vllm.model_executor.layers.fla.ops.fused_recurrent

fused_recurrent_gated_delta_rule

fused_recurrent_gated_delta_rule(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    g: Tensor,
    beta: Tensor | None = None,
    scale: float | None = None,
    initial_state: Tensor | None = None,
    inplace_final_state: bool = True,
    cu_seqlens: LongTensor | None = None,
    ssm_state_indices: Tensor | None = None,
    num_accepted_tokens: Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[Tensor, Tensor]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `q` | `Tensor` | Queries of shape `[B, T, H, K]`. | required |
| `k` | `Tensor` | Keys of shape `[B, T, H, K]`. | required |
| `v` | `Tensor` | Values of shape `[B, T, HV, V]`. GVA is applied if `HV > H`. | required |
| `g` | `Tensor` | g (decays) of shape `[B, T, HV]`. | `None` |
| `beta` | `Tensor` | Betas of shape `[B, T, HV]`. | `None` |
| `scale` | `Optional[float]` | Scale factor for the attention scores. If not provided, defaults to `1 / sqrt(K)`. | `None` |
| `initial_state` | `Optional[Tensor]` | Initial state of shape `[N, HV, V, K]` for `N` input sequences. For equal-length input sequences, `N` equals the batch size `B`. | `None` |
| `inplace_final_state` | `bool` | Whether to store the final state in-place to save memory. | `True` |
| `cu_seqlens` | `Optional[LongTensor]` | Cumulative sequence lengths of shape `[N+1]` used for variable-length training, consistent with the FlashAttention API. | `None` |
| `ssm_state_indices` | `Optional[Tensor]` | Indices to map the input sequences to the initial/final states (see the sketch after this table). | `None` |
| `num_accepted_tokens` | `Optional[Tensor]` | Number of accepted tokens for each sequence during decoding. | `None` |
| `use_qk_l2norm_in_kernel` | `bool` | Whether to apply L2 normalization to `q` and `k` inside the kernel. | `False` |
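
The `ssm_state_indices` argument implies that `initial_state` can act as a slot-indexed cache: row `ssm_state_indices[i]` is the state read (and, with `inplace_final_state=True`, updated) for the `i`-th sequence. The sketch below illustrates that assumed convention for a single-token decode step; the cache layout, index dtype, and exact kernel requirements shown here are assumptions for illustration, not a documented contract.

>>> import torch
>>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
>>> B, T, H, HV, K, V = 4, 1, 4, 8, 128, 128   # one decode token per sequence
>>> q = torch.randn(B, T, H, K, device='cuda')
>>> k = torch.randn(B, T, H, K, device='cuda')
>>> v = torch.randn(B, T, HV, V, device='cuda')
>>> g = torch.randn(B, T, HV, device='cuda').clamp(max=0.0)  # log-decays <= 0
>>> beta = torch.rand(B, T, HV, device='cuda')
>>> # hypothetical cache with more slots than active sequences
>>> state_cache = torch.zeros(16, HV, V, K, device='cuda')
>>> # the 4 sequences read/write cache rows 7, 2, 11 and 5 (illustrative)
>>> slots = torch.tensor([7, 2, 11, 5], device='cuda', dtype=torch.long)
>>> o, _ = fused_recurrent_gated_delta_rule(
...     q, k, v, g, beta,
...     initial_state=state_cache,
...     ssm_state_indices=slots,
... )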

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `o` | `Tensor` | Outputs of shape `[B, T, HV, V]`. |
| `final_state` | `Tensor` | Final state of shape `[N, HV, V, K]`. |

Examples:

>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
>>> # inputs with equal lengths
>>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
>>> q = torch.randn(B, T, H, K, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, HV, V, device='cuda')
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
>>> h0 = torch.randn(B, HV, V, K, device='cuda')
>>> o, ht = fused_recurrent_gated_delta_rule(
...     q, k, v, g, beta,
...     initial_state=h0,
... )
>>> # for variable-length inputs, the batch size B is expected to be 1 and cu_seqlens is required
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
>>> # for a batch with 4 sequences, cu_seqlens with 5 start/end positions is expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o_var, ht_var = fused_recurrent_gated_delta_rule(
...     q, k, v, g, beta,
...     initial_state=h0,
...     cu_seqlens=cu_seqlens,
... )
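
Because `final_state` has the same `[N, HV, V, K]` layout as `initial_state`, a long sequence can also be processed in chunks by feeding each chunk's final state back in as the next chunk's initial state. A minimal sketch, reusing the equal-length shapes from the example above (before the varlen rearrange; the chunk size of 512 is arbitrary):

>>> state, outputs = h0, []
>>> for start in range(0, T, 512):
...     sl = slice(start, start + 512)
...     o_chunk, state = fused_recurrent_gated_delta_rule(
...         q[:, sl], k[:, sl], v[:, sl], g[:, sl], beta[:, sl],
...         initial_state=state,
...     )
...     outputs.append(o_chunk)
>>> o_stream = torch.cat(outputs, dim=1)

With the default `inplace_final_state=True`, the returned state may reuse the buffer passed as `initial_state`, so a loop like this need not allocate a fresh state per chunk.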

Source code in vllm/model_executor/layers/fla/ops/fused_recurrent.py
def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor | None = None,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, HV, V]`.
            GVA is applied if `HV > H`.
        g (torch.Tensor):
            g (decays) of shape `[B, T, HV]`.
        beta (torch.Tensor):
            betas of shape `[B, T, HV]`.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, HV, V, K]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        inplace_final_state (bool):
            Whether to store the final state in-place to save memory.
            Default: `True`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        ssm_state_indices (Optional[torch.Tensor]):
            Indices to map the input sequences to the initial/final states.
        num_accepted_tokens (Optional[torch.Tensor]):
            Number of accepted tokens for each sequence during decoding.
        use_qk_l2norm_in_kernel (bool):
            Whether to apply L2 normalization to `q` and `k` inside the kernel.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HV, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, HV, V, K]`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, HV, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, HV, V, K, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
        ...     q, k, v, g, beta,
        ...     initial_state=h0,
        ... )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions is expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
        ...     q, k, v, g, beta,
        ...     initial_state=h0,
        ...     cu_seqlens=cu_seqlens,
        ... )
    """
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
            f"Please flatten variable-length inputs before processing."
        )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        beta = torch.ones_like(q[..., 0])
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        inplace_final_state,
        cu_seqlens,
        ssm_state_indices,
        num_accepted_tokens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state
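
For reference, the recurrence this kernel fuses can be written as a naive per-timestep loop. The sketch below is an assumed reading of the gated delta rule (state decayed by `exp(g_t)`, then a rank-1 delta-rule update of strength `beta_t`), using the `[B, HV, V, K]` state layout from the docstring and assuming consecutive value heads share a query/key head under GVA. It illustrates the math only, not the kernel's exact numerics, and omits the variable-length, state-indexing, and in-kernel L2-norm paths.

import torch


def naive_gated_delta_rule(q, k, v, g, beta, scale=None, initial_state=None):
    # Shapes follow the docstring: q, k: [B, T, H, K]; v: [B, T, HV, V];
    # g, beta: [B, T, HV]. Returns o: [B, T, HV, V] and the final state
    # S: [B, HV, V, K].
    B, T, H, K = q.shape
    HV, V = v.shape[2], v.shape[3]
    if scale is None:
        scale = K ** -0.5
    # GVA: each of the H query/key heads serves HV // H value heads (assumed
    # to be consecutive).
    q = q.repeat_interleave(HV // H, dim=2)
    k = k.repeat_interleave(HV // H, dim=2)
    S = q.new_zeros(B, HV, V, K) if initial_state is None else initial_state.clone()
    outputs = []
    for t in range(T):
        q_t, k_t, v_t = q[:, t], k[:, t], v[:, t]   # [B, HV, K] / [B, HV, V]
        a_t = g[:, t].exp()[..., None, None]        # gate/decay, [B, HV, 1, 1]
        b_t = beta[:, t][..., None]                 # [B, HV, 1]
        S = a_t * S                                 # decay the state
        # Delta rule: nudge the state's prediction for k_t toward v_t.
        pred = torch.einsum('bhvk,bhk->bhv', S, k_t)
        S = S + torch.einsum('bhv,bhk->bhvk', b_t * (v_t - pred), k_t)
        outputs.append(torch.einsum('bhvk,bhk->bhv', S, q_t * scale))
    return torch.stack(outputs, dim=1), S

Under these assumptions, its outputs should roughly track the fused kernel for small shapes, up to numerical precision.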