
vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector

OffloadPromMetrics

Bases: KVConnectorPromMetrics

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
class OffloadPromMetrics(KVConnectorPromMetrics):
    def __init__(
        self,
        vllm_config: VllmConfig,
        metric_types: dict[type[PromMetric], type[PromMetricT]],
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
        # (engine_idx, transfer_type) -> (metric with bound labels)
        self.histogram_transfer_size: dict[tuple[int, str], PromMetricT] = {}
        self.counter_kv_bytes: dict[tuple[int, str], PromMetricT] = {}
        self.counter_kv_transfer_time: dict[tuple[int, str], PromMetricT] = {}
        buckets = [  # In bytes
            1e6,
            5e6,
            10e6,
            20e6,
            40e6,
            60e6,
            80e6,
            100e6,
            150e6,
            200e6,
        ]

        self._counter_kv_bytes = self._counter_cls(
            name="vllm:kv_offload_total_bytes",
            documentation="Number of bytes offloaded by KV connector",
            labelnames=labelnames + ["transfer_type"],
        )

        self._counter_kv_transfer_time = self._counter_cls(
            name="vllm:kv_offload_total_time",
            documentation="Total time measured by all KV offloading operations",
            labelnames=labelnames + ["transfer_type"],
        )

        self._histogram_transfer_size = self._histogram_cls(
            name="vllm:kv_offload_size",
            documentation="Histogram of KV offload transfer size, in bytes.",
            buckets=buckets[:],
            labelnames=labelnames + ["transfer_type"],
        )

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        """
        Observe transfer statistics from the new data structure.
        transfer_stats_data is expected to be a dict where:
        - keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
        - values are lists of OffloadingOperationMetrics objects
        """

        for transfer_type, ops in transfer_stats_data.items():
            # Cache:
            if (engine_idx, transfer_type) not in self.histogram_transfer_size:
                self.histogram_transfer_size[(engine_idx, transfer_type)] = (
                    self._histogram_transfer_size.labels(
                        *(self.per_engine_labelvalues[engine_idx] + [transfer_type])
                    )
                )
                self.counter_kv_bytes[(engine_idx, transfer_type)] = (
                    self._counter_kv_bytes.labels(
                        *(self.per_engine_labelvalues[engine_idx] + [transfer_type])
                    )
                )
                self.counter_kv_transfer_time[(engine_idx, transfer_type)] = (
                    self._counter_kv_transfer_time.labels(
                        *(self.per_engine_labelvalues[engine_idx] + [transfer_type])
                    )
                )

            # Process ops:
            assert isinstance(ops, list)
            for op in ops:  # ops is a list of serialized OffloadingOperationMetrics
                assert isinstance(op, dict)
                # Observe size histogram
                self.histogram_transfer_size[(engine_idx, transfer_type)].observe(
                    op["op_size"]
                )

                # Increment byte and time counters
                self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"])

                self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc(
                    op["op_time"]
                )

observe

observe(
    transfer_stats_data: dict[str, Any], engine_idx: int = 0
)

Observe transfer statistics from the new data structure. transfer_stats_data is expected to be a dict where:

- keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
- values are lists of OffloadingOperationMetrics objects

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
    """
    Observe transfer statistics from the new data structure.
    transfer_stats_data is expected to be a dict where:
    - keys are transfer type strings (e.g., "cpu_to_gpu", "gpu_to_cpu")
    - values are lists of OffloadingOperationMetrics objects
    """

    for transfer_type, ops in transfer_stats_data.items():
        # Cache:
        if (engine_idx, transfer_type) not in self.histogram_transfer_size:
            self.histogram_transfer_size[(engine_idx, transfer_type)] = (
                self._histogram_transfer_size.labels(
                    *(self.per_engine_labelvalues[engine_idx] + [transfer_type])
                )
            )
            self.counter_kv_bytes[(engine_idx, transfer_type)] = (
                self._counter_kv_bytes.labels(
                    *(self.per_engine_labelvalues[engine_idx] + [transfer_type])
                )
            )
            self.counter_kv_transfer_time[(engine_idx, transfer_type)] = (
                self._counter_kv_transfer_time.labels(
                    *(self.per_engine_labelvalues[engine_idx] + [transfer_type])
                )
            )

        # Process ops:
        assert isinstance(ops, list)
        for op in ops:  # ops is a list of serialized OffloadingOperationMetrics
            assert isinstance(op, dict)
            # Observe size histogram
            self.histogram_transfer_size[(engine_idx, transfer_type)].observe(
                op["op_size"]
            )

            # Increment byte and time counters
            self.counter_kv_bytes[(engine_idx, transfer_type)].inc(op["op_size"])

            self.counter_kv_transfer_time[(engine_idx, transfer_type)].inc(
                op["op_time"]
            )
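
For illustration, here is a hedged sketch of the payload shape observe expects and how a single call updates the metrics. The prom_metrics variable is an assumed, already-constructed OffloadPromMetrics instance (building one requires a VllmConfig, the Prometheus metric classes, and label names, which are omitted here); the byte and time values are made up.

# Hypothetical payload: keys are transfer-type strings, values are lists
# of serialized OffloadingOperationMetrics, i.e. dicts with "op_size"
# (bytes) and "op_time" (seconds).
transfer_stats_data = {
    "gpu_to_cpu": [
        {"op_size": 4_194_304, "op_time": 0.0021},
        {"op_size": 8_388_608, "op_time": 0.0043},
    ],
    "cpu_to_gpu": [
        {"op_size": 4_194_304, "op_time": 0.0017},
    ],
}

# The first time a (engine_idx, transfer_type) pair is seen, the labeled
# metric children are created and cached; every op then updates the size
# histogram and increments the byte and time counters.
prom_metrics.observe(transfer_stats_data, engine_idx=0)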

OffloadingConnectorScheduler

Implementation of Scheduler side methods

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
class OffloadingConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, spec: OffloadingSpec):
        self.gpu_block_size = spec.gpu_block_size
        self.offloaded_block_size = spec.offloaded_block_size
        self.block_size_factor = self.offloaded_block_size // self.gpu_block_size
        self.manager: OffloadingManager = spec.get_manager()

        self._requests: dict[ReqId, Request] = {}
        # list of GPU block IDs per request
        self._request_block_ids: dict[ReqId, list[int]] = {}
        # requests to load for the current scheduler step
        self._reqs_to_load: dict[ReqId, TransferSpec] = {}
        # request blocks are stored in order
        # index of next block (of size offloaded_block_size) to offload
        self._next_stored_block_idx: dict[ReqId, int] = {}
        # if GPU prefix caching is enabled,
        # track loaded blocks to avoid redundant loads
        self._blocks_being_loaded: set[BlockHash] | None = (
            set() if spec.vllm_config.cache_config.enable_prefix_caching else None
        )

        # request ID -> set(block hashes being stored/load)
        self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
        self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)

    def _get_block_hashes(
        self,
        req: Request,
        start_idx: int = 0,
        end_idx: int | None = None,
    ) -> Iterable[BlockHash]:
        return islice(
            req.block_hashes,
            self.block_size_factor * start_idx + self.block_size_factor - 1,
            self.block_size_factor * end_idx if end_idx else None,
            self.block_size_factor,
        )

    def get_num_new_matched_tokens(
        self, request: Request, num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded beyond the
        num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that can be loaded beyond what is
                  already computed.
                  If None, it means that the connector needs more time to
                  determine the number of matched tokens, and the scheduler
                  should query for this request again later.
                - `True` if tokens will be loaded asynchronously
                  (between scheduler steps).
        """
        num_blocks = request.num_tokens // self.offloaded_block_size

        assert len(request.block_hashes) // self.block_size_factor == num_blocks
        block_hashes = self._get_block_hashes(request)

        self.manager.touch(block_hashes)

        full_block_tokens = self.offloaded_block_size * num_blocks
        if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
            # we can load less than a block, skip
            return 0, False

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        hits = self.manager.lookup(
            self._get_block_hashes(request, start_idx=start_block_idx)
        )
        if hits is None:
            # indicates a lookup that should be tried later
            return None, False
        if hits == 0:
            return 0, False

        num_hit_tokens = (
            self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
        )
        logger.debug(
            "Request %s hit %s offloaded tokens after %s GPU hit tokens",
            request.request_id,
            num_hit_tokens,
            num_computed_tokens,
        )
        if num_hit_tokens < self.offloaded_block_size:
            return 0, False

        if self._blocks_being_loaded:
            block_hashes = self._get_block_hashes(
                request, start_idx=start_block_idx, end_idx=start_block_idx + hits
            )

            if any(
                block_hash in self._blocks_being_loaded for block_hash in block_hashes
            ):
                # hit blocks are being loaded, delay request
                logger.debug(
                    "Delaying request %s since some of its blocks are already"
                    " being loaded",
                    request.request_id,
                )
                return None, False

        return num_hit_tokens, True

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        self._requests[request.request_id] = request
        # the block ids are updated in _get_reqs_to_store
        self._request_block_ids[request.request_id] = []

        if num_external_tokens == 0:
            return

        block_groups = blocks.get_block_ids()
        block_ids = block_groups[0]

        num_computed_gpu_blocks = sum(
            block.block_hash is not None for block in blocks.blocks[0]
        )
        num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
        full_block_tokens = num_computed_tokens + num_external_tokens
        assert full_block_tokens % self.offloaded_block_size == 0

        num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
        assert num_external_tokens == num_pending_gpu_blocks * self.gpu_block_size

        start_block_idx = num_computed_tokens // self.offloaded_block_size
        num_blocks = full_block_tokens // self.offloaded_block_size

        assert len(request.block_hashes) // self.block_size_factor >= num_blocks
        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=num_blocks
        )

        src_spec = self.manager.prepare_load(block_hashes)
        dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])

        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=num_blocks
        )

        self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
        req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
        req_blocks_being_loaded.update(block_hashes)
        self._next_stored_block_idx[request.request_id] = num_blocks

        if self._blocks_being_loaded is not None:
            self._blocks_being_loaded.update(req_blocks_being_loaded)

    def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
        reqs_to_store: dict[ReqId, TransferSpec] = {}
        # iterate over both new and cached requests
        for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
            if preempted:
                self._request_block_ids[req_id] = []

            if new_block_id_groups:
                new_block_ids = new_block_id_groups[0]
                self._request_block_ids[req_id] += new_block_ids

            block_ids = self._request_block_ids[req_id]

            req = self._requests[req_id]
            new_tokens = scheduler_output.num_scheduled_tokens[req_id]
            total_tokens = req.num_computed_tokens + new_tokens
            num_blocks = total_tokens // self.offloaded_block_size
            start_block_idx = self._next_stored_block_idx.get(req_id, 0)
            num_new_blocks = num_blocks - start_block_idx

            if num_new_blocks <= 0:
                continue

            # NOTE: In async scheduling, placeholders may temporarily make
            # len(req.block_hashes) < num_blocks * self.block_size_factor.

            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks
            )
            store_output = self.manager.prepare_store(new_block_hashes)
            if store_output is None:
                logger.warning(
                    "Request %s: cannot store %s blocks", req_id, num_new_blocks
                )
                continue

            self._next_stored_block_idx[req_id] = num_blocks

            if not store_output.block_hashes_to_store:
                continue
            block_hashes_to_store = set(store_output.block_hashes_to_store)

            block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
            self.manager.touch(block_hashes)

            new_block_hashes = self._get_block_hashes(
                req, start_idx=start_block_idx, end_idx=num_blocks
            )
            dst_spec = store_output.store_spec
            src_block_ids: list[int] = []
            for idx, blk_hash in enumerate(new_block_hashes):
                if blk_hash not in block_hashes_to_store:
                    continue
                offloaded_block_idx = start_block_idx + idx
                gpu_block_idx = offloaded_block_idx * self.block_size_factor
                for i in range(self.block_size_factor):
                    src_block_ids.append(block_ids[gpu_block_idx + i])
            src_spec = GPULoadStoreSpec(src_block_ids)

            reqs_to_store[req_id] = (src_spec, dst_spec)
            self._reqs_being_stored[req_id] |= block_hashes_to_store

            logger.debug(
                "Request %s offloading %s blocks starting from block #%d",
                req_id,
                len(block_hashes_to_store),
                start_block_idx,
            )

        return reqs_to_store

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        meta = OffloadingConnectorMetadata(
            reqs_to_load=self._reqs_to_load,
            reqs_to_store=self._get_reqs_to_store(scheduler_output),
        )
        self._reqs_to_load = {}

        # NOTE (orozery): we should move this logic to update_connector_output
        # once KVConnectorOutput allows us to report completed transfers
        for req_id in scheduler_output.preempted_req_ids or ():
            block_hashes = self._reqs_being_stored.get(req_id)
            if block_hashes:
                self.manager.complete_store(block_hashes)
                block_hashes.clear()

        return meta

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        for req_id in connector_output.finished_sending or []:
            block_hashes = self._reqs_being_stored.pop(req_id, None)
            if block_hashes:
                self.manager.complete_store(block_hashes)

        for req_id in connector_output.finished_recving or []:
            block_hashes = self._reqs_being_loaded.pop(req_id, None)
            if block_hashes:
                if self._blocks_being_loaded:
                    self._blocks_being_loaded.difference_update(block_hashes)
                self.manager.complete_load(block_hashes)

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        req_id = request.request_id
        self._requests.pop(req_id, None)
        self._request_block_ids.pop(req_id, None)
        self._next_stored_block_idx.pop(req_id, None)

        request_being_stored = req_id in self._reqs_being_stored
        return request_being_stored, None

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Take the KV cache events from the connector.

        Returns:
            A list of KV cache events.
        """
        for event in self.manager.take_events():
            if event.removed:
                yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
            else:
                yield BlockStored(
                    block_hashes=event.block_hashes,
                    parent_block_hash=None,
                    token_ids=[],
                    lora_id=None,
                    block_size=event.block_size,
                    medium=event.medium,
                    lora_name=None,
                )
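
The scheduler tracks hashes at GPU-block granularity while offloading whole blocks of offloaded_block_size tokens, so _get_block_hashes selects the hash of the last GPU block inside each offloaded block. The following self-contained sketch (integers stand in for BlockHash values; the 16/64-token sizes are assumed purely for illustration) shows what that islice indexing yields.

from itertools import islice

# Assumed sizes: 16-token GPU blocks grouped into 64-token offloaded
# blocks, so block_size_factor == 4.
gpu_block_size = 16
offloaded_block_size = 64
block_size_factor = offloaded_block_size // gpu_block_size

# Stand-in for request.block_hashes: one "hash" per GPU block.
block_hashes = list(range(12))  # 12 GPU blocks -> 3 offloaded blocks

def get_block_hashes(start_idx: int = 0, end_idx: int | None = None):
    # Mirrors _get_block_hashes: pick the hash of the LAST GPU block
    # within each offloaded block, stepping by block_size_factor.
    return islice(
        block_hashes,
        block_size_factor * start_idx + block_size_factor - 1,
        block_size_factor * end_idx if end_idx else None,
        block_size_factor,
    )

print(list(get_block_hashes()))             # [3, 7, 11]
print(list(get_block_hashes(start_idx=1)))  # [7, 11]
print(list(get_block_hashes(end_idx=2)))    # [3, 7]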

get_num_new_matched_tokens

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]

Get number of new tokens that can be loaded beyond the num_computed_tokens.

Parameters:

request (Request): the request object. Required.

num_computed_tokens (int): the number of locally computed tokens for this request. Required.

Returns:

tuple[int | None, bool]: A tuple with the following elements:

- The number of tokens that can be loaded beyond what is already computed. If None, it means that the connector needs more time to determine the number of matched tokens, and the scheduler should query for this request again later.
- True if tokens will be loaded asynchronously (between scheduler steps).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def get_num_new_matched_tokens(
    self, request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]:
    """
    Get number of new tokens that can be loaded beyond the
    num_computed_tokens.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A tuple with the following elements:
            - The number of tokens that can be loaded beyond what is
              already computed.
              If None, it means that the connector needs more time to
              determine the number of matched tokens, and the scheduler
              should query for this request again later.
            - `True` if tokens will be loaded asynchronously
              (between scheduler steps).
    """
    num_blocks = request.num_tokens // self.offloaded_block_size

    assert len(request.block_hashes) // self.block_size_factor == num_blocks
    block_hashes = self._get_block_hashes(request)

    self.manager.touch(block_hashes)

    full_block_tokens = self.offloaded_block_size * num_blocks
    if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
        # we can load less than a block, skip
        return 0, False

    start_block_idx = num_computed_tokens // self.offloaded_block_size
    hits = self.manager.lookup(
        self._get_block_hashes(request, start_idx=start_block_idx)
    )
    if hits is None:
        # indicates a lookup that should be tried later
        return None, False
    if hits == 0:
        return 0, False

    num_hit_tokens = (
        self.offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
    )
    logger.debug(
        "Request %s hit %s offloaded tokens after %s GPU hit tokens",
        request.request_id,
        num_hit_tokens,
        num_computed_tokens,
    )
    if num_hit_tokens < self.offloaded_block_size:
        return 0, False

    if self._blocks_being_loaded:
        block_hashes = self._get_block_hashes(
            request, start_idx=start_block_idx, end_idx=start_block_idx + hits
        )

        if any(
            block_hash in self._blocks_being_loaded for block_hash in block_hashes
        ):
            # hit blocks are being loaded, delay request
            logger.debug(
                "Delaying request %s since some of its blocks are already"
                " being loaded",
                request.request_id,
            )
            return None, False

    return num_hit_tokens, True
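
A small worked example of the hit-token arithmetic above, with an assumed offloaded_block_size of 256 tokens and made-up lookup results:

offloaded_block_size = 256

# 512 tokens are already computed on the GPU, so the lookup starts at
# offloaded block index 2; suppose it reports 3 consecutive hits.
num_computed_tokens = 512
start_block_idx = num_computed_tokens // offloaded_block_size  # 2
hits = 3

num_hit_tokens = (
    offloaded_block_size * (start_block_idx + hits) - num_computed_tokens
)
print(num_hit_tokens)  # 768, i.e. 256 * 5 - 512; the method returns (768, True)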

request_finished

request_finished(
    request: Request, block_ids: list[int]
) -> tuple[bool, dict[str, Any] | None]

Called when a request has finished, before its blocks are freed.

Returns:

tuple[bool, dict[str, Any] | None]:

- bool: True if the request is being saved/sent asynchronously and blocks should not be freed until the request_id is returned from get_finished().
- dict[str, Any] | None: Optional KVTransferParams to be included in the request outputs returned by the engine.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def request_finished(
    self,
    request: Request,
    block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
    """
    Called when a request has finished, before its blocks are freed.

    Returns:
        True if the request is being saved/sent asynchronously and blocks
        should not be freed until the request_id is returned from
        get_finished().
        Optional KVTransferParams to be included in the request outputs
        returned by the engine.
    """
    req_id = request.request_id
    self._requests.pop(req_id, None)
    self._request_block_ids.pop(req_id, None)
    self._next_stored_block_idx.pop(req_id, None)

    request_being_stored = req_id in self._reqs_being_stored
    return request_being_stored, None

take_events

take_events() -> Iterable[KVCacheEvent]

Take the KV cache events from the connector.

Returns:

Iterable[KVCacheEvent]: A list of KV cache events.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def take_events(self) -> Iterable[KVCacheEvent]:
    """Take the KV cache events from the connector.

    Returns:
        A list of KV cache events.
    """
    for event in self.manager.take_events():
        if event.removed:
            yield BlockRemoved(block_hashes=event.block_hashes, medium=event.medium)
        else:
            yield BlockStored(
                block_hashes=event.block_hashes,
                parent_block_hash=None,
                token_ids=[],
                lora_id=None,
                block_size=event.block_size,
                medium=event.medium,
                lora_name=None,
            )

update_connector_output

update_connector_output(
    connector_output: KVConnectorOutput,
)

Update KVConnector state from worker-side connectors output.

Parameters:

connector_output (KVConnectorOutput): the worker-side connector's output. Required.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def update_connector_output(self, connector_output: KVConnectorOutput):
    """
    Update KVConnector state from worker-side connectors output.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors output.
    """
    for req_id in connector_output.finished_sending or []:
        block_hashes = self._reqs_being_stored.pop(req_id, None)
        if block_hashes:
            self.manager.complete_store(block_hashes)

    for req_id in connector_output.finished_recving or []:
        block_hashes = self._reqs_being_loaded.pop(req_id, None)
        if block_hashes:
            if self._blocks_being_loaded:
                self._blocks_being_loaded.difference_update(block_hashes)
            self.manager.complete_load(block_hashes)

OffloadingConnectorStats dataclass

Bases: KVConnectorStats

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@dataclass
class OffloadingConnectorStats(KVConnectorStats):
    def __post_init__(self):
        if not self.data:
            # Empty container init, no data is passed in.
            self.reset()

    def reset(self):
        self.data: dict[str, list[OffloadingOperationMetrics]] = {}

    def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
        if not other.is_empty():
            for k, v in other.data.items():
                if k not in self.data:
                    self.data[k] = v
                else:
                    accumulator = self.data[k]
                    assert isinstance(accumulator, list)
                    accumulator.extend(v)
        return self

    def reduce(self) -> dict[str, int | float]:
        """
        Reduce the observations collected during a time interval to one or
        more representative values (eg avg/median/sum of the series).
        This is meant to be called by the logger to produce a summary of the
        stats for the last time interval.
        """
        return_dict: dict[str, int | float] = {}
        for transfer_type, ops_list in self.data.items():
            assert isinstance(ops_list, list)
            total_bytes = 0
            total_time = 0
            for op in ops_list:
                assert isinstance(op, dict)
                total_bytes += op["op_size"]
                total_time += op["op_time"]
            return_dict[f"{transfer_type}_total_bytes"] = total_bytes
            return_dict[f"{transfer_type}_total_time"] = total_time
        return return_dict

    def is_empty(self) -> bool:
        return not self.data

    def record_transfer(self, num_bytes: int, time: float, transfer_type: TransferType):
        src, dst = transfer_type
        transfer_type_key = src + "_to_" + dst
        op = OffloadingOperationMetrics(num_bytes, time)
        if transfer_type_key in self.data:
            self.data[transfer_type_key].append(op)
        else:
            self.data[transfer_type_key] = [op]

reduce

reduce() -> dict[str, int | float]

Reduce the observations collected during a time interval to one or more representative values (eg avg/median/sum of the series). This is meant to be called by the logger to produce a summary of the stats for the last time interval.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def reduce(self) -> dict[str, int | float]:
    """
    Reduce the observations collected during a time interval to one or
    more representative values (eg avg/median/sum of the series).
    This is meant to be called by the logger to produce a summary of the
    stats for the last time interval.
    """
    return_dict: dict[str, int | float] = {}
    for transfer_type, ops_list in self.data.items():
        assert isinstance(ops_list, list)
        total_bytes = 0
        total_time = 0
        for op in ops_list:
            assert isinstance(op, dict)
            total_bytes += op["op_size"]
            total_time += op["op_time"]
        return_dict[f"{transfer_type}_total_bytes"] = total_bytes
        return_dict[f"{transfer_type}_total_time"] = total_time
    return return_dict
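
A hedged sketch of the logger-side reduction: it assumes the stats object can be constructed directly with a data dict (as suggested by __post_init__) whose per-op entries have already been serialized into dicts with "op_size" and "op_time" keys; the numbers are illustrative only.

stats = OffloadingConnectorStats(
    data={
        "gpu_to_cpu": [
            {"op_size": 2_000_000, "op_time": 0.003},
            {"op_size": 4_000_000, "op_time": 0.005},
        ],
        "cpu_to_gpu": [
            {"op_size": 2_000_000, "op_time": 0.002},
        ],
    }
)

print(stats.reduce())
# {'gpu_to_cpu_total_bytes': 6000000, 'gpu_to_cpu_total_time': 0.008,
#  'cpu_to_gpu_total_bytes': 2000000, 'cpu_to_gpu_total_time': 0.002}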

OffloadingConnectorWorker

Implementation of Worker side methods

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
class OffloadingConnectorWorker:
    """Implementation of Worker side methods"""

    def __init__(self, spec: OffloadingSpec):
        self.spec = spec
        self.worker = OffloadingWorker()

        self._job_counter = 0

        self.kv_connector_stats = OffloadingConnectorStats()
        # job_id -> (req_id, store)
        self._jobs: dict[int, tuple[ReqId, bool]] = {}
        # req_id -> active load job ID
        self._load_job: dict[ReqId, int] = {}
        # req_id -> set(active job IDs)
        self._store_jobs = defaultdict[ReqId, set[int]](set)
        # list of store jobs pending submission (job_id, transfer_spec)
        self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []

        self._finished_reqs_waiting_for_store: set[ReqId] = set()

    def _generate_job_id(self) -> int:
        job_id = self._job_counter
        self._job_counter = job_id + 1
        return job_id

    def _register_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ):
        for src_cls, dst_cls, handler in self.spec.get_handlers(
            kv_caches, attn_backends
        ):
            self.worker.register_handler(src_cls, dst_cls, handler)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        layer_names = list(kv_caches.keys())
        layers = get_layers_from_vllm_config(
            self.spec.vllm_config, Attention, layer_names
        )
        attn_backends = {
            layer_name: layers[layer_name].get_attn_backend()
            for layer_name in layer_names
        }
        self._register_handlers(kv_caches, attn_backends)

    def register_cross_layers_kv_cache(
        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
    ):
        cross_layer_name = "ALL_LAYERS"
        kv_caches = {cross_layer_name: kv_cache}
        attn_backends = {cross_layer_name: attn_backend}
        self._register_handlers(kv_caches, attn_backends)

    def handle_preemptions(self, preempted_req_ids: set[str]):
        for job_id, transfer_spec in self._unsubmitted_store_jobs:
            success = self.worker.transfer_async(job_id, transfer_spec)
            assert success
        self._unsubmitted_store_jobs.clear()

        for req_id in preempted_req_ids:
            job_ids = self._store_jobs.get(req_id)
            if job_ids:
                self.worker.wait(job_ids)

    def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
        for job_id, transfer_spec in self._unsubmitted_store_jobs:
            success = self.worker.transfer_async(job_id, transfer_spec)
            assert success
        self._unsubmitted_store_jobs.clear()

        for req_id, transfer_spec in metadata.reqs_to_load.items():
            job_id = self._generate_job_id()
            self._jobs[job_id] = (req_id, False)
            assert req_id not in self._load_job
            self._load_job[req_id] = job_id
            success = self.worker.transfer_async(job_id, transfer_spec)
            assert success

    def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
        for req_id, transfer_spec in metadata.reqs_to_store.items():
            job_id = self._generate_job_id()
            self._jobs[job_id] = (req_id, True)
            self._store_jobs[req_id].add(job_id)
            # NOTE(orozery): defer the store to the beginning of the next engine step,
            # so that offloading starts AFTER transfers related to token sampling,
            # thereby avoiding delays to token generation due to offloading.
            self._unsubmitted_store_jobs.append((job_id, transfer_spec))

    def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens.
        Returns a list of request IDs that finished loading or storing.

        Returns:
            ids of requests that have finished asynchronous transfer
            tuple of (sending/saving ids, recving/loading ids).
        """
        finished_sending = set()
        finished_recving = set()
        for transfer_result in self.worker.get_finished():
            # we currently do not support job failures
            job_id = transfer_result.job_id
            assert transfer_result.success
            req_id, store = self._jobs.pop(job_id)
            if (
                transfer_result.transfer_time
                and transfer_result.transfer_size is not None
                and transfer_result.transfer_type is not None
            ):
                self.kv_connector_stats.record_transfer(
                    num_bytes=transfer_result.transfer_size,
                    time=transfer_result.transfer_time,
                    transfer_type=transfer_result.transfer_type,
                )
            if store:
                req_jobs = self._store_jobs[req_id]
                req_jobs.remove(job_id)
                if req_jobs:
                    continue

                if req_id in self._finished_reqs_waiting_for_store:
                    self._finished_reqs_waiting_for_store.remove(req_id)
                    finished_sending.add(req_id)
                    del self._store_jobs[req_id]
            else:
                req_job = self._load_job[req_id]
                assert job_id == req_job
                del self._load_job[req_id]
                finished_recving.add(req_id)

        for req_id in finished_req_ids:
            pending_req_jobs = self._store_jobs.get(req_id)
            if pending_req_jobs:
                self._finished_reqs_waiting_for_store.add(req_id)
            elif pending_req_jobs is not None:
                finished_sending.add(req_id)
                del self._store_jobs[req_id]

        return finished_sending, finished_recving

    def get_kv_connector_stats(self) -> KVConnectorStats | None:
        """
        Get the KV transfer stats for the connector.
        """

        if self.kv_connector_stats.is_empty():
            return None
        # Clear stats for next iteration
        kv_connector_stats = self.kv_connector_stats
        self.kv_connector_stats = OffloadingConnectorStats()
        return kv_connector_stats
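
To make the deferred-store behavior concrete, here is a hedged sketch of the per-step call order as seen by the worker. worker_connector and the metadata objects are assumed, illustrative names; the actual sequencing is driven by the engine and the connector front-end.

# Step N: prepare_store_kv only queues (job_id, transfer_spec) pairs, so
# offloading cannot delay the transfers needed for token sampling.
worker_connector.prepare_store_kv(metadata_step_n)

# Step N+1: start_kv_transfers first submits the stores deferred from
# step N, then starts the loads requested for this step.
worker_connector.start_kv_transfers(metadata_step_n_plus_1)

# After the engine step, completed jobs are drained and translated into
# (finished_sending, finished_recving) request-ID sets.
finished_sending, finished_recving = worker_connector.get_finished(
    finished_req_ids=set()
)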

get_finished

get_finished(
    finished_req_ids: set[str],
) -> tuple[set[str], set[str]]

Notifies the worker-side connector of the IDs of requests that have finished generating tokens, and returns the request IDs that have finished loading or storing.

Returns:

tuple[set[str], set[str]]: IDs of requests that have finished asynchronous transfers, as a tuple of (sending/saving ids, recving/loading ids).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
    """
    Notifies worker-side connector ids of requests that have
    finished generating tokens.
    Returns a list of request IDs that finished loading or storing.

    Returns:
        ids of requests that have finished asynchronous transfer
        tuple of (sending/saving ids, recving/loading ids).
    """
    finished_sending = set()
    finished_recving = set()
    for transfer_result in self.worker.get_finished():
        # we currently do not support job failures
        job_id = transfer_result.job_id
        assert transfer_result.success
        req_id, store = self._jobs.pop(job_id)
        if (
            transfer_result.transfer_time
            and transfer_result.transfer_size is not None
            and transfer_result.transfer_type is not None
        ):
            self.kv_connector_stats.record_transfer(
                num_bytes=transfer_result.transfer_size,
                time=transfer_result.transfer_time,
                transfer_type=transfer_result.transfer_type,
            )
        if store:
            req_jobs = self._store_jobs[req_id]
            req_jobs.remove(job_id)
            if req_jobs:
                continue

            if req_id in self._finished_reqs_waiting_for_store:
                self._finished_reqs_waiting_for_store.remove(req_id)
                finished_sending.add(req_id)
                del self._store_jobs[req_id]
        else:
            req_job = self._load_job[req_id]
            assert job_id == req_job
            del self._load_job[req_id]
            finished_recving.add(req_id)

    for req_id in finished_req_ids:
        pending_req_jobs = self._store_jobs.get(req_id)
        if pending_req_jobs:
            self._finished_reqs_waiting_for_store.add(req_id)
        elif pending_req_jobs is not None:
            finished_sending.add(req_id)
            del self._store_jobs[req_id]

    return finished_sending, finished_recving

get_kv_connector_stats

get_kv_connector_stats() -> KVConnectorStats | None

Get the KV transfer stats for the connector.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
def get_kv_connector_stats(self) -> KVConnectorStats | None:
    """
    Get the KV transfer stats for the connector.
    """

    if self.kv_connector_stats.is_empty():
        return None
    # Clear stats for next iteration
    kv_connector_stats = self.kv_connector_stats
    self.kv_connector_stats = OffloadingConnectorStats()
    return kv_connector_stats
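
A brief usage note, with assumed names: get_kv_connector_stats follows a take-and-reset pattern, so each call returns the stats accumulated since the previous call (or None) and starts a fresh accumulator for the next interval.

# Illustrative polling loop body (worker_connector is an assumed name).
stats = worker_connector.get_kv_connector_stats()
if stats is not None:
    # e.g. {'gpu_to_cpu_total_bytes': ..., 'gpu_to_cpu_total_time': ...}
    print(stats.reduce())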