多头哈希(Multi-Head Hashing)
class MultiHeadHashEmbedding:
    """Multi-head hashed N-gram embedding lookup.

    Each N-gram is mapped through K independent hash functions into K
    separate embedding tables, and the K resulting vectors are
    concatenated. Two distinct N-grams are unlikely to collide under
    every hash head at once, which softens the impact of any single
    hash collision.
    """

    def __init__(self, config: EngramConfig, seed: int = 42):
        # NOTE(review): `seed` is neither stored nor forwarded, so it has no
        # visible effect in this chunk — confirm whether _init_hash_params
        # (defined elsewhere) was meant to consume it.
        self.config = config
        # One (multiplier, xor-mask) pair per (n-gram order, hash head);
        # populated by the helper below.
        self._init_hash_params()
        self._init_embedding_tables()

    def compute_ngram_hash(self, ngram_tokens: Tuple[int, ...], n: int, k: int) -> int:
        """Map an N-gram to a table row via h(x) = ((a * x) XOR b) mod M."""
        a, b = self.hash_params[(n, k)]
        # Fold the token IDs into a single integer with a base-31 rolling
        # hash, masked to 31 bits to keep the intermediate bounded.
        rolled = 0
        for tok in ngram_tokens:
            rolled = (rolled * 31 + tok) & 0x7FFFFFFF
        return ((a * rolled) ^ b) % self.config.table_size

    def lookup_embedding(self, token_ids: List[int], position: int) -> np.ndarray:
        """Return the concatenated multi-head N-gram embedding for `position`."""
        # Every (n-gram order, hash head) pair addresses its own
        # independently initialized table.
        pieces = [
            self.embedding_tables[(n, head)][self.compute_ngram_hash(ngram, n, head)]
            for n, ngram in self.extract_ngrams(token_ids, position)
            for head in range(self.config.num_hash_heads)
        ]
        # Heads and orders are stacked along the feature axis.
        return np.concatenate(pieces, axis=-1)

# A concrete worked example follows below.
---
假设条件
# 假设配置
config.num_hash_heads = 4 # K = 4 个哈希头
config.embedding_dim = 64 # 每个头输出 64 维向量
config.table_size = 10000 # 哈希表大小 M = 10000
# 假设 (n=3, k=0) 这个表的第 3377 行存的是向量:
embedding_tables[(3, 0)][3377] = [0.1, -0.2, 0.5, ...64个维度...]
---
现在有个 3-gram:["我", "爱", "学习"]
假设 token ID 分别是:[102, 558, 923]
步骤 1:提取 N-gram
extract_ngrams(token_ids, position)
# 返回: [(3, (102, 558, 923))] # 3-gram,token IDs 是元组
步骤 2:对这个 3-gram 计算 4 个哈希头
# 对每个哈希头 k=0,1,2,3 分别计算
# k=0: hash_idx = 3377
hash_idx_0 = compute_ngram_hash((102, 558, 923), n=3, k=0)
# → 假设返回 3377
# k=1: hash_idx = 8821
hash_idx_1 = compute_ngram_hash((102, 558, 923), n=3, k=1)
# → 假设返回 8821
# k=2: hash_idx = 115
hash_idx_2 = compute_ngram_hash((102, 558, 923), n=3, k=2)
# → 假设返回 115
# k=3: hash_idx = 6699
hash_idx_3 = compute_ngram_hash((102, 558, 923), n=3, k=3)
# → 假设返回 6699
步骤 3:从 4 个嵌入表中查向量
# 从 4 个不同的表查
vec0 = embedding_tables[(3, 0)][3377] # 64维
vec1 = embedding_tables[(3, 1)][8821] # 64维
vec2 = embedding_tables[(3, 2)][115] # 64维
vec3 = embedding_tables[(3, 3)][6699] # 64维
步骤 4:拼接
result = np.concatenate([vec0, vec1, vec2, vec3], axis=-1)
# → 4 × 64 = 256 维向量
---
总结
| 概念 | 说明 |
|------|------|
| 3-gram | 一串 3 个 token,如 "我 爱 学习" |
| 哈希头 K=4 | 用 4 个不同公式算出 4 个索引 |
| 每个索引 | 去对应表查一个 64 维向量 |
| 最终输出 | 4 个向量拼起来 = 256 维 |
关键点:同一个 3-gram,用 k=0 和 k=1 查的是不同的哈希表,存放在 embedding_tables[(3, 0)] 和 embedding_tables[(3, 1)]。这 4 个表是独立初始化的,所以同一个 N-gram 在不同头下会拿到完全不同的向量。

Engram
Engram负责模式匹配;MoE负责组合和推理。
class TransformerBlockWithEngram(nn.Module):
"""Transformer block with integrated Engram conditional memory.
Architecture:
h = h + Attention(LayerNorm(h))
h = h + Engram(h, input_ids) # <-- Engram insertion point
h = h + MLP(LayerNorm(h))
"""