#AI#code

flash_attention.py

SDPA: Scaled Dot-Product Attention

def _sdpa_attention(q, k, v, window_size, enable_gqa):
    """
    SDPA attention with sliding window support.
    q, k, v are (B, H, T, D) format.
    window_size: (left, right) sliding window. -1 means unlimited.
    """
    Tq = q.size(2) ## (B,H,T,D),取T
    Tk = k.size(2)
    window = window_size[0] ## window格式(left,right),取left
 
	## prefill阶段
    # Full context, same length
    if (window < 0 or window >= Tq) and Tq == Tk:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
 
    # Single token generation
    if Tq == 1:
        if window >= 0 and window < Tk:
            # window is "left" tokens we need to include (window + 1) keys total
            start = max(0, Tk - (window + 1)) # 第Tk的下标是Tk-1,窗口长(window+1)
            k = k[:, :, start:, :]
            v = v[:, :, start:, :]
        return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) # k和v切片已经不包含未来token,所以可`is_causal=False`
 
    # Need explicit mask for sliding window/chunk inference
    device = q.device
    # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
    row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
    col_idx = torch.arange(Tk, device=device).unsqueeze(0)
    mask = col_idx <= row_idx
 
    # sliding window (left)
    if window >= 0 and window < Tk:
        mask = mask & ((row_idx - col_idx) <= window)
 
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
 

chunk解码时,Tk=历史长度+Tq

row_idx = (Tk - Tq) + arange(Tq).unsqueeze(1) ## query
col_idx = arange(Tk).unsqueeze(0) ## key
mask = col_idx <= row_idx
  • col_idx:所有 key 的列号 0..Tk-1
  • row_idx:每个 query 在“全序列里的绝对位置”
    • 关键是 Tk - Tq 这个偏移量:把 query 行对齐到 cache 末尾(bottom-right 对齐)
  • mask = col_idx <= row_idx:每个 query 只能看不超过自己位置的 key(因果)
例子:Tk=8, Tq=3
- Tk-Tq = 5
- row_idx = [5,6,7](shape 3x1)
- col_idx = [0,1,2,3,4,5,6,7](shape 1x8)
得到 mask(T=True, F=False):
- row0 (<=5): T T T T T T F F
- row1 (<=6): T T T T T T T F
- row2 (<=7): T T T T T T T T
含义:这 3 个 query 分别对应全局位置 5/6/7。

再加一层“左滑窗”限制

if window >= 0 and window < Tk:
    mask = mask & ((row_idx - col_idx) <= window)
  • row_idx - col_idx = query 位置和 key 位置的距离(往左看多远)
  • 只保留距离 <= window 的 key,即每个 query 最多看到 window+1 个 key(含自身)。例:上例中若 window=2,row0(位置 5)只保留 col 3..5

推理时更新某层的 k_cache 和 v_cache,并在更新后的 cache 上计算注意力

def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
                            causal=False, window_size=(-1, -1)):
    """
    Flash Attention with KV cache for inference.

    FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.

    Args:
        q: Queries, shape (B, T_new, H, D)
        k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
        k, v: New keys/values to insert, shape (B, T_new, H_kv, D). Pass both
            or neither.
        cache_seqlens: Number of entries already in the cache, shape (B,) int32
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T_new, H, D)

    Raises:
        ValueError: (SDPA fallback only) if exactly one of k/v is given, or if
            writing T_new new entries at the current position would run past
            the end of the cache.

    Notes on the SDPA fallback:
        - Only cache_seqlens[0] is read, i.e. every batch element is assumed
          to be at the same cache position.
        - `causal` is not consulted: _sdpa_attention always masks causally,
          with the queries aligned to the end of the attended keys.
        - When k/v are None nothing is written and only the first
          cache_seqlens[0] cache entries are attended, matching the flash-attn
          kvcache API.
    """
    if USE_FA3:
        return _fa3.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size
        )

    # SDPA fallback: manually manage KV cache
    if (k is None) != (v is None):
        raise ValueError("k and v must either both be given or both be None")

    _, T_new, _, _ = q.shape
    # Uniform position across the batch is assumed, so only element 0 is read.
    pos = cache_seqlens[0].item()

    if k is not None:
        end_pos = pos + T_new
        t_max = k_cache.size(1)
        # Without this check an overflowing write is clipped by slicing: e.g.
        # pos == T_max with T_new == 1 broadcasts into an empty slice, so the
        # new token is silently never stored nor attended.
        if end_pos > t_max:
            raise ValueError(
                f"KV cache overflow: cannot write {T_new} tokens at position {pos} "
                f"into a cache of length {t_max}"
            )
        # Insert new k, v along the time dim only (in place, matching FA3 behavior)
        k_cache[:, pos:end_pos, :, :] = k
        v_cache[:, pos:end_pos, :, :] = v
    else:
        # Nothing was inserted: attend only to entries already in the cache,
        # never to unwritten slots beyond cache_seqlens.
        end_pos = pos

    # Cache contents up to the current end position
    k_full = k_cache[:, :end_pos, :, :]
    v_full = v_cache[:, :end_pos, :, :]

    # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
    q_sdpa = q.transpose(1, 2)
    k_sdpa = k_full.transpose(1, 2)
    v_sdpa = v_full.transpose(1, 2)

    # Grouped-query attention when the cache holds fewer heads than q
    enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
    y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)

    return y_sdpa.transpose(1, 2)  # back to (B, T, H, D)

engine.py

class KVCache:
    """
    KV Cache designed for Flash Attention 3's flash_attn_with_kvcache API.

    Key differences from FA2-style cache:
    - Tensors are (B, T, H, D) not (B, H, T, D)
    - FA3 updates the cache in-place during flash_attn_with_kvcache
    - Position tracked per batch element via cache_seqlens tensor

    All layers live in one pair of tensors with a leading layer dimension,
    while a single cache_seqlens tensor is shared by every layer.
    """

    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers, device, dtype):
        """Allocate zero-filled K/V buffers for all layers and a zeroed position counter.

        NOTE(review): num_heads is presumably the KV head count (H_kv) when
        grouped-query attention is used — confirm against the caller.
        """
        # ... (other attribute setup elided in this excerpt)
        # Pre-allocate cache tensors: (n_layers, B, T, H, D)
        # Indexing the first dim yields one layer's (B, T, H, D) buffer; seq_len
        # is the maximum number of positions the cache can hold.
        self.k_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        self.v_cache = torch.zeros(num_layers, batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
        # Current sequence length per batch element (FA3 needs int32).
        # Starts at 0 (empty cache); one tensor for all layers, not one per layer.
        self.cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
  • 每层都有自己的 k_cache 和 v_cache(都存放在形状为 (n_layers, B, T, H, D) 的张量里,第 0 维按层索引)
  • 所有层共享一个全局的 cache_seqlens(形状 (B,)),记录 cache 的写入位置

gpt.py

class CausalSelfAttention(nn.Module):
	"""Causal self-attention (excerpt: only the KV-cache inference branch of forward is shown)."""
	def forward(self, x, ve, cos_sin, window_size, kv_cache):
		# ... (start of forward and the matching `if` branch — presumably the no-kv_cache path — elided in this excerpt)
        else: # using kv_cache
            # Inference: use flash_attn_with_kvcache which handles cache management
            # (it writes the new k/v into the cache in place, then attends over it).
            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
            y = flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache,
                k=k, v=v,
                cache_seqlens=kv_cache.cache_seqlens,
                causal=True,
                window_size=window_size,
            )
            # Advance position after last layer processes.
            # Every layer passes the same kv_cache.cache_seqlens as its write offset,
            # so it may only move once all layers have inserted this step's tokens.
            if self.layer_idx == kv_cache.n_layers - 1:
                kv_cache.advance(T) # updates cache_seqlens (presumably += T)
  • 最后一层推理完成才更新 cache 写入位置:所有层共用同一个 cache_seqlens,若在中间层就 advance,后面的层会把本步的 k/v 写到错误的位置

参考