Skip to content

vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe

Utility helpers for NVFP4 + FlashInfer fused-MoE path

_supports_activation

_supports_activation(activation: str) -> bool

Supports silu activation only.

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def _supports_activation(activation: str) -> bool:
    """Return True only for the "silu" activation, the sole one accepted."""
    supported_activations = ("silu",)
    return activation in supported_activations

_supports_current_device

_supports_current_device() -> bool

Supports only Blackwell-family GPUs.

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def _supports_current_device() -> bool:
    """Report whether the current device is a Blackwell-family CUDA GPU.

    The platform must be CUDA and must belong to device capability
    family 100.
    """
    platform = current_platform
    on_cuda = platform.is_cuda()
    # `and` short-circuits: the capability query only runs on CUDA platforms.
    return on_cuda and platform.is_device_capability_family(100)

_supports_no_act_and_mul

_supports_no_act_and_mul() -> bool

Does not support non-gated MoE (e.g. Nemotron-Nano).

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def _supports_no_act_and_mul() -> bool:
    """Report support for non-gated (no act-and-mul) MoE layers.

    Always False: non-gated MoE (e.g. Nemotron-Nano) is not supported.
    """
    return False

_supports_parallel_config

_supports_parallel_config(
    moe_parallel_config: FusedMoEParallelConfig,
) -> bool

TRTLLM is a monolithic kernel that requires dispatch_router_logits() for the naive dispatch/combine path. DeepEP HT only implements dispatch() for the modular kernel path, so TRTLLM is incompatible with DeepEP HT.

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """
    Return whether the given parallel config can be used with TRTLLM.

    TRTLLM is a monolithic kernel that requires dispatch_router_logits() for
    the naive dispatch/combine path. DeepEP HT only implements dispatch() for
    the modular kernel path, so TRTLLM is incompatible with DeepEP HT.
    """
    uses_deepep_ht = moe_parallel_config.use_deepep_ht_kernels
    return not uses_deepep_ht

_supports_quant_scheme

_supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool

Supports Nvfp4 quantization.

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def _supports_quant_scheme(
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> bool:
    """Return whether the weight/activation quantization pair is supported.

    Only the (kNvfp4Static, kNvfp4Dynamic) weight/activation key pair is
    accepted.
    """
    supported_pairs = ((kNvfp4Static, kNvfp4Dynamic),)
    requested_pair = (weight_key, activation_key)
    return requested_pair in supported_pairs

_supports_routing_method

_supports_routing_method(
    routing_method: RoutingMethodType,
) -> bool

Monolithic kernels need to express router support.

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def _supports_routing_method(
    routing_method: RoutingMethodType,
) -> bool:
    """Return whether this kernel accepts `routing_method`.

    Monolithic kernels need to express router support themselves, so the
    accepted routing methods are listed explicitly here.
    """
    # NOTE(rob): potentially allow others here. This is a conservative list.
    allowed_methods = (
        RoutingMethodType.DeepSeekV3,
        RoutingMethodType.Renormalize,
        RoutingMethodType.RenormalizeNaive,
        RoutingMethodType.Llama4,
    )
    return routing_method in allowed_methods

flashinfer_trtllm_fp4_moe

flashinfer_trtllm_fp4_moe(
    layer: Module,
    x: Tensor | tuple[Tensor, Tensor],
    router_logits: Tensor,
    top_k: int,
    activation: str,
    global_num_experts: int,
    num_expert_group: int | None,
    topk_group: int | None,
    custom_routing_function: object | None,
    e_score_correction_bias: Tensor | None,
) -> Tensor

Apply FlashInfer TensorRT-LLM FP4 MoE kernel.

Parameters:

Name Type Description Default
layer Module

The MoE layer with weights and scales

required
x Tensor | tuple[Tensor, Tensor]

Input tensor, or an already-quantized (FP4 data, scale) tuple

required
router_logits Tensor

Router logits for expert selection

required
top_k int

Number of experts to select per token

required
activation str

Activation function to use

required
global_num_experts int

Total number of experts across all ranks

required
num_expert_group int | None

Number of expert groups (for grouped routing)

required
topk_group int | None

Top-k within each group

required
custom_routing_function object | None

Custom routing function (e.g., Llama4)

required
e_score_correction_bias Tensor | None

Optional routing bias correction

required

Returns:

Type Description
Tensor

Output tensor from the MoE layer

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def flashinfer_trtllm_fp4_moe(
    layer: torch.nn.Module,
    x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    router_logits: torch.Tensor,
    top_k: int,
    activation: str,
    global_num_experts: int,
    num_expert_group: int | None,
    topk_group: int | None,
    custom_routing_function: object | None,
    e_score_correction_bias: torch.Tensor | None,
) -> torch.Tensor:
    """
    Run the FlashInfer TensorRT-LLM FP4 block-scale MoE kernel, passing the
    router logits to the kernel for expert selection.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor, or an already-quantized `(fp4_data, scale)` tuple
        router_logits: Router logits for expert selection
        top_k: Number of experts to select per token
        activation: Activation function to use (must be "silu")
        global_num_experts: Total number of experts across all ranks
        num_expert_group: Number of expert groups (for grouped routing);
            passed to the kernel as 0 when None
        topk_group: Top-k within each group; passed to the kernel as 0
            when None
        custom_routing_function: Custom routing function (e.g., Llama4)
        e_score_correction_bias: Optional routing bias correction

    Returns:
        Output tensor from the MoE layer (first element of the kernel's
        return value)
    """
    import flashinfer

    from vllm.model_executor.models.llama4 import Llama4MoE

    # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2404
    assert activation == "silu", (
        "Only SiLU activation is supported for FlashInfer TRTLLM FP4 MoE. "
        f"{activation} found instead."
    )

    if isinstance(x, tuple):
        # The caller already quantized the activations: (fp4 data, scales).
        x_fp4, x_scale = x
    else:
        # Quantize the input to FP4 with the layer's activation scale.
        x_fp4, x_scale = ops.scaled_fp4_quant(
            x, layer.a1_gscale, is_sf_swizzled_layout=False
        )

    # The layer's routing method is the default; the Llama4 custom routing
    # function overrides it with FlashInfer's Llama4 routing type.
    routing_method_type = layer.routing_method_type
    if custom_routing_function is Llama4MoE.custom_routing_function:
        routing_method_type = flashinfer.RoutingMethodType.Llama4

    # Cast to Fp32 (required by kernel) -- only done for DeepSeekV3 routing.
    if routing_method_type == RoutingMethodType.DeepSeekV3:
        router_logits = router_logits.to(torch.float32)

    # All block scales are handed to the kernel reinterpreted as FP8 (e4m3).
    fp8 = torch.float8_e4m3fn
    x_scale_flat = x_scale.view(fp8).flatten()
    w13_scale = layer.w13_weight_scale.data.view(fp8)
    w2_scale = layer.w2_weight_scale.data.view(fp8)

    # The kernel takes 0 rather than None when grouped routing is not in use.
    n_group = 0 if num_expert_group is None else num_expert_group
    group_top_k = 0 if topk_group is None else topk_group

    n_local_experts = layer.local_num_experts

    # Call TRT-LLM FP4 block-scale MoE kernel
    kernel_outputs = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
        routing_logits=router_logits,
        routing_bias=e_score_correction_bias,
        hidden_states=x_fp4,
        hidden_states_scale=x_scale_flat,
        gemm1_weights=layer.w13_weight.data,
        gemm1_weights_scale=w13_scale,
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.w2_weight.data,
        gemm2_weights_scale=w2_scale,
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        n_group=n_group,
        topk_group=group_top_k,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * n_local_experts,
        local_num_experts=n_local_experts,
        routed_scaling_factor=None,
        routing_method_type=routing_method_type,
        do_finalize=True,
    )

    return kernel_outputs[0]

flashinfer_trtllm_fp4_routed_moe

flashinfer_trtllm_fp4_routed_moe(
    layer: Module,
    x: Tensor,
    topk_ids: Tensor,
    topk_weights: Tensor,
    top_k: int,
    activation: str,
    global_num_experts: int,
) -> Tensor

Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed input top k expert indices and scores rather than computing top k expert indices from scores.

Parameters:

Name Type Description Default
layer Module

The MoE layer with weights and scales

required
x Tensor

Input tensor

required
topk_ids Tensor

Ids of selected experts

required
topk_weights Tensor

Routing weights of the selected experts

required
top_k int

Number of experts to select per token

required
activation str

Activation function to use

required
global_num_experts int

Total number of experts across all ranks

required

Returns:

Type Description
Tensor

Output tensor from the MoE layer

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def flashinfer_trtllm_fp4_routed_moe(
    layer: torch.nn.Module,
    x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    top_k: int,
    activation: str,
    global_num_experts: int,
) -> torch.Tensor:
    """
    Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
    input top k expert indices and scores rather than computing
    top k expert indices from scores.

    Args:
        layer: The MoE layer with weights and scales
        x: Input tensor, or an already-quantized `(fp4_data, scale)` tuple
        topk_ids: Ids of selected experts
        topk_weights: Routing weights of the selected experts; converted to
            bfloat16 and packed into the low 16 bits next to each expert id
        top_k: Number of experts to select per token
        activation: Activation function to use (must be "silu")
        global_num_experts: Total number of experts across all ranks

    Returns:
        Output tensor from the MoE layer (first element of the kernel's
        return value)
    """
    import flashinfer

    # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
    assert activation == "silu", (
        "Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
        f"{activation} found instead."
    )

    # Pack top k ids and expert weights into a single int32 tensor, as
    # required by TRT-LLM: expert id in the high 16 bits, bfloat16 weight
    # bits in the low 16 bits.
    # The int16 view is sign-extended when widened to int32, so a weight with
    # its sign bit set (negative value or -0.0) would otherwise overwrite the
    # expert id; mask it down to the low 16 bits before OR-ing.
    weight_bits = (
        topk_weights.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF
    )
    packed_tensor = (topk_ids.to(torch.int32) << 16) | weight_bits

    if isinstance(x, tuple):
        # Hidden states are already quantized: (fp4 data, scales)
        hidden_states_fp4, hidden_states_scale_linear_fp4 = x
    else:
        # Quantize input to FP4
        (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
            x, layer.a1_gscale, is_sf_swizzled_layout=False
        )

    # Call TRT-LLM FP4 block-scale MoE kernel
    out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
        topk_ids=packed_tensor,
        routing_bias=None,
        hidden_states=hidden_states_fp4,
        hidden_states_scale=hidden_states_scale_linear_fp4.view(
            torch.float8_e4m3fn
        ).flatten(),
        gemm1_weights=layer.w13_weight.data,
        gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
        gemm1_bias=None,
        gemm1_alpha=None,
        gemm1_beta=None,
        gemm1_clamp_limit=None,
        gemm2_weights=layer.w2_weight.data,
        gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
        gemm2_bias=None,
        output1_scale_scalar=layer.g1_scale_c.data,
        output1_scale_gate_scalar=layer.g1_alphas.data,
        output2_scale_scalar=layer.g2_alphas.data,
        num_experts=global_num_experts,
        top_k=top_k,
        # Routing was done by the caller, so no expert grouping is passed.
        n_group=0,
        topk_group=0,
        intermediate_size=layer.intermediate_size_per_partition,
        local_expert_offset=layer.ep_rank * layer.local_num_experts,
        local_num_experts=layer.local_num_experts,
        routed_scaling_factor=None,
        # NOTE(review): 1 is presumably RoutingMethodType.Renormalize --
        # confirm against the FlashInfer enum.
        routing_method_type=1,
        do_finalize=True,
    )[0]

    return out

is_supported_config_trtllm

is_supported_config_trtllm(
    moe_config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
    activation_format: FusedMoEActivationFormat,
) -> tuple[bool, str | None]

This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def is_supported_config_trtllm(
    moe_config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
    activation_format: mk.FusedMoEActivationFormat,
) -> tuple[bool, str | None]:
    """
    Check whether the TRTLLM kernel supports the given MoE configuration.

    This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config

    Returns:
        `(True, None)` when every check passes; otherwise `(False, reason)`
        where `reason` describes the first check that failed.
    """

    def _unsupported(what: str) -> tuple[bool, str]:
        return False, f"kernel does not support {what}"

    # Checks run in a fixed order; the first failure is reported.
    if not _supports_current_device():
        return _unsupported("current device")
    if not moe_config.is_act_and_mul and not _supports_no_act_and_mul():
        return _unsupported("no act_and_mul MLP layer")
    if not _supports_activation(moe_config.activation):
        return _unsupported(f"{moe_config.activation} activation")
    if not _supports_quant_scheme(weight_key, activation_key):
        return _unsupported("quantization scheme")
    if not _supports_parallel_config(moe_config.moe_parallel_config):
        return _unsupported("parallel config")
    if not _supports_routing_method(moe_config.routing_method):
        return _unsupported("routing method")
    if activation_format != mk.FusedMoEActivationFormat.Standard:
        return _unsupported("activation format")

    return True, None

reorder_w1w3_to_w3w1

reorder_w1w3_to_w3w1(
    weight: Tensor, scale: Tensor, dim: int = -2
) -> tuple[Tensor, Tensor]

Re-order the concatenated [w1, w3] tensors to [w3, w1]

Source code in vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
def reorder_w1w3_to_w3w1(
    weight: torch.Tensor, scale: torch.Tensor, dim: int = -2
) -> tuple[torch.Tensor, torch.Tensor]:
    """Swap the two halves of `weight` and `scale` along `dim`.

    Turns tensors concatenated as `[w1, w3]` into `[w3, w1]`. Both tensors
    are split at half of `weight`'s size along `dim`, and the results are
    returned as contiguous tensors.

    Args:
        weight: Concatenated `[w1, w3]` weight; its size along `dim` must
            be even.
        scale: Scale tensor, split at the same position as `weight`.
        dim: Dimension along which the halves are concatenated.

    Returns:
        The reordered `(weight, scale)` pair.
    """
    full = weight.shape[dim]
    assert full % 2 == 0, f"Expected even size in dim {dim}, got {full}"
    half = full // 2

    w_first, w_second = torch.split(weight, half, dim=dim)
    s_first, s_second = torch.split(scale, half, dim=dim)

    swapped_weight = torch.cat((w_second, w_first), dim=dim).contiguous()
    swapped_scale = torch.cat((s_second, s_first), dim=dim).contiguous()
    return swapped_weight, swapped_scale