vllm.model_executor.utils

Utils for model executor.

replace_parameter

replace_parameter(
    layer: Module, param_name: str, new_data: Tensor | None
)

Replace a parameter of a layer while preserving the ability to reload the weight. This function is intended to be called within implementations of the process_weights_after_loading method.

It should not be called on weights that are tied/shared.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| layer | Module | Layer containing the parameter to replace | required |
| param_name | str | Name of the parameter to replace | required |
| new_data | Tensor \| None | New data for the parameter, or None to set the parameter to None | required |
Source code in vllm/model_executor/utils.py
def replace_parameter(
    layer: torch.nn.Module, param_name: str, new_data: torch.Tensor | None
):
    """
    Replace a parameter of a layer while maintaining the ability to reload the weight.
    Called within implementations of the `process_weights_after_loading` method.

    This function should not be called on weights that are tied/shared.

    Args:
        layer: Layer containing parameter to replace
        param_name: Name of parameter to replace
        new_data: New data of the new parameter, or None to set the parameter to None
    """
    # should not be used on a tied/shared param

    # If new_data is None, set the parameter to None
    if new_data is None:
        setattr(layer, param_name, None)
        return

    if isinstance(new_data, torch.nn.Parameter):
        new_data = new_data.data
    new_param = torch.nn.Parameter(new_data, requires_grad=False)

    old_param: torch.nn.Parameter | None = getattr(layer, param_name, None)
    if old_param is not None and hasattr(old_param, "weight_loader"):
        weight_loader = old_param.weight_loader
        set_weight_attrs(new_param, {"weight_loader": weight_loader})

    setattr(layer, param_name, new_param)
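
As an illustrative sketch (not code from vLLM itself), here is how a hypothetical quantization method might use replace_parameter inside process_weights_after_loading; the class name, the repacking step, and the weight_scale_2 parameter name are assumptions for illustration:

import torch

from vllm.model_executor.utils import replace_parameter


class MyQuantMethod:  # hypothetical quantization method
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Repack the loaded weight into the layout the kernel expects.
        # The transpose here stands in for a real repacking step.
        repacked = layer.weight.data.t().contiguous()
        # The replacement parameter inherits the old one's weight_loader,
        # so the weight can still be reloaded later.
        replace_parameter(layer, "weight", repacked)
        # A parameter that is no longer needed after repacking can be
        # dropped by passing None (hypothetical parameter name).
        replace_parameter(layer, "weight_scale_2", None)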

set_weight_attrs

set_weight_attrs(
    weight: Tensor, weight_attrs: dict[str, Any] | None
)

Set attributes on a weight tensor.

This method will not overwrite existing attributes; attempting to set an attribute that already exists raises an AssertionError.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| weight | Tensor | The weight tensor. | required |
| weight_attrs | dict[str, Any] \| None | A dictionary of attributes to set on the weight tensor. | required |
Source code in vllm/model_executor/utils.py
def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: dict[str, Any] | None,
):
    """Set attributes on a weight tensor.

    This method will not overwrite existing attributes; attempting to do so
    raises an AssertionError.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
    """
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"

        # NOTE(woosuk): During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        # TODO(woosuk): Remove this hack once we have a better solution.
        from vllm.platforms import current_platform

        if current_platform.use_sync_weight_loader() and key == "weight_loader":
            value = current_platform.make_synced_weight_loader(value)
        setattr(weight, key, value)
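
As an illustrative sketch (again, not code from vLLM itself), typical usage attaches loader metadata to a freshly created parameter; the loader function below is an assumption for illustration:

import torch

from vllm.model_executor.utils import set_weight_attrs


# Hypothetical loader that copies a checkpoint tensor into the
# pre-allocated parameter in place.
def simple_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    param.data.copy_(loaded_weight)


weight = torch.nn.Parameter(torch.empty(1024, 1024), requires_grad=False)
# Attach the loader so the model loader knows how to fill this tensor.
# Setting "weight_loader" a second time would trip the assertion above.
set_weight_attrs(weight, {"weight_loader": simple_weight_loader})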