class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
    """TurboQuant attention implementation.

    K/V are quantized and written into a combined TurboQuant cache by a
    fused Triton store kernel (``do_kv_cache_update`` -> ``_store_kv``).
    Attention is then computed per batch phase:

    * decode requests run the Triton TurboQuant decode kernel directly on
      the compressed cache;
    * first-chunk prefills attend to the raw (unquantized) K/V of the
      current batch with flash_attn (or SDPA when flash_attn is missing);
    * continuation prefill chunks either reuse the decode kernel (short
      chunks, ``q_len <= _CONTINUATION_DECODE_THRESHOLD``) or dequantize the
      cached K/V and run flash_attn/SDPA with a causal mask (long chunks).

    Features these kernels do not implement (ALiBi, sliding window,
    logit soft-capping, non-decoder attention, fused output quantization)
    raise ``NotImplementedError`` instead of being silently ignored.
    """

    supports_quant_query_input: bool = False

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        **kwargs,
    ):
        """Validate options and pre-compute TurboQuant kernel constants.

        Raises:
            NotImplementedError: for attention features the TurboQuant
                kernels do not implement (every path below computes plain
                causal attention with ``softmax_scale=scale``).
        """
        # Every attention path in this class is plain causal decoder
        # attention; reject options that would otherwise be silently
        # ignored and produce wrong outputs.
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "TurboQuantAttentionImpl only supports decoder "
                f"self-attention, got attn_type={attn_type!r}."
            )
        if alibi_slopes is not None:
            raise NotImplementedError(
                "ALiBi is not supported by TurboQuantAttentionImpl."
            )
        if sliding_window is not None:
            raise NotImplementedError(
                "Sliding-window attention is not supported by "
                "TurboQuantAttentionImpl."
            )
        if logits_soft_cap:
            raise NotImplementedError(
                "Logits soft-capping is not supported by TurboQuantAttentionImpl."
            )

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        from vllm.model_executor.layers.quantization.turboquant.config import (
            TurboQuantConfig,
        )

        self.tq_config = TurboQuantConfig.from_cache_dtype(kv_cache_dtype, head_size)
        # Pre-compute kernel constants from config (avoid repeated arithmetic)
        cfg = self.tq_config
        # Bytes per head for the key payload: bit-packed MSE indices, or
        # one byte per element in FP8-key mode.
        self._mse_bytes = (
            math.ceil(head_size * cfg.key_mse_bits / 8)
            if not cfg.key_fp8
            else head_size
        )
        # Bytes per head for the bit-packed value payload.
        self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8)
        self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1
        # Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph,
        # and benchmarks show no regression vs dynamic in eager mode).
        vllm_config = get_current_vllm_config()
        self.max_num_kv_splits = (
            vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
        )

    def _ensure_on_device(self, layer, device):
        """One-time derivation of TQ buffers (rotation matrices, midpoints).

        Registered buffers (_tq_signs, _tq_centroids) are already on the
        correct device via register_buffer + model.to(device). The derived
        tensors are attached to ``layer`` and ``layer._tq_cached`` marks
        the work as done, so later calls are a single ``hasattr`` check.
        """
        if not hasattr(layer, "_tq_cached"):
            D = layer._tq_signs.shape[0]
            signs = layer._tq_signs.to(device=device, dtype=torch.float32)
            # WHT rotation: orthonormal + self-inverse, enabling future
            # in-kernel butterfly fusion and trivial inverse for continuation.
            H = _build_hadamard(D, str(device))
            # Row i of H scaled by signs[i]; _tq_Pi is its transpose.
            layer._tq_PiT = (signs.unsqueeze(1) * H).contiguous()
            layer._tq_Pi = layer._tq_PiT.T.contiguous()
            c = layer._tq_centroids.to(device=device, dtype=torch.float32)
            # Precompute midpoints for threshold-based quantization.
            # NOTE(review): midpoints come from the *sorted* centroids while
            # decode/dequant use layer._tq_centroids as stored; this assumes
            # the centroids are already sorted ascending — confirm.
            c_sorted, _ = c.sort()
            layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2
            # Decode buffers (_tq_mid_o_buf, _tq_output_buf, _tq_lse_buf)
            # are pre-allocated via register_buffer in Attention.__init__
            # and moved to GPU by model.to(device) — no allocation needed
            # here. The memory profiler sees them before KV cache sizing.
            layer._tq_cached = True

    def do_kv_cache_update(
        self,
        layer: torch.nn.Module,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        """Store compressed K/V into the combined TQ cache.

        Called as a separate custom op (unified_kv_cache_update) BEFORE
        the attention forward, matching FlashAttention's split pattern.
        slot_mapping is already sliced to num_actual_tokens by the caller.
        """
        N = slot_mapping.shape[0]
        if N <= 0:
            return
        device = key.device
        self._ensure_on_device(layer, device)
        k = key[:N].view(N, self.num_kv_heads, self.head_size)
        v = value[:N].view(N, self.num_kv_heads, self.head_size)
        self._store_kv(k, v, kv_cache, slot_mapping, layer)

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: "TurboQuantMetadata",
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute attention for the first ``num_actual_tokens`` tokens.

        The KV cache has already been updated by ``do_kv_cache_update``.
        The result is written into ``output`` (allocated as
        ``(num_tokens, num_heads * head_size)`` when not supplied), which
        may be 2-D ``(N, Hq*D)`` or 3-D ``(N, Hq, D)``. Returns ``output``;
        it is zero-filled when there is no metadata or no actual token.

        Raises:
            NotImplementedError: if ``output_scale`` or
                ``output_block_scale`` is given (fused output quantization
                is not implemented by this backend).
        """
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "Fused output quantization is not supported by "
                "TurboQuantAttentionImpl."
            )
        num_tokens = query.shape[0]
        if output is None:
            output = torch.zeros(
                num_tokens,
                self.num_heads * self.head_size,
                dtype=query.dtype,
                device=query.device,
            )
        if attn_metadata is None:
            return output.fill_(0)
        # Slice to actual tokens
        N = attn_metadata.num_actual_tokens
        if N <= 0:
            return output.fill_(0)
        q = query[:N].view(N, self.num_heads, self.head_size)
        # Get TQ buffers, ensure on device (one-time migration).
        # Use Any-typed alias for dynamic _tq_* attrs set by _ensure_on_device.
        tq_layer: Any = layer
        device = q.device
        self._ensure_on_device(tq_layer, device)
        Pi = tq_layer._tq_Pi
        PiT = tq_layer._tq_PiT
        centroids = tq_layer._tq_centroids
        # Compute attention (KV cache was already updated by do_kv_cache_update)
        # With reorder_batch_threshold=1, decodes come first in the batch.
        # num_decodes/num_decode_tokens from metadata give the split point.
        num_decodes = attn_metadata.num_decodes
        num_decode_tokens = attn_metadata.num_decode_tokens
        if not attn_metadata.is_prefill:
            # Pure decode batch — fast path
            attn_out = self._decode_attention(
                q, kv_cache, attn_metadata, Pi, centroids, PiT, layer
            )
        elif num_decodes == 0:
            # Pure prefill batch
            k = key[:N].view(N, self.num_kv_heads, self.head_size)
            v = value[:N].view(N, self.num_kv_heads, self.head_size)
            attn_out = self._prefill_attention(
                q,
                k,
                v,
                kv_cache,
                attn_metadata,
                Pi,
                centroids,
                PiT,
                layer=layer,
            )
        else:
            # Mixed batch: decodes first (guaranteed by reorder_batch).
            # Rows [:num_decode_tokens] and [num_decode_tokens:] are both
            # fully written below, so no zero-initialisation is needed.
            attn_out = torch.empty(
                N, self.num_heads, self.head_size, device=device, dtype=q.dtype
            )
            # --- Decode portion (first num_decodes requests) ---
            # Use full-batch max_seq_len as safe upper bound (no GPU sync).
            decode_meta = TurboQuantMetadata(
                seq_lens=attn_metadata.seq_lens[:num_decodes],
                slot_mapping=attn_metadata.slot_mapping[:num_decode_tokens],
                block_table=attn_metadata.block_table[:num_decodes],
                query_start_loc=attn_metadata.query_start_loc[: num_decodes + 1],
                num_actual_tokens=num_decode_tokens,
                max_query_len=1,
                max_seq_len=attn_metadata.max_seq_len,
                is_prefill=False,
            )
            attn_out[:num_decode_tokens] = self._decode_attention(
                q[:num_decode_tokens], kv_cache, decode_meta, Pi, centroids, PiT, layer
            )
            # --- Prefill portion (remaining requests) ---
            # CRITICAL: use prefill-specific max_seq_len so flash_attn's
            # fast path (max_query_len == max_seq_len) can trigger for
            # first-chunk prefills. Using full-batch max_seq_len breaks
            # this because decode requests inflate max_seq_len.
            prefill_seq_lens = attn_metadata.seq_lens[num_decodes:]
            # Python-side max of the prefill seq_lens. NOTE: if seq_lens is
            # a device tensor, .tolist() still forces a device->host sync.
            prefill_max_seq = max(prefill_seq_lens.tolist())
            # Re-base offsets so the prefill portion starts at token 0.
            prefill_qsl = (
                attn_metadata.query_start_loc[num_decodes:] - num_decode_tokens
            )
            prefill_meta = TurboQuantMetadata(
                seq_lens=prefill_seq_lens,
                slot_mapping=attn_metadata.slot_mapping[num_decode_tokens:N],
                block_table=attn_metadata.block_table[num_decodes:],
                query_start_loc=prefill_qsl,
                num_actual_tokens=N - num_decode_tokens,
                max_query_len=attn_metadata.max_query_len,
                max_seq_len=prefill_max_seq,
                is_prefill=True,
            )
            k = key[:N].view(N, self.num_kv_heads, self.head_size)
            v = value[:N].view(N, self.num_kv_heads, self.head_size)
            attn_out[num_decode_tokens:] = self._prefill_attention(
                q[num_decode_tokens:],
                k[num_decode_tokens:],
                v[num_decode_tokens:],
                kv_cache,
                prefill_meta,
                Pi,
                centroids,
                PiT,
                layer=layer,
            )
        # Write into output buffer: attn_out is (N, Hq, D)
        # output may be 2D (N, Hq*D) or 3D (N, Hq, D)
        if output.ndim == 3:
            output[:N] = attn_out.to(output.dtype)
        else:
            output[:N] = attn_out.reshape(N, -1).to(output.dtype)
        return output

    # ------------------------------------------------------------------ #
    #                 Store K/V into combined cache (vectorized)         #
    # ------------------------------------------------------------------ #
    def _store_kv(
        self,
        key: torch.Tensor,  # (N, Hk, D)
        value: torch.Tensor,  # (N, Hk, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        slot_mapping: torch.Tensor,
        layer: Any,
    ):
        """Quantize + store via fused Triton kernel.

        Uses the rotation ``layer._tq_PiT`` and quantization thresholds
        ``layer._tq_midpoints`` derived by ``_ensure_on_device``.
        """
        triton_turboquant_store(
            key,
            value,
            kv_cache,
            slot_mapping,
            layer._tq_PiT,
            layer._tq_midpoints,
            mse_bits=self.tq_config.key_mse_bits,
            key_packed_size=self.tq_config.key_packed_size,
            value_quant_bits=self.tq_config.effective_value_quant_bits,
            key_fp8=self.tq_config.key_fp8,
        )

    # ------------------------------------------------------------------ #
    #              Prefill: SDPA on raw Q/K/V with causal mask           #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _all_requests_first_chunk(
        query_start_loc: torch.Tensor, seq_lens: torch.Tensor
    ) -> bool:
        """Return True iff every request's query length equals its seq_len.

        ``query_start_loc`` holds ``num_reqs + 1`` cumulative offsets; only
        the first ``num_reqs`` entries of ``seq_lens`` are compared. The
        comparison runs on-device and costs one device->host sync.
        """
        num_reqs = query_start_loc.shape[0] - 1
        q_lens = query_start_loc[1:] - query_start_loc[:-1]
        return bool((q_lens == seq_lens[:num_reqs]).all().item())

    def _prefill_attention(
        self,
        query: torch.Tensor,  # (N, Hq, D)
        key: torch.Tensor,  # (N, Hk, D)
        value: torch.Tensor,  # (N, Hk, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        attn_metadata: TurboQuantMetadata,
        Pi: torch.Tensor,
        centroids: torch.Tensor,
        PiT: torch.Tensor | None = None,
        layer: Any = None,
    ) -> torch.Tensor:
        """Attention for prefill requests (first chunks and continuations).

        ``query``/``key``/``value`` hold the raw current-chunk tokens of all
        prefill requests, laid out by ``attn_metadata.query_start_loc``.
        Returns an ``(N, Hq, D)`` tensor in ``query.dtype``.
        """
        N, Hq, D = query.shape
        query_start_loc = attn_metadata.query_start_loc
        num_reqs = query_start_loc.shape[0] - 1
        # Fast path: one batched flash_attn call over the raw K/V of this
        # batch. It is only valid when NO request has previously cached KV,
        # i.e. q_len == seq_len for every request. The Python-int check
        # max_query_len == max_seq_len (no sync) is necessary but NOT
        # sufficient: q_lens=[100, 50] with seq_lens=[100, 80] passes it
        # although request 1 has 30 cached tokens the fast path would
        # ignore. So it only pre-filters; the per-request check confirms.
        if (
            _HAS_FLASH_ATTN
            and attn_metadata.max_query_len == attn_metadata.max_seq_len
            and self._all_requests_first_chunk(
                query_start_loc, attn_metadata.seq_lens
            )
        ):
            output = torch.empty(N, Hq, D, device=query.device, dtype=query.dtype)
            flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=query_start_loc,
                cu_seqlens_k=query_start_loc,
                max_seqlen_q=attn_metadata.max_query_len,
                max_seqlen_k=attn_metadata.max_query_len,
                softmax_scale=self.scale,
                causal=True,
                out=output,
            )
            return output
        # Continuation or no flash_attn: per-request attention.
        # For continuation chunks (seq_len > q_len), we must attend to
        # previously cached K/V from the TQ cache, not just the current
        # chunk's raw K/V.
        Hk = key.shape[1]
        use_gqa = Hk < Hq
        output = torch.zeros(N, Hq, D, device=query.device, dtype=query.dtype)
        # Convert to Python lists once instead of per-request .item()
        # calls that each force a sync.
        qsl = query_start_loc.tolist()
        seq_lens_list = attn_metadata.seq_lens.tolist()
        # Pre-allocate cu_seqlens for single-request flash_attn calls
        # to avoid per-request host→device tensor creation.
        _cu_2 = torch.zeros(2, device=query.device, dtype=torch.int32)
        for i in range(num_reqs):
            q_start = qsl[i]
            q_end = qsl[i + 1]
            q_len = q_end - q_start
            if q_len <= 0:
                continue
            seq_len = seq_lens_list[i]
            q_seq = query[q_start:q_end]  # (q_len, Hq, D)
            k_seq = key[q_start:q_end]  # (q_len, Hk, D)
            v_seq = value[q_start:q_end]  # (q_len, Hk, D)
            if q_len == seq_len:
                # First-chunk prefill: all K/V are in the current batch.
                if _HAS_FLASH_ATTN:
                    out = torch.empty_like(q_seq)
                    _cu_2[1] = q_len
                    cu = _cu_2
                    flash_attn_varlen_func(
                        q=q_seq,
                        k=k_seq,
                        v=v_seq,
                        cu_seqlens_q=cu,
                        cu_seqlens_k=cu,
                        max_seqlen_q=q_len,
                        max_seqlen_k=q_len,
                        softmax_scale=self.scale,
                        causal=True,
                        out=out,
                    )
                else:
                    # SDPA expects heads before tokens: (H, q_len, D).
                    q_t = q_seq.transpose(0, 1).contiguous()
                    k_t = k_seq.transpose(0, 1).contiguous()
                    v_t = v_seq.transpose(0, 1).contiguous()
                    out = F.scaled_dot_product_attention(
                        q_t,
                        k_t,
                        v_t,
                        is_causal=True,
                        scale=self.scale,
                        enable_gqa=use_gqa,
                    ).transpose(0, 1)
                output[q_start:q_end] = out.to(query.dtype)
            else:
                # Continuation chunk: tokens already stored to TQ cache
                # by do_kv_cache_update. Use decode kernel directly to
                # avoid O(cached_len) full-dequant per continuation.
                # For large continuations, fall back to _continuation_prefill.
                cached_len = seq_len - q_len
                if q_len <= _CONTINUATION_DECODE_THRESHOLD:
                    # Fast path: treat each query as a decode request
                    # with incremental seq_lens for causal masking.
                    synth_seq_lens = torch.arange(
                        cached_len + 1,
                        seq_len + 1,
                        device=query.device,
                        dtype=attn_metadata.seq_lens.dtype,
                    )
                    # Every synthetic request reads the same block table row.
                    synth_bt = attn_metadata.block_table[i : i + 1].expand(q_len, -1)
                    out = triton_turboquant_decode_attention(
                        query=q_seq,
                        kv_cache=kv_cache,
                        block_table=synth_bt,
                        seq_lens=synth_seq_lens,
                        Pi=Pi,
                        centroids=centroids,
                        scale=self.scale,
                        mse_bits=self.tq_config.key_mse_bits,
                        key_packed_size=self.tq_config.key_packed_size,
                        value_quant_bits=(self.tq_config.effective_value_quant_bits),
                        key_fp8=self.tq_config.key_fp8,
                        norm_correction=self.tq_config.norm_correction,
                        PiT=PiT,
                    )
                else:
                    # Large continuation: dequant cached K/V and use
                    # flash_attn for better throughput.
                    out = self._continuation_prefill(
                        layer,
                        q_seq,
                        k_seq,
                        v_seq,
                        kv_cache,
                        attn_metadata.block_table[i : i + 1],
                        cached_len,
                        seq_len,
                        Pi,
                        centroids,
                    )
                output[q_start:q_end] = out.to(query.dtype)
        return output

    def _continuation_prefill(
        self,
        layer: Any,
        query: torch.Tensor,  # (q_len, Hq, D)
        key_chunk: torch.Tensor,  # (q_len, Hk, D)
        val_chunk: torch.Tensor,  # (q_len, Hk, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        block_table: torch.Tensor,  # (1, max_num_blocks)
        cached_len: int,
        seq_len: int,
        Pi: torch.Tensor,
        centroids: torch.Tensor,
    ) -> torch.Tensor:
        """Handle continuation chunk by dequanting cached K/V from TQ cache.

        Dequants the ``cached_len`` previously cached K/V positions into
        float16 scratch buffers (kept on ``layer`` as
        ``_tq_k_dequant_buf`` / ``_tq_v_dequant_buf`` and grown when too
        small), inverse-rotates MSE keys with ``Pi``, concatenates the
        current chunk's raw K/V, then runs flash_attn (or SDPA) with a
        causal mask where query ``p`` sees keys ``j <= cached_len + p``.
        Returns ``(q_len, Hq, D)``.
        """
        q_len, Hq, D = query.shape
        Hk = key_chunk.shape[1]
        device = query.device
        block_size = kv_cache.shape[1]
        BLOCK_D = triton.next_power_of_2(D)
        mse_bytes = self._mse_bytes
        val_data_bytes = self._val_data_bytes
        # Dequant cached K/V from TQ cache
        # Allocate slightly over to align to block_size for the grid.
        # Reuse cached buffers to avoid per-call allocation (~16MB at 8K).
        alloc_len = math.ceil(cached_len / block_size) * block_size
        buf_shape = (1, Hk, alloc_len, D)
        k_buf = getattr(layer, "_tq_k_dequant_buf", None)
        if k_buf is None or k_buf.shape[2] < alloc_len:
            k_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
            v_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
            layer._tq_k_dequant_buf = k_buf
            layer._tq_v_dequant_buf = v_buf
        else:
            v_buf = layer._tq_v_dequant_buf
        # Views into the (possibly larger) scratch buffers; strides are
        # passed to the kernel explicitly, so non-contiguity is fine.
        k_cached = k_buf[:, :, :alloc_len, :].zero_()
        v_cached = v_buf[:, :, :alloc_len, :].zero_()
        grid = (alloc_len, 1 * Hk)
        _tq_full_dequant_kv[grid](
            kv_cache,
            block_table,
            centroids,
            k_cached,
            v_cached,
            k_cached.stride(0),
            k_cached.stride(1),
            k_cached.stride(2),
            v_cached.stride(0),
            v_cached.stride(1),
            v_cached.stride(2),
            kv_cache.stride(0),
            kv_cache.stride(1),
            kv_cache.stride(2),
            block_table.stride(0),
            HEAD_DIM=D,
            BLOCK_SIZE=block_size,
            NUM_KV_HEADS=Hk,
            MSE_BYTES=mse_bytes,
            KPS=self.tq_config.key_packed_size,
            VQB=self.tq_config.effective_value_quant_bits,
            VAL_DATA_BYTES=val_data_bytes,
            MSE_BITS=self.tq_config.key_mse_bits,
            KEY_FP8=1 if self.tq_config.key_fp8 else 0,
            BLOCK_D=BLOCK_D,
            NORM_CORRECTION=1 if self.tq_config.norm_correction else 0,
            FP8_E4B15=_use_fp8_e4b15(device.index or 0),
            num_warps=4,
        )
        # Inverse-rotate MSE keys back to original space
        if not self.tq_config.key_fp8:
            k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D).float()
            k_flat = k_flat @ Pi
            k_cached_trim = (
                k_flat.to(torch.float16).reshape(Hk, cached_len, D).transpose(0, 1)
            )  # (cached_len, Hk, D)
        else:
            k_cached_trim = (
                k_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
            )  # (cached_len, Hk, D)
        v_cached_trim = (
            v_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
        )  # (cached_len, Hk, D)
        # Concatenate cached + current chunk K/V (match query dtype)
        qdtype = query.dtype
        k_full = torch.cat([k_cached_trim.to(qdtype), key_chunk], dim=0)
        v_full = torch.cat([v_cached_trim.to(qdtype), val_chunk], dim=0)
        # Attention: q_len queries attending to seq_len K/V with causal mask
        if _HAS_FLASH_ATTN:
            output = torch.empty(q_len, Hq, D, device=device, dtype=query.dtype)
            cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
            cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
            # NOTE(review): relies on flash_attn aligning the causal mask to
            # the bottom-right when seqlen_q < seqlen_k (flash-attn >= 2.1).
            flash_attn_varlen_func(
                q=query,
                k=k_full,
                v=v_full,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=seq_len,
                softmax_scale=self.scale,
                causal=True,
                out=output,
            )
            return output
        else:
            # SDPA fallback: expand KV for GQA, build causal mask
            q_t = query.transpose(0, 1).unsqueeze(0)  # (1, Hq, q_len, D)
            k_t = k_full.transpose(0, 1).unsqueeze(0)  # (1, Hk, seq_len, D)
            v_t = v_full.transpose(0, 1).unsqueeze(0)  # (1, Hk, seq_len, D)
            # Build causal mask: query position p can attend to K position j
            # where j <= cached_len + p (p is 0-indexed within chunk)
            q_pos = torch.arange(q_len, device=device).unsqueeze(1) + cached_len
            k_pos = torch.arange(seq_len, device=device).unsqueeze(0)
            mask = k_pos <= q_pos  # (q_len, seq_len)
            out = F.scaled_dot_product_attention(
                q_t,
                k_t,
                v_t,
                attn_mask=mask,
                scale=self.scale,
                enable_gqa=(Hk < Hq),
            )  # (1, Hq, q_len, D)
            return out[0].transpose(0, 1)  # (q_len, Hq, D)

    # ------------------------------------------------------------------ #
    #                 Decode: Triton TQ decode attention                 #
    # ------------------------------------------------------------------ #
    def _decode_attention(
        self,
        query: torch.Tensor,  # (B, Hq, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        attn_metadata: TurboQuantMetadata,
        Pi: torch.Tensor,
        centroids: torch.Tensor,
        PiT: torch.Tensor | None = None,
        layer: torch.nn.Module | None = None,
    ) -> torch.Tensor:
        """Run the Triton TurboQuant decode kernel on the compressed cache.

        Decode scratch buffers (``_tq_mid_o_buf``, ``_tq_output_buf``,
        ``_tq_lse_buf``) are taken from ``layer`` when present, otherwise
        ``None`` is passed. ``layer`` is also forwarded as ``buf_holder`` —
        presumably so the kernel wrapper can allocate/attach missing
        buffers itself; confirm against triton_turboquant_decode_attention.
        Uses the fixed ``self.max_num_kv_splits`` (cudagraph-safe grid).
        """
        # Grab decode scratch buffers from the layer when it has them.
        mid_o_buf = output_buf = lse_buf = None
        if layer is not None:
            mid_o_buf = getattr(layer, "_tq_mid_o_buf", None)
            output_buf = getattr(layer, "_tq_output_buf", None)
            lse_buf = getattr(layer, "_tq_lse_buf", None)
        result = triton_turboquant_decode_attention(
            query=query,
            kv_cache=kv_cache,
            block_table=attn_metadata.block_table,
            seq_lens=attn_metadata.seq_lens,
            Pi=Pi,
            centroids=centroids,
            scale=self.scale,
            mse_bits=self.tq_config.key_mse_bits,
            key_packed_size=self.tq_config.key_packed_size,
            value_quant_bits=self.tq_config.effective_value_quant_bits,
            key_fp8=self.tq_config.key_fp8,
            norm_correction=self.tq_config.norm_correction,
            PiT=PiT,
            mid_o_buf=mid_o_buf,
            output_buf=output_buf,
            lse_buf=lse_buf,
            buf_holder=layer,
            max_num_kv_splits=self.max_num_kv_splits,
        )
        return result