
vllm.v1.attention.ops.vit_attn_wrappers

This file contains ops that make ViT attention compatible with torch.compile. Some of the operations they wrap are not supported by torch.compile, for instance the .item() call in flash attention.

Using these ops and wrapping vision blocks with torch.compile can increase throughput in vision models by about 5% (relative) on H100 and improve token latencies by about 7%. See qwen2_5_vl for example usage.
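As a rough illustration of wrapping a vision block with torch.compile, the sketch below compiles a toy pre-norm attention block. `ToyVisionBlock` and its layer layout are hypothetical and are not the Qwen2.5-VL vision block; the sketch only shows where the compile call goes and the (batch, seq, heads, head_size) to (batch, heads, seq, head_size) transposes around SDPA.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyVisionBlock(nn.Module):
    """Hypothetical pre-norm self-attention block (not vLLM's Qwen2.5-VL block)."""

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.norm = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, d = x.shape  # (batch, seq_len, dim)
        qkv = self.qkv(self.norm(x)).view(b, s, 3, self.num_heads, d // self.num_heads)
        q, k, v = qkv.unbind(dim=2)  # each (batch, seq_len, num_heads, head_size)
        # SDPA expects (batch, num_heads, seq_len, head_size).
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        )
        out = out.transpose(1, 2).reshape(b, s, d)  # back to (batch, seq_len, dim)
        return x + self.proj(out)


block = ToyVisionBlock(dim=64, num_heads=4).eval()
# Wrap the whole block; torch.compile traces it lazily on the first call.
compiled_block = torch.compile(block)
```

Because compilation is lazy, the first call to `compiled_block` pays the tracing cost, and later calls that satisfy the same guards (for example, the same input shapes) reuse the compiled code.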

To use these ops, you must have PyTorch 2.4.0 or later installed.

apply_sdpa

apply_sdpa(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    scale: float | None = None,
    enable_gqa: bool = False,
) -> Tensor

Input shape: (batch_size x seq_len x num_heads x head_size). The output is returned in the same layout.
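With grouped-query attention, k and v carry fewer heads than q, and contiguous groups of query heads share one key/value head. PyTorch's documented reference implementation for `enable_gqa=True` expands k and v with `repeat_interleave` along the head dimension. The sketch below does that expansion explicitly in the (batch_size x seq_len x num_heads x head_size) layout, so it does not depend on the `enable_gqa` argument; `expand_kv_heads` and `gqa_attention_bshd` are hypothetical helper names.

```python
import torch
import torch.nn.functional as F


def expand_kv_heads(kv: torch.Tensor, num_q_heads: int) -> torch.Tensor:
    """Hypothetical helper: repeat KV heads so (b, s, h_kv, d) becomes (b, s, h_q, d).

    Query head i is served by KV head i // (num_q_heads // h_kv).
    """
    num_kv_heads = kv.shape[2]
    assert num_q_heads % num_kv_heads == 0, "h_q must be a multiple of h_kv"
    return kv.repeat_interleave(num_q_heads // num_kv_heads, dim=2)


def gqa_attention_bshd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: GQA in the (b, s, h, d) layout via explicit KV expansion."""
    k = expand_kv_heads(k, q.shape[2])
    v = expand_kv_heads(v, q.shape[2])
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))  # -> (b, h, s, d)
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2)  # -> (b, s, h, d)
```

Expanding k and v materializes extra copies; letting SDPA broadcast the heads itself, where `enable_gqa` is supported, avoids that at the call site.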

Source code in vllm/v1/attention/ops/vit_attn_wrappers.py
import einops
import torch
import torch.nn.functional as F


def apply_sdpa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float | None = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """
    Input shape:
    (batch_size x seq_len x num_heads x head_size)
    """
    # SDPA expects (batch, heads, seq, head_size), so move the head axis forward.
    q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
    output = F.scaled_dot_product_attention(
        q, k, v, dropout_p=0.0, scale=scale, enable_gqa=enable_gqa
    )
    # Restore the caller's (batch, seq, heads, head_size) layout.
    output = einops.rearrange(output, "b h s d -> b s h d")
    return output