Class CrossEncoderLoader

java.lang.Object
com.redis.vl.utils.rerank.CrossEncoderLoader

public class CrossEncoderLoader extends Object
Loads and runs ONNX cross-encoder models for document reranking.

Cross-encoders differ from embedding models:
  - Input: a query + document pair, tokenized together with token_type_ids
  - Output: a single relevance score (not an embedding vector)

This loader handles BertForSequenceClassification models exported to ONNX.
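A cross-encoder scores one query-document pair at a time, so reranking amounts to scoring every candidate against the query and sorting by score. The sketch below shows that loop in plain Java. The class RerankSketch and its BiFunction scorer are hypothetical: the scorer stands in for one tokenizePair call followed by one runInference call, which keeps the example free of ONNX Runtime. It also assumes that a higher score means a more relevant document.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.function.BiFunction;

public class RerankSketch {

    /**
     * Scores each (query, document) pair with the supplied scorer and returns
     * the documents ordered from most to least relevant.
     * Assumption: higher score = more relevant.
     */
    public static List<String> rerank(String query, List<String> documents,
                                      BiFunction<String, String, Float> scorer) {
        // Pair each document with its score so every pair is scored exactly once.
        List<Scored> scored = new ArrayList<>();
        for (String doc : documents) {
            scored.add(new Scored(doc, scorer.apply(query, doc)));
        }
        // Sort descending by score.
        scored.sort(Comparator.comparingDouble((Scored s) -> s.score).reversed());
        List<String> ranked = new ArrayList<>();
        for (Scored s : scored) {
            ranked.add(s.document);
        }
        return ranked;
    }

    private static final class Scored {
        final String document;
        final float score;

        Scored(String document, float score) {
            this.document = document;
            this.score = score;
        }
    }

    public static void main(String[] args) {
        // Toy scorer standing in for tokenizePair + runInference:
        // counts how many query words occur in the document.
        BiFunction<String, String, Float> toyScorer = (q, d) -> {
            float hits = 0;
            for (String w : q.toLowerCase().split("\\s+")) {
                if (d.toLowerCase().contains(w)) {
                    hits++;
                }
            }
            return hits;
        };
        List<String> docs = List.of("Redis is an in-memory store",
                                    "Paris is in France",
                                    "Redis vector search with ONNX");
        System.out.println(rerank("redis vector", docs, toyScorer));
    }
}
```

In real use the scorer runs the model once per candidate, so cost grows linearly with the number of documents. This is why cross-encoders are usually applied to a short candidate list rather than to a whole corpus.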

  • Constructor Summary

    Constructors
    Constructor            Description
    CrossEncoderLoader()   Creates a new CrossEncoderLoader.
  • Method Summary

    Modifier and Type           Method and Description
    ai.onnxruntime.OrtSession   loadModel(Path modelDir, ai.onnxruntime.OrtEnvironment env)
                                Load an ONNX cross-encoder model from the specified directory.
    float                       runInference(ai.onnxruntime.OrtSession session, long[][] inputIds, long[][] tokenTypeIds, long[][] attentionMask)
                                Run inference to get the relevance score.
    Map<String,long[][]>        tokenizePair(String query, String document)
                                Tokenize a query-document pair for cross-encoder input.

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
  • Constructor Details

    • CrossEncoderLoader

      public CrossEncoderLoader()
      Creates a new CrossEncoderLoader.
  • Method Details

    • loadModel

      public ai.onnxruntime.OrtSession loadModel(Path modelDir, ai.onnxruntime.OrtEnvironment env) throws IOException, ai.onnxruntime.OrtException
      Load an ONNX cross-encoder model from the specified directory.
      Parameters:
      modelDir - Path to the directory containing the ONNX model files
      env - The ONNX runtime environment to use
      Returns:
      The loaded ONNX runtime session
      Throws:
      IOException - if model files cannot be read
      ai.onnxruntime.OrtException - if the ONNX runtime fails to load the model
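      This page does not list which files loadModel expects inside modelDir. As a hedged illustration of failing fast with an IOException before any ONNX Runtime work starts, the hypothetical helper ModelDirCheck.findOnnxFiles (not part of this class) only checks that the directory exists and contains at least one *.onnx file. It does not load or validate the model itself.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ModelDirCheck {

    /**
     * Hypothetical pre-flight check: returns the .onnx files found directly in
     * modelDir (sorted), or throws IOException if the directory is unusable.
     */
    public static List<Path> findOnnxFiles(Path modelDir) throws IOException {
        if (!Files.isDirectory(modelDir)) {
            throw new IOException("Not a directory: " + modelDir);
        }
        // Files.list holds an open directory handle, so close it via try-with-resources.
        try (Stream<Path> entries = Files.list(modelDir)) {
            List<Path> onnx = entries
                    .filter(Files::isRegularFile)
                    .filter(p -> p.getFileName().toString().endsWith(".onnx"))
                    .sorted()
                    .collect(Collectors.toList());
            if (onnx.isEmpty()) {
                throw new IOException("No .onnx file in " + modelDir);
            }
            return onnx;
        }
    }

    public static void main(String[] args) throws IOException {
        // Build a throwaway directory containing a placeholder model file.
        Path dir = Files.createTempDirectory("cross-encoder");
        Files.createFile(dir.resolve("model.onnx"));
        System.out.println(findOnnxFiles(dir));
    }
}
```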
    • tokenizePair

      public Map<String,long[][]> tokenizePair(String query, String document)
      Tokenize a query-document pair for cross-encoder input.

      Uses a HuggingFace tokenizer to encode the query-document pair with special tokens, attention masks, and token type IDs.

      Parameters:
      query - The query text
      document - The document text
      Returns:
      Map with the keys input_ids, token_type_ids, and attention_mask, each mapped to a long[][] array
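      The token type IDs described under runInference follow the usual BERT sentence-pair layout: [CLS] query [SEP] document [SEP], with type 0 for the query segment (including [CLS] and the first [SEP]) and type 1 for the document segment and its closing [SEP]. The sketch below builds the three arrays that tokenizePair returns from IDs that are assumed to be already tokenized, using an assumed batch dimension of 1 ([1][seqLen]). It is an illustration, not this class's implementation. The class PairEncodingSketch is hypothetical, real tokenization (WordPiece splitting, truncation, padding) is left to the HuggingFace tokenizer, and the [CLS]/[SEP] IDs are parameters (101 and 102 in the bert-base-uncased vocabulary).

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class PairEncodingSketch {

    /**
     * Builds BERT-style pair inputs: [CLS] query [SEP] document [SEP].
     * Each array is shaped [1][seqLen] (a batch of one pair); no padding or truncation.
     */
    public static Map<String, long[][]> encodePair(long[] queryIds, long[] docIds,
                                                   long clsId, long sepId) {
        int seqLen = queryIds.length + docIds.length + 3; // [CLS] + two [SEP]
        long[] inputIds = new long[seqLen];
        long[] tokenTypeIds = new long[seqLen];
        long[] attentionMask = new long[seqLen];

        int i = 0;
        inputIds[i++] = clsId;                        // start of segment A (query)
        for (long id : queryIds) inputIds[i++] = id;
        inputIds[i++] = sepId;                        // end of segment A
        int docStart = i;
        for (long id : docIds) inputIds[i++] = id;
        inputIds[i++] = sepId;                        // end of segment B (document)

        for (int j = 0; j < seqLen; j++) {
            tokenTypeIds[j] = j >= docStart ? 1 : 0;  // 0 = query, 1 = document
            attentionMask[j] = 1;                     // no padding, so every token is attended
        }

        Map<String, long[][]> out = new HashMap<>();
        out.put("input_ids", new long[][] {inputIds});
        out.put("token_type_ids", new long[][] {tokenTypeIds});
        out.put("attention_mask", new long[][] {attentionMask});
        return out;
    }

    public static void main(String[] args) {
        // Query/document IDs here are illustrative only.
        Map<String, long[][]> enc = encodePair(new long[] {2054, 2003}, new long[] {7680}, 101, 102);
        System.out.println(Arrays.toString(enc.get("input_ids")[0]));
        System.out.println(Arrays.toString(enc.get("token_type_ids")[0]));
    }
}
```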
    • runInference

      public float runInference(ai.onnxruntime.OrtSession session, long[][] inputIds, long[][] tokenTypeIds, long[][] attentionMask) throws ai.onnxruntime.OrtException
      Run inference to get the relevance score.
      Parameters:
      session - The ONNX session
      inputIds - Input token IDs
      tokenTypeIds - Token type IDs (0 for query tokens, 1 for document tokens); only used for BERT models
      attentionMask - Attention mask
      Returns:
      Relevance score
      Throws:
      ai.onnxruntime.OrtException - if inference fails
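      This page does not say whether the returned relevance score is a raw logit or a normalized probability. Assuming the model's first output is a [batch][1] logit tensor read into a float[][] (a common shape for single-label cross-encoders), extracting the score could look like the sketch below. The class ScoreSketch is hypothetical, and its sigmoid step is an optional normalization, not something runInference is documented to perform.

```java
public class ScoreSketch {

    /**
     * Reads the relevance logit for the first (and only) pair in the batch.
     * Assumes logits are shaped [batch][numLabels] with numLabels == 1.
     */
    public static float rawScore(float[][] logits) {
        if (logits.length == 0 || logits[0] == null || logits[0].length == 0) {
            throw new IllegalArgumentException("empty logits");
        }
        return logits[0][0];
    }

    /** Optional: squash a raw logit into (0, 1) with the logistic sigmoid. */
    public static float sigmoid(float x) {
        return (float) (1.0 / (1.0 + Math.exp(-x)));
    }

    public static void main(String[] args) {
        float[][] logits = {{2.0f}};
        float raw = rawScore(logits);
        System.out.println(raw + " -> " + sigmoid(raw));
    }
}
```

      Because the sigmoid is monotonic, applying it never changes the ranking of documents. It only matters when scores are compared against a fixed threshold or shown to users.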