Package com.redis.vl.utils.vectorize
Class OnnxModelLoader
java.lang.Object
com.redis.vl.utils.vectorize.OnnxModelLoader
Loads and runs ONNX models for generating embeddings. Handles tokenization, inference, and
post-processing.
-
Constructor Summary
OnnxModelLoader()
    Creates a new OnnxModelLoader.
-
Method Summary
getEmbedding(String text)
    Get embedding for a single text.
getEmbeddings(List<String> texts)
    Get embeddings for multiple texts.
int getHiddenSize()
    Get the hidden size (same as embedding dimension).
com.google.gson.JsonObject getTokenizer()
    Get a copy of the tokenizer configuration to prevent internal representation exposure.
ai.onnxruntime.OrtSession loadModel(Path modelDir)
    Load an ONNX model from the specified directory.
ai.onnxruntime.OrtSession loadModel(Path modelDir, ai.onnxruntime.OrtEnvironment env)
    Load an ONNX model from the specified directory with a specific environment.
float[][] meanPooling(float[][][] tokenEmbeddings)
    Deprecated. Use meanPoolingWithAttention for correct Sentence Transformers behavior.
float[][] meanPoolingWithAttention(float[][][] tokenEmbeddings, long[][] attentionMask)
    Apply attention-masked mean pooling to token embeddings (Sentence Transformers style).
float[][] normalize(float[][] embeddings)
    Normalize embeddings to unit length.
float[][] runInference(ai.onnxruntime.OrtSession session, ai.onnxruntime.OnnxTensor inputTensor)
    Run inference on the model with tokenized input.
long[][] tokenize(String text)
    Tokenize a single text string.
long[][] tokenizeBatch(List<String> texts)
    Tokenize a batch of text strings.
-
Constructor Details
-
OnnxModelLoader
public OnnxModelLoader()
Creates a new OnnxModelLoader. The ONNX runtime environment is provided when the model is loaded.
-
-
Method Details
-
loadModel
public ai.onnxruntime.OrtSession loadModel(Path modelDir) throws IOException, ai.onnxruntime.OrtException
Load an ONNX model from the specified directory.
Parameters:
modelDir - Path to the directory containing the ONNX model files
Returns:
The loaded ONNX runtime session
Throws:
IOException - if the model files cannot be read
ai.onnxruntime.OrtException - if the ONNX runtime fails to load the model
-
loadModel
public ai.onnxruntime.OrtSession loadModel(Path modelDir, ai.onnxruntime.OrtEnvironment env) throws IOException, ai.onnxruntime.OrtException
Load an ONNX model from the specified directory with a specific environment.
Parameters:
modelDir - Path to the directory containing the ONNX model files
env - The ONNX runtime environment to use
Returns:
The loaded ONNX runtime session
Throws:
IOException - if the model files cannot be read
ai.onnxruntime.OrtException - if the ONNX runtime fails to load the model
-
getHiddenSize
public int getHiddenSize()
Get the hidden size (same as the embedding dimension).
Returns:
The hidden size of the model
-
getTokenizer
public com.google.gson.JsonObject getTokenizer()
Get a deep copy of the tokenizer configuration. Returning a copy prevents callers from modifying the loader's internal representation.
Returns:
A deep copy of the tokenizer configuration
-
tokenize
Tokenize a single text string.
Parameters:
text - The text to tokenize
Returns:
A 2D array of token IDs with a single row (batch size 1)
-
tokenizeBatch
Tokenize a batch of text strings.
Parameters:
texts - List of texts to tokenize
Returns:
A 2D array of token IDs, one row per input text
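This page does not say how texts of different token lengths are combined into one 2D array. A common approach, shown here only as a minimal sketch and not as OnnxModelLoader's documented behavior, is to right-pad every row to the longest length and build a matching attention mask (1 for real tokens, 0 for padding), which is the shape meanPoolingWithAttention expects. The class PaddingSketch, its methods, and the pad ID of 0 are hypothetical.

```java
import java.util.Arrays;

// Hypothetical helper, not part of com.redis.vl. Assumes right-padding with pad ID 0.
final class PaddingSketch {

    // Right-pads each row with padId so every row has the length of the longest one.
    static long[][] padBatch(long[][] ragged, long padId) {
        int maxLen = 0;
        for (long[] row : ragged) {
            maxLen = Math.max(maxLen, row.length);
        }
        long[][] padded = new long[ragged.length][maxLen];
        for (int i = 0; i < ragged.length; i++) {
            Arrays.fill(padded[i], padId);
            System.arraycopy(ragged[i], 0, padded[i], 0, ragged[i].length);
        }
        return padded;
    }

    // Builds the matching mask: 1 for each original token, 0 for each padded position.
    static long[][] attentionMask(long[][] ragged, int paddedLength) {
        long[][] mask = new long[ragged.length][paddedLength];
        for (int i = 0; i < ragged.length; i++) {
            Arrays.fill(mask[i], 0, ragged[i].length, 1L);
        }
        return mask;
    }

    public static void main(String[] args) {
        long[][] ragged = {{5, 6, 7}, {5, 7}}; // hypothetical token IDs
        long[][] ids = padBatch(ragged, 0L);
        long[][] mask = attentionMask(ragged, ids[0].length);
        System.out.println(Arrays.deepToString(ids));  // [[5, 6, 7], [5, 7, 0]]
        System.out.println(Arrays.deepToString(mask)); // [[1, 1, 1], [1, 1, 0]]
    }
}
```

Right-padding keeps each row's real tokens at the start, so the mask for a row is simply a run of ones followed by zeros.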
-
runInference
public float[][] runInference(ai.onnxruntime.OrtSession session, ai.onnxruntime.OnnxTensor inputTensor) throws ai.onnxruntime.OrtException
Run inference on the model with tokenized input.
Parameters:
session - The ONNX runtime session
inputTensor - The input tensor containing token IDs
Returns:
A 2D array of embeddings, one row per input
Throws:
ai.onnxruntime.OrtException - if inference fails
-
meanPoolingWithAttention
public float[][] meanPoolingWithAttention(float[][][] tokenEmbeddings, long[][] attentionMask)
Apply attention-masked mean pooling to token embeddings (Sentence Transformers style). Only positions where the attention mask is 1 (real tokens) contribute to the average; padding positions are ignored.
Parameters:
tokenEmbeddings - 3D array of token embeddings [batch, sequence, hidden]
attentionMask - 2D attention mask [batch, sequence] (1 = real token, 0 = padding)
Returns:
A 2D array of pooled embeddings [batch, hidden]
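As a minimal sketch, assuming mask values of exactly 0 or 1 as described above, masked mean pooling sums the hidden vectors of the real tokens in each row and divides by how many there are. MaskedMeanPoolingSketch is a hypothetical stand-in, not the library's implementation. How the real method treats a row that is entirely padding is not documented here; the clamped divisor below is an assumed policy.

```java
// Hypothetical re-implementation for illustration; not the library's source.
final class MaskedMeanPoolingSketch {

    // tokenEmbeddings: [batch][sequence][hidden]; attentionMask: [batch][sequence], 1 = real, 0 = padding.
    static float[][] meanPoolingWithAttention(float[][][] tokenEmbeddings, long[][] attentionMask) {
        float[][] pooled = new float[tokenEmbeddings.length][];
        for (int b = 0; b < tokenEmbeddings.length; b++) {
            float[][] sequence = tokenEmbeddings[b];
            int hidden = sequence[0].length;
            float[] sum = new float[hidden];
            float count = 0f;
            for (int t = 0; t < sequence.length; t++) {
                if (attentionMask[b][t] == 0L) {
                    continue; // padding position: excluded from the average
                }
                for (int h = 0; h < hidden; h++) {
                    sum[h] += sequence[t][h];
                }
                count += 1f;
            }
            // Clamp the divisor so an all-padding row yields zeros rather than NaN (assumed policy).
            float divisor = Math.max(count, 1e-9f);
            for (int h = 0; h < hidden; h++) {
                sum[h] /= divisor;
            }
            pooled[b] = sum;
        }
        return pooled;
    }

    public static void main(String[] args) {
        float[][][] tokens = {{{2f, 4f}, {6f, 8f}, {100f, 100f}}}; // last position is padding
        long[][] mask = {{1, 1, 0}};
        float[] v = meanPoolingWithAttention(tokens, mask)[0];
        System.out.println(v[0] + ", " + v[1]); // 4.0, 6.0
    }
}
```

The large values at the padded position have no effect on the result, which is the point of passing the mask.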
-
meanPooling
Deprecated. Use meanPoolingWithAttention for correct Sentence Transformers behavior.
Apply mean pooling to token embeddings (legacy method). This method takes no attention mask, so it cannot exclude padding positions from the average.
Parameters:
tokenEmbeddings - 3D array of token embeddings [batch, sequence, hidden]
Returns:
A 2D array of pooled embeddings [batch, hidden]
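Because meanPooling receives no mask, it presumably divides by the full sequence length, so padding positions dilute the result. The hypothetical LegacyMeanPoolingSketch below is not the library's code; it assumes a plain unmasked average and shows the effect on a two-position row whose second position is padding. Using a zero vector for the padding hidden state is a simplification; real models produce arbitrary values there.

```java
// Hypothetical sketch of unmasked mean pooling; assumes every row has at least one position.
final class LegacyMeanPoolingSketch {

    // Averages over every sequence position, padding included.
    static float[][] meanPooling(float[][][] tokenEmbeddings) {
        float[][] pooled = new float[tokenEmbeddings.length][];
        for (int b = 0; b < tokenEmbeddings.length; b++) {
            float[][] sequence = tokenEmbeddings[b];
            int hidden = sequence[0].length;
            float[] sum = new float[hidden];
            for (float[] token : sequence) {
                for (int h = 0; h < hidden; h++) {
                    sum[h] += token[h];
                }
            }
            for (int h = 0; h < hidden; h++) {
                sum[h] /= sequence.length; // divides by all positions, not just real tokens
            }
            pooled[b] = sum;
        }
        return pooled;
    }

    public static void main(String[] args) {
        // One real token [2, 4] followed by one padding position shown as [0, 0].
        float[][][] batch = {{{2f, 4f}, {0f, 0f}}};
        float[] v = meanPooling(batch)[0];
        System.out.println(v[0] + ", " + v[1]); // 1.0, 2.0 (excluding padding would give 2.0, 4.0)
    }
}
```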
-
normalize
public float[][] normalize(float[][] embeddings)
Normalize each embedding to unit length (L2 norm of 1).
Parameters:
embeddings - 2D array of embeddings to normalize, one row per embedding
Returns:
The normalized embeddings, each with unit length
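Unit length means an L2 norm of 1: each row is divided by the square root of the sum of its squared components. The hypothetical NormalizeSketch below is a minimal illustration, not the library's code. Whether the real method modifies its input in place and how it handles an all-zero row are not documented, so the sketch's choices (new arrays, zero rows left as zeros) are assumptions.

```java
// Hypothetical L2 normalization sketch; returns new arrays and leaves all-zero rows as zeros (assumed policy).
final class NormalizeSketch {

    static float[][] normalize(float[][] embeddings) {
        float[][] result = new float[embeddings.length][];
        for (int i = 0; i < embeddings.length; i++) {
            float[] v = embeddings[i];
            double sumOfSquares = 0.0;
            for (float x : v) {
                sumOfSquares += (double) x * x;
            }
            double norm = Math.sqrt(sumOfSquares);
            float[] unit = new float[v.length];
            if (norm > 0.0) {
                for (int j = 0; j < v.length; j++) {
                    unit[j] = (float) (v[j] / norm);
                }
            }
            result[i] = unit;
        }
        return result;
    }

    // For unit-length vectors, the dot product equals the cosine similarity.
    static double dot(float[] a, float[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) {
            s += (double) a[i] * b[i];
        }
        return s;
    }

    public static void main(String[] args) {
        float[] u = normalize(new float[][] {{3f, 4f}})[0];
        System.out.println(u[0] + ", " + u[1]); // 0.6, 0.8
        System.out.println(dot(u, u));          // approximately 1.0
    }
}
```

One practical reason to normalize: once vectors have unit length, cosine similarity reduces to a plain dot product, as the dot helper illustrates.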
-
getEmbedding
Get the embedding for a single text.
Parameters:
text - The text to generate an embedding for
Returns:
A list of floats representing the embedding vector
Throws:
ai.onnxruntime.OrtException - if inference fails
-
getEmbeddings
Get embeddings for multiple texts.
Parameters:
texts - List of texts to generate embeddings for
Returns:
A list of embedding vectors
Throws:
ai.onnxruntime.OrtException - if inference fails
-