@dataclass
class LoRAKernelMeta:
    # LoRA id of each input token (-1 means the token uses no LoRA).
    token_lora_mapping: torch.Tensor
    # Token indices stable-sorted (grouped) by their LoRA id.
    token_indices_sorted_by_lora_ids: torch.Tensor
    # Sorted unique LoRA ids present in the batch; unused slots hold -1.
    active_lora_ids: torch.Tensor
    # Number of tokens mapped to each entry of active_lora_ids.
    num_tokens_per_lora: torch.Tensor
    # Start offsets (prefix sums of num_tokens_per_lora, with a leading 0).
    lora_token_start_loc: torch.Tensor
    # The V1 architecture uses traced torch.compile graphs to execute a
    # forward pass. Things to note about this process:
    # 1. Tracing specializes Python scalar objects into constant values.
    # 2. Tracing cannot handle dynamic control flow (dynamic control flow is
    #    an experimental feature in PyTorch).
    # 3. The internals of torch.ops functions are not traced.
    # We disguise the "no_lora" flag as a CPU tensor and leverage point 3 to
    # early-exit from inside the lora_expand / lora_shrink torch operations.
no_lora_flag_cpu: torch.Tensor
    # Number of distinct LoRA-id groups in token_lora_mapping (including the
    # -1 / no-LoRA group when present).
    # Stored as a Python int to avoid a GPU->CPU sync during the forward pass.
num_active_loras: int = 0
# Captured LoRA counts for cudagraph specialization (sorted list).
# When specialize_active_lora is enabled, num_active_loras is rounded up
# to the nearest value in this list to match cudagraph capture keys.
# Empty list means no specialization (use actual count).
captured_lora_counts: list[int] = field(default_factory=list)
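    # Example: with captured_lora_counts = [1, 2, 4] and 3 active LoRAs,
    # num_active_loras is rounded up to 4; with 5 active LoRAs (more than the
    # largest captured value), the actual count of 5 is kept.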
@staticmethod
def make(
max_loras: int,
max_num_tokens: int,
device: torch.device | str,
captured_lora_counts: list[int] | None = None,
) -> "LoRAKernelMeta":
token_lora_mapping = torch.empty(
max_num_tokens, dtype=torch.int32, device=device
)
token_indices_sorted_by_lora_ids = torch.empty(
max_num_tokens, dtype=torch.int32, device=device
)
        # +1 because "no-lora" is also a possibility.
        # Example: with max_loras = 3, active_lora_ids could be
        # [-1, 0, 2, 1].
active_lora_ids = torch.empty(max_loras + 1, dtype=torch.int32, device=device)
        # Continuing the example, num_tokens_per_lora could be [3, 10, 5, 2].
num_tokens_per_lora = torch.zeros(
max_loras + 1, dtype=torch.int32, device=device
)
        # +2 because the first entry is always 0 and there are max_loras + 1
        # possible LoRA ids (including "no-lora").
        # Continuing the example, lora_token_start_loc is
        # [0, 3, 13, 18, 20].
lora_token_start_loc = torch.zeros(
max_loras + 2, dtype=torch.int32, device=device
)
no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu")
return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
active_lora_ids=active_lora_ids,
num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu,
captured_lora_counts=sorted(captured_lora_counts)
if captured_lora_counts
else [],
)
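    # Illustrative usage (a sketch; the sizes, device, and variable names
    # below are made up for the example):
    #
    #   meta = LoRAKernelMeta.make(max_loras=4, max_num_tokens=8192,
    #                              device="cuda")
    #   # Once per forward pass, with one LoRA id per token (-1 = no LoRA):
    #   mapping = torch.tensor([1, 1, -1, 0, 1], dtype=torch.int32,
    #                          device="cuda")
    #   meta.prepare_tensors(mapping)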
    def _reset(self):
        self.active_lora_ids.fill_(-1)
        self.num_tokens_per_lora.fill_(0)
        self.lora_token_start_loc.fill_(0)
        self.no_lora_flag_cpu.fill_(False)
        self.num_active_loras = 0
        # captured_lora_counts is static configuration set in make() and is
        # deliberately not cleared here; clearing it per forward pass would
        # disable the cudagraph rounding in prepare_tensors().
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
Prepare kernel metadata tensors for the current forward pass.
Args:
            token_lora_mapping (torch.Tensor): Tensor containing the LoRA
                index for each input token (-1 for tokens that use no LoRA).
"""
self._reset()
# Check and record no-lora case.
no_lora = torch.all(token_lora_mapping == -1)
self.no_lora_flag_cpu[0] = no_lora
if no_lora:
# Early exit. LoRA kernels will not be run.
return
num_tokens = token_lora_mapping.size(0)
# copy token lora mapping
self.token_lora_mapping[:num_tokens].copy_(
token_lora_mapping, non_blocking=True
)
# token_indices_sorted_by_lora_ids
_, token_indices_sorted_by_lora_ids = torch.sort(
token_lora_mapping, stable=True
)
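        # Worked example: for token_lora_mapping [1, 1, -1, 0, 1], the stable
        # sort gives values [-1, 0, 1, 1, 1] and indices [2, 3, 0, 1, 4],
        # i.e. token indices grouped by LoRA id.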
# start gpu transfer
self.token_indices_sorted_by_lora_ids[:num_tokens].copy_(
token_indices_sorted_by_lora_ids, non_blocking=True
)
# active_lora_ids, num_tokens_per_lora
lora_ids, num_tokens_per_lora = torch.unique(
token_lora_mapping, sorted=True, return_counts=True
)
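        # Continuing the example: lora_ids is [-1, 0, 1] and
        # num_tokens_per_lora is [1, 1, 3].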
self.active_lora_ids[: lora_ids.size(0)].copy_(lora_ids, non_blocking=True)
self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_(
num_tokens_per_lora, non_blocking=True
)
self.num_active_loras = lora_ids.size(0)
        # Round up num_active_loras to match the cudagraph capture keys. This
        # ensures the kernel grid dimension matches the captured graph. If the
        # count exceeds the largest captured value, the actual count is kept.
if self.captured_lora_counts and self.num_active_loras > 0:
idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras)
if idx < len(self.captured_lora_counts):
self.num_active_loras = self.captured_lora_counts[idx]
# lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
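        # Continuing the example: cumsum([1, 1, 3]) is [1, 2, 5], so
        # lora_token_start_loc starts with [0, 1, 2, 5] (remaining entries
        # stay 0). Tokens of the i-th active LoRA occupy
        # token_indices_sorted_by_lora_ids[
        #     lora_token_start_loc[i] : lora_token_start_loc[i + 1]].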
self.lora_token_start_loc[1 : 1 + lora_token_start_loc.size(0)].copy_(
lora_token_start_loc, non_blocking=True
)
def meta_args(
self,
token_nums: int,
specialize_active_lora: bool,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
int,
]:
"""
This function returns the kernel metadata required for the current
forward pass execution of the kernel. The function returns all the
metadata required by the kernel, in order, as a tuple, so it can be
unpacked directly during the lora_shrink/lora_expand function call.
Args:
token_nums (int): Number of input tokens in the current forward
pass of the kernel.
"""
max_loras = self.active_lora_ids.size(0) - 1
return (
self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
self.num_tokens_per_lora,
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
self.num_active_loras if specialize_active_lora else max_loras + 1,
)
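    # Illustrative call site (a sketch; the unpacked names and `num_tokens`
    # are made up, and the lora_shrink / lora_expand signatures are not
    # shown here):
    #
    #   (token_lora_mapping, token_indices_sorted_by_lora_ids,
    #    num_tokens_per_lora, lora_token_start_loc, active_lora_ids,
    #    no_lora_flag_cpu, num_lora_slots) = meta.meta_args(
    #       token_nums=num_tokens, specialize_active_lora=False)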