Skip to content

vllm.model_executor.layers.fla.ops.solve_tril

solve_tril

solve_tril(
    A: Tensor,
    cu_seqlens: Tensor | None = None,
    output_dtype: dtype = float,
) -> Tensor

Compute the inverse of the matrix I + A. A should be strictly lower triangular, i.e., A.triu() == 0.

Parameters:

Name Type Description Default
A Tensor

[B, T, H, BT], where BT should only be 16, 32, or 64.

required
cu_seqlens Tensor | None

The cumulative sequence lengths of the input tensor. Default: None.

None
output_dtype dtype

The dtype of the output tensor. Default: torch.float. If None, the output dtype will be the same as the input dtype.

float

Returns:

Type Description
Tensor

(I + A)^-1 with the same shape as A

Source code in vllm/model_executor/layers/fla/ops/solve_tril.py
@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
    """
    Compute the inverse of the matrix I + A.

    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor | None):
            The cumulative sequence lengths of the input tensor. When given,
            the chunk layout is taken from `prepare_chunk_indices` instead of
            being derived from `T`. Default: `None`.
        output_dtype (torch.dtype | None):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.

    Returns:
        (I + A)^-1 with the same shape as A

    Raises:
        AssertionError: If the last dimension of A is not 16, 32, or 64.
        ValueError: For the same condition when assertions are disabled
            (e.g. `python -O`), instead of failing later on an unbound kernel.
    """
    assert A.shape[-1] in [16, 32, 64]
    output_dtype = A.dtype if output_dtype is None else output_dtype

    B, T, H, BT = A.shape

    # Pick the kernel matching the block size. The explicit `else` keeps the
    # failure readable even if the assert above has been compiled out.
    if BT == 16:
        merge_fn = solve_tril_16x16_kernel
    elif BT == 32:
        merge_fn = merge_16x16_to_32x32_inverse_kernel
    elif BT == 64:
        merge_fn = merge_16x16_to_64x64_inverse_kernel
    else:
        raise ValueError(
            f"Unsupported block size BT={BT}; expected one of 16, 32, or 64."
        )

    # With variable-length input the number of chunks comes from the
    # precomputed chunk indices; otherwise it is ceil(T / BT).
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)

    # Zero-initialised output buffer handed to the kernel as `Ai`.
    Ai = torch.zeros_like(A, dtype=output_dtype)

    # Launch grid: one program per (chunk, batch * head) pair.
    merge_fn[NT, B * H](
        A=A,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        USE_TMA=is_tma_supported,
        DOT_PRECISION=FLA_TRIL_PRECISION,
    )
    return Ai