September 4, 2025 in embedding 5 minutes
gemini-embedding-001 (Google, 2025), text-embedding-3-large (2024, OpenAI), voyage-context-3 (2025, Voyage AI) ๋ฑ ์ต์ ์๋ฒ ๋ฉ ๋ชจ๋ธ์์ Matryoshka Representation Learning; MRL๋ฅผ ์ง์ํ๋ค๊ณ ํ๋๋ฐ MRL์ด ๋ญ๊น?
Matryoshka Representation Learning(MRL)
๋ ๋ฌ์์ ์ธํ ๋งํธ๋ฃ์์นด ์ฒ๋ผ ํ๋์ ์๋ฒ ๋ฉ ๋ฒกํฐ ์์ ์ฌ๋ฌ ์ธ๋ถํ๋ ์ ๋ณด๋ฅผ ๋ด์ ๋ค์ด์คํธ๋ฆผ ์์
(downstream task)์ ์ฐ์ฐ ์ ์ฝ ์กฐ๊ฑด์ ์ ๋์ ์ผ๋ก ์ ์ํ ์ ์๋๋ก ์ค๊ณ๋ ํํ ํ์ต(representation learning) ๊ธฐ๋ฒ์ด๋ค.
*“multi embedding"์ด๋ผ๊ณ ๋ ๋ถ๋ฅด๊ธฐ๋ ํ๋ค. ๋ ผ๋ฌธ์์๋ multi-objective MRL๋ก ํํํ๋ค *downstream task๋ ํ์ต๋ ์๋ฒ ๋ฉ์ ๋ถ๋ฅ/๊ฒ์/๋ญํน ๋ฑ๊ณผ ๊ฐ์ ํ์ ์์ (ํํ ํ๋ จ๋ ๋ชจ๋ธ์ ๋ค์ด์คํธ๋ฆผ ํ ์คํฌ์ ๋ง๊ฒ ํ์ธํ๋ํ๋ค๊ณ ํ๋ค.)
ย
ํฐ ์๋ฒ ๋ฉ ์์ ๊ทธ ์์ฒด๋ก๋ ์ ์ฉํ ์์ ์๋ฒ ๋ฉ๋ค์ด ๊ฒน๊ฒน์ด ๋ค์ด๊ฐ ์์ด ์ํฉ์ ๋ง๊ฒ ๊บผ๋ด ์ฌ์ฉํ ์ ์๋ค๋ ๊ฒ์ด๋ค. โ์ฑ ์ฝ๊ธฐโ๋ก ๋น์ ํ์๋ฉด 32์ฐจ์์ผ๋ก ์ฑ ํ์ง์ ๋ชฉ์ฐจ๋ฅผ ์ดํด๋ณด๊ณ , ๋ด์ฉ์ ๋ ์ฝ๊ณ ์ถ์ผ๋ฉด 128์ฐจ์๊น์ง ์ฑ ์ ํผ์ณ๋ณด๊ณ , ๊ทธ๋ผ์๋ ๋ถ์กฑํ๋ฉด ๋ถ๋ก, ์ฆ ์ต์ข ์ฐจ์๊น์ง ๋ณด๋ ๊ฒ์ด๋ค.
์ผ๋ฐ์ ์ผ๋ก ์๋ฒ ๋ฉ ์ฐจ์์ด ๋์์๋ก ์ฑ๋ฅ์ด ์ฌ๋ผ๊ฐ์ง๋ง, ๊ทธ๋ด์๋ก ๋น์ฉ๊ณผ ์๋๋ ํจ๊ป ์ฆ๊ฐํ๋ค. ๋ฐ๋ผ์ ์ํฉ์ ๋ง๊ฒ ์ ์ฐํ๊ฒ ์ ์ฐจ์, ๊ณ ์ฐจ์์ ์ ํํ ์ ์๋๋ก ํ๋ ๊ฒ์ด๋ค.
์๋ฒ ๋ฉ ํ๋ ์์ ๊ฐ๋ตํ ์ ๋ณด(coarse)๋ถํฐ ์์ธํ ์ ๋ณด(fine)๊น์ง ์์๋๋ก ๋ด์๋์!
๊ทธ๋์ ๊ณ ์ฐจ์๋ง ์ฌ์ฉํ๋ ๊ฒ ์๋๋ผ ์ โ์คโ๊ณ ์ฐจ์์ ์ํฉ์ ๋ง๊ฒ ์ฐ์!
๋ณธ ๋ ผ๋ฌธ์์๋ rigidity โ flexibility ์ ๊ฐ๋ ์ผ๋ก ์ค๋ช ํ๋๋ฐ, ๋ฌด์์ ๋งํ๋๊ฑธ๊น?
๊ณผ๊ฑฐ ์๋ฒ ๋ฉ ๊ณ ์ ์ฐจ์(fixed dimension)์ด๋ผ, ๋ค๋ฅธ ์ฐจ์์ด ํ์ํ ๊ฒฝ์ฐ ์ด์ ๋ง๊ฒ ๋ชจ๋ธ์ ๋ค์ ํ๋ จํด์ผํ๋ค. (์ค๋ฒํค๋ โ)
MRL์ ๊ณ์ฐ ์์๊ณผ ์ํฉ์ ๋ง์ถฐ ๋ฐ๋ก ํ๋์ ๋ชจ๋ธ & ํ๋์ ์๋ฒ ๋ฉ ๋ฒกํฐ๋ก ํ์ํ ์ฐจ์์ ์ ๋ณด๋ฅผ ์ ํํ์ฌ ์ฌ์ฉํ ์ ์์ด ์ ์ฐํ๋ฉฐ ํจ์จ์ ์ด๋ค.
์๋ฒ ๋ฉ ํ๋๋ก coarse โ fine (์ ๋ฐ์ โ ์ธ๋ถ์ ์ผ๋ก) ์์ฝ ์ ๋ณด์์ ์ธ๋ถ์ ์ธ ์ ๋ณด๋ฅผ ๋ด๋ ๊ฒ์ด๋ค. ๋ค์ํ ์ธ๋ถ์ฑ(granularity)๋ฅผ ๊ณ ๋ คํ์ฌ ์ ๋ณด๋ฅผ ๊ณ์ธต์ ์ผ๋ก ์ธ์ฝ๋ฉํ๋ค.
์๋ฒ ๋ฉ ์ฐจ์ $d$, ๋ฐ์ดํฐ ํฌ๊ธฐ $N$, ๋ผ๋ฒจ ์ $L$์ ๋น๋กํ์ฌ ์ ํ์ ์ผ๋ก ์ถ๋ก ๋น์ฉ์ด ์ฆ๊ฐํ๋ค.
๊ฒฝ์ฌ ๊ธฐ๋ฐ ํ์ต์ ์๋ฏธ ์๋ ์ ๋ณด๊ฐ ๋ฒกํฐ ์ ์ญ์ผ๋ก ํผ์ง๋ ๊ฒฝํฅ์ด ์์ด ๋ฎ์ ์ฐจ์๋ง์ผ๋ก ์ถฉ๋ถํ ์์ ์์๋ ํฐ ์ฐจ์์ ๊ฐ์ ํ๋ ๋นํจ์จ์ด ๋ฐ์ํ๋ค.
*์ฌ๊ธฐ์์ ์ถ๋ก ๋น์ฉ์ ๊ณ์ฐ๋ deep representation(๊ณ ์ฐจ์ ๋ฒกํฐ; ๋ฐ์ดํฐ์ ๋ณธ์ง์ ์ด๊ณ ์๋ฏธ ์๋ ์ ๋ณด๊ฐ ์์ถ๋ ํํ๋ก ํํ; ๊ณ ์์ค ํํ)์ ์ค์ ๋ค์ด์คํธ๋ฆผ ์ ํ๋ฆฌ์ผ์ด์ (downstream application)์ ํ์ฉํ ๋ ๋ฐ์ํ๋ ๋น์ฉ์ ๋งํ๋ค.
MRL์ ์ถ๊ฐ ์ถ๋ก ๋น์ฉ ์์ด ๊ธฐ์กด ํ์ดํ๋ผ์ธ์ ์กฐ๊ธ ์์ ํด ์ ์ํ ์๋ฒ ๋ฉ(adaptive embeddings)์ ๋ง๋ค ์ ์๋ค.
MRL์ ๊ณ ์ฐจ์ ์๋ฒ ๋ฉ ๋ฒกํฐ ์์ coarse-to-fine granularity ์์ค์ ์ ๋ณด๋ฅผ ๊ณ์ธต์ ์ผ๋ก ์ธ์ฝ๋ฉํ๋ค.
$O(\\log d)$ : d-์ฐจ์ ๋ฒกํฐ ์์ $M=\\{8,16,32,\\dots,d\\}$ ์ ์ ํด, ๊ฐ $m \\in \\mathcal{M}$์๋ํด ์ $m$๊ฐ $z_{1:m}$๋ง ์ฌ์ฉํด๋ ์ ์ฉํ๋๋ก ํ์ตํ๋ค. ์๋ฅผ ๋ค์ด 1024 ์ฐจ์์ ์ต์ข ์๋ฒ ๋ฉ ๋ฒกํฐ๊ฐ ์๋ค๋ฉด, ๊ทธ ์ค ์ฒ์ 512์ฐจ์๋ง ์ฌ์ฉํด๋ ํน์ ๋ชฉ์ ์ ์์ ํ๊ณ ์ ์ฉํ ์๋ฒ ๋ฉ์ผ๋ก ์ฌ์ฉ ๊ฐ๋ฅํ๋ค.
$$ \min_{\theta_F, \{W^{(m)}\}} \frac{1}{N} \sum_{m \in \mathcal{M}} \sum_{i} c_m \mathcal{L}(W^{(m)} F(x_i; \theta_F)_{1:m}, y_i) $$๋จ์ํ ์ต์ข ์ฐจ์์ ๋ํ ์์ค ํจ์(loss function)๋ง ์ต์ ํํ๋ ๊ฒ์ด ์๋๋ผ ๋ฏธ๋ฆฌ ์ ํด๋ ์ฌ๋ฌ ์ค์ฒฉ๋(nested) ์ฐจ์๋ค ๊ฐ๊ฐ์ ๋ํด ๋์์ ์์ค ํจ์๋ฅผ ์ต์ ํํ๋๋ก ๋ชจ๋ธ ํ๋ จํ๋ค. ๋ชจ๋ ์ค๊ฐ ์ฐจ์์ ๊ฐ๊ฐ์ ๋ชจ๋ธ์ ๋ํด ๋ ๋ฆฝ์ ์ผ๋ก ํ์ตํ์ง ์๊ณ ๋ ๋์ผํ ์ฑ๋ฅ์ ์ ์งํ ์ ์๋ค. ์ด์ฒ๋ผ ํ๋์ ๋ชจ๋ธ๋ก ํ ๋ฒ์ forward pass๋ง์ผ๋ก ๋ชจ๋ ๊ณ์ธต์ ํํ์ ์ป์ด ์ถ๋ก ์ ์๋นํ ๊ณ์ฐ ๋น์ฉ์ ์ ๊ฐํ ์ ์๋ค.
# PyTorch code for Matryoshka Cross-Entropy Loss
import torch.nn as nn
class Matryoshka_CE_Loss(nn.Module):
def __init__(self, relative_importance, **kwargs):
super(Matryoshka_CE_Loss, self).__init__()
self.criterion = nn.CrossEntropyLoss(**kwargs)
self.relative_importance = relative_importance
def forward(self, output, target):
loss = 0
for i in range(len(output)):
loss += self.relative_importance[i] * self.criterion(output[i], target)
return loss
์ผ๋ฐ ์๋ฒ ๋ฉ์ ์ ์ฐจ์์ ๊ทธ๋ฅ ์๋ผ๋ด ์ฌ์ฉํ๋ฉด, ๊ทธ ์ ์ฐจ์ ๋ฒกํฐ๋ ์๋ณธ ์ ๋ณด๋ฅผ ์ ๋๋ก ๋ด์ง ๋ชปํ ๊ฐ๋ฅ์ฑ์ด ๋๋ค. ์ผ๋ฐ์ ์ผ๋ก ๊ฒฝ์ฌ ๊ธฐ๋ฐ ํ์ต ๋ชจ๋ธ์ ์๋ฒ ๋ฉ ๋ฒกํฐ์ ๋ชจ๋ ์ฐจ์์ ๊ฑธ์ณ ์ ๋ณด๊ฐ ํ์ฐ(diffuse)๋์ด ์ธ์ฝ๋ฉ๋๋ ๊ฒฝํฅ์ด ์๊ธฐ ๋๋ฌธ์ด๋ค. ๊ทธ๋์ ์ฒ์ ๋ช ์ฐจ์๋ง ๋ผ์ด๋ด ์ฌ์ฉํ๋ฉด ์ ๋ณด์ ํ์ง์ ๋ณด์ฅํ ์ ์๋ค.
PCA/SVD๊ณผ ๊ฐ์ ์ฌํ ์์ถ์ ์ฐจ์์ ์กฐ๊ธ ์ค์ด๋ฉด ์ ํ๋๊ฐ ์ฌ๋ผ๊ฐ์ง๋ง, ๊ณผํ๊ฒ ์ค์ด๋ฉด ์ ํ๋๊ฐ ๋ง์ด ๊ฐ์ํ๋ค. ๋ฐ๋ฉด, MRL์ ํ์ต ์ดํ ๋จ๊ณ๊ฐ ์๋๋ผ end-to-end ํ์ต ๋จ๊ณ์์ ๋ฏธ๋ฆฌ ์ฌ๋ฌ ์ฐจ์์์ ์ฌ์ฉํ ์ ์๋๋ก ์ต์ ํ๋์ด ์์ด ์ ํ๋๊ฐ ์ ์ง๋๋ค.
25๋ 9์์ ์๊ฐ๋ EmbeddingGemma ๋ชจ๋ธ๋ MRL์ ์ง์ํ๋๋ฐ, ๊ณต์ ๋ฌธ์์์ MRL ์ค๋ช ์ด ๊ฐ๋ตํ ์ ๋์ด ์๋ค.
# MRL test for `google/embeddinggemma-300M`
import os, numpy as np, torch
from sentence_transformers import SentenceTransformer
MODEL_ID = "google/embeddinggemma-300M"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TOKEN = os.getenv("HF_TOKEN") # ํ์ ์ ์ค์
# Load Model
model = SentenceTransformer(MODEL_ID, device=DEVICE, token=TOKEN)
data = ["์์ดํฐ", "๊ฐค๋ญ์", "์ผ์ฑ", "๊ณ ์์ด"]
def l2norm(E): # ํ๋ณ L2 ์ ๊ทํ
return E / (np.linalg.norm(E, axis=1, keepdims=True) + 1e-12)
def cosine_to_anchor(E):
a = E[0]
sims = []
for i in range(1, len(E)):
s = float(np.dot(a, E[i]) / (np.linalg.norm(a)*np.linalg.norm(E[i]) + 1e-12))
sims.append((data[i], s))
return sims
def show(title, E):
sims = cosine_to_anchor(E)
print(f"\n[{title}] shape={E.shape}")
for name, s in sims:
print(f" {data[0]} vs {name}: {s:.4f}")
order = [name for name, s in sorted(sims, key=lambda x: x[1], reverse=True)]
print(" rank:", " > ")
return np.array([s for _, s in sims])
def spearman(u, v): # ์์ ์์ ์ฑ ๊ฐ๋จ ์งํ
r = lambda x: np.argsort(np.argsort(-x))
return float(np.corrcoef(r(u), r(v))[0, 1])
# ===== 1/ full embedding =====
emb_full = model.encode(data, convert_to_numpy=True)
D = emb_full.shape[1]
s_full = show("FULL", emb_full)
# ===== 2/ truncate to 512 dims + L2 normalization =====
E512 = l2norm(emb_full[:, :min(512, D)])
s_512 = show("TRUNCATE 512 + L2", E512)
# ===== 3/ truncate to 256 dims + L2 normalization =====
E256 = l2norm(emb_full[:, :min(256, D)])
s_256 = show("TRUNCATE 256 + L2", E256)
# check MRL
print(f"\nSpearman(FULL vs 512) = {spearman(s_full, s_512):.3f}")
print(f"Spearman(FULL vs 256) = {spearman(s_full, s_256):.3f}")
print(f"Base dim = {D}")
[FULL] shape=(4, 768)
์์ดํฐ vs ๊ฐค๋ญ์: 0.9355
์์ดํฐ vs ์ผ์ฑ: 0.9326
์์ดํฐ vs ๊ณ ์์ด: 0.8970
rank: ์์ดํฐ > ๊ฐค๋ญ์ > ์ผ์ฑ > ๊ณ ์์ด
[TRUNCATE 512 + L2] shape=(4, 512)
์์ดํฐ vs ๊ฐค๋ญ์: 0.9442
์์ดํฐ vs ์ผ์ฑ: 0.9419
์์ดํฐ vs ๊ณ ์์ด: 0.9133
rank: ์์ดํฐ > ๊ฐค๋ญ์ > ์ผ์ฑ > ๊ณ ์์ด
[TRUNCATE 256 + L2] shape=(4, 256)
์์ดํฐ vs ๊ฐค๋ญ์: 0.9568
์์ดํฐ vs ์ผ์ฑ: 0.9548
์์ดํฐ vs ๊ณ ์์ด: 0.9270
rank: ์์ดํฐ > ๊ฐค๋ญ์ > ์ผ์ฑ > ๊ณ ์์ด
Spearman(FULL vs 512) = 1.000
Spearman(FULL vs 256) = 1.000
Base dim = 768
์ค์ ๋ก EmbeddingGemma ๋ชจ๋ธ์์ MRL์ด ์ ๋์ํ๋์ง ํ์ธํ ๊ฒฐ๊ณผ ์ฐจ์์ ๊ธฐ์กด 768์ฐจ์์์ 256์ฐจ์์ผ๋ก ์ถ์ํ์๋๋ฐ๋ ์๋ฏธ๋ฅผ ์ ๋ํ๋ด๊ณ ์๋ค๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.