vllm.distributed.device_communicators.cpu_communicator

CpuCommunicator

Bases: DeviceCommunicatorBase

Source code in vllm/distributed/device_communicators/cpu_communicator.py
class CpuCommunicator(DeviceCommunicatorBase):
    def __init__(
        self,
        cpu_group: ProcessGroup,
        device: torch.device | None = None,
        device_group: ProcessGroup | None = None,
        unique_name: str = "",
    ):
        super().__init__(cpu_group, device, device_group, unique_name)
        self.dist_module = torch.distributed

        if (
            (
                current_platform.get_cpu_architecture() == CpuArchEnum.X86
                or current_platform.get_cpu_architecture() == CpuArchEnum.ARM
            )
            and hasattr(torch.ops._C, "init_shm_manager")
            and (unique_name.startswith("tp") or unique_name.startswith("pp"))
        ):
            self.dist_module = _CPUSHMDistributed(self)

        if self.use_all2all:
            if self.all2all_backend != "naive":  # type: ignore[has-type]
                logger.warning(
                    "`%s` all2all manager is not supported on CPU. "
                    "Falling back to `naive` all2all manager for CPU.",
                    self.all2all_backend,  # type: ignore[has-type]
                )
                self.all2all_backend = "naive"
            if self.all2all_backend == "naive":
                from .all2all import NaiveAll2AllManager

                self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
                logger.info("Using naive all2all manager.")

    def all_reduce(self, input_):
        self.dist_module.all_reduce(input_, group=self.device_group)
        return input_

    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
    ) -> torch.Tensor | None:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None

        # Gather.
        self.dist_module.gather(
            input_, gather_list, dst=self.ranks[dst], group=self.device_group
        )

        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here,
        # stack-style all-gather has compatibility issues with
        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * self.world_size,) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(
            output_size, dtype=input_.dtype, device=input_.device
        )
        # All-gather.
        self.dist_module.all_gather_into_tensor(
            output_tensor, input_, group=self.device_group
        )

        # Reshape
        output_tensor = output_tensor.reshape((self.world_size,) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim]
            + (self.world_size * input_size[dim],)
            + input_size[dim + 1 :]
        )
        return output_tensor

    def send_tensor_dict(
        self,
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int,
    ) -> None:
        return self.dist_module.send_tensor_dict(tensor_dict, dst)

    def recv_tensor_dict(
        self,
        src: int,
    ) -> dict[str, torch.Tensor | Any]:
        return self.dist_module.recv_tensor_dict(src)

    def dispatch_router_logits(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
        """
        Dispatch the hidden states and router logits to the appropriate device.
        This is a no-op in the base class.
        """

        assert self.all2all_manager is not None
        return self.all2all_manager.dispatch_router_logits(
            hidden_states,
            router_logits,
            is_sequence_parallel,
            extra_tensors,
        )

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
        """
        Dispatch the hidden states and topk weights/ids to the appropriate device.
        This is a no-op in the base class.
        """
        assert self.all2all_manager is not None
        return self.all2all_manager.dispatch(
            hidden_states,
            topk_weights,
            topk_ids,
            is_sequence_parallel,
            extra_tensors=extra_tensors,
        )

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """
        Combine the hidden states and router logits from the appropriate device.
        This is a no-op in the base class.
        """
        assert self.all2all_manager is not None
        return self.all2all_manager.combine(
            hidden_states,
            is_sequence_parallel,
        )
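
The `all_gather` implementation above gathers along dim 0 and then rearranges the result so it matches a concatenation along the caller's `dim`. Below is a minimal local sketch of that rearrangement, with `torch.cat` standing in for the collective; the shapes and `world_size` are illustrative, not part of this module.

import torch

world_size = 4
dim = 1  # gather dimension requested by the caller
shards = [torch.randn(2, 3, 5) for _ in range(world_size)]  # one tensor per rank
input_size = shards[0].size()

# What all_gather_into_tensor produces: all ranks concatenated along dim 0.
flat = torch.cat(shards, dim=0)  # shape (world_size * 2, 3, 5)

# The reshape/movedim sequence from CpuCommunicator.all_gather.
output = flat.reshape((world_size,) + input_size)
output = output.movedim(0, dim)
output = output.reshape(
    input_size[:dim]
    + (world_size * input_size[dim],)
    + input_size[dim + 1 :]
)

# Equivalent to concatenating the per-rank shards along the requested dim.
assert torch.equal(output, torch.cat(shards, dim=dim))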

combine

combine(
    hidden_states: Tensor,
    is_sequence_parallel: bool = False,
) -> Tensor

Combine the hidden states across the group via the all2all manager. (In the base class this method is a no-op.)

Source code in vllm/distributed/device_communicators/cpu_communicator.py
def combine(
    self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
    """
    Combine the hidden states and router logits from the appropriate device.
    This is a no-op in the base class.
    """
    assert self.all2all_manager is not None
    return self.all2all_manager.combine(
        hidden_states,
        is_sequence_parallel,
    )

dispatch

dispatch(
    hidden_states: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[Tensor] | None = None,
) -> (
    tuple[Tensor, Tensor, Tensor]
    | tuple[Tensor, Tensor, Tensor, list[Tensor]]
)

Dispatch the hidden states and top-k weights/ids across the group via the all2all manager. (In the base class this method is a no-op.)

Source code in vllm/distributed/device_communicators/cpu_communicator.py
def dispatch(
    self,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[torch.Tensor] | None = None,
) -> (
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
    """
    Dispatch the hidden states and topk weights/ids to the appropriate device.
    This is a no-op in the base class.
    """
    assert self.all2all_manager is not None
    return self.all2all_manager.dispatch(
        hidden_states,
        topk_weights,
        topk_ids,
        is_sequence_parallel,
        extra_tensors=extra_tensors,
    )
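
A hedged sketch of the dispatch/combine round trip these methods are typically used for. `comm` is assumed to be an already-initialized CpuCommunicator and `run_local_experts` is whatever applies the local experts to the dispatched tokens; both names are illustrative, not part of this module.

# Hypothetical MoE expert-parallel step: hidden_states, topk_weights and
# topk_ids are assumed to already exist on this rank.
hidden_states, topk_weights, topk_ids = comm.dispatch(
    hidden_states, topk_weights, topk_ids
)
expert_output = run_local_experts(hidden_states, topk_weights, topk_ids)
# combine() is the return leg: it brings the expert outputs back to the
# token layout the caller had before dispatch().
hidden_states = comm.combine(expert_output)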

dispatch_router_logits

dispatch_router_logits(
    hidden_states: Tensor,
    router_logits: Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[Tensor] | None = None,
) -> (
    tuple[Tensor, Tensor]
    | tuple[Tensor, Tensor, list[Tensor]]
)

Dispatch the hidden states and router logits across the group via the all2all manager. (In the base class this method is a no-op.)

Source code in vllm/distributed/device_communicators/cpu_communicator.py
def dispatch_router_logits(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[torch.Tensor] | None = None,
) -> (
    tuple[torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
    """
    Dispatch the hidden states and router logits to the appropriate device.
    This is a no-op in the base class.
    """

    assert self.all2all_manager is not None
    return self.all2all_manager.dispatch_router_logits(
        hidden_states,
        router_logits,
        is_sequence_parallel,
        extra_tensors,
    )
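
Per the return annotation above, the result gains a third element when extra_tensors is supplied. A hedged call sketch, assuming an already-initialized CpuCommunicator `comm`; the tensor names are illustrative.

# Without extras: two dispatched tensors come back.
hidden_states, router_logits = comm.dispatch_router_logits(
    hidden_states, router_logits
)

# With extras: the dispatched extra tensors are returned as a third element.
hidden_states, router_logits, extras = comm.dispatch_router_logits(
    hidden_states, router_logits, extra_tensors=[residual]
)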

gather

gather(
    input_: Tensor, dst: int = 0, dim: int = -1
) -> Tensor | None

NOTE: We assume that the input tensor is on the same device across all the ranks. NOTE: dst is the local rank of the destination rank.

Source code in vllm/distributed/device_communicators/cpu_communicator.py
def gather(
    self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> torch.Tensor | None:
    """
    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.
    """
    world_size = self.world_size
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
    )
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Allocate output tensor.
    if self.rank_in_group == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None

    # Gather.
    self.dist_module.gather(
        input_, gather_list, dst=self.ranks[dst], group=self.device_group
    )

    if self.rank_in_group == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor
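
`dst` is a rank local to this communicator's group, while torch.distributed.gather expects a global rank; `self.ranks[dst]` performs that lookup. A minimal sketch of the mapping, with hypothetical global ranks:

# self.ranks maps local rank -> global rank for the ranks in this group.
ranks = [4, 5, 6, 7]           # hypothetical global ranks forming the group
dst_local = 2                  # the `dst` argument (local rank of the receiver)
dst_global = ranks[dst_local]  # what gets passed to torch.distributed.gather
assert dst_global == 6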