Package com.redis.vl.utils.rerank
Class CrossEncoderLoader
java.lang.Object
com.redis.vl.utils.rerank.CrossEncoderLoader
Loads and runs ONNX cross-encoder models for document reranking.
Cross-encoders differ from embedding models:
- Input: a query + document pair, tokenized together with token_type_ids
- Output: a single relevance score (not an embedding vector)
This loader handles BertForSequenceClassification models exported to ONNX.
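A cross-encoder scores each query-document pair in its own forward pass, so reranking reduces to scoring every candidate and sorting by score. The sketch below assumes a generic scorer function that stands in for tokenizePair followed by runInference. RerankSketch, Scored, and rerank are hypothetical names, not part of this package.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.function.ToDoubleBiFunction;

/** Hypothetical helper: orders documents by a cross-encoder-style pair score. */
final class RerankSketch {

    /** A document together with the relevance score assigned to it. */
    record Scored(String document, double score) {}

    /**
     * Scores every (query, document) pair independently and returns the
     * documents sorted from most to least relevant.
     */
    static List<Scored> rerank(String query, List<String> documents,
                               ToDoubleBiFunction<String, String> scorer) {
        List<Scored> scored = new ArrayList<>();
        for (String doc : documents) {
            // The query and the document are scored together, one pair at a time.
            scored.add(new Scored(doc, scorer.applyAsDouble(query, doc)));
        }
        // List.sort is stable, so documents with equal scores keep their input order.
        scored.sort(Comparator.comparingDouble(Scored::score).reversed());
        return scored;
    }
}
```

With a scorer that wraps this loader, rerank(query, candidates, scorer) returns the candidates most relevant first. Because every pair costs a full forward pass, cross-encoders are usually applied to a short candidate list produced by a cheaper first-stage retriever.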
Constructor Summary
Constructors
Constructor | Description
CrossEncoderLoader() | Creates a new CrossEncoderLoader.
Method Summary
Modifier and Type | Method | Description
ai.onnxruntime.OrtSession | loadModel(Path modelDir, ai.onnxruntime.OrtEnvironment env) | Load an ONNX cross-encoder model from the specified directory.
float | runInference(ai.onnxruntime.OrtSession session, long[][] inputIds, long[][] tokenTypeIds, long[][] attentionMask) | Run inference to get the relevance score.
Map | tokenizePair(String query, String document) | Tokenize a query-document pair for cross-encoder input.
Constructor Details
CrossEncoderLoader
public CrossEncoderLoader()
Creates a new CrossEncoderLoader.
Method Details
loadModel
public ai.onnxruntime.OrtSession loadModel(Path modelDir, ai.onnxruntime.OrtEnvironment env) throws IOException, ai.onnxruntime.OrtException
Load an ONNX cross-encoder model from the specified directory.
Parameters:
modelDir - Path to the directory containing the ONNX model files
env - The ONNX Runtime environment to use
Returns:
The loaded ONNX Runtime session
Throws:
IOException - if the model files cannot be read
ai.onnxruntime.OrtException - if ONNX Runtime fails to load the model
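The contract above separates file-system failures (IOException) from runtime failures (OrtException). The following is a minimal sketch of the file-system side, assuming the directory holds a file named model.onnx; that file name is an assumption, since the layout this loader actually expects is not documented here. ModelDirCheck and resolveModelFile are hypothetical. With ONNX Runtime's Java API, the resolved path could then be passed to env.createSession(path.toString(), new OrtSession.SessionOptions()), which is where an OrtException would surface.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.NoSuchFileException;
import java.nio.file.Path;

/** Hypothetical pre-flight check mirroring loadModel's IOException contract. */
final class ModelDirCheck {

    /** Assumed model file name; the real loader's expected layout is not documented here. */
    static final String MODEL_FILE = "model.onnx";

    /**
     * Returns the ONNX file inside modelDir, or throws IOException if the
     * directory or the model file is missing or unreadable.
     */
    static Path resolveModelFile(Path modelDir) throws IOException {
        if (!Files.isDirectory(modelDir)) {
            throw new NoSuchFileException(modelDir.toString(), null, "not a directory");
        }
        Path model = modelDir.resolve(MODEL_FILE);
        if (!Files.isRegularFile(model) || !Files.isReadable(model)) {
            throw new NoSuchFileException(model.toString(), null, "ONNX model file missing or unreadable");
        }
        return model;
    }
}
```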
tokenizePair
Tokenize a query-document pair for cross-encoder input.
Uses a HuggingFace tokenizer to encode the query-document pair with special tokens, an attention mask, and token type IDs.
Parameters:
query - The query text
document - The document text
Returns:
A map containing input_ids, token_type_ids, and attention_mask
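For a BERT-style cross-encoder, the pair is conventionally laid out as [CLS] query [SEP] document [SEP], with token_type_ids set to 0 for the query segment and 1 for the document segment, and attention_mask set to 1 for real tokens and 0 for padding. The sketch below assumes the query and document have already been converted to WordPiece IDs and uses the bert-base-uncased special-token IDs (101 for [CLS], 102 for [SEP], 0 for [PAD]); other vocabularies differ. PairEncodingSketch and encodePair are hypothetical, padding to maxLength is an assumption, and truncation is omitted.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical illustration of BERT-style pair encoding; not the library's tokenizer. */
final class PairEncodingSketch {
    // bert-base-uncased special-token IDs (assumption; other vocabularies differ).
    static final long CLS = 101;
    static final long SEP = 102;
    // [PAD] is 0 in that vocabulary, so freshly allocated (zeroed) arrays are already padded.

    /** Lays out [CLS] query [SEP] document [SEP], padded with zeros to maxLength. */
    static Map<String, long[][]> encodePair(long[] queryIds, long[] docIds, int maxLength) {
        int used = queryIds.length + docIds.length + 3; // three special tokens
        if (used > maxLength) {
            // A real tokenizer would truncate; this sketch just rejects the input.
            throw new IllegalArgumentException("pair needs " + used + " tokens, max is " + maxLength);
        }
        long[] inputIds = new long[maxLength];
        long[] tokenTypeIds = new long[maxLength];  // 0 = query segment (and padding)
        long[] attentionMask = new long[maxLength]; // 0 = padding
        int i = 0;
        inputIds[i++] = CLS;
        for (long id : queryIds) inputIds[i++] = id;
        inputIds[i++] = SEP;
        int docStart = i;
        for (long id : docIds) inputIds[i++] = id;
        inputIds[i++] = SEP;
        for (int j = docStart; j < i; j++) tokenTypeIds[j] = 1; // document tokens and final [SEP]
        for (int j = 0; j < i; j++) attentionMask[j] = 1;       // every non-padding position
        Map<String, long[][]> encoded = new HashMap<>();
        // Batch dimension of 1 to match runInference's long[][] parameters.
        encoded.put("input_ids", new long[][] {inputIds});
        encoded.put("token_type_ids", new long[][] {tokenTypeIds});
        encoded.put("attention_mask", new long[][] {attentionMask});
        return encoded;
    }
}
```

Wrapping each array in a batch of one lets the three entries be passed straight to runInference's long[][] parameters.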
runInference
public float runInference(ai.onnxruntime.OrtSession session, long[][] inputIds, long[][] tokenTypeIds, long[][] attentionMask) throws ai.onnxruntime.OrtException
Run inference to get the relevance score.
Parameters:
session - The ONNX session
inputIds - Input token IDs
tokenTypeIds - Token type IDs (0 for query, 1 for document); used only by BERT models
attentionMask - Attention mask
Returns:
The relevance score
Throws:
ai.onnxruntime.OrtException - if inference fails
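runInference takes three inputs that must share one [batch][seqLen] shape and returns a single float. A BertForSequenceClassification head exported with one label typically emits logits of shape [batch][1]; with ONNX Runtime's Java API, calling getValue() on such a float output tensor returns a value that can be cast to float[][]. The helpers below are hypothetical and make no ONNX calls: they validate the input shapes, read the single logit, and show an optional sigmoid. Whether runInference returns the raw logit or a squashed score is not stated in this documentation.

```java
/** Hypothetical helpers around a cross-encoder forward pass; no ONNX calls are made. */
final class InferenceSketch {

    /** Checks that the three input tensors share one [batch][seqLen] shape. */
    static void checkShapes(long[][] inputIds, long[][] tokenTypeIds, long[][] attentionMask) {
        if (inputIds.length != tokenTypeIds.length || inputIds.length != attentionMask.length) {
            throw new IllegalArgumentException("batch sizes differ");
        }
        for (int b = 0; b < inputIds.length; b++) {
            int len = inputIds[b].length;
            if (tokenTypeIds[b].length != len || attentionMask[b].length != len) {
                throw new IllegalArgumentException("sequence lengths differ in row " + b);
            }
        }
    }

    /** Reads the single logit of a one-label classification head, shape [1][1] assumed. */
    static float firstLogit(float[][] logits) {
        if (logits.length == 0 || logits[0].length == 0) {
            throw new IllegalArgumentException("empty logits");
        }
        return logits[0][0];
    }

    /** Optional mapping of a raw logit into (0, 1); not necessarily what runInference does. */
    static float sigmoid(float logit) {
        return (float) (1.0 / (1.0 + Math.exp(-logit)));
    }
}
```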