Package com.redis.vl.utils.rerank
Class CrossEncoderLoader
java.lang.Object
com.redis.vl.utils.rerank.CrossEncoderLoader
Loads and runs ONNX cross-encoder models for document reranking.
Cross-encoders differ from embedding models:
- Input: a query-document pair, tokenized together with token_type_ids
- Output: a single relevance score (not an embedding vector)
This loader handles BertForSequenceClassification models exported to ONNX.
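To make the pair input concrete, here is a minimal sketch of the standard BERT sentence-pair layout, [CLS] query [SEP] document [SEP]. token_type_ids are 0 for the query segment (including [CLS] and the first [SEP]) and 1 for the document segment. The sketch assumes the token IDs have already been produced by a WordPiece tokenizer. The class name, the encodePair helper, and the [CLS]/[SEP] IDs (101 and 102, the bert-base-uncased values) are illustrative assumptions, not part of CrossEncoderLoader.

```java
import java.util.Arrays;

public class PairEncodingSketch {
    // Assumption: bert-base-uncased special-token IDs; other vocabularies differ.
    static final long CLS_ID = 101L;
    static final long SEP_ID = 102L;

    /** The three parallel sequences a BERT cross-encoder consumes for one pair. */
    record Encoded(long[] inputIds, long[] tokenTypeIds, long[] attentionMask) {}

    /** Lays out [CLS] query [SEP] document [SEP] from pre-tokenized IDs (no truncation or padding). */
    static Encoded encodePair(long[] queryIds, long[] docIds) {
        int len = queryIds.length + docIds.length + 3; // [CLS] + two [SEP]
        long[] ids = new long[len];
        long[] types = new long[len];
        long[] mask = new long[len];
        int i = 0;
        ids[i++] = CLS_ID;                    // segment A starts with [CLS]
        for (long t : queryIds) ids[i++] = t; // query tokens, segment A
        ids[i++] = SEP_ID;                    // closes segment A
        int docStart = i;
        for (long t : docIds) ids[i++] = t;   // document tokens, segment B
        ids[i++] = SEP_ID;                    // closes segment B
        // token_type_ids: 0 up to and including the first [SEP], 1 for the document and final [SEP]
        Arrays.fill(types, docStart, len, 1L);
        // attention_mask: every position is a real token because this sketch does not pad
        Arrays.fill(mask, 1L);
        return new Encoded(ids, types, mask);
    }

    public static void main(String[] args) {
        Encoded e = encodePair(new long[]{2054, 2003}, new long[]{7592, 2088, 999});
        System.out.println(Arrays.toString(e.inputIds()));
        System.out.println(Arrays.toString(e.tokenTypeIds()));
    }
}
```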
Constructor Summary
Constructors
- CrossEncoderLoader(): Creates a new CrossEncoderLoader.
Method Summary
- ai.onnxruntime.OrtSession loadModel(Path modelDir, ai.onnxruntime.OrtEnvironment env): Load an ONNX cross-encoder model from the specified directory.
- float runInference(ai.onnxruntime.OrtSession session, long[][] inputIds, long[][] tokenTypeIds, long[][] attentionMask): Run inference to get the relevance score.
- tokenizePair(String query, String document): Tokenize a query-document pair for cross-encoder input.
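The three methods are meant to be chained for reranking: load a session, tokenize and score each query-document pair, then order the documents by score. The sketch below shows that reranking loop with the scoring step abstracted as a function so it runs without ONNX Runtime. RerankSketch, Scored, and the word-overlap scorer are hypothetical stand-ins for a real tokenizePair plus runInference call, not part of this API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.ToDoubleBiFunction;

public class RerankSketch {
    /** A document paired with the relevance score assigned to it. */
    record Scored(String document, double score) {}

    /**
     * Scores every (query, document) pair and returns the documents sorted by descending score.
     * The scorer stands in for tokenizePair(...) followed by runInference(...).
     */
    static List<Scored> rerank(String query, List<String> docs,
                               ToDoubleBiFunction<String, String> scorer) {
        List<Scored> out = new ArrayList<>();
        for (String d : docs) {
            out.add(new Scored(d, scorer.applyAsDouble(query, d)));
        }
        out.sort((a, b) -> Double.compare(b.score(), a.score())); // highest score first
        return out;
    }

    public static void main(String[] args) {
        // Toy scorer for illustration only: counts query words that occur in the document.
        ToDoubleBiFunction<String, String> overlap = (q, d) -> {
            int hits = 0;
            for (String w : q.toLowerCase().split("\\s+")) {
                if (d.toLowerCase().contains(w)) hits++;
            }
            return hits;
        };
        List<Scored> ranked = rerank("redis vector search",
                List.of("Cooking pasta", "Redis supports vector search", "Vector math basics"),
                overlap);
        ranked.forEach(s -> System.out.println(s.score() + "  " + s.document()));
    }
}
```

In real use, the scorer would wrap tokenizePair and runInference against a session obtained once from loadModel, so the model is not reloaded for every pair.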
Constructor Details
CrossEncoderLoader
public CrossEncoderLoader()
Creates a new CrossEncoderLoader.
Method Details
loadModel
public ai.onnxruntime.OrtSession loadModel(Path modelDir, ai.onnxruntime.OrtEnvironment env) throws IOException, ai.onnxruntime.OrtException
Load an ONNX cross-encoder model from the specified directory.
Parameters:
  modelDir - Path to the directory containing the ONNX model files
  env - The ONNX Runtime environment to use
Returns:
  The loaded ONNX Runtime session
Throws:
  IOException - if the model files cannot be read
  ai.onnxruntime.OrtException - if ONNX Runtime fails to load the model
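The IOException in this contract covers a missing or unreadable model directory or file. A minimal sketch of such a check follows. It assumes the exported network is stored as model.onnx inside modelDir, a common convention for ONNX exports that this documentation does not state. ModelDirSketch and resolveModelFile are hypothetical names.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class ModelDirSketch {
    // Assumption: the file name CrossEncoderLoader actually expects is not documented here.
    static final String MODEL_FILE = "model.onnx";

    /** Resolves the ONNX file inside modelDir, failing with IOException as the contract describes. */
    static Path resolveModelFile(Path modelDir) throws IOException {
        if (!Files.isDirectory(modelDir)) {
            throw new IOException("Model directory not found: " + modelDir);
        }
        Path model = modelDir.resolve(MODEL_FILE);
        if (!Files.isReadable(model)) {
            throw new IOException("ONNX model file not readable: " + model);
        }
        return model;
    }
}
```

A resolved path could then be passed to ONNX Runtime's OrtEnvironment.createSession(String, OrtSession.SessionOptions) to obtain an OrtSession. Whether loadModel does exactly this internally is not documented here.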
tokenizePair
Tokenize a query-document pair for cross-encoder input.
Uses a HuggingFace tokenizer to encode the query-document pair with special tokens, attention masks, and token type IDs.
Parameters:
  query - The query text
  document - The document text
Returns:
  Map containing input_ids, token_type_ids, and attention_mask
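The three entries of the returned map correspond to runInference's three long[][] parameters. The sketch below assumes the map values are long[][] arrays with a batch dimension of 1 (shape [1][seqLen]). That is an inference from runInference's signature, not a documented return type. TokenMapSketch and toInputMap are hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class TokenMapSketch {
    /** Wraps one token sequence as a batch of size 1: shape [1][seqLen]. */
    static long[][] asBatchOfOne(long[] seq) {
        return new long[][] { seq.clone() };
    }

    /**
     * Hypothetical packaging of tokenizer output under the three keys tokenizePair documents.
     * The long[][] value type is an assumption chosen to match runInference's parameters.
     */
    static Map<String, long[][]> toInputMap(long[] ids, long[] typeIds, long[] mask) {
        if (ids.length != typeIds.length || ids.length != mask.length) {
            throw new IllegalArgumentException(
                    "input_ids, token_type_ids and attention_mask must have equal length");
        }
        Map<String, long[][]> m = new LinkedHashMap<>();
        m.put("input_ids", asBatchOfOne(ids));
        m.put("token_type_ids", asBatchOfOne(typeIds));
        m.put("attention_mask", asBatchOfOne(mask));
        return m;
    }
}
```

With such a map m, the scoring call would read runInference(session, m.get("input_ids"), m.get("token_type_ids"), m.get("attention_mask")).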
runInference
public float runInference(ai.onnxruntime.OrtSession session, long[][] inputIds, long[][] tokenTypeIds, long[][] attentionMask) throws ai.onnxruntime.OrtException
Run inference to get the relevance score.
Parameters:
  session - The ONNX session
  inputIds - Input token IDs
  tokenTypeIds - Token type IDs (0 for query, 1 for document); only used for BERT models
  attentionMask - Attention mask
Returns:
  Relevance score
Throws:
  ai.onnxruntime.OrtException - if inference fails
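The tokenTypeIds note implies that the input feed depends on the model architecture, and the single relevance score must be read out of the model's output tensor. The sketch below illustrates both steps under these assumptions:
- The ONNX input names are input_ids, attention_mask, and token_type_ids (the names tokenizePair documents).
- The set of declared names would come from OrtSession.getInputNames().
- The classification head emits logits of shape [1][1], which ONNX Runtime's Java API returns as a float[][] from the output tensor's getValue().

Whether runInference applies a sigmoid to the logit is not documented, so that step is shown as optional. InferenceSketch and its methods are hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

public class InferenceSketch {
    /**
     * Builds the name-to-array feed, including token_type_ids only when the model declares
     * that input (BERT does; some non-BERT encoders do not).
     */
    static Map<String, long[][]> selectInputs(Set<String> inputNames, long[][] inputIds,
                                              long[][] tokenTypeIds, long[][] attentionMask) {
        Map<String, long[][]> feeds = new LinkedHashMap<>();
        feeds.put("input_ids", inputIds);
        feeds.put("attention_mask", attentionMask);
        if (inputNames.contains("token_type_ids")) {
            feeds.put("token_type_ids", tokenTypeIds);
        }
        return feeds;
    }

    /** Reads the single logit for the only pair in the batch (assumed shape [1][1]). */
    static float scoreFromLogits(float[][] logits) {
        return logits[0][0];
    }

    /** Optional squashing of a raw logit into (0, 1). */
    static float sigmoid(float x) {
        return (float) (1.0 / (1.0 + Math.exp(-x)));
    }
}
```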