vllm.lora.layers

Modules:

base
column_parallel_linear
fused_moe
logits_processor
replicated_linear
row_parallel_linear
utils

BaseLayerWithLoRA

Bases: Module

Source code in vllm/lora/layers/base.py
class BaseLayerWithLoRA(nn.Module):
    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper,
    ):
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError

can_replace_layer classmethod

can_replace_layer(
    source_layer: Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/base.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer."""
    raise NotImplementedError

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

Initializes lora matrices.

Source code in vllm/lora/layers/base.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Initializes lora matrices."""
    ...

reset_lora

reset_lora(index: int)

Resets the lora weights at index back to 0.

Source code in vllm/lora/layers/base.py
def reset_lora(self, index: int):
    """Resets the lora weights at index back to 0."""
    ...

set_lora

set_lora(
    index: int,
    lora_a: Tensor | list[Tensor],
    lora_b: Tensor | list[Tensor],
)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/base.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    ...

slice_lora_a

slice_lora_a(
    lora_a: Tensor | list[Tensor | None],
) -> Tensor | list[Tensor | None]

Slice lora a if splitting for tensor parallelism.

Source code in vllm/lora/layers/base.py
def slice_lora_a(
    self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora a if splitting for tensor parallelism."""
    ...

slice_lora_b

slice_lora_b(
    lora_b: Tensor | list[Tensor | None],
) -> Tensor | list[Tensor | None]

Slice lora b if splitting with tensor parallelism.

Source code in vllm/lora/layers/base.py
def slice_lora_b(
    self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora b if splitting with tensor parallelism."""
    ...
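
Example (illustrative; not part of the vLLM source). The methods above form the full contract of a LoRA layer: create_lora_weights allocates zero-filled stacked buffers, set_lora copies one adapter's (already sliced) A/B matrices into a slot, and reset_lora clears that slot. The toy class below mirrors this flow with plain tensors; it does not inherit from vLLM's classes, and all names and sizes are made up.

import torch


class ToyLinearWithLoRA:
    """Toy stand-in that mirrors the BaseLayerWithLoRA contract."""

    def __init__(self, input_size: int, output_size: int) -> None:
        self.input_size, self.output_size = input_size, output_size

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a  # no tensor parallelism in this toy

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b

    def create_lora_weights(self, max_loras: int, max_lora_rank: int) -> None:
        # One slot per adapter; adapters with a smaller rank use a prefix of it.
        self.lora_a_stacked = torch.zeros(max_loras, max_lora_rank, self.input_size)
        self.lora_b_stacked = torch.zeros(max_loras, self.output_size, max_lora_rank)

    def reset_lora(self, index: int) -> None:
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor) -> None:
        self.reset_lora(index)
        self.lora_a_stacked[index, : lora_a.shape[0], : lora_a.shape[1]].copy_(lora_a)
        self.lora_b_stacked[index, : lora_b.shape[0], : lora_b.shape[1]].copy_(lora_b)


layer = ToyLinearWithLoRA(input_size=64, output_size=32)
layer.create_lora_weights(max_loras=2, max_lora_rank=8)
layer.set_lora(0, torch.randn(4, 64), torch.randn(32, 4))  # rank-4 adapter in slot 0
layer.reset_lora(0)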

ColumnParallelLinearWithLoRA

Bases: BaseLinearLayerWithLoRA

LoRA on top of ColumnParallelLinear layer. LoRA B is sliced for tensor parallelism. There are two types for the base_layer: 1. ColumnParallelLinear, e.g. dense_h_to_4h in FalconForCausalLM. 2. MergedColumnParallelLinear, e.g. gate_up_proj in Phi3ForCausalLM.

Source code in vllm/lora/layers/column_parallel_linear.py
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B is sliced for tensor parallelism.
    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = type(base_layer) is MergedColumnParallelLinear
        self.output_size = self.base_layer.output_size_per_partition
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Applicable to cases where the base_layer is
        # MergedColumnParallelLinear.
        if self.is_merged_col_linear:
            shard_size = self.output_size // 2
            offset = lora_b.shape[0] // 2

            left_weight = lora_b[
                self.tp_rank * shard_size : (self.tp_rank + 1) * shard_size, :
            ]
            right_weight = lora_b[
                offset + self.tp_rank * shard_size : offset
                + (self.tp_rank + 1) * shard_size,
                :,
            ]
            lora_b = torch.cat([left_weight, right_weight], dim=0)
        # Applicable to cases where the base_layer is
        # ColumnParallelLinear.
        else:
            shard_size = self.output_size
            start_idx = self.tp_rank * shard_size
            end_idx = (self.tp_rank + 1) * shard_size
            lora_b = lora_b[start_idx:end_idx, :]
        return lora_b

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.base_layer.return_bias:
            return output

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        if type(source_layer) is ColumnParallelLinear:
            return True
        if type(source_layer) is MergedColumnParallelLinear:
            if len(packed_modules_list) != 1:
                return False
            # Exclude layers with 3+ output sizes - those are handled by
            # MergedColumnParallelLinearVariableSliceWithLoRA since this
            # class's slice_lora_b assumes exactly 2 slices.
            return not (
                hasattr(source_layer, "output_sizes")
                and len(source_layer.output_sizes) >= 3
            )
        return False
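
Example (illustrative; not part of the vLLM source). The slicing arithmetic of slice_lora_b above can be reproduced with plain tensors. The sizes and the tp_rank value below are made up; the two branches mirror the ColumnParallelLinear and MergedColumnParallelLinear cases respectively.

import torch

tp_size, tp_rank = 2, 1
rank = 8
full_output_size = 32  # unsharded output dim (for the merged case: gate + up)
output_size_per_partition = full_output_size // tp_size  # 16

lora_b = torch.randn(full_output_size, rank)

# ColumnParallelLinear: each rank keeps one contiguous block of LoRA-B rows.
shard_size = output_size_per_partition
plain_shard = lora_b[tp_rank * shard_size : (tp_rank + 1) * shard_size, :]  # (16, 8)

# MergedColumnParallelLinear (e.g. gate_up_proj): the gate half and the up half
# of LoRA B are sliced independently and re-concatenated.
shard_size = output_size_per_partition // 2  # 8 rows per half per rank
offset = lora_b.shape[0] // 2                # start of the up half
left = lora_b[tp_rank * shard_size : (tp_rank + 1) * shard_size, :]
right = lora_b[offset + tp_rank * shard_size : offset + (tp_rank + 1) * shard_size, :]
merged_shard = torch.cat([left, right], dim=0)  # (16, 8)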

forward

forward(
    input_: Tensor,
) -> Tensor | tuple[Tensor, Tensor | None]

Forward of ColumnParallelLinear

Parameters:

• input_ (Tensor, required): Tensor whose last dimension is input_size.

Returns:

Tensor | tuple[Tensor, Tensor | None]
  • output
  • bias
Source code in vllm/lora/layers/column_parallel_linear.py
def forward(
    self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    """Forward of ColumnParallelLinear

    Args:
        input_: Tensor whose last dimension is `input_size`.

    Returns:
        - output
        - bias
    """
    bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

    # Matrix multiply.
    output_parallel = self.apply(input_, bias)
    if self.base_layer.gather_output and self.tp_size > 1:
        # All-gather across the partitions.
        output = tensor_model_parallel_all_gather(output_parallel)
    else:
        output = output_parallel

    if not self.base_layer.return_bias:
        return output

    output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
    return output, output_bias

ColumnParallelLinearWithShardedLoRA

Bases: ColumnParallelLinearWithLoRA

Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

Based on S-LoRA, slicing happens along the rank dim.

Source code in vllm/lora/layers/column_parallel_linear.py
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
    # their `lora_a` and `lora_b` have different sharding patterns. After
    # completing the `lora_a` GEMM , a gather operation is performed.
    # Therefore, the sharding of `lora_a` only needs to correspond with the
    # gather operation.
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        lora_a = lora_a[start_idx : start_idx + shard_size, :]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
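
Example (illustrative; not part of the vLLM source). A minimal sketch of the rank-dimension split used by slice_lora_a above, with made-up sizes. Each rank keeps a contiguous block of LoRA-A rows and computes a partial shrink result; the per-rank partials are gathered afterwards (see the class comment above).

import torch

tp_size, tp_rank = 4, 2
rank, input_size = 16, 64

lora_a = torch.randn(rank, input_size)
shard_size = rank // tp_size                          # rows of A owned by this rank
start = tp_rank * shard_size
lora_a_shard = lora_a[start : start + shard_size, :]  # (4, 64)

x = torch.randn(3, input_size)       # 3 tokens
partial = x @ lora_a_shard.t()       # (3, 4) partial shrink output on this rank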

FusedMoE3DWithLoRA

Bases: FusedMoEWithLoRA

Source code in vllm/lora/layers/fused_moe.py
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
    def __init__(self, base_layer):
        super().__init__(base_layer)
        self._w13_slices = 1

    def _create_lora_b_weights(self, max_loras, lora_config):
        self.w13_lora_b_stacked: tuple[torch.Tensor] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition * 2,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""

        assert isinstance(model_config, PretrainedConfig)
        self._base_model = model_config.architectures[0]
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        # HACK: Currently, only GPT-OSS is in interleaved order
        if self._base_model == "GptOssForCausalLM":
            # For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
            # are stored in interleaved order, so the corresponding LoRA weights
            # need to be sliced accordingly.
            w1_lora_b = w13_lora_b[:, ::2, :]
            w3_lora_b = w13_lora_b[:, 1::2, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
                1, 2
            )
        else:
            slice_size = w13_lora_b.shape[1] // 2
            w1_lora_b = w13_lora_b[:, :slice_size, :]
            w3_lora_b = w13_lora_b[:, slice_size:, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        assert len(lora_a) == len(lora_b) == 2

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        w13_lora_a, w2_lora_a = lora_a
        w13_lora_b, w2_lora_b = lora_b

        sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
        sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)

        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        self.w13_lora_a_stacked[0][
            index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
        ].copy_(sliced_w13_lora_a, non_blocking=True)
        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
        ].copy_(sliced_w13_lora_b, non_blocking=True)
        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    @property
    def w13_input_size(self):
        """
        Full size
        """
        return self.w13_lora_a_stacked[0].shape[-1]

    @property
    def w13_output_size(self):
        """
        Full size
        """
        return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size

    @property
    def w2_input_size(self):
        """
        Full size
        """
        return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size

    @property
    def w2_output_size(self):
        """
        Full size
        """
        return self.base_layer.hidden_size

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # source_layer is FusedMoE or SharedFusedMoE
        return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1
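
Example (illustrative; not part of the vLLM source). The two weight layouts handled by _slice_w13_b above can be seen with a small tensor: for most models the w1 (gate_proj) rows come first and the w3 (up_proj) rows second, while GPT-OSS stores them interleaved row by row. Sizes below are made up; the code mirrors the slicing logic rather than calling it.

import torch

tp_size, tp_rank = 2, 0
num_experts, rank = 2, 4
intermediate_per_partition = 3                       # w1 (and w3) rows per rank
full_intermediate = intermediate_per_partition * tp_size

w13_lora_b = torch.randn(num_experts, 2 * full_intermediate, rank)
start = tp_rank * intermediate_per_partition
end = (tp_rank + 1) * intermediate_per_partition

# Concatenated layout: [w1 rows | w3 rows].
w1 = w13_lora_b[:, :full_intermediate, :]
w3 = w13_lora_b[:, full_intermediate:, :]
concat_shard = torch.cat([w1[:, start:end, :], w3[:, start:end, :]], dim=1)

# Interleaved layout (GPT-OSS): rows alternate w1, w3, w1, w3, ...
w1_i = w13_lora_b[:, ::2, :]
w3_i = w13_lora_b[:, 1::2, :]
interleaved_shard = torch.stack(
    [w1_i[:, start:end, :], w3_i[:, start:end, :]], dim=2
).flatten(1, 2)

print(concat_shard.shape, interleaved_shard.shape)   # both (2, 6, 4)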

w13_input_size property

w13_input_size

Full size

w13_output_size property

w13_output_size

Full size

w2_input_size property

w2_input_size

Full size

w2_output_size property

w2_output_size

Full size

can_replace_layer classmethod

can_replace_layer(
    source_layer: Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/fused_moe.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer."""
    # source_layer is FusedMoE or SharedFusedMoE
    return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

Initializes lora matrices.

Source code in vllm/lora/layers/fused_moe.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Initializes lora matrices."""

    assert isinstance(model_config, PretrainedConfig)
    self._base_model = model_config.architectures[0]
    self.max_loras = lora_config.max_loras
    self.fully_sharded = lora_config.fully_sharded_loras

    self.adapter_enabled = torch.tensor(
        [0] * (max_loras + 1), dtype=torch.int, device=self.device
    )

    self._create_lora_a_weights(max_loras, lora_config)
    self._create_lora_b_weights(max_loras, lora_config)

set_lora

set_lora(
    index: int,
    lora_a: Tensor | list[Tensor],
    lora_b: Tensor | list[Tensor],
)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/fused_moe.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    # Make mypy happy
    assert isinstance(lora_a, list)
    assert isinstance(lora_b, list)
    assert len(lora_a) == len(lora_b) == 2

    self.reset_lora(index)
    self.adapter_enabled[index] = 1

    w13_lora_a, w2_lora_a = lora_a
    w13_lora_b, w2_lora_b = lora_b

    sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
    sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)

    sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
    sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

    self.w13_lora_a_stacked[0][
        index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
    ].copy_(sliced_w13_lora_a, non_blocking=True)
    self.w2_lora_a_stacked[0][
        index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
    ].copy_(sliced_w2_lora_a, non_blocking=True)

    self.w13_lora_b_stacked[0][
        index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
    ].copy_(sliced_w13_lora_b, non_blocking=True)
    self.w2_lora_b_stacked[0][
        index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
    ].copy_(sliced_w2_lora_b, non_blocking=True)

FusedMoEWithLoRA

Bases: BaseLayerWithLoRA

Source code in vllm/lora/layers/fused_moe.py
class FusedMoEWithLoRA(BaseLayerWithLoRA):
    def __init__(self, base_layer: FusedMoE) -> None:
        super().__init__()
        self.base_layer = base_layer

        assert not self.base_layer.use_ep, (
            "EP support for Fused MoE LoRA is not implemented yet."
        )
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.device = _get_lora_device(base_layer)
        # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
        # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
        self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1
        self._inject_lora_into_fused_moe()

    def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
        normalized_config = {}
        for key, value in config.items():
            if key.islower():
                if key.startswith("block_"):
                    normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
                else:
                    normalized_key = key.upper()
            else:
                normalized_key = key
            normalized_config[normalized_key] = value
        return normalized_config

    def _get_lora_moe_configs(
        self,
        op_prefix: str,
        num_loras: int,
        rank: int,
        num_slices: int,
        M: int,
        layer: FusedMoE,
        top_k: int,
        config_dtype: str,
    ):
        if envs.VLLM_TUNED_CONFIG_FOLDER:
            hidden_size = layer.hidden_size
            intermediate_size = layer.intermediate_size_per_partition
            shrink_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_shrink",
                max_loras=num_loras,
                batch=M,
                hidden_size=hidden_size,
                rank=rank,
                num_slices=num_slices,
                moe_intermediate_size=intermediate_size,
            )
            expand_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_expand",
                max_loras=num_loras,
                batch=M,
                hidden_size=hidden_size,  # lora_a_stacked.shape[-1],
                rank=rank,
                num_slices=num_slices,
                moe_intermediate_size=intermediate_size,  # lora_b_stacked.shape[-2],
            )
        else:  # fall back to the default config
            get_config_func = functools.partial(
                try_get_optimal_moe_lora_config,
                w1_shape=layer.w13_weight.size(),
                w2_shape=layer.w2_weight.size(),
                rank=rank,
                top_k=top_k,
                dtype=config_dtype,
                M=M,
                block_shape=layer.quant_method.moe_quant_config.block_shape,
            )
            shrink_config = get_config_func(
                op_type=f"fused_moe_lora_{op_prefix}_shrink"
            )
            expand_config = get_config_func(
                op_type=f"fused_moe_lora_{op_prefix}_expand"
            )
        shrink_config = self._normalize_keys(shrink_config)
        expand_config = self._normalize_keys(expand_config)
        return shrink_config, expand_config

    def _inject_lora_into_fused_moe(self):
        moe_state_dict = {}
        top_k = self.base_layer.top_k

        self.base_layer.ensure_moe_quant_config_init()
        quant_config = self.base_layer.quant_method.moe_quant_config

        if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
            # Use the existing modular kernel from the quant method
            m_fused_moe_fn = self.base_layer.quant_method.moe_mk
        else:
            # Create a new modular kernel via select_gemm_impl
            prepare_finalize = MoEPrepareAndFinalizeNoEP()
            m_fused_moe_fn = FusedMoEModularKernel(
                prepare_finalize,
                self.base_layer.quant_method.select_gemm_impl(
                    prepare_finalize, self.base_layer
                ),
                self.base_layer.shared_experts,
            )

        if quant_config.use_mxfp4_w4a16:
            assert isinstance(
                m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
            )
        else:
            assert isinstance(m_fused_moe_fn.fused_experts, TritonExperts)

        def fwd_decorator(layer, func):
            def wrapper(*args, **kwargs):
                moe_state_dict["hidden_states"] = kwargs["hidden_states"]
                moe_state_dict["topk_ids"] = kwargs["topk_ids"]
                moe_state_dict["topk_weights"] = kwargs["topk_weights"]
                moe_state_dict["expert_map"] = kwargs["expert_map"]
                moe_state_dict["apply_router_weight_on_input"] = kwargs[
                    "apply_router_weight_on_input"
                ]
                result = func(*args, **kwargs)
                return result

            return wrapper

        def act_decorator(layer, func):
            def wrapper(*args, **kwargs):
                _, output, input = args

                hidden_states = moe_state_dict["hidden_states"]
                topk_weights = moe_state_dict["topk_weights"]
                curr_topk_ids = moe_state_dict["topk_ids"]

                expert_map = moe_state_dict["expert_map"]

                config_dtype = _get_config_dtype_str(
                    dtype=hidden_states.dtype,
                    use_fp8_w8a8=False,
                    use_int8_w8a16=False,
                    use_int4_w4a16=False,
                )
                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                num_tokens = hidden_states.size(0)
                M = min(num_tokens, CHUNK_SIZE)
                max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
                shrink_config, expand_config = self._get_lora_moe_configs(
                    op_prefix="w13",
                    num_loras=self.max_loras,
                    rank=max_lora_rank,
                    num_slices=self._w13_slices,
                    M=M,
                    layer=layer,
                    top_k=top_k,
                    config_dtype=config_dtype,
                )

                # SPARSITY_FACTOR is a heuristic margin ensuring tokens * top_k
                # activates only a small fraction of total experts * loras.
                SPARSITY_FACTOR = 8
                naive_block_assignment = (
                    expert_map is None
                    and num_tokens * top_k * SPARSITY_FACTOR
                    <= self.base_layer.local_num_experts * self.max_loras
                )

                # get the block size of m from customized config or default config
                (
                    token_lora_mapping,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                ) = self.punica_wrapper.moe_lora_align_block_size(
                    curr_topk_ids,
                    num_tokens,
                    shrink_config["BLOCK_SIZE_M"],
                    self.base_layer.local_num_experts,
                    self.max_loras,
                    self.adapter_enabled,
                    expert_map,
                    naive_block_assignment,
                )

                moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
                moe_state_dict["expert_ids_lora"] = expert_ids_lora
                moe_state_dict["num_tokens_post_padded_lora"] = (
                    num_tokens_post_padded_lora
                )
                moe_state_dict["token_lora_mapping"] = token_lora_mapping

                if sorted_token_ids_lora is not None:
                    expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
                    sorted_token_ids_lora = sorted_token_ids_lora.view(
                        self.max_loras, -1
                    )
                #

                self.punica_wrapper.add_lora_fused_moe(
                    input.view(-1, top_k, input.shape[-1]),
                    hidden_states,
                    self.w13_lora_a_stacked,
                    self.w13_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    shrink_config,  ## pass the shrink config
                    expand_config,  ## pass the expand config
                    self.adapter_enabled,
                    fully_sharded=self.fully_sharded,
                    token_lora_mapping=token_lora_mapping,
                )

                result = func(*args, **kwargs)

                moe_state_dict["intermediate_cache2"] = output
                return result

            return wrapper

        def moe_sum_decorator(layer, func):
            def wrapper(*args, **kwargs):
                hidden_states = moe_state_dict["hidden_states"]
                topk_weights = moe_state_dict["topk_weights"]

                config_dtype = _get_config_dtype_str(
                    dtype=hidden_states.dtype,
                    use_fp8_w8a8=False,
                    use_int8_w8a16=False,
                    use_int4_w4a16=False,
                )
                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                num_tokens = hidden_states.size(0)
                M = min(num_tokens, CHUNK_SIZE)
                max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
                shrink_config, expand_config = self._get_lora_moe_configs(
                    op_prefix="w2",
                    num_loras=self.max_loras,
                    rank=max_lora_rank,
                    num_slices=1,
                    M=M,
                    layer=layer,
                    top_k=top_k,
                    config_dtype=config_dtype,
                )

                sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
                expert_ids_lora = moe_state_dict["expert_ids_lora"]
                num_tokens_post_padded_lora = moe_state_dict[
                    "num_tokens_post_padded_lora"
                ]
                token_lora_mapping = moe_state_dict.get("token_lora_mapping")

                if sorted_token_ids_lora is not None:
                    expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
                    sorted_token_ids_lora = sorted_token_ids_lora.view(
                        self.max_loras, -1
                    )
                intermediate_cache2 = moe_state_dict["intermediate_cache2"]
                intermediate_cache3 = args[0]

                shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)

                self.punica_wrapper.add_lora_fused_moe(
                    intermediate_cache3,
                    intermediate_cache2,
                    self.w2_lora_a_stacked,
                    self.w2_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    shrink_config,  ## pass the shrink config
                    expand_config,  ## pass the expand config
                    self.adapter_enabled,
                    True,
                    fully_sharded=self.fully_sharded,
                    offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
                    token_lora_mapping=token_lora_mapping,
                )

                result = func(*args, **kwargs)
                return result

            return wrapper

        fused_experts = m_fused_moe_fn.fused_experts

        m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
        fused_experts.activation = act_decorator(
            self.base_layer, fused_experts.activation
        )
        fused_experts.moe_sum = moe_sum_decorator(
            self.base_layer, fused_experts.moe_sum
        )
        self.base_layer.quant_method = FusedMoEModularMethod(
            self.base_layer.quant_method, m_fused_moe_fn
        )

    def _create_lora_a_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
    ):
        self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank
                    if not self.fully_sharded
                    else divide(lora_config.max_lora_rank, self.tp_size),
                    self.base_layer.hidden_size,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank,
                    self.base_layer.intermediate_size_per_partition,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)
        # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
        # to create a dummy LoRA weights.
        # TODO Optimize this section
        self.lora_a_stacked = []
        self.lora_b_stacked = []
        for lora_id in range(max_loras):
            for experts_id in range(self.base_layer.local_num_experts):
                # For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
                # For non-gated MoE: up_proj (w1), down_proj (w2)
                self.lora_a_stacked.append(
                    self.w13_lora_a_stacked[0][lora_id][experts_id]
                )
                self.lora_a_stacked.append(
                    self.w2_lora_a_stacked[0][lora_id][experts_id]
                )

                self.lora_b_stacked.append(
                    self.w13_lora_b_stacked[0][lora_id][experts_id]
                )
                self.lora_b_stacked.append(
                    self.w2_lora_b_stacked[0][lora_id][experts_id]
                )

                # Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
                if self._w13_slices == 2:
                    self.lora_a_stacked.append(
                        self.w13_lora_a_stacked[1][lora_id][experts_id]
                    )
                    self.lora_b_stacked.append(
                        self.w13_lora_b_stacked[1][lora_id][experts_id]
                    )

    def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
        """
        if self.tp_size == 1 or not self.fully_sharded:
            return w13_lora_a

        # w13_lora_a shape (num_experts,rank,input_size)
        current_lora_rank = w13_lora_a.shape[1]
        assert current_lora_rank % self.tp_size == 0
        # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
        shard_size = self.w13_lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return w13_lora_a[:, start_idx:end_idx, :]

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w13_lora_b[:, start_idx:end_idx, :]

    def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
        """
        if self.tp_size == 1:
            return w2_lora_a
        # w2_lora_a shape (num_experts,rank,input_size)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_a[:, :, start_idx:end_idx]

    def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
        """
        if self.tp_size == 1 or not self.fully_sharded:
            return w2_lora_b
        # Based on S-LoRA, we slice W2 B along the hidden_size dim.
        # w2_lora_b shape (num_experts,output_size,rank)
        shard_size = self.w2_lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_b[:, start_idx:end_idx, :]

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        for pos in range(self._w13_slices):
            self.w13_lora_a_stacked[pos][index] = 0
            self.w13_lora_b_stacked[pos][index] = 0

        self.w2_lora_a_stacked[0][index] = 0
        self.w2_lora_b_stacked[0][index] = 0
        self.adapter_enabled[index] = 0

    #

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        num_experts = self.w13_lora_a_stacked[0].shape[1]

        w1_lora_a, w2_lora_a, w3_lora_a = lora_a
        w1_lora_b, w2_lora_b, w3_lora_b = lora_b
        assert (
            num_experts
            == w1_lora_a.shape[0]
            == w2_lora_a.shape[0]
            == w3_lora_a.shape[0]
        )

        slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
        slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)

        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        self.w13_lora_a_stacked[0][
            index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
        ].copy_(slliced_w1_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
        ].copy_(slliced_w1_lora_b, non_blocking=True)

        # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
        if self._w13_slices == 2:
            slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
            slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)

            self.w13_lora_a_stacked[1][
                index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
            ].copy_(slliced_w3_lora_a, non_blocking=True)

            self.w13_lora_b_stacked[1][
                index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
            ].copy_(slliced_w3_lora_b, non_blocking=True)

        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    def forward(self, *args, **kwargs):
        return self.base_layer.forward(*args, **kwargs)

    def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
        return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)

    @property
    def _shared_experts(self):
        return self.base_layer._shared_experts

    @property
    def quant_method(self):
        return self.base_layer.quant_method

    @property
    def is_internal_router(self) -> bool:
        return self.base_layer.is_internal_router

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""

        # source_layer is FusedMoE or SharedFusedMoE
        return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2
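
Example (illustrative; not part of the vLLM source). The stacked buffers allocated by _create_lora_a_weights and _create_lora_b_weights above reduce to a handful of shapes. The sketch below reproduces them for the gated case (_w13_slices == 2) with fully_sharded_loras disabled; all sizes are made up.

import torch

max_loras, local_num_experts = 4, 8
max_lora_rank, hidden_size, intermediate_per_partition = 16, 512, 128
w13_slices = 2  # gate_proj (w1) and up_proj (w3)

w13_lora_a_stacked = tuple(
    torch.zeros(max_loras, local_num_experts, max_lora_rank, hidden_size)
    for _ in range(w13_slices)
)
w13_lora_b_stacked = tuple(
    torch.zeros(max_loras, local_num_experts, intermediate_per_partition, max_lora_rank)
    for _ in range(w13_slices)
)
w2_lora_a_stacked = (
    torch.zeros(max_loras, local_num_experts, max_lora_rank, intermediate_per_partition),
)
w2_lora_b_stacked = (
    torch.zeros(max_loras, local_num_experts, hidden_size, max_lora_rank),
)
# Mirrors create_lora_weights above: one int flag per slot, max_loras + 1 entries.
adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int)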

_slice_w13_a

_slice_w13_a(w13_lora_a: Tensor) -> Tensor

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
    """
    if self.tp_size == 1 or not self.fully_sharded:
        return w13_lora_a

    # w13_lora_a shape (num_experts,rank,input_size)
    current_lora_rank = w13_lora_a.shape[1]
    assert current_lora_rank % self.tp_size == 0
    # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
    shard_size = self.w13_lora_a_stacked[0].shape[2]
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size
    return w13_lora_a[:, start_idx:end_idx, :]

_slice_w2_a

_slice_w2_a(w2_lora_a: Tensor) -> Tensor

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
    """
    if self.tp_size == 1:
        return w2_lora_a
    # w2_lora_a shape (num_experts,rank,input_size)
    shard_size = self.base_layer.intermediate_size_per_partition
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size

    return w2_lora_a[:, :, start_idx:end_idx]

_slice_w2_b

_slice_w2_b(w2_lora_b: Tensor) -> Tensor

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
    """
    if self.tp_size == 1 or not self.fully_sharded:
        return w2_lora_b
    # Based on S-LoRA, we slice W2 B along the hidden_size dim.
    # w2_lora_b shape (num_experts,output_size,rank)
    shard_size = self.w2_lora_b_stacked[0].shape[2]
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size

    return w2_lora_b[:, start_idx:end_idx, :]
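
Example (illustrative; not part of the vLLM source). A short sketch of the w2 slicing above with made-up sizes: whenever tp_size > 1, w2 LoRA A is sliced along its input (intermediate) dimension, while w2 LoRA B is only sliced (along hidden_size) when fully_sharded_loras is enabled.

import torch

tp_size, tp_rank = 2, 1
num_experts, rank, hidden_size = 2, 8, 32
intermediate_per_partition = 16
full_intermediate = intermediate_per_partition * tp_size

w2_lora_a = torch.randn(num_experts, rank, full_intermediate)
w2_lora_b = torch.randn(num_experts, hidden_size, rank)

# A: keep only the intermediate columns this rank owns.
start = tp_rank * intermediate_per_partition
end = (tp_rank + 1) * intermediate_per_partition
w2_lora_a_shard = w2_lora_a[:, :, start:end]            # (2, 8, 16)

# B: with fully_sharded_loras, each rank keeps a hidden_size / tp_size slab;
# otherwise it stays replicated.
fully_sharded = True
if fully_sharded:
    shard = hidden_size // tp_size
    w2_lora_b_shard = w2_lora_b[:, tp_rank * shard : (tp_rank + 1) * shard, :]  # (2, 16, 8)
else:
    w2_lora_b_shard = w2_lora_b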

can_replace_layer classmethod

can_replace_layer(
    source_layer: Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/fused_moe.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer."""

    # source_layer is FusedMoE or SharedFusedMoE
    return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

Initializes lora matrices.

Source code in vllm/lora/layers/fused_moe.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Initializes lora matrices."""
    self.max_loras = lora_config.max_loras
    self.fully_sharded = lora_config.fully_sharded_loras

    self.adapter_enabled = torch.tensor(
        [0] * (max_loras + 1), dtype=torch.int, device=self.device
    )

    self._create_lora_a_weights(max_loras, lora_config)
    self._create_lora_b_weights(max_loras, lora_config)
    # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
    # to create a dummy LoRA weights.
    # TODO Optimize this section
    self.lora_a_stacked = []
    self.lora_b_stacked = []
    for lora_id in range(max_loras):
        for experts_id in range(self.base_layer.local_num_experts):
            # For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
            # For non-gated MoE: up_proj (w1), down_proj (w2)
            self.lora_a_stacked.append(
                self.w13_lora_a_stacked[0][lora_id][experts_id]
            )
            self.lora_a_stacked.append(
                self.w2_lora_a_stacked[0][lora_id][experts_id]
            )

            self.lora_b_stacked.append(
                self.w13_lora_b_stacked[0][lora_id][experts_id]
            )
            self.lora_b_stacked.append(
                self.w2_lora_b_stacked[0][lora_id][experts_id]
            )

            # Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
            if self._w13_slices == 2:
                self.lora_a_stacked.append(
                    self.w13_lora_a_stacked[1][lora_id][experts_id]
                )
                self.lora_b_stacked.append(
                    self.w13_lora_b_stacked[1][lora_id][experts_id]
                )

reset_lora

reset_lora(index: int)

Resets the lora weights at index back to 0.

Source code in vllm/lora/layers/fused_moe.py
def reset_lora(self, index: int):
    """Resets the lora weights at index back to 0."""
    for pos in range(self._w13_slices):
        self.w13_lora_a_stacked[pos][index] = 0
        self.w13_lora_b_stacked[pos][index] = 0

    self.w2_lora_a_stacked[0][index] = 0
    self.w2_lora_b_stacked[0][index] = 0
    self.adapter_enabled[index] = 0

set_lora

set_lora(
    index: int,
    lora_a: Tensor | list[Tensor],
    lora_b: Tensor | list[Tensor],
)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/fused_moe.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    # Make mypy happy
    assert isinstance(lora_a, list)
    assert isinstance(lora_b, list)

    self.reset_lora(index)
    self.adapter_enabled[index] = 1

    num_experts = self.w13_lora_a_stacked[0].shape[1]

    w1_lora_a, w2_lora_a, w3_lora_a = lora_a
    w1_lora_b, w2_lora_b, w3_lora_b = lora_b
    assert (
        num_experts
        == w1_lora_a.shape[0]
        == w2_lora_a.shape[0]
        == w3_lora_a.shape[0]
    )

    slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
    slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)

    sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
    sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

    self.w13_lora_a_stacked[0][
        index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
    ].copy_(slliced_w1_lora_a, non_blocking=True)

    self.w13_lora_b_stacked[0][
        index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
    ].copy_(slliced_w1_lora_b, non_blocking=True)

    # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
    if self._w13_slices == 2:
        slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
        slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)

        self.w13_lora_a_stacked[1][
            index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
        ].copy_(slliced_w3_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[1][
            index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
        ].copy_(slliced_w3_lora_b, non_blocking=True)

    self.w2_lora_a_stacked[0][
        index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
    ].copy_(sliced_w2_lora_a, non_blocking=True)

    self.w2_lora_b_stacked[0][
        index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
    ].copy_(sliced_w2_lora_b, non_blocking=True)

LogitsProcessorWithLoRA

Bases: BaseLayerWithLoRA

LoRA wrapper for LogitsProcessor, with extra logic to handle the application of the LoRA adapter and added LoRA vocabulary.

Parameters:

• base_layer (LogitsProcessor, required): LogitsProcessor layer.
• hidden_size (int, required): hidden size of the model.
• dtype (dtype, required): data type of the model.
• device (device, required): device of the model.
• sharded_to_full_mapping (list[int] | None, required): index mapping from sharded vocab to full vocab received from base_layer.get_sharded_to_full_mapping(). If None, no reindexing will be done.
Source code in vllm/lora/layers/logits_processor.py
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(
        self,
        base_layer: LogitsProcessor,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
        sharded_to_full_mapping: list[int] | None,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        # TODO: Verify if this condition can be further relaxed
        if not 32000 <= self.base_layer.vocab_size <= 257024:
            raise ValueError(
                "When using LoRA, vocab size must be between 32000 and 257024"
            )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping, device=self.device, dtype=torch.long
            )
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        self.reset_lora(index)
        self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
            lora_a, non_blocking=True
        )
        self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: torch.Tensor | None = None,
    ) -> torch.Tensor | None:
        # Get the logits for the next tokens.
        if hasattr(lm_head, "base_layer"):
            actual_lm_head = lm_head.base_layer
        else:
            actual_lm_head = lm_head
        logits = actual_lm_head.quant_method.apply(actual_lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)

        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 5, 2, 6, 3, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits(
            logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0
        )

        if not current_platform.can_update_inplace():
            logits = lora_output

        # Remove paddings in vocab (if any).
        logits = logits[:, : self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False
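
As a quick illustration of the reindexing step above, the following standalone sketch applies the example mapping from the comment with plain torch indexing. The logit values are set equal to the token id they score purely for readability; none of this touches the real vLLM objects.

import torch

# Toy reproduction of the reindex in _get_logits: gather columns with the
# sharded-to-full mapping so that column i scores token id i afterwards.
# Values are illustrative only (each logit value equals the token id it scores).
logits = torch.tensor([[0.0, 1.0, 4.0, -1.0, 2.0, 3.0, 5.0, -1.0]])
sharded_to_full_mapping_gpu = torch.tensor([0, 1, 4, 5, 2, 6, 3, 7])
reindexed = logits[:, sharded_to_full_mapping_gpu]
print(reindexed)  # tensor([[ 0.,  1.,  2.,  3.,  4.,  5., -1., -1.]])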

MergedColumnParallelLinearVariableSliceWithLoRA

Bases: MergedColumnParallelLinearWithLoRA

MergedColumnParallelLinear with variable number of slices (3+).

This handles cases where the checkpoint has a single weight for the whole module (not split into slices), but the layer itself has multiple slices.

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedColumnParallelLinearVariableSliceWithLoRA(
    MergedColumnParallelLinearWithLoRA
):
    """MergedColumnParallelLinear with variable number of slices (3+).

    This handles cases where the checkpoint has a single weight for the whole
    module (not split into slices), but the layer itself has multiple slices.
    """

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Support MergedColumnParallelLinear with 3 or more slices
        # (2 slices are handled by MergedColumnParallelLinearWithLoRA)
        if type(source_layer) is not MergedColumnParallelLinear:
            return False

        # If packed_modules_list has 3+ items, use this class
        if len(packed_modules_list) >= 3:
            return True

        # If packed_modules_list has exactly 2 items, let
        # MergedColumnParallelLinearWithLoRA handle it
        if len(packed_modules_list) == 2:
            return False

        # If packed_modules_list is empty or has 1 item,
        # check the layer's output_sizes.
        # This handles cases where the checkpoint has a single weight
        # but the layer has multiple slices (3+)
        return (
            hasattr(source_layer, "output_sizes")
            and len(source_layer.output_sizes) >= 3
        )

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Override to handle single tensor weights
        that need to be split into slices."""
        self.reset_lora(index)

        # Handle case where checkpoint has single tensor weights
        # lora_a shape: (rank, input_size) - same for all slices, duplicate it
        if isinstance(lora_a, torch.Tensor):
            lora_a = [lora_a] * self.n_slices

        # lora_b shape: (total_output_size, rank) -
        # split along dim 0 based on output_sizes
        if isinstance(lora_b, torch.Tensor):
            output_sizes = self.base_layer.output_sizes
            lora_b_list = []
            start_idx = 0
            for output_size in output_sizes:
                end_idx = start_idx + output_size
                lora_b_list.append(lora_b[start_idx:end_idx, :])
                start_idx = end_idx
            lora_b = lora_b_list

        # Now call parent's set_lora which expects lists
        super().set_lora(index, lora_a, lora_b)

set_lora

set_lora(
    index: int,
    lora_a: Tensor | list[Tensor],
    lora_b: Tensor | list[Tensor],
)

Override to handle single tensor weights that need to be split into slices.

Source code in vllm/lora/layers/column_parallel_linear.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Override to handle single tensor weights
    that need to be split into slices."""
    self.reset_lora(index)

    # Handle case where checkpoint has single tensor weights
    # lora_a shape: (rank, input_size) - same for all slices, duplicate it
    if isinstance(lora_a, torch.Tensor):
        lora_a = [lora_a] * self.n_slices

    # lora_b shape: (total_output_size, rank) -
    # split along dim 0 based on output_sizes
    if isinstance(lora_b, torch.Tensor):
        output_sizes = self.base_layer.output_sizes
        lora_b_list = []
        start_idx = 0
        for output_size in output_sizes:
            end_idx = start_idx + output_size
            lora_b_list.append(lora_b[start_idx:end_idx, :])
            start_idx = end_idx
        lora_b = lora_b_list

    # Now call parent's set_lora which expects lists
    super().set_lora(index, lora_a, lora_b)
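
The split performed in set_lora can be pictured with a small standalone sketch. The per-slice output sizes and LoRA rank below are made-up values, and torch.split is used as an equivalent to the manual slicing loop.

import torch

# Hypothetical per-slice output sizes and rank; not taken from any real model.
output_sizes = [1024, 1024, 512]
rank = 16
lora_b = torch.randn(sum(output_sizes), rank)   # single fused checkpoint tensor
lora_b_list = list(torch.split(lora_b, output_sizes, dim=0))
assert [t.shape[0] for t in lora_b_list] == output_sizes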

MergedColumnParallelLinearWithLoRA

Bases: ColumnParallelLinearWithLoRA

ColumnParallelLinear layer that is composed of 2 sublayers (slices) packed together (e.g. gate_proj + up_proj -> gate_up_proj).

This means we have 2 LoRAs, each applied to one half of the layer.

Both slices must have the same size.

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(
        self, base_layer: MergedColumnParallelLinear | QKVParallelLinear
    ) -> None:
        super().__init__(base_layer)
        # There are two LoRA layers.
        # The output_sizes in MergedColumnParallelLinear is not sharded by tp,
        # so we need to divide it by tp_size to get the correct slice sizes.
        output_sizes = self.base_layer.output_sizes
        self.output_slices = tuple(
            divide(output_size, self.tp_size) for output_size in output_sizes
        )
        self.n_slices = len(self.output_slices)
        self.output_ids = (self.tp_rank,) * self.n_slices

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """
        The main reason for overriding this function is to enhance code
        maintainability.
        """
        self.lora_config = lora_config

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank
            if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size)
        )

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                output_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for output_size in self.output_slices
        )

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        return lora_a

    def slice_lora_b(
        self, lora_b: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        sliced_lora_b = [None] * self.n_slices
        for i, (shard_id, shard_size) in enumerate(
            zip(self.output_ids, self.output_slices)
        ):
            if (lora_b_i := lora_b[i]) is not None:
                sliced_lora_b[i] = lora_b_i[
                    shard_size * shard_id : shard_size * (shard_id + 1), :
                ]
        return sliced_lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        for i in range(self.n_slices):
            if (lora_a_i := lora_a[i]) is not None:
                self.lora_a_stacked[i][
                    index, 0, : lora_a_i.shape[0], : lora_a_i.shape[1]
                ].copy_(lora_a_i, non_blocking=True)
            if (lora_b_i := lora_b[i]) is not None:
                self.lora_b_stacked[i][
                    index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
                ].copy_(lora_b_i, non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 2
        )

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

The main reason for overriding this function is to enhance code maintainability.

Source code in vllm/lora/layers/column_parallel_linear.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """
    The main reason for overriding this function is to enhance code
    maintainability.
    """
    self.lora_config = lora_config

    lora_a_output_size_per_partition = (
        lora_config.max_lora_rank
        if not lora_config.fully_sharded_loras
        else divide(lora_config.max_lora_rank, self.tp_size)
    )

    self.lora_a_stacked = tuple(
        torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        for _ in range(self.n_slices)
    )
    self.lora_b_stacked = tuple(
        torch.zeros(
            max_loras,
            1,
            output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        for output_size in self.output_slices
    )
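
A rough standalone sketch of what slice_lora_b does for the two equal slices under tensor parallelism; the tp setup, slice size, and rank below are assumptions chosen only to make the shapes concrete.

import torch

tp_size, tp_rank = 2, 1                  # assumed tensor-parallel setup
full_slice_size, rank = 2048, 16         # per-slice output size before sharding (assumed)
shard_size = full_slice_size // tp_size  # what divide(output_size, tp_size) yields

lora_b = [torch.randn(full_slice_size, rank) for _ in range(2)]
sliced = [b[shard_size * tp_rank : shard_size * (tp_rank + 1), :] for b in lora_b]
assert all(s.shape == (shard_size, rank) for s in sliced)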

MergedColumnParallelLinearWithShardedLoRA

Bases: MergedColumnParallelLinearWithLoRA

Differs from MergedColumnParallelLinearWithLoRA by slicing the LoRA A's also.

Based on S-LoRA, slicing happens along the rank dim.

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        # NOTE: lora_a contains 2 subloras, and each sublora could be None.
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        lora_a = [
            lora_a[0][output_start_idx : output_start_idx + output_shard_size, :]
            if lora_a[0] is not None
            else None,
            lora_a[1][output_start_idx : output_start_idx + output_shard_size, :]
            if lora_a[1] is not None
            else None,
        ]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
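
The rank-dimension slicing of LoRA A can be sketched in isolation as follows; the tensor-parallel setup, LoRA rank, and hidden size are assumed values, not defaults.

import torch

tp_size, tp_rank = 4, 2                # assumed tensor-parallel setup
max_lora_rank, input_size = 16, 4096   # assumed LoRA rank and hidden size
rank_shard = max_lora_rank // tp_size  # shape[2] of each lora_a_stacked entry

lora_a = [torch.randn(max_lora_rank, input_size) for _ in range(2)]
start = tp_rank * rank_shard
sliced = [a[start : start + rank_shard, :] for a in lora_a]
assert all(s.shape == (rank_shard, input_size) for s in sliced)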

MergedQKVParallelLinearWithLoRA

Bases: MergedColumnParallelLinearWithLoRA

MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) packed together in qkv proj fashion (q_proj + k_proj + v_proj -> qkv_proj).

This means we have 3 LoRAs, each applied to one slice of the layer.

The Q slice may have a different shape than the K and V slices (which both have the same shape).

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        # There are three LoRA layers.
        self.n_slices = len(self.base_layer.output_sizes)

        self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
        self.kv_proj_shard_size = (
            self.base_layer.num_kv_heads * self.base_layer.head_size
        )
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )
        self.output_ids = (
            self.q_shard_id,
            self.kv_shard_id,
            self.kv_shard_id,
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """
        The main reason for overloading this function is to handle inconsistent
        weight dimensions in qkv lora.
        """
        super().create_lora_weights(max_loras, lora_config, model_config)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

The main reason for overloading this function is to handle inconsistent weight dimensions in qkv lora.

Source code in vllm/lora/layers/column_parallel_linear.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """
    The main reason for overloading this function is to handle inconsistent
    weight dimensions in qkv lora.
    """
    super().create_lora_weights(max_loras, lora_config, model_config)
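
The shard bookkeeping from __init__ above can be restated as a standalone sketch; all head counts and sizes below are hypothetical and only illustrate the arithmetic.

# Hypothetical per-rank GQA configuration (values already divided by tp_size).
head_size = 128
num_heads = 8                # query heads on this rank
num_kv_heads = 2             # kv heads on this rank
num_kv_head_replicas = 1
tp_rank = 3

q_proj_shard_size = num_heads * head_size       # 1024
kv_proj_shard_size = num_kv_heads * head_size   # 256
q_shard_id = tp_rank
kv_shard_id = tp_rank // num_kv_head_replicas

output_slices = (q_proj_shard_size, kv_proj_shard_size, kv_proj_shard_size)
output_ids = (q_shard_id, kv_shard_id, kv_shard_id)
print(output_slices, output_ids)  # (1024, 256, 256) (3, 3, 3)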

MergedQKVParallelLinearWithShardedLoRA

Bases: MergedQKVParallelLinearWithLoRA

Differs from MergedQKVParallelLinearWithLoRA by slicing the LoRA A's also.

Based on S-LoRA, slicing happens along the rank dim.

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
    """
    Differs from MergedQKVParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :]
            if lora_a[0] is not None
            else None,
            lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :]
            if lora_a[1] is not None
            else None,
            lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :]
            if lora_a[2] is not None
            else None,
        ]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

QKVParallelLinearWithLoRA

Bases: ColumnParallelLinearWithLoRA

ColumnParallelLinear layer that is specifically designed for qkv_proj. Certain models, such as chatglm3 and baichuan-7b, only contain a single LoRA within their qkv_proj layer.

During inference with Tensor Parallel, the weights of lora_b must be accurately partitioned according to the respective ranks.

The Q slice may have a different shape than the K and V slices (which both have the same shape).

Source code in vllm/lora/layers/column_parallel_linear.py
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.q_proj_total_size = (
            self.base_layer.total_num_heads * self.base_layer.head_size
        )
        self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
        self.kv_proj_shard_size = (
            self.base_layer.num_kv_heads * self.base_layer.head_size
        )
        self.kv_proj_total_size = (
            self.base_layer.total_num_kv_heads * self.base_layer.head_size
        )
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[
            self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
            * (self.q_shard_id + 1),
            :,
        ]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[
            k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
            + self.kv_proj_shard_size * (self.kv_shard_id + 1),
            :,
        ]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[
            v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
            + self.kv_proj_shard_size * (self.kv_shard_id + 1),
            :,
        ]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
        return lora_b

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1
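
The q/k/v slicing of a single fused lora_b can be sketched with plain tensors as follows; all sizes are assumptions, and num_kv_head_replicas is taken to be 1 so that kv_shard_id == tp_rank.

import torch

head_size = 64
total_num_heads = total_num_kv_heads = 8   # assumed head counts
tp_size, tp_rank = 2, 1
num_heads = total_num_heads // tp_size
num_kv_heads = total_num_kv_heads // tp_size

q_total = total_num_heads * head_size       # rows of fused lora_b belonging to Q
kv_total = total_num_kv_heads * head_size   # rows belonging to K (and again to V)
q_shard = num_heads * head_size             # Q rows kept on this rank
kv_shard = num_kv_heads * head_size         # K / V rows kept on this rank
rank = 8

lora_b = torch.randn(q_total + 2 * kv_total, rank)   # fused [Q | K | V] rows
q = lora_b[q_shard * tp_rank : q_shard * (tp_rank + 1), :]
k_off = q_total
k = lora_b[k_off + kv_shard * tp_rank : k_off + kv_shard * (tp_rank + 1), :]
v_off = k_off + kv_total
v = lora_b[v_off + kv_shard * tp_rank : v_off + kv_shard * (tp_rank + 1), :]
sliced = torch.cat([q, k, v], dim=0)
assert sliced.shape == (q_shard + 2 * kv_shard, rank)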

QKVParallelLinearWithShardedLoRA

Bases: QKVParallelLinearWithLoRA

Differs from QKVParallelLinearWithLoRA by slicing the LoRA A's also.

Based on S-LoRA, slicing happens along the rank dim.

Source code in vllm/lora/layers/column_parallel_linear.py
class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
    """
    Differs from QKVParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        lora_a = lora_a[start_idx : start_idx + shard_size, :]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

ReplicatedLinearWithLoRA

Bases: BaseLinearLayerWithLoRA

Source code in vllm/lora/layers/replicated_linear.py
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__(
            base_layer,
        )
        self.output_size = self.base_layer.output_size
        # To ensure interface compatibility, n_slices is always set to 1.
        self.n_slices = 1

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

        # Matrix multiply.
        output = self.apply(input_, bias)

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None

        if not self.base_layer.return_bias:
            return output

        return output, output_bias

    # ReplicatedLinear should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is ReplicatedLinear

    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora a if splitting for tensor parallelism."""
        return lora_a

    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora b if splitting with tensor parallelism."""
        return lora_b

forward

forward(
    input_: Tensor,
) -> Tensor | tuple[Tensor, Tensor | None]

Forward of ReplicatedLinearWithLoRA

Parameters:

input_ (Tensor, required): Tensor whose last dimension is input_size.

Returns:

Tensor | tuple[Tensor, Tensor | None]:
  • output
  • bias
Source code in vllm/lora/layers/replicated_linear.py
def forward(
    self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    """Forward of ReplicatedLinearWithLoRA

    Args:
        input_: Tensor whose last dimension is `input_size`.

    Returns:
        - output
        - bias
    """
    bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

    # Matrix multiply.
    output = self.apply(input_, bias)

    output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None

    if not self.base_layer.return_bias:
        return output

    return output, output_bias

slice_lora_a

slice_lora_a(
    lora_a: Tensor | list[Tensor | None],
) -> Tensor | list[Tensor | None]

Slice lora a if splitting for tensor parallelism.

Source code in vllm/lora/layers/replicated_linear.py
def slice_lora_a(
    self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora a if splitting for tensor parallelism."""
    return lora_a

slice_lora_b

slice_lora_b(
    lora_b: Tensor | list[Tensor | None],
) -> Tensor | list[Tensor | None]

Slice lora b if splitting with tensor parallelism.

Source code in vllm/lora/layers/replicated_linear.py
def slice_lora_b(
    self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora b if splitting with tensor parallelism."""
    return lora_b

RowParallelLinearWithLoRA

Bases: BaseLinearLayerWithLoRA

Source code in vllm/lora/layers/row_parallel_linear.py
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__(base_layer)

        # reset input_size
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        # There is only one LoRA layer.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        shard_size = self.input_size
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_a = lora_a[:, start_idx:end_idx]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        bias_ = (
            None
            if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
            else self.base_layer.bias
        )
        output_parallel = self.apply(input_parallel, bias_)
        if self.base_layer.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        if not self.base_layer.return_bias:
            return output

        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is RowParallelLinear

forward

forward(
    input_: Tensor,
) -> Tensor | tuple[Tensor, Tensor | None]

Forward of RowParallelLinear

Parameters:

input_ (Tensor, required): tensor whose last dimension is input_size. If input_is_parallel is set, then the last dimension is input_size // tp_size.

Returns:

Tensor | tuple[Tensor, Tensor | None]:
  • output
  • bias
Source code in vllm/lora/layers/row_parallel_linear.py
def forward(
    self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    """Forward of RowParallelLinear

    Args:
        input_: tensor whose last dimension is `input_size`. If
                `input_is_parallel` is set, then the last dimension
                is `input_size // tp_size`.

    Returns:
        - output
        - bias
    """
    # set up backprop all-reduce.
    if self.base_layer.input_is_parallel:
        input_parallel = input_
    else:
        # TODO: simplify code below
        splitted_input = split_tensor_along_last_dim(
            input_, num_partitions=self.tp_size
        )
        input_parallel = splitted_input[self.tp_rank].contiguous()

    # Matrix multiply.
    bias_ = (
        None
        if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
        else self.base_layer.bias
    )
    output_parallel = self.apply(input_parallel, bias_)
    if self.base_layer.reduce_results and self.tp_size > 1:
        output = tensor_model_parallel_all_reduce(output_parallel)
    else:
        output = output_parallel

    output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
    if not self.base_layer.return_bias:
        return output

    return output, output_bias
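
slice_lora_a here keeps only the columns that belong to this rank's input shard. A minimal standalone sketch with assumed sizes:

import torch

tp_size, tp_rank = 2, 0                    # assumed tensor-parallel setup
input_size_per_partition, rank = 2048, 16  # assumed shard width and LoRA rank

lora_a_full = torch.randn(rank, input_size_per_partition * tp_size)
start = tp_rank * input_size_per_partition
lora_a_shard = lora_a_full[:, start : start + input_size_per_partition]
assert lora_a_shard.shape == (rank, input_size_per_partition)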

RowParallelLinearWithShardedLoRA

Bases: RowParallelLinearWithLoRA

Differs from RowParallelLinearWithLoRA by slicing the LoRA B's also.

Based on S-LoRA, slicing happens along the output dim. This yields a combined partial sum from the row parallel base layer and column partitioned output from the LoRA.

Source code in vllm/lora/layers/row_parallel_linear.py
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[start_idx:end_idx, :]
        return lora_b

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
        buffer = torch.zeros(
            (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
            dtype=torch.float32,
            device=x.device,
        )

        shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink(
            buffer, x, self.lora_a_stacked, 1.0
        )
        if not current_platform.can_update_inplace():
            buffer = shrunk_buffer
        if self.tp_size > 1:
            buffer = tensor_model_parallel_all_reduce(buffer)

        # Following S-LoRA, this allows fusing the all_gather and all_reduce
        # by adding the column-partitioned LoRA output to a slice of the
        # output tensor, which is a partial sum due to row parallelism. All
        # that remains is a standard all_reduce. Note that the output is not
        # the same as a normal row_parallel output; it must be reduced before
        # being used.
        # NOTE: offsets are based on the rank.
        shard_size = self.lora_b_stacked[0].shape[2]
        offset_start = self.tp_rank * shard_size
        lora_output: torch.Tensor | None = self.punica_wrapper.add_expand(
            output,
            buffer,
            self.lora_b_stacked,
            self.output_slices,
            offset_start=offset_start,
            add_input=True,
        )

        if not current_platform.can_update_inplace():
            output = lora_output

        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
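
The comment in apply() can be checked with a single-process toy sketch: if each rank writes its column-partitioned LoRA-B contribution at its own offset, summing the per-rank outputs (the role of the final all_reduce) recovers the full LoRA output. The plain matmuls and shapes below are assumptions made for the sketch; the real path goes through the Punica kernels, and each rank is assumed to already hold the full shrink result, as it does after the buffer all_reduce above.

import torch

tp_size, tokens, in_features, rank, out_features = 2, 3, 16, 4, 8
x = torch.randn(tokens, in_features)
lora_a = torch.randn(rank, in_features)
lora_b = torch.randn(out_features, rank)

reference = x @ lora_a.t() @ lora_b.t()  # full (unsharded) LoRA output

shard = out_features // tp_size
summed = torch.zeros(tokens, out_features)
for r in range(tp_size):  # stands in for the per-rank work
    contrib = x @ lora_a.t() @ lora_b[r * shard : (r + 1) * shard, :].t()
    per_rank = torch.zeros(tokens, out_features)
    per_rank[:, r * shard : (r + 1) * shard] = contrib  # write at this rank's offset
    summed += per_rank  # the eventual all_reduce
assert torch.allclose(summed, reference, atol=1e-5)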