vllm.model_executor.models.audioflamingo3

AudioFlamingo3EmbeddingInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size
  • naf: Number of audio features
  • hs: Hidden size (must match the hidden size of the language model backbone)
Source code in vllm/model_executor/models/audioflamingo3.py
class AudioFlamingo3EmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size
        - naf: Number of audio features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["audio_embeds"] = "audio_embeds"

    audio_embeds: Annotated[
        list[torch.Tensor],
        TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}),
    ]
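
A minimal sketch of a valid embedding input, assuming a hidden size of 3584 purely for illustration (it must equal the Qwen2 backbone's hidden size); each audio may contribute a different number of features since naf is a dynamic dimension:

import torch

hidden_size = 3584  # illustrative; must equal the language model backbone's hidden size
audio_embeds = [
    torch.randn(750, hidden_size),  # first audio: 750 audio features
    torch.randn(320, hidden_size),  # second audio: "naf" may differ per entry (dynamic dim)
]
# Parsed by the model as:
# AudioFlamingo3EmbeddingInputs(type="audio_embeds", audio_embeds=audio_embeds)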

AudioFlamingo3Encoder

Bases: Qwen2AudioEncoder

Source code in vllm/model_executor/models/audioflamingo3.py
class AudioFlamingo3Encoder(Qwen2AudioEncoder):
    def __init__(
        self,
        config: PretrainedConfig,
    ):
        super().__init__(config)
        self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
        # self.layer_norm is already initialized in super().__init__
        # Keep a dummy freqs parameter for MusicFlamingo checkpoints.
        self.pos_emb = nn.Module()
        freqs = torch.empty(getattr(config, "num_mel_bins", 128))
        self.pos_emb.register_parameter(
            "freqs", nn.Parameter(freqs, requires_grad=False)
        )

    def forward(
        self,
        input_features: torch.Tensor | list[torch.Tensor],
        attention_mask: torch.Tensor | None = None,
    ):
        # input_features: (batch, num_mel_bins, seq_len)
        if isinstance(input_features, list):
            input_features = torch.stack(input_features)

        hidden_states = nn.functional.gelu(self.conv1(input_features))
        hidden_states = nn.functional.gelu(self.conv2(hidden_states))
        hidden_states = hidden_states.transpose(-1, -2)
        hidden_states = (
            hidden_states + self.embed_positions.weight[: hidden_states.size(-2), :]
        ).to(hidden_states.dtype)

        for layer in self.layers:
            # Qwen2AudioEncoderLayer expects layer_head_mask as third arg.
            layer_outputs = layer(hidden_states, attention_mask, None)
            hidden_states = layer_outputs[0]

        # AvgPool (time/2) + LayerNorm
        # hidden_states: (batch, seq_len, hidden_size)
        hidden_states = hidden_states.permute(0, 2, 1)  # (batch, hidden_size, seq_len)
        hidden_states = self.avg_pooler(hidden_states)
        hidden_states = hidden_states.permute(
            0, 2, 1
        )  # (batch, seq_len/2, hidden_size)
        hidden_states = self.layer_norm(hidden_states)

        return hidden_states

    def _get_feat_extract_output_lengths(self, input_lengths: torch.Tensor):
        """
        Computes the output length of the convolutional layers and the output length
        of the audio encoder
        """
        input_lengths = (input_lengths - 1) // 2 + 1
        output_lengths = (input_lengths - 2) // 2 + 1
        return input_lengths, output_lengths

_get_feat_extract_output_lengths

_get_feat_extract_output_lengths(input_lengths: Tensor)

Computes the output length of the convolutional layers and the output length of the audio encoder

Source code in vllm/model_executor/models/audioflamingo3.py
def _get_feat_extract_output_lengths(self, input_lengths: torch.Tensor):
    """
    Computes the output length of the convolutional layers and the output length
    of the audio encoder
    """
    input_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (input_lengths - 2) // 2 + 1
    return input_lengths, output_lengths
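
A worked example of the two downsampling steps, assuming one full 30-second chunk of 3000 mel frames and one shorter padded chunk:

import torch

mel_lengths = torch.tensor([3000, 1000])      # valid mel frames per chunk
conv_lengths = (mel_lengths - 1) // 2 + 1     # tensor([1500, 500]): after the stride-2 conv
output_lengths = (conv_lengths - 2) // 2 + 1  # tensor([750, 250]): after AvgPool1d(2, 2)

So a full 3000-frame chunk yields 750 audio embeddings at the encoder output.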

AudioFlamingo3FeatureInputs

Bases: TensorSchema

Dimensions
  • num_chunks: Number of audio chunks (flattened)
  • nmb: Number of mel bins
  • num_audios: Number of original audio files
Source code in vllm/model_executor/models/audioflamingo3.py
class AudioFlamingo3FeatureInputs(TensorSchema):
    """
    Dimensions:
        - num_chunks: Number of audio chunks (flattened)
        - nmb: Number of mel bins
        - num_audios: Number of original audio files
    """

    type: Literal["audio_features"]
    input_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("num_chunks", "nmb", 3000),
    ]

    feature_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("num_chunks", 3000),
    ]

    chunk_counts: Annotated[
        torch.Tensor,
        TensorShape("num_audios"),
    ]
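
A minimal sketch of the expected shapes, assuming 128 mel bins and two audios where the first was split into two 30-second chunks and the second fits in a single chunk:

import torch

num_mel_bins = 128  # "nmb"; illustrative value
input_features = torch.randn(3, num_mel_bins, 3000)             # (num_chunks, nmb, 3000)
feature_attention_mask = torch.ones(3, 3000, dtype=torch.long)  # 1 marks valid mel frames
chunk_counts = torch.tensor([2, 1])                             # (num_audios,)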

AudioFlamingo3ForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP, SupportsLoRA

AudioFlamingo3 model for conditional generation.

This model integrates a Whisper-based audio encoder with a Qwen2 language model. It supports multi-chunk audio processing: long audio is split into fixed-length chunks that are encoded independently, and the resulting chunk embeddings are concatenated back into a single embedding sequence per original audio (tracked via chunk_counts).
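
A hedged offline-inference sketch using vLLM's standard multimodal API; the checkpoint identifier is assumed for illustration, and the audio placeholder in the prompt depends on the model's processor and chat template:

import librosa
from vllm import LLM, SamplingParams

# Assumed checkpoint name; substitute the actual AudioFlamingo3 checkpoint you are serving.
llm = LLM(model="nvidia/audio-flamingo-3", trust_remote_code=True)

# Whisper-style feature extraction typically expects 16 kHz audio.
audio, sr = librosa.load("speech.wav", sr=16000)

# The prompt must contain the audio placeholder token expected by the model's processor.
outputs = llm.generate(
    {
        "prompt": "<prompt containing the audio placeholder> Transcribe the audio.",
        "multi_modal_data": {"audio": (audio, sr)},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)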

Source code in vllm/model_executor/models/audioflamingo3.py
@MULTIMODAL_REGISTRY.register_processor(
    AudioFlamingo3MultiModalProcessor,
    info=AudioFlamingo3ProcessingInfo,
    dummy_inputs=AudioFlamingo3DummyInputsBuilder,
)
class AudioFlamingo3ForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
):
    """
    AudioFlamingo3 model for conditional generation.

    This model integrates a Whisper-based audio encoder with a Qwen2 language model.
    It supports multi-chunk audio processing.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = AudioFlamingo3Encoder(
                config.audio_config,
            )
            self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> AudioFlamingo3Inputs | None:
        input_features = kwargs.pop("input_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        feature_attention_mask = kwargs.pop("feature_attention_mask", None)
        chunk_counts = kwargs.pop("chunk_counts", None)

        if input_features is None and audio_embeds is None:
            return None

        if audio_embeds is not None:
            return AudioFlamingo3EmbeddingInputs(
                type="audio_embeds", audio_embeds=audio_embeds
            )

        if input_features is not None:
            return AudioFlamingo3FeatureInputs(
                type="audio_features",
                input_features=input_features,
                feature_attention_mask=feature_attention_mask,
                chunk_counts=chunk_counts,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self, audio_input: AudioFlamingo3Inputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        if audio_input["type"] == "audio_embeds":
            audio_embeds = audio_input["audio_embeds"]
            return tuple(audio_embeds)

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]
        chunk_counts = audio_input.get("chunk_counts")

        if isinstance(input_features, list):
            input_features = torch.cat(input_features, dim=0)
            feature_attention_mask = torch.cat(feature_attention_mask, dim=0)

        if chunk_counts is None:
            chunk_counts = [1] * input_features.shape[0]
        elif isinstance(chunk_counts, torch.Tensor):
            chunk_counts = chunk_counts.tolist()
        elif (
            isinstance(chunk_counts, list)
            and chunk_counts
            and isinstance(chunk_counts[0], torch.Tensor)
        ):
            chunk_counts = [c.item() for c in chunk_counts]

        # Calculate output lengths
        input_lengths = feature_attention_mask.sum(-1)
        # Conv downsampling
        conv_lengths = (input_lengths - 1) // 2 + 1
        # AvgPool downsampling
        audio_output_lengths = (conv_lengths - 2) // 2 + 1

        batch_size, _, max_mel_seq_len = input_features.shape

        # Calculate max_seq_len after convs (before pooling) for attention mask
        max_seq_len = (max_mel_seq_len - 1) // 2 + 1

        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=conv_lengths.dtype,
                device=conv_lengths.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = conv_lengths.unsqueeze(-1).expand(batch_size, max_seq_len)
        # Create mask
        padding_mask = seq_range >= lengths_expand

        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device,
        )
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        # Forward pass
        audio_features = self.audio_tower(
            input_features, attention_mask=audio_attention_mask
        )

        # Project
        audio_features = self.multi_modal_projector(audio_features)

        # Masking after pooling
        num_audios, max_audio_tokens, embed_dim = audio_features.shape
        audio_output_lengths = audio_output_lengths.unsqueeze(1)
        audio_features_mask = (
            torch.arange(max_audio_tokens)
            .expand(num_audios, max_audio_tokens)
            .to(audio_output_lengths.device)
            < audio_output_lengths
        )
        masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)

        # Split into per-chunk embeddings; chunks are regrouped per audio below.
        chunk_embeddings = torch.split(
            masked_audio_features, audio_output_lengths.flatten().tolist()
        )

        grouped_embeddings = []
        current_idx = 0
        for count in chunk_counts:
            audio_chunks = chunk_embeddings[current_idx : current_idx + count]
            grouped_embeddings.append(torch.cat(audio_chunks, dim=0))
            current_idx += count
        return tuple(grouped_embeddings)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        masked_audio_features = self._process_audio_input(audio_input)
        return masked_audio_features

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
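
For the multi-chunk path, the tail of _process_audio_input concatenates per-chunk embeddings back into one tensor per original audio. A standalone sketch of that grouping with dummy tensors (hidden size 3584 chosen for illustration):

import torch

hidden_size = 3584
# Three chunks: the first two belong to audio 0, the third to audio 1.
chunk_embeddings = (
    torch.randn(750, hidden_size),
    torch.randn(750, hidden_size),
    torch.randn(400, hidden_size),
)
chunk_counts = [2, 1]

grouped, current_idx = [], 0
for count in chunk_counts:
    grouped.append(torch.cat(chunk_embeddings[current_idx : current_idx + count], dim=0))
    current_idx += count
# grouped[0].shape == (1500, 3584); grouped[1].shape == (400, 3584)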

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models

Source code in vllm/model_executor/models/audioflamingo3.py
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Get the module prefix in multimodal models
    """
    return MultiModelKeys.from_string_field(
        language_model="language_model.",
        connector="multi_modal_projector.",
        tower_model="audio_tower.",
    )