🪆 Introduction to Matryoshka Embedding Models
Matryoshka embedding models store the most important information in the earliest dimensions and less important information in later ones, much like Russian nesting dolls. As a result, truncated embeddings retain most of their usefulness, whereas embeddings from standard models tend to lose more when truncated.
Why use them?
- Shortlisting + reranking: shrink embeddings for fast candidate retrieval, then process finalists at full dimensionality
- Trade-offs: choose an embedding size that balances storage cost, processing speed, and performance for your use case
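
The shortlisting and reranking idea from the list above can be sketched in a few lines. This is a minimal illustration in plain Python so it runs without dependencies; the function names, the toy 4-dimensional vectors, and the parameter values are all assumptions made for the example, and a real system would more likely use NumPy or a vector index.

```python
import math

def normalize(vec):
    """Scale a vector to unit length (returned unchanged if it is all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def cosine(a, b):
    """Cosine similarity of two equal-length vectors."""
    a, b = normalize(a), normalize(b)
    return sum(x * y for x, y in zip(a, b))

def shortlist_then_rerank(query, corpus, short_dim, shortlist_size, top_k):
    """Return indices of the top_k corpus vectors for the query.

    Stage 1 scores every document using only the first short_dim dimensions.
    Stage 2 rescores with the full vectors, but only for the shortlist.
    """
    coarse = sorted(
        range(len(corpus)),
        key=lambda i: cosine(query[:short_dim], corpus[i][:short_dim]),
        reverse=True,
    )[:shortlist_size]
    return sorted(coarse, key=lambda i: cosine(query, corpus[i]), reverse=True)[:top_k]

# Toy 4-dimensional vectors, made up for illustration.
corpus = [
    [0.9, 0.1, 0.0, 0.4],
    [0.8, 0.2, 0.5, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.1, 0.9, 0.3, 0.1],
]
query = [1.0, 0.0, 0.6, 0.0]

result = shortlist_then_rerank(query, corpus, short_dim=2, shortlist_size=2, top_k=1)
print(result)  # prints [1]
```

In this toy data the 2-dimensional pass ranks document 0 first, while the full-dimensional pass over the two finalists prefers document 1, which is the kind of correction the second stage exists to make.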
Training
Standard training: produce embeddings → compute loss on full-size → update weights.
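MRL (Matryoshka Representation Learning) training: produce embeddings → compute loss at multiple dimensionalities (e.g., 768, 512, 256, 128, 64) → sum losses → update weights. Because every truncated prefix must perform well on its own, the model learns to frontload the most important information.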
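
To make the "sum losses" step concrete, here is a minimal sketch in plain Python. It is not the library's implementation: `pair_loss` is a hypothetical stand-in base loss (squared error between cosine similarity and a gold score), and the 4-dimensional embeddings are made up for illustration.

```python
import math

def cosine(a, b):
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)

def pair_loss(emb_a, emb_b, gold_score):
    """Stand-in base loss: squared error between cosine similarity and a gold score."""
    return (cosine(emb_a, emb_b) - gold_score) ** 2

def matryoshka_loss(emb_a, emb_b, gold_score, dims):
    """Sum the base loss over the first d dimensions for each d in dims."""
    return sum(pair_loss(emb_a[:d], emb_b[:d], gold_score) for d in dims)

# Two toy embeddings that should be similar (gold score 1.0).
emb_a = [1.0, 0.0, 1.0, 0.0]
emb_b = [0.0, 1.0, 1.0, 0.0]

full_only = pair_loss(emb_a, emb_b, gold_score=1.0)
combined = matryoshka_loss(emb_a, emb_b, gold_score=1.0, dims=[4, 2])
print(full_only, combined)
```

Here the full-size loss is about 0.25, but the first two dimensions alone disagree completely, so the combined loss is about 1.25. That extra penalty is what pushes useful information toward the early dimensions during training.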
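from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss, MatryoshkaLoss

# Base model to fine-tune; "microsoft/mpnet-base" is an example choice.
model = SentenceTransformer("microsoft/mpnet-base")

# Wrap the base loss so it is also computed on truncated embeddings.
loss = MatryoshkaLoss(
    model=model,
    loss=CoSENTLoss(model=model),
    matryoshka_dims=[768, 512, 256, 128, 64],
)

Usage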
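from sentence_transformers import SentenceTransformer

# truncate_dim=64 keeps only the first 64 dimensions of each embedding.
model = SentenceTransformer("tomaarsen/mpnet-base-nli-matryoshka", truncate_dim=64)
embeddings = model.encode(["The weather is nice!", "It's sunny outside!"])
print(embeddings.shape)  # (2, 64)

Re-normalize after manual truncation if the original embeddings were normalized, because a truncated unit vector generally no longer has unit length.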
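
The re-normalization step can be illustrated without any libraries. The helper name `truncate_and_normalize` and the toy vector are assumptions for this sketch; with NumPy or PyTorch the same thing is a slice followed by a division by the L2 norm.

```python
import math

def truncate_and_normalize(embedding, dim):
    """Keep the first `dim` values and rescale them to unit length."""
    truncated = list(embedding[:dim])
    norm = math.sqrt(sum(x * x for x in truncated))
    if norm == 0.0:
        return truncated  # nothing to rescale
    return [x / norm for x in truncated]

full = [0.6, 0.0, 0.8, 0.0]  # toy unit-length embedding
prefix = full[:2]            # plain truncation: length is now 0.6, not 1.0
short = truncate_and_normalize(full, 2)

print(math.sqrt(sum(x * x for x in prefix)))
print(math.sqrt(sum(x * x for x in short)))
```

Without the rescaling, dot products between truncated vectors would no longer equal cosine similarities.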
Results (STSBenchmark)
- Matryoshka model exceeds standard model at all dimensionalities
- At 8.3% of embedding size (64d of 768d): 98.37% performance preserved vs 96.46% for standard model
The result is significant storage savings and downstream speedups for a small loss in accuracy (about 1.6% at 64 dimensions).
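
As a rough back-of-the-envelope check on the storage claim, the sketch below compares raw storage at 768 and 64 dimensions. The corpus size of one million vectors and the 4 bytes per value (float32) are assumptions for illustration, not figures from the benchmark.

```python
def embedding_storage_bytes(num_vectors, dim, bytes_per_value=4):
    """Raw storage for dense embeddings, assuming float32 values by default."""
    return num_vectors * dim * bytes_per_value

full_dim, short_dim = 768, 64
num_vectors = 1_000_000  # hypothetical corpus size

full_bytes = embedding_storage_bytes(num_vectors, full_dim)
short_bytes = embedding_storage_bytes(num_vectors, short_dim)

print(f"size ratio: {short_dim / full_dim:.1%}")  # 8.3%
print(f"full: {full_bytes / 1e9:.3f} GB, truncated: {short_bytes / 1e9:.3f} GB")
print(f"reduction: {full_bytes // short_bytes}x")  # 12x
```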