Skip to content

vllm.model_executor.models.llama4

Inference-only LLaMA model compatible with HuggingFace weights.

Llama4Model

Bases: LlamaModel

Source code in vllm/model_executor/models/llama4.py
@support_torch_compile
class Llama4Model(LlamaModel):
    """LlamaModel subclass with weight loading that understands MoE experts.

    On top of the base model it records the expert counts from the config and
    provides ``load_weights`` / ``load_moe_expert_weights``, which handle:

    * stacked projections (q/k/v -> qkv_proj, gate/up -> gate_up_proj),
    * kv cache quantization scales,
    * MoE expert weights stored either per expert or fused into a single
      tensor across all experts,
    * flat (one tensor for all experts) expert scale parameters.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
    ):
        # Expert counts are read later by load_weights() when it builds the
        # per-expert parameter mapping.
        self.num_experts = vllm_config.model_config.hf_config.num_local_experts
        self.n_redundant_experts = (
            vllm_config.parallel_config.eplb_config.num_redundant_experts
        )
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def load_moe_expert_weights(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: dict[str, nn.Parameter],
        loaded_params: set[str],
        expert_params_mapping: list[tuple[str, str, int, str]],
        fused: bool = True,
    ) -> bool:
        """
        Load MoE expert weights.

        Args:
            name: The name of the weight to load.
            loaded_weight: The weight to load.
            params_dict: The dictionary of module parameters.
            loaded_params: The set of already loaded parameters. Updated in
                place with the name of every parameter loaded here.
            expert_params_mapping: The mapping of expert parameters, as
                (param_name, weight_name, expert_id, shard_id) tuples. Must be
                generated by SharedFusedMoE.make_expert_params_mapping().
            fused: Whether the expert weights are fused into a single weight
                tensor or are separate weight tensors for each expert.
                When fused is True, loaded_weight should have shape of:
                [num_experts, hidden_in, hidden_out] for gate/up/down proj and
                [hidden_out, hidden_in] for the others like router.
                When fused is False, loaded_weight should have shape of:
                [hidden_out, hidden_in].

        Returns:
            True if loaded_weight is one of MoE weights and the MoE expert
            weights are loaded successfully, False otherwise.
        """

        # Whether the MoE expert weights are loaded successfully.
        expert_param_loaded = False

        # If fused is True, the loaded weight is in the layout of:
        # [num_experts, hidden_in, hidden_out], so we must transpose the last
        # two dimensions to match the expected layout of the parameters.
        if fused and loaded_weight.ndim == 3:
            loaded_weight = loaded_weight.transpose(-1, -2)

            # If the gate_proj and up_proj weights are fused into a single
            # weight tensor, we need to split the weight tensor into a tuple
            # of two weight tensors along the hidden_out dimension.
            if "experts.gate_up_proj" in name:
                loaded_weight = loaded_weight.chunk(2, dim=-2)

        # Iterate over all the expert parameters and load the weights if we find
        # a match in weight name.
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            # Start every iteration from the shared input again (this is a
            # plain rebind, not a copy) so that the shard / expert selection
            # done below does not leak into the next mapping entry.
            new_loaded_weight = loaded_weight

            # If expert weights are fused into a single weight tensor, remove
            # the expert index from the expected weight name.
            if fused:
                # The string between e_str and proj_str is the expert index.
                e_str, _, proj_str, _ = weight_name.split(".")
                weight_name = f"{e_str}.{proj_str}"
                param_name = f"{param_name}weight"

            # Skip if the current weight is not one of the MoE weights.
            if weight_name not in name:
                continue

            # Replace the weight name with the parameter name.
            full_param_name = name.replace(weight_name, param_name)

            # Skip if the current weight corresponds to a parameter that
            # does not exist on the current PP (pipeline parallel) rank.
            if is_pp_missing_parameter(name, self):
                continue

            # Skip if the current weight is for the bias.
            if (
                name.endswith(".bias") or name.endswith("_bias")
            ) and name not in params_dict:
                continue

            param = params_dict[full_param_name]
            weight_loader = param.weight_loader

            if fused:
                # If the parameter is for w13 together, the corresponding weight
                # will be a tuple, so we must select the correct weight
                # depending on the shard id, which is either "w1" or "w3".
                if "w13" in full_param_name:
                    assert shard_id in ["w1", "w3"]
                    shard_idx = 0 if shard_id == "w1" else 1
                    new_loaded_weight = new_loaded_weight[shard_idx]

                # If EP (expert parallel) is enabled, update expert_id to the
                # starting expert index for the current EP rank and extract the
                # corresponding expert weights.
                layer_idx = extract_layer_index(name)
                expert_map = self.layers[layer_idx].feed_forward.experts.expert_map
                if expert_map is not None:
                    # Indices of the entries of expert_map that are not -1.
                    local_expert_indices = (
                        (expert_map != -1)
                        .nonzero()
                        .flatten()
                        .to(new_loaded_weight.device)
                    )
                    # Workaround for FP8 CPU indexing on older PyTorch:
                    # https://github.com/vllm-project/vllm/issues/32862
                    # Any one-byte floating point dtype is treated as FP8.
                    is_fp8_dtype = new_loaded_weight.dtype == (
                        current_platform.fp8_dtype()
                    ) or (
                        new_loaded_weight.dtype.is_floating_point
                        and new_loaded_weight.element_size() == 1
                    )
                    if (
                        new_loaded_weight.device.type == "cpu"
                        and is_fp8_dtype
                        and not is_torch_equal_or_newer("2.11.0")
                    ):
                        # PyTorch < 2.11 doesn't support CPU float8 indexing,
                        # so index in float16 and cast back. The right-hand
                        # side is fully evaluated before the rebind, so the
                        # final .to() still sees the original dtype.
                        new_loaded_weight = new_loaded_weight.to(torch.float16)[
                            local_expert_indices
                        ].to(new_loaded_weight.dtype)
                    else:
                        new_loaded_weight = new_loaded_weight[local_expert_indices]
                    expert_id = local_expert_indices[0].item()
            else:
                # TODO: add EP support for non fused weights
                pass

            # Load the weight into the module parameter with corresponding
            # shard id and expert id.
            weight_loader(
                param,
                new_loaded_weight,
                full_param_name,
                shard_id=shard_id,
                expert_id=expert_id,
            )
            loaded_params.add(full_param_name)
            expert_param_loaded = True

        return expert_param_loaded

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load checkpoint weights into this model's parameters.

        Each (name, tensor) pair is dispatched, in order of precedence, to:
        kv cache quantization scales, stacked parameters (qkv_proj /
        gate_up_proj), MoE expert weights, flat expert scale parameters, and
        finally regular parameters.

        Args:
            weights: Iterable of (checkpoint weight name, weight tensor).

        Returns:
            The set of names of the module parameters that were loaded.
        """
        # Name mapping from the parameter name to the shard name and
        # corresponding shard id.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        # Indicate whether the expert weights are fused into a single weight
        # tensor. Once a fused expert weight has been seen, this stays True
        # for all the remaining weights.
        fused_experts_params = False
        # Expert parameter mapping for the case where the expert weights are
        # not fused into a single weight tensor.
        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.num_experts,
            num_redundant_experts=self.n_redundant_experts,
        )
        # Expert parameter mapping for the case where the expert weights are
        # fused into a single weight tensor.
        expert_params_mapping_fused = SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_up_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="gate_up_proj",
            num_experts=1,
        )
        # All the module parameters.
        params_dict = dict(self.named_parameters())
        # The module parameters that have been loaded.
        loaded_params: set[str] = set()

        # Iterate over all the weights and load them into module parameters.
        for name, loaded_weight in weights:
            # If the name contains "experts.gate_up_proj" or "experts.down_proj"
            # without the expert indices, it means the expert weights are fused
            # into a single weight tensor across all experts.
            if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                fused_experts_params = True
                expert_params_mapping = expert_params_mapping_fused

            # If kv cache quantization scales exist and the weight name
            # corresponds to one of the kv cache quantization scales, load
            # them.
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Use the tensor as is when it is 0-dim, otherwise take its
                # first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Iterate over stacked_params_mapping to check if the current weight
            # is one of the stacked parameters. If so, load the weight with the
            # corresponding shard id. Note that MoE weights are handled
            # separately in the else block.
            #
            # Inside this loop, `continue` means "try the next mapping entry"
            # while `break` means "done with this weight": it bypasses the
            # for-else block below and moves on to the next checkpoint weight.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip if the current weight is not one of the stacked
                # parameters or if the current weight is a MoE weight.
                if weight_name not in name or "experts" in name:
                    continue

                # For ModelOpt checkpoints, we need to rename the self_attn
                # weight/weight_scale names except for kv cache scales.
                if not (
                    name.endswith((".k_scale", ".v_scale")) and "self_attn" in name
                ):
                    name = name.replace(weight_name, param_name)

                # Skip the weight if it corresponds to a parameter that does
                # not exist on the current PP (pipeline parallel) rank. The
                # name has already been rewritten at this point, so no other
                # mapping entry can match it; stop here instead of relying on
                # the else block to reject it again.
                if is_pp_missing_parameter(name, self):
                    break

                # Remap kv cache scale names for ModelOpt checkpoints.
                # TODO: ModelOpt should implement get_cache_scale() such that
                #       kv cache scale name remapping can be done there.
                if name.endswith("scale"):
                    remapped_name = maybe_remap_kv_scale_name(name, params_dict)
                    if remapped_name is None:
                        # Nothing to load this scale into: skip the weight.
                        # This must be `break`, not `continue`; continuing
                        # would carry a None name into the next
                        # `weight_name not in name` check (or into the else
                        # block) and raise TypeError.
                        break
                    name = remapped_name

                # Load the weight into the module parameter with corresponding
                # shard id and exit the for loop and the else block.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)

                # default_weight_loader does not take a shard id.
                if weight_loader == default_weight_loader:
                    weight_loader(param, loaded_weight)
                else:
                    weight_loader(param, loaded_weight, shard_id)

                loaded_params.add(name)
                break

            # Handle normal (non-stacked) weights and MoE weights.
            else:
                # First, try to load MoE weights using load_moe_expert_weights.
                # If successful, move on to next loaded weight.
                if self.load_moe_expert_weights(
                    name,
                    loaded_weight,
                    params_dict,
                    loaded_params,
                    expert_params_mapping,
                    fused=fused_experts_params,
                ):
                    continue

                # Skip if the current weight corresponds to a parameter that
                # does not exist on the current PP (pipeline parallel) rank.
                if is_pp_missing_parameter(name, self):
                    continue

                # Handle flat expert scale parameters that don't match
                # per-expert patterns, i.e. one weight scale tensor for all
                # experts.
                scale_names = [
                    "w13_input_scale",
                    "w13_weight_scale",
                    "w2_input_scale",
                    "w2_weight_scale",
                ]
                if "experts." in name and any(
                    scale_name in name for scale_name in scale_names
                ):
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )

                    # If weight loader supports special moe loading, use it to
                    # avoid expensive runtime reflection
                    if getattr(weight_loader, "supports_moe_loading", False):
                        # Map the weight name to the corresponding shard id.
                        shard_id = "w2" if "w2_" in name else "w1"

                        # Transpose if weight scales are FP8 block scales with
                        # three dimensions:
                        # [num_experts, hidden_in, hidden_out].
                        if (
                            name.endswith("weight_scale")
                            and loaded_weight.dtype == torch.float8_e4m3fn
                            and loaded_weight.ndim == 3
                        ):
                            loaded_weight = loaded_weight.transpose(-1, -2)

                        # Load the weight into the module parameter with
                        # corresponding shard id and expert id.
                        weight_loader(
                            param, loaded_weight, name, shard_id=shard_id, expert_id=0
                        )

                    else:
                        # Regular weight loader (handles both
                        # param.weight_loader and default_weight_loader)
                        weight_loader(param, loaded_weight)

                    loaded_params.add(name)
                    continue

                # Handle normal (non-stacked, non-MoE) weights.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        # Finally, return the set of loaded parameters.
        return loaded_params

load_moe_expert_weights

load_moe_expert_weights(
    name: str,
    loaded_weight: Tensor,
    params_dict: dict[str, Parameter],
    loaded_params: set[str],
    expert_params_mapping: list[tuple[str, str, int, str]],
    fused: bool = True,
) -> bool

Load MoE expert weights.

Parameters:

Name Type Description Default
name str

The name of the weight to load.

required
loaded_weight Tensor

The weight to load.

required
params_dict dict[str, Parameter]

The dictionary of module parameters.

required
loaded_params set[str]

The set of already loaded parameters.

required
expert_params_mapping list[tuple[str, str, int, str]]

The mapping of expert parameters. Must be generated by SharedFusedMoE.make_expert_params_mapping().

required
fused bool

Whether the expert weights are fused into a single weight tensor or are separate weight tensors for each expert. When fused is True, loaded_weight should have shape of: [num_experts, hidden_in, hidden_out] for gate/up/down proj and [hidden_out, hidden_in] for the others like router. When fused is False, loaded_weight should have shape of: [hidden_out, hidden_in].

True

Returns:

Type Description
bool

True if loaded_weight is one of the MoE weights and the MoE expert weights are loaded successfully, False otherwise.

Source code in vllm/model_executor/models/llama4.py
def load_moe_expert_weights(
    self,
    name: str,
    loaded_weight: torch.Tensor,
    params_dict: dict[str, nn.Parameter],
    loaded_params: set[str],
    expert_params_mapping: list[tuple[str, str, int, str]],
    fused: bool = True,
) -> bool:
    """
    Try to load ``loaded_weight`` as a MoE expert weight.

    Args:
        name: Checkpoint name of the weight being loaded.
        loaded_weight: Tensor holding the weight.
        params_dict: Module parameters keyed by parameter name.
        loaded_params: Names of parameters loaded so far; every parameter
            loaded by this call is added to it in place.
        expert_params_mapping: (param_name, weight_name, expert_id, shard_id)
            entries produced by SharedFusedMoE.make_expert_params_mapping().
        fused: True when one tensor carries the weights of all experts,
            False when each expert has its own tensor.
            Fused tensors are expected as
            [num_experts, hidden_in, hidden_out] for gate/up/down proj and
            [hidden_out, hidden_in] for everything else (e.g. the router).
            Non-fused tensors are expected as [hidden_out, hidden_in].

    Returns:
        True when at least one mapping entry matched ``name`` and its weight
        was handed to the parameter's weight loader, False otherwise.
    """

    # Fused 3-D tensors come as [num_experts, hidden_in, hidden_out]; swap the
    # last two dims to get the layout the parameters use.
    if fused and loaded_weight.ndim == 3:
        loaded_weight = loaded_weight.transpose(-1, -2)

        # A fused gate_up_proj tensor is split in two along hidden_out,
        # yielding a (first half, second half) tuple.
        if "experts.gate_up_proj" in name:
            loaded_weight = loaded_weight.chunk(2, dim=-2)

    looks_like_bias = name.endswith((".bias", "_bias"))
    any_loaded = False

    for mapped_param, mapped_weight, expert_id, shard_id in expert_params_mapping:
        if fused:
            # Fused checkpoints have no expert index in the weight name, so
            # drop the component sitting between the two kept parts.
            experts_part, _, proj_part, _ = mapped_weight.split(".")
            mapped_weight = f"{experts_part}.{proj_part}"
            mapped_param = f"{mapped_param}weight"

        # Not the weight this mapping entry describes.
        if mapped_weight not in name:
            continue

        # The parameter does not live on this PP (pipeline parallel) rank.
        if is_pp_missing_parameter(name, self):
            continue

        # Bias without a parameter of the same name.
        if looks_like_bias and name not in params_dict:
            continue

        target_name = name.replace(mapped_weight, mapped_param)
        param = params_dict[target_name]
        load_into_param = param.weight_loader

        # Each mapping entry starts over from the shared input tensor.
        tensor = loaded_weight

        # NOTE: only the fused path has EP handling.
        # TODO: add EP support for non fused weights
        if fused:
            # For w13 the input is the (w1, w3) pair produced by chunk();
            # pick the half that belongs to this shard.
            if "w13" in target_name:
                assert shard_id in ["w1", "w3"]
                tensor = tensor[0 if shard_id == "w1" else 1]

            # With EP (expert parallel) enabled, keep only the experts mapped
            # to this rank and report the first of them as expert_id.
            moe_layer = self.layers[extract_layer_index(name)]
            expert_map = moe_layer.feed_forward.experts.expert_map
            if expert_map is not None:
                is_local = expert_map != -1
                local_ids = is_local.nonzero().flatten().to(tensor.device)

                # Workaround for FP8 CPU indexing on older PyTorch:
                # https://github.com/vllm-project/vllm/issues/32862
                source_dtype = tensor.dtype
                fp8_like = source_dtype == current_platform.fp8_dtype() or (
                    source_dtype.is_floating_point and tensor.element_size() == 1
                )
                index_via_fp16 = (
                    tensor.device.type == "cpu"
                    and fp8_like
                    and not is_torch_equal_or_newer("2.11.0")
                )
                if index_via_fp16:
                    # PyTorch < 2.11 doesn't support CPU float8 indexing.
                    tensor = tensor.to(torch.float16)[local_ids].to(source_dtype)
                else:
                    tensor = tensor[local_ids]
                expert_id = local_ids[0].item()

        # Hand the selected tensor to the parameter's own loader together with
        # the shard and expert it belongs to.
        load_into_param(
            param,
            tensor,
            target_name,
            shard_id=shard_id,
            expert_id=expert_id,
        )
        loaded_params.add(target_name)
        any_loaded = True

    return any_loaded