Class OnnxModelLoader

java.lang.Object
com.redis.vl.utils.vectorize.OnnxModelLoader

public class OnnxModelLoader extends Object
Loads and runs ONNX models for generating embeddings. Handles tokenization, inference, and post-processing.
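A minimal usage sketch, assuming the loader retains the session returned by loadModel and uses it for getEmbedding; the model directory and input text below are placeholders, not part of this API:

    import com.redis.vl.utils.vectorize.OnnxModelLoader;
    import java.nio.file.Path;
    import java.util.List;

    public class OnnxEmbeddingExample {
        public static void main(String[] args) throws Exception {
            OnnxModelLoader loader = new OnnxModelLoader();
            // Placeholder directory containing the exported ONNX model and tokenizer files
            loader.loadModel(Path.of("models/all-MiniLM-L6-v2"));

            List<Float> embedding = loader.getEmbedding("Redis is an in-memory data store.");
            System.out.println("Embedding dimension: " + embedding.size());
        }
    }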
  • Constructor Summary

    Constructors
    Constructor
    Description
    OnnxModelLoader()
    Creates a new OnnxModelLoader.
  • Method Summary

    Modifier and Type
    Method
    Description
    List<Float>
    getEmbedding(String text)
    Get embedding for a single text.
    List<List<Float>>
    getEmbeddings(List<String> texts)
    Get embeddings for multiple texts.
    int
    getHiddenSize()
    Get the hidden size (same as embedding dimension).
    com.google.gson.JsonObject
    getTokenizer()
    Get a copy of the tokenizer configuration to prevent internal representation exposure.
    ai.onnxruntime.OrtSession
    loadModel(Path modelDir)
    Load an ONNX model from the specified directory.
    ai.onnxruntime.OrtSession
    loadModel(Path modelDir, ai.onnxruntime.OrtEnvironment env)
    Load an ONNX model from the specified directory with a specific environment.
    float[][]
    meanPooling(float[][][] tokenEmbeddings)
    Deprecated. Use meanPoolingWithAttention for correct sentence-transformers behavior.
    float[][]
    meanPoolingWithAttention(float[][][] tokenEmbeddings, long[][] attentionMask)
    Apply attention-masked mean pooling to token embeddings (Sentence Transformers style).
    float[][]
    normalize(float[][] embeddings)
    Normalize embeddings to unit length.
    float[][]
    runInference(ai.onnxruntime.OrtSession session, ai.onnxruntime.OnnxTensor inputTensor)
    Run inference on the model with tokenized input.
    long[][]
    tokenize(String text)
    Tokenize a single text string.
    long[][]
    tokenizeBatch(List<String> texts)
    Tokenize a batch of text strings.

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
  • Constructor Details

    • OnnxModelLoader

      public OnnxModelLoader()
      Creates a new OnnxModelLoader. The ONNX runtime environment will be provided when loading the model.
  • Method Details

    • loadModel

      public ai.onnxruntime.OrtSession loadModel(Path modelDir) throws IOException, ai.onnxruntime.OrtException
      Load an ONNX model from the specified directory.
      Parameters:
      modelDir - Path to the directory containing the ONNX model files
      Returns:
      The loaded ONNX runtime session
      Throws:
      IOException - if model files cannot be read
      ai.onnxruntime.OrtException - if the ONNX runtime fails to load the model
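      A short sketch, assuming modelDir points to a directory with the exported ONNX model and tokenizer files (the path is a placeholder) and that IOException and OrtException are handled by the caller:

        OnnxModelLoader loader = new OnnxModelLoader();
        java.nio.file.Path modelDir = java.nio.file.Path.of("/opt/models/all-MiniLM-L6-v2"); // placeholder
        ai.onnxruntime.OrtSession session = loader.loadModel(modelDir);
        // The returned session can be passed to runInference(...)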
    • loadModel

      public ai.onnxruntime.OrtSession loadModel(Path modelDir, ai.onnxruntime.OrtEnvironment env) throws IOException, ai.onnxruntime.OrtException
      Load an ONNX model from the specified directory with a specific environment.
      Parameters:
      modelDir - Path to the directory containing the ONNX model files
      env - The ONNX runtime environment to use
      Returns:
      The loaded ONNX runtime session
      Throws:
      IOException - if model files cannot be read
      ai.onnxruntime.OrtException - if the ONNX runtime fails to load the model
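      A sketch for applications that already manage a shared ONNX runtime environment (the path is a placeholder; checked exceptions are handled by the caller):

        ai.onnxruntime.OrtEnvironment env = ai.onnxruntime.OrtEnvironment.getEnvironment();
        OnnxModelLoader loader = new OnnxModelLoader();
        ai.onnxruntime.OrtSession session =
            loader.loadModel(java.nio.file.Path.of("/opt/models/all-MiniLM-L6-v2"), env); // placeholder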
    • getHiddenSize

      public int getHiddenSize()
      Get the hidden size (same as embedding dimension).
      Returns:
      The hidden size of the model
    • getTokenizer

      public com.google.gson.JsonObject getTokenizer()
      Get a copy of the tokenizer configuration to prevent internal representation exposure.
      Returns:
      A deep copy of the tokenizer configuration
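      A small sketch of inspecting model metadata with getHiddenSize and getTokenizer, assuming loader has already loaded a model:

        int dim = loader.getHiddenSize();                      // e.g. 384 for MiniLM-style models
        com.google.gson.JsonObject tokenizerConfig = loader.getTokenizer();
        // The returned JsonObject is a copy; editing it does not change the loader's state
        System.out.println("Embedding dimension: " + dim);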
    • tokenize

      public long[][] tokenize(String text)
      Tokenize a single text string.
      Parameters:
      text - The text to tokenize
      Returns:
      A 2D array of token IDs (batch size 1)
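      A sketch of tokenizing one string, assuming a loaded model:

        long[][] tokenIds = loader.tokenize("hello world");
        // Batch size is 1: tokenIds.length == 1, tokenIds[0] holds the token IDs
        System.out.println("Token count: " + tokenIds[0].length);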
    • tokenizeBatch

      public long[][] tokenizeBatch(List<String> texts)
      Tokenize a batch of text strings.
      Parameters:
      texts - List of texts to tokenize
      Returns:
      A 2D array of token IDs, one row per input text
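      A sketch of tokenizing several strings at once, assuming a loaded model:

        java.util.List<String> texts = java.util.List.of("first document", "second document");
        long[][] batchIds = loader.tokenizeBatch(texts);
        // One row of token IDs per input text
        System.out.println("Rows: " + batchIds.length);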
    • runInference

      public float[][] runInference(ai.onnxruntime.OrtSession session, ai.onnxruntime.OnnxTensor inputTensor) throws ai.onnxruntime.OrtException
      Run inference on the model with tokenized input.
      Parameters:
      session - The ONNX runtime session
      inputTensor - The input tensor containing token IDs
      Returns:
      2D array of embeddings, one row per input
      Throws:
      ai.onnxruntime.OrtException - if inference fails
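      A sketch wiring tokenize, OnnxTensor.createTensor (standard ONNX Runtime Java API), and runInference together; it assumes env is the OrtEnvironment used when loading, session is the value returned by loadModel, and OrtException is handled by the caller:

        long[][] tokenIds = loader.tokenize("hello world");
        try (ai.onnxruntime.OnnxTensor inputTensor =
                 ai.onnxruntime.OnnxTensor.createTensor(env, tokenIds)) {
            float[][] embeddings = loader.runInference(session, inputTensor);
            // One row per input; each row has getHiddenSize() values
        }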
    • meanPoolingWithAttention

      public float[][] meanPoolingWithAttention(float[][][] tokenEmbeddings, long[][] attentionMask)
      Apply attention-masked mean pooling to token embeddings (Sentence Transformers style). Only averages over non-padding tokens where attention_mask == 1.
      Parameters:
      tokenEmbeddings - 3D array of token embeddings [batch, sequence, hidden]
      attentionMask - 2D array of attention mask [batch, sequence] (1=real token, 0=padding)
      Returns:
      2D array of pooled embeddings [batch, hidden]
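      A worked sketch with a synthetic batch of one sequence (three tokens, hidden size 2) where the last position is padding:

        float[][][] tokenEmbeddings = {
            {
                { 1.0f, 2.0f },
                { 3.0f, 4.0f },
                { 9.0f, 9.0f }   // padding position, excluded by the mask below
            }
        };
        long[][] attentionMask = { { 1, 1, 0 } };
        float[][] pooled = loader.meanPoolingWithAttention(tokenEmbeddings, attentionMask);
        // pooled[0] averages only the first two rows: { 2.0f, 3.0f }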
    • meanPooling

      @Deprecated public float[][] meanPooling(float[][][] tokenEmbeddings)
      Deprecated.
      Use meanPoolingWithAttention for correct sentence-transformers behavior.
      Apply mean pooling to token embeddings (legacy method).
      Parameters:
      tokenEmbeddings - 3D array of token embeddings [batch, sequence, hidden]
      Returns:
      2D array of pooled embeddings [batch, hidden]
    • normalize

      public float[][] normalize(float[][] embeddings)
      Normalize embeddings to unit length.
      Parameters:
      embeddings - 2D array of embeddings to normalize
      Returns:
      Normalized embeddings with unit length
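      A small sketch; the input values are chosen so the L2 norm is exactly 5:

        float[][] normalized = loader.normalize(new float[][] { { 3.0f, 4.0f } });
        // {3, 4} has L2 norm 5, so the unit-length result is { 0.6f, 0.8f }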
    • getEmbedding

      public List<Float> getEmbedding(String text) throws ai.onnxruntime.OrtException
      Get embedding for a single text.
      Parameters:
      text - The text to generate an embedding for
      Returns:
      List of floats representing the embedding vector
      Throws:
      ai.onnxruntime.OrtException - if inference fails
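      A sketch of the single-text convenience path, assuming a loaded model and OrtException handled by the caller:

        java.util.List<Float> vector = loader.getEmbedding("vector search with Redis");
        System.out.println("Dimension: " + vector.size());   // equals getHiddenSize()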
    • getEmbeddings

      public List<List<Float>> getEmbeddings(List<String> texts) throws ai.onnxruntime.OrtException
      Get embeddings for multiple texts.
      Parameters:
      texts - List of texts to generate embeddings for
      Returns:
      List of embedding vectors
      Throws:
      ai.onnxruntime.OrtException - if inference fails
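      A sketch of the batch convenience path, assuming a loaded model and OrtException handled by the caller:

        java.util.List<java.util.List<Float>> vectors =
            loader.getEmbeddings(java.util.List.of("first text", "second text"));
        // One embedding per input text
        System.out.println("Count: " + vectors.size());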