Skip to content

vllm.model_executor.layers.fused_moe.utils

_fp8_perm

_fp8_perm(m: Tensor, idx: Tensor) -> Tensor

A permutation routine that works on fp8 types.

Source code in vllm/model_executor/layers/fused_moe/utils.py
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    A permutation routine that works on fp8 types.
    """
    if torch.is_floating_point(m) and m.dtype.itemsize == 1:
        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
    else:
        return m[idx, ...]

_fp8_quantize

_fp8_quantize(
    A: Tensor,
    A_scale: Tensor | None,
    per_act_token: bool,
    block_shape: list[int] | None = None,
) -> tuple[Tensor, Tensor]

Perform fp8 quantization on the inputs. If a block_shape is provided, the output will be blocked.

Source code in vllm/model_executor/layers/fused_moe/utils.py
def _fp8_quantize(
    A: torch.Tensor,
    A_scale: torch.Tensor | None,
    per_act_token: bool,
    block_shape: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Perform fp8 quantization on the inputs.  If a block_shape
    is provided, the output will be blocked.
    """
    if block_shape is None:
        # TODO(luka): use QuantFP8 custom op
        #  https://github.com/vllm-project/vllm/issues/20711
        A, A_scale = ops.scaled_fp8_quant(
            A, A_scale, use_per_token_if_dynamic=per_act_token
        )
    else:
        assert not per_act_token
        assert len(block_shape) == 2
        _, block_k = block_shape[0], block_shape[1]
        A, A_scale = per_token_group_quant_fp8(A, block_k)
        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)

    return A, A_scale

_int8_quantize

_int8_quantize(
    A: Tensor,
    A_scale: Tensor | None,
    per_act_token: bool,
    block_shape: list[int] | None = None,
) -> tuple[Tensor, Tensor]

Perform int8 quantization on the inputs. If a block_shape is provided, the output will be blocked.

Source code in vllm/model_executor/layers/fused_moe/utils.py
def _int8_quantize(
    A: torch.Tensor,
    A_scale: torch.Tensor | None,
    per_act_token: bool,
    block_shape: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Perform int8 quantization on the inputs.  If a block_shape
    is provided, the output will be blocked.
    """

    # If weights are per-channel (per_channel_quant=True), then
    # activations apply per-token quantization. Otherwise, assume
    # activation tensor-wise fp8/int8 quantization, dynamic or static
    if block_shape is None:
        assert per_act_token, "int8 quantization only supports block or channel-wise"
        A, A_scale = per_token_quant_int8(A)
    else:
        assert not per_act_token
        assert len(block_shape) == 2
        _, block_k = block_shape[0], block_shape[1]
        A, A_scale = per_token_group_quant_int8(A, block_k)
        assert cdiv(A.size(-1), block_k) == A_scale.size(-1)

    return A, A_scale

_resize_cache

_resize_cache(x: Tensor, v: tuple[int, ...]) -> Tensor

Shrink the given tensor and apply the given view to it. This is used to resize the intermediate fused_moe caches.

Source code in vllm/model_executor/layers/fused_moe/utils.py
def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
    """
    Shrink the given tensor and apply the given view to it.  This is
    used to resize the intermediate fused_moe caches.

    Args:
        x: Backing buffer; must hold at least ``prod(v)`` elements.
        v: Target shape.

    Returns:
        The first ``prod(v)`` elements of ``x.flatten()`` viewed as ``v``.
        ``flatten`` returns a view when ``x`` is contiguous, so in that case
        the result aliases ``x``'s storage rather than copying it.
    """
    numel = prod(v)
    assert numel <= x.numel(), (
        f"cannot view {x.shape} ({x.numel()} elements) as {v} "
        f"({numel} elements)"
    )  # CUDAGRAPH unfriendly?
    return x.flatten()[:numel].view(*v)

apply_moe_activation

apply_moe_activation(
    activation: str, output: Tensor, input: Tensor
) -> Tensor

Apply MoE activation function.

For gated *_and_mul activations (silu, gelu, swigluoai, swiglustep):

- Expects output.size(-1) * 2 == input.size(-1)

For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul):

- Expects output.size(-1) == input.size(-1)

Source code in vllm/model_executor/layers/fused_moe/utils.py
def apply_moe_activation(
    activation: str,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """
    Apply MoE activation function.

    For gated *_and_mul activations (silu, gelu, swigluoai, swiglustep):
        - Expects output.size(-1) * 2 == input.size(-1)

    For *_no_mul activations (silu_no_mul, gelu_no_mul, relu2_no_mul):
        - Expects output.size(-1) == input.size(-1)

    Args:
        activation: Activation name. Any name ending in "_no_mul" is
            size-checked as a non-gated activation.
        output: Destination tensor; filled by the selected kernel and
            returned.
        input: Source tensor. NOTE: for the relu2_no_mul activation this
            tensor is modified in place (ReLU is applied to it before its
            square is written into ``output``).

    Returns:
        ``output``.

    Raises:
        AssertionError: if the last-dimension sizes do not match the
            expected ratio. This check runs before the activation name is
            validated.
        ValueError: if ``activation`` is not supported.
    """
    is_no_mul = activation.endswith("_no_mul")
    if is_no_mul:
        assert output.size(-1) == input.size(-1), (
            f"{activation} expects equal sizes: {output.size(-1)} vs {input.size(-1)}"
        )
    else:
        assert output.size(-1) * 2 == input.size(-1), (
            f"{activation} expects 2x ratio: {output.size(-1) * 2} vs {input.size(-1)}"
        )

    # Gated activations: `input` is twice as wide as `output` (checked above).
    if activation == "silu":
        torch.ops._C.silu_and_mul(output, input)
    elif activation == "gelu":
        torch.ops._C.gelu_and_mul(output, input)
    elif activation == "swigluoai":
        torch.ops._C.swigluoai_and_mul(output, input)
    elif activation == "swiglustep":
        # Local import: only this branch needs it.
        from vllm.model_executor.layers.activation import swiglustep_and_mul_triton

        swiglustep_and_mul_triton(output, input)

    # Activations without gated multiplication
    elif activation == SILU_NO_MUL:
        output.copy_(F.silu(input))
    elif activation == GELU_NO_MUL:
        output.copy_(F.gelu(input))
    elif activation == RELU2_NO_MUL:
        # NOTE: ReLU is applied to the caller's `input` in place (presumably
        # to avoid a temporary); its square is then written into `output`.
        F.relu(input, inplace=True)
        torch.square(input, out=output)
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")

    return output

count_expert_num_tokens

count_expert_num_tokens(
    topk_ids: Tensor,
    num_local_experts: int,
    expert_map: Tensor | None,
) -> Tensor

Count the number of tokens assigned to each expert.

Parameters:

- topk_ids (torch.Tensor): Tensor mapping each token to its list of experts.
- num_local_experts (int): Number of experts in this rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices from the global expert space to the local expert space of the expert parallel shard.

Returns: A tensor of size num_local_experts, where tensor[i] holds the number of tokens assigned to the ith expert.

Source code in vllm/model_executor/layers/fused_moe/utils.py
def count_expert_num_tokens(
    topk_ids: torch.Tensor, num_local_experts: int, expert_map: torch.Tensor | None
) -> torch.Tensor:
    """
    Count the number to tokens assigned to each expert.

    Parameters:
    - topk_ids (torch.Tensor): Tensor mapping each token to its
    list of experts.
    - num_local_experts (int): Number of experts in this rank.
    - expert_map (Optional[torch.Tensor]):  A tensor mapping expert indices
    from the global expert space to the local expert space of the expert
    parallel shard.

    Returns:
    A tensor of size num_local_experts, where tensor[i] holds the number
    of tokens assigned to the ith expert.
    """
    assert topk_ids.dtype.is_signed, "The kernel uses -1 to represent invalid topk_ids"
    expert_num_tokens = torch.empty(
        (num_local_experts), device=topk_ids.device, dtype=torch.int32
    )

    grid = num_local_experts
    BLOCK_SIZE = min(topk_ids.numel(), 1024)
    BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE)

    _count_expert_num_tokens[(grid,)](
        topk_ids,
        expert_num_tokens,
        num_local_experts,
        topk_ids.numel(),
        expert_map,
        HAS_EXPERT_MAP=expert_map is not None,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return expert_num_tokens