@MULTIMODAL_REGISTRY.register_processor(
Ernie4_5VLMultiModalProcessor,
info=Ernie4_5_VLProcessingInfo,
dummy_inputs=Ernie4_5_VLDummyInputsBuilder,
)
class Ernie4_5_VLMoeForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
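    """Ernie 4.5 VL MoE model for conditional generation: a vision
    transformer (Ernie4_5_VisionTransformer) plus a variable-resolution
    resampler feeding the Ernie4_5_VLMoeForCausalLM language model."""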
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
    # Map HF checkpoint weight names to vLLM module names for correct loading.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
            # The "model." rule above rewrites "model.resampler_model." to
            # "language_model.model.resampler_model."; remap it back to the
            # top-level "resampler_model." module.
"language_model.model.resampler_model.": "resampler_model.",
},
        # Resampler weight name remappings.
orig_to_new_substr={
"spatial_linear.0.": "spatial_linear1.",
"spatial_linear.2.": "spatial_linear2.",
"spatial_linear.3.": "spatial_norm.",
"temporal_linear.0.": "temporal_linear1.",
"temporal_linear.2.": "temporal_linear2.",
"temporal_linear.3.": "temporal_norm.",
},
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
if modality.startswith("video"):
return "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"
raise ValueError("Only image or video modality is supported")
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_model = Ernie4_5_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.resampler_model = VariableResolutionResamplerModel(
self.config.pixel_hidden_size,
self.config.hidden_size,
self.config.spatial_conv_size,
self.config.temporal_conv_size,
config=self.config,
prefix=maybe_prefix(prefix, "resampler_model"),
)
with self._mark_language_model(vllm_config):
self.language_model = Ernie4_5_VLMoeForCausalLM(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.visual_token_mask = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
        if getattr(self.config, "im_patch_id", None) is not None:
visual_token_ids = [
token_id
for token_id in [
self.config.im_patch_id,
getattr(self.config, "image_start_token_id", None),
getattr(self.config, "image_end_token_id", None),
getattr(self.config, "video_start_token_id", None),
getattr(self.config, "video_end_token_id", None),
]
if token_id is not None
]
self._visual_token_ids_tensor_cache = torch.tensor(
visual_token_ids, dtype=torch.long
)
else:
self._visual_token_ids_tensor_cache = None
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
"""compute logits"""
return self.language_model.compute_logits(hidden_states)
def _vision_forward(
self,
pixel_values: torch.Tensor,
grid_thw: torch.Tensor,
) -> torch.Tensor:
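        """Run the vision tower. Non-positive grid entries are filtered out,
        and each [t, h, w] row of grid_thw is expanded into t rows of
        [1, h, w] so that every frame is encoded as its own spatial grid."""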
if grid_thw is not None:
grid_thw = grid_thw[grid_thw > 0]
if grid_thw.numel() % 3 != 0:
raise ValueError(
f"grid_thw has {grid_thw.numel()} elements after filtering,"
"which is not divisible by 3."
)
grid_thw = grid_thw.reshape(-1, 3)
# example: [[1,64,64],[2,80,80]] -> [[1,64,64],[1,80,80],[1,80,80]]
grid_thw = F.pad(
torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0),
[1, 0, 0, 0],
value=1,
)
image_features = self.vision_model(pixel_values, grid_thw)
return image_features
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
"""Set mask for visual tokens (image/video patches and delimiters)."""
if self._visual_token_ids_tensor_cache is None:
self.visual_token_mask = None
return
        # Move the cached id tensor to the device/dtype of input_ids.
visual_token_ids_tensor = self._visual_token_ids_tensor_cache.to(
device=input_ids.device,
dtype=input_ids.dtype,
)
self.visual_token_mask = torch.isin(input_ids, visual_token_ids_tensor).reshape(
-1, 1
)
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
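        """Build 3-D M-RoPE positions (temporal, height, width) for the given
        prompt tokens. Returns a (3, seq_len) position tensor and the position
        delta used to extend positions during decoding."""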
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
hf_config = self.config
image_token_id = hf_config.im_patch_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size
llm_pos_ids_list: list = []
if image_grid_thw or video_grid_thw:
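            # Classify each placeholder token: image patch tokens that appear
            # between video start/end markers belong to a video, otherwise to
            # an image; everything else is text.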
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
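            # Group consecutive tokens of the same modality into
            # (modality, start_index, end_index) runs.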
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
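                # Positions for this run continue after the maximum position
                # of the previous run.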
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = image_grid_thw[mm_data_idx]
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_conv_size,
w // spatial_conv_size,
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
elif modality_type == "video":
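                    # Each step of llm_grid_t covers temporal_conv_size raw
                    # frames; all patches in that step share temporal index
                    # t_idx, while spatial indices restart at st_idx.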
t, h, w = video_grid_thw[mm_data_idx]
llm_grid_t, llm_grid_h, llm_grid_w = (
t // temporal_conv_size,
h // spatial_conv_size,
w // spatial_conv_size,
)
for t_idx in range(llm_grid_t):
t_index = (
torch.tensor(t_idx)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Ernie4_5_VLImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
if pixel_values is None:
return None
        return Ernie4_5_VLImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
        )
def _parse_and_validate_video_input(
self, **kwargs: object
) -> Ernie4_5_VLVideoInputs | None:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
if pixel_values_videos is None:
return None
        return Ernie4_5_VLVideoPixelInputs(
            type="pixel_values_videos",
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
        )
def _process_image_input(
self, image_input: Ernie4_5_VLImageInputs
) -> tuple[torch.Tensor, ...]:
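        """Encode image pixels with the vision tower, project them through the
        resampler, and split the result into per-image embeddings."""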
grid_thw = image_input["image_grid_thw"]
assert grid_thw.ndim == 2
pixel_values = image_input["pixel_values"].type(self.vision_model.dtype)
image_features = self._vision_forward(
pixel_values=pixel_values, grid_thw=grid_thw
)
image_embeds = self.resampler_model(image_features, grid_thw)
merge_size = self.vision_model.spatial_merge_size
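        # Number of output embeddings per image: t * h * w patches reduced by
        # the spatial merge factor (merge_size ** 2).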
sizes = grid_thw.prod(-1) // merge_size // merge_size
return image_embeds.split(sizes.tolist())
def _process_video_input(
self, video_input: Ernie4_5_VLVideoInputs
) -> tuple[torch.Tensor, ...]:
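        """Encode video pixels with the vision tower, project them through the
        resampler, and split the result into per-video embeddings."""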
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
pixel_values_videos = video_input["pixel_values_videos"].type(
self.vision_model.dtype
)
video_features = self._vision_forward(
pixel_values=pixel_values_videos, grid_thw=grid_thw
)
video_embeds = self.resampler_model(video_features, grid_thw)
merge_size = self.vision_model.spatial_merge_size
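        # Number of output embeddings per video: t * h * w patches reduced by
        # both the temporal conv factor and the spatial merge factor.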
sizes = (
(grid_thw.prod(-1) // self.config.temporal_conv_size)
// merge_size
// merge_size
)
return video_embeds.split(sizes.tolist())
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
        # Preserve the order of modalities (when more than one is present)
        # based on the order of the kwargs.
for input_key in kwargs:
if (
input_key in ("pixel_values", "image_embeds")
and "images" not in modalities
):
modalities["images"] = self._parse_and_validate_image_input(**kwargs)
if (
input_key in ("pixel_values_videos", "video_embeds")
and "videos" not in modalities
):
modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
return modalities
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
        # The resulting multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
image_embeddings = self._process_image_input(image_input)
multimodal_embeddings += tuple(image_embeddings)
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
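        """Embed input ids, recording which positions hold visual tokens
        before delegating to the base implementation."""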
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
# This is to satisfy the type checker for each overload
if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids)
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
):
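        """Run the language model forward pass, forwarding the cached
        visual_token_mask (padded to the length of inputs_embeds) when it has
        been set by embed_input_ids."""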
forward_kwargs = {
"input_ids": input_ids,
"positions": positions,
"intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds,
}
if self.visual_token_mask is not None:
if self.visual_token_mask.shape[0] != inputs_embeds.shape[0]:
padding_len = inputs_embeds.shape[0] - self.visual_token_mask.shape[0]
                # Right-pad the mask with False so it matches the length of
                # inputs_embeds.
pad = torch.zeros(
(padding_len, self.visual_token_mask.shape[1]),
dtype=self.visual_token_mask.dtype,
device=self.visual_token_mask.device,
)
self.visual_token_mask = torch.cat([self.visual_token_mask, pad], dim=0)
forward_kwargs.update({"visual_token_mask": self.visual_token_mask})
self.visual_token_mask = None
hidden_states = self.language_model.model(
**forward_kwargs,
**kwargs,
)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)