flash_attention.py
SDPA: Scaled Dot-Product Attention
def _sdpa_attention(q, k, v, window_size, enable_gqa):
    """
    SDPA attention with sliding-window support.

    Args:
        q, k, v: Attention inputs in (B, H, T, D) layout.
        window_size: (left, right) sliding window; -1 means unlimited.
            Only the left width is used here (attention is causal, so the
            right width never clips anything).
        enable_gqa: Forwarded to SDPA for grouped-query attention
            (fewer KV heads than query heads).

    Returns:
        Attention output of shape (B, H, Tq, D).
    """
    Tq = q.size(2)  # (B, H, T, D) -> T of the queries
    Tk = k.size(2)
    window = window_size[0]  # (left, right): only the left width matters

    # Prefill: full context, same length. Plain causal attention suffices
    # because the window cannot clip anything (max lookback is Tq - 1).
    if (window < 0 or window >= Tq) and Tq == Tk:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)

    # Single-token generation: slice the keys/values instead of masking.
    if Tq == 1:
        if window >= 0 and window < Tk:
            # A left window of `window` tokens plus the token itself
            # means (window + 1) keys in total; last key index is Tk - 1.
            start = max(0, Tk - (window + 1))
            k = k[:, :, start:, :]
            v = v[:, :, start:, :]
        # The slice already excludes future tokens, so is_causal=False.
        return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)

    # Chunked inference (Tq != Tk): is_causal does not align queries to the
    # cache position, so build an explicit boolean mask with bottom-right
    # alignment (query i sits at absolute position Tk - Tq + i).
    device = q.device
    row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)  # (Tq, 1) absolute query positions
    col_idx = torch.arange(Tk, device=device).unsqueeze(0)              # (1, Tk) key positions
    mask = col_idx <= row_idx  # causal: each query sees keys at or before itself
    # Sliding window (left): keep only keys at most `window` positions back.
    if window >= 0 and window < Tk:
        mask = mask & ((row_idx - col_idx) <= window)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
chunk解码时,Tk=历史长度+Tq
row_idx = (Tk - Tq) + arange(Tq).unsqueeze(1) ## query
col_idx = arange(Tk).unsqueeze(0) ## key
mask = col_idx <= row_idx
- col_idx:所有 key 的列号 0..Tk-1
- row_idx:每个 query 在“全序列里的绝对位置”
- 关键是 Tk - Tq 这个偏移量:把 query 行对齐到 cache 末尾(bottom-right 对齐)
- mask = col_idx <= row_idx:每个 query 只能看不超过自己位置的 key(因果)
例子:Tk=8, Tq=3
- Tk-Tq = 5
- row_idx = [5,6,7](shape 3x1)
- col_idx = [0,1,2,3,4,5,6,7](shape 1x8)
得到 mask(T=True, F=False):
- row0 (<=5): T T T T T T F F
- row1 (<=6): T T T T T T T F
- row2 (<=7): T T T T T T T T
含义:这 3 个 query 分别对应全局位置 5/6/7。再加一层“左滑窗”限制
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
- row_idx - col_idx = query 位置和 key 位置的距离(往左看多远)
- 只保留距离 <= window 的 key
推理时更新某层的k_cache和v_cache
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
                            causal=False, window_size=(-1, -1)):
    """
    Flash Attention with KV cache for inference.

    FA3 updates k_cache/v_cache in-place; the SDPA fallback below does the same.

    Args:
        q: Queries, shape (B, T_new, H, D).
        k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D).
        k, v: New keys/values to insert, shape (B, T_new, H_kv, D); may be None.
        cache_seqlens: Current position in cache, shape (B,) int32.
            The SDPA fallback treats None as position 0.
        causal: Whether to use causal masking (forwarded to FA3 only; the
            SDPA fallback is always causal by construction).
        window_size: (left, right) sliding window; -1 means unlimited.

    Returns:
        Output tensor of shape (B, T_new, H, D).
    """
    if USE_FA3:
        return _fa3.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size,
        )

    # SDPA fallback: manage the KV cache manually.
    B, T_new, H, D = q.shape
    # NOTE(review): assumes a uniform position across the batch, so only
    # cache_seqlens[0] is consulted. Guard against None (the signature default)
    # so a fresh cache starts at position 0 instead of crashing.
    pos = 0 if cache_seqlens is None else int(cache_seqlens[0].item())

    # Insert new k, v into the cache (in-place, matching FA3 behavior).
    # Only the T dimension advances; all other dims match the cache.
    if k is not None and v is not None:
        k_cache[:, pos:pos + T_new, :, :] = k
        v_cache[:, pos:pos + T_new, :, :] = v

    # View of the cache covering the history plus the new tokens.
    end_pos = pos + T_new
    k_full = k_cache[:, :end_pos, :, :]
    v_full = v_cache[:, :end_pos, :, :]

    # SDPA layout: (B, T, H, D) -> (B, H, T, D).
    q_sdpa = q.transpose(1, 2)
    k_sdpa = k_full.transpose(1, 2)
    v_sdpa = v_full.transpose(1, 2)
    enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
    y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
    return y_sdpa.transpose(1, 2)  # back to (B, T, H, D)
# engine.py
class KVCache:
    """
    KV cache designed for Flash Attention 3's flash_attn_with_kvcache API.

    Key differences from an FA2-style cache:
      - Tensors are (B, T, H, D), not (B, H, T, D).
      - FA3 updates the cache in-place during flash_attn_with_kvcache.
      - Position is tracked per batch element via the cache_seqlens tensor.

    Notes:
      - Each layer has its own k_cache/v_cache slice (leading dim of the tensors).
      - A single global cache_seqlens records the cache write position.
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
        # ... (other initialization elided in this excerpt)
        # Pre-allocate cache tensors: (n_layers, B, T, H, D).
        self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        # Current sequence length per batch element (FA3 requires int32).
        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
gpt.py
class CausalSelfAttention(nn.Module):
    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        # ... (training / no-cache branch and q/k/v projection elided in this excerpt)
        else:  # using kv_cache
            # Inference: flash_attn_with_kvcache handles cache management
            # (writes the new k/v at cache_seqlens and attends over history).
            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
            y = flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache,
                k=k, v=v,
                cache_seqlens=kv_cache.cache_seqlens,
                causal=True,
                window_size=window_size,
            )
            # Advance the shared write position exactly once per step: only
            # the last layer updates cache_seqlens, so every layer in this
            # forward pass reads the same position.
            if self.layer_idx == kv_cache.n_layers - 1:
                kv_cache.advance(T)  # T is presumably the new-token count from the elided code above — confirm