Skip to content

vllm.model_executor.models.step3_vl

Step3VLImageEmbeddingInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size * number of images
  • f: Image feature size
  • h: Hidden size (must match the hidden size of the language model backbone)
Source code in vllm/model_executor/models/step3_vl.py
class Step3VLImageEmbeddingInputs(TensorSchema):
    """
    Image input carried as an embedding tensor, tagged ``"image_embeds"``.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match the hidden size of the language model
          backbone)
    """

    # Literal tag naming this input variant; defaults to "image_embeds" so it
    # does not have to be passed explicitly.
    type: Literal["image_embeds"] = "image_embeds"
    # Embedding tensor whose shape is annotated as (bn, f, h).
    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]

Step3VLImagePixelInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size * number of images
  • c: Number of channels (3)
  • h: Height
  • w: Width
  • bnp: Batch size * number of images * number of patches
  • hp: Height of patch
  • wp: Width of patch
Source code in vllm/model_executor/models/step3_vl.py
class Step3VLImagePixelInputs(TensorSchema):
    """
    Image input carried as pixel tensors, tagged ``"pixel_values"``.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3); written as the literal ``3`` in the
          shape annotations below
        - h: Height
        - w: Width
        - bnp: Batch size * number of images * number of patches
        - hp: Height of patch
        - wp: Width of patch
    """

    # Literal tag naming this input variant. It defaults to its only legal
    # value, matching Step3VLImageEmbeddingInputs, so callers may omit it;
    # passing type="pixel_values" explicitly keeps working.
    type: Literal["pixel_values"] = "pixel_values"
    # Per-image pixel tensor, shape annotated as (bn, 3, h, w).
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
    # Per-patch pixel tensor, shape annotated as (bnp, 3, hp, wp).
    patch_pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "hp", "wp")]
    # One entry per image, shape annotated as (bn,); presumably the number of
    # patches belonging to each image -- confirm against the producer.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]

Step3VisionAttention

Bases: Module

Multi-headed attention from 'Attention Is All You Need' paper

Source code in vllm/model_executor/models/step3_vl.py
class Step3VisionAttention(nn.Module):
    """Multi-head self-attention layer built from a fused QKV projection,
    an ``MMEncoderAttention`` core and an output projection."""

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections and the attention core.

        Args:
            config: Object exposing ``hidden_size`` and
                ``num_attention_heads``.
            quant_config: Forwarded unchanged to both linear layers.
            prefix: Name prefix; ``.qkv_proj``, ``.out_proj`` and ``.attn``
                are appended for the respective submodules.

        Raises:
            AssertionError: If the total head count is not divisible by the
                effective tensor-parallel size.
        """
        super().__init__()
        self.config = config

        hidden_size = config.hidden_size
        head_count = config.num_attention_heads
        self.embed_dim = hidden_size
        self.total_num_heads = head_count
        self.head_dim = hidden_size // head_count

        self.scale = self.head_dim**-0.5

        # With data-parallel ViT every rank keeps all heads, so the effective
        # tensor-parallel size collapses to 1 and TP is disabled on the
        # linear layers below.
        vit_data_parallel = is_vit_use_data_parallel()
        if vit_data_parallel:
            shard_count = 1
        else:
            shard_count = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % shard_count == 0
        self.num_heads = self.total_num_heads // shard_count

        self.q_size = self.num_heads * self.head_dim

        # Keyword arguments common to both linear layers.
        linear_kwargs = dict(
            bias=True,
            quant_config=quant_config,
            disable_tp=vit_data_parallel,
        )
        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.total_num_heads,
            prefix=f"{prefix}.qkv_proj",
            **linear_kwargs,
        )
        self.out_proj = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            prefix=f"{prefix}.out_proj",
            **linear_kwargs,
        )

        # MMEncoderAttention picks the attention backend by itself.
        self.attn = MMEncoderAttention(
            self.num_heads,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Apply self-attention to ``hidden_states``.

        Input shape: Batch x Time x Channel.

        Returns:
            The first element of ``out_proj``'s result.

        Raises:
            ValueError: If ``hidden_states`` does not have exactly three
                dimensions (the size unpack below fails).
        """
        # The unpacked values are unused; the unpack itself is what rejects
        # inputs that are not 3-D.
        _batch, _seq_len, _channels = hidden_states.size()

        # One fused projection yields q, k and v concatenated on the last
        # axis; cut it into three equal pieces.
        fused_qkv, _ = self.qkv_proj(hidden_states)
        query, key, value = fused_qkv.chunk(3, dim=-1)

        context = self.attn(query, key, value)

        projected, _ = self.out_proj(context)
        return projected

forward

forward(hidden_states: Tensor)

Input shape: Batch x Time x Channel

Source code in vllm/model_executor/models/step3_vl.py
def forward(
    self,
    hidden_states: torch.Tensor,
):
    """Apply self-attention to ``hidden_states``.

    Input shape: Batch x Time x Channel.

    Returns:
        The first element of ``out_proj``'s result.

    Raises:
        ValueError: If ``hidden_states`` does not have exactly three
            dimensions (the size unpack below fails).
    """
    # The unpacked values are unused; the unpack itself is what rejects
    # inputs that are not 3-D.
    _batch, _seq_len, _channels = hidden_states.size()

    # One fused projection yields q, k and v concatenated on the last axis;
    # cut it into three equal pieces.
    fused_qkv, _ = self.qkv_proj(hidden_states)
    query, key, value = fused_qkv.chunk(3, dim=-1)

    # MMEncoderAttention picks the attention backend by itself.
    context = self.attn(query, key, value)

    projected, _ = self.out_proj(context)
    return projected