Pixeltable provides built-in support for many popular embedding models, but you can also integrate your own. This guide shows you how to create and use custom embedding functions for any model architecture.

Quick Start

Here’s a simple example using a custom BERT model:

import tensorflow as tf
import tensorflow_hub as hub
import pixeltable as pxt

@pxt.udf
def custom_bert_embed(text: str) -> pxt.Array[(512,), pxt.Float]:
    """Basic BERT embedding function"""
    preprocessor = hub.load('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
    model = hub.load('https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2')
    
    tensor = tf.constant([text])
    result = model(preprocessor(tensor))['pooled_output']
    return result.numpy()[0, :]

# Create table and add embedding index
docs = pxt.create_table('documents', {'text': pxt.String})
docs.add_embedding_index('text', string_embed=custom_bert_embed)
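
Once the index exists, you can insert rows and query them by similarity. The following is a minimal sketch using Pixeltable’s similarity API; the sample strings are placeholders:

# Insert documents; the index embeds them with custom_bert_embed
docs.insert([
    {'text': 'Pixeltable manages multimodal data declaratively.'},
    {'text': 'BERT produces contextual text embeddings.'},
])

# Rank rows by similarity to a query string
sim = docs.text.similarity('semantic search with BERT')
results = docs.select(docs.text, sim=sim).order_by(sim, asc=False).limit(5).collect()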

Production Best Practices

The quick start example works, but it isn’t production-ready: the UDF reloads the model from TF Hub on every single call. The sections below cover how to optimize your custom embedding UDFs.

Model Caching

Always cache your model instances to avoid reloading on every call:

@pxt.udf
def optimized_bert_embed(text: str) -> pxt.Array[(512,), pxt.Float]:
    """BERT embedding function with model caching"""
    if not hasattr(optimized_bert_embed, 'model'):
        # Load models once
        optimized_bert_embed.preprocessor = hub.load(
            'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
        )
        optimized_bert_embed.model = hub.load(
            'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2'
        )
    
    tensor = tf.constant([text])
    result = optimized_bert_embed.model(
        optimized_bert_embed.preprocessor(tensor)
    )['pooled_output']
    return result.numpy()[0, :]
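
If you prefer to keep the UDF body minimal, the one-time loading can live in a module-level helper instead of function attributes. This is an equivalent sketch; _load_bert and cached_bert_embed are illustrative names, not part of Pixeltable’s API:

from functools import lru_cache

@lru_cache(maxsize=1)
def _load_bert():
    """Load the preprocessor/encoder pair once per process"""
    preprocessor = hub.load('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
    model = hub.load('https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2')
    return preprocessor, model

@pxt.udf
def cached_bert_embed(text: str) -> pxt.Array[(512,), pxt.Float]:
    """BERT embedding function backed by a cached loader"""
    preprocessor, model = _load_bert()
    result = model(preprocessor(tf.constant([text])))['pooled_output']
    return result.numpy()[0, :]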

Batch Processing

Use Pixeltable’s batching support to amortize per-call overhead and embed many strings in a single forward pass:

from pixeltable.func import Batch

@pxt.udf(batch_size=32)
def batched_bert_embed(texts: Batch[str]) -> Batch[pxt.Array[(512,), pxt.Float]]:
    """BERT embedding function with batching"""
    if not hasattr(batched_bert_embed, 'model'):
        batched_bert_embed.preprocessor = hub.load(
            'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
        )
        batched_bert_embed.model = hub.load(
            'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2'
        )
    
    # Process entire batch at once
    tensor = tf.constant(list(texts))
    results = batched_bert_embed.model(
        batched_bert_embed.preprocessor(tensor)
    )['pooled_output']
    return list(results.numpy())
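
A batched UDF drops in wherever the scalar version worked; Pixeltable passes rows to it in groups of up to batch_size. A sketch (the idx_name is arbitrary, and add_computed_column assumes a recent Pixeltable release):

# Back the embedding index with the batched UDF ...
docs.add_embedding_index('text', idx_name='bert_batched', string_embed=batched_bert_embed)

# ... or materialize embeddings as a stored computed column
docs.add_computed_column(embedding=batched_bert_embed(docs.text))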

Error Handling

Always implement proper error handling in production UDFs:

import logging

logger = logging.getLogger(__name__)

@pxt.udf
def robust_bert_embed(text: str) -> pxt.Array[(512,), pxt.Float]:
    """BERT embedding with input validation and error handling"""
    try:
        if not text or not text.strip():
            raise ValueError("Empty text input")

        if not hasattr(robust_bert_embed, 'model'):
            # Load and cache models once, as in optimized_bert_embed above
            robust_bert_embed.preprocessor = hub.load(
                'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
            )
            robust_bert_embed.model = hub.load(
                'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/2'
            )

        tensor = tf.constant([text])
        result = robust_bert_embed.model(
            robust_bert_embed.preprocessor(tensor)
        )['pooled_output']

        return result.numpy()[0, :]

    except Exception as e:
        logger.error(f"Embedding failed: {e}")
        raise
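
A further safeguard is to validate the output shape before returning, so a model change that alters the embedding dimension fails loudly instead of silently corrupting the index. A sketch; validate_embedding is a hypothetical helper, not a Pixeltable API:

import numpy as np

def validate_embedding(arr: np.ndarray, dim: int = 512) -> np.ndarray:
    """Raise if the embedding doesn't match the declared dimensionality"""
    if arr.shape != (dim,):
        raise ValueError(f'expected shape ({dim},), got {arr.shape}')
    return arr

# Inside the try block above:
#     return validate_embedding(result.numpy()[0, :])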

Additional Resources