@MULTIMODAL_REGISTRY.register_processor(
DeepseekOCRMultiModalProcessor,
info=DeepseekOCRProcessingInfo,
dummy_inputs=DeepseekOCRDummyInputsBuilder,
)
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
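    """DeepSeek-OCR model for vLLM.

    Image inputs are encoded by a SAM ViT-B backbone and a CLIP vision
    transformer, fused, and projected by an MLP projector into the embedding
    space of the DeepSeek language backbone. Image positions in the prompt are
    marked with the "<image>" placeholder.
    """
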
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# map prefix for language backbone
"model.embed_tokens.": "language_model.model.embed_tokens.",
"model.layers.": "language_model.model.layers.",
"model.norm.": "language_model.model.norm.",
"lm_head.": "language_model.lm_head.",
# remove "model." prefix for other components
"model.": "",
}
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
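        """Return the prompt placeholder ("<image>") for image inputs."""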
if modality.startswith("image"):
return "<image>"
raise ValueError("Only image modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: DeepseekVLV2Config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.vision_config = config.vision_config
self.projector_config = config.projector_config
self.text_config = config.text_config
model_config = vllm_config.model_config
tokenizer = cached_tokenizer_from_config(model_config)
self.image_token_id = tokenizer.vocab[_IMAGE_TOKEN]
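        # Vision tower: a SAM ViT-B backbone plus a CLIP ViT (ViT-L/14-sized
        # config); their outputs are fused in the encode helpers below.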
with self._mark_tower_model(vllm_config, "image"):
self.sam_model = build_sam_vit_b()
clip_vision_config = CLIPVisionConfig(
hidden_size=1024,
intermediate_size=4096,
num_attention_heads=16,
num_hidden_layers=24,
image_size=224,
patch_size=14,
projection_dim=512,
layer_norm_eps=1e-5,
)
self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.projector = MlpProjector(self.projector_config)
self.tile_tag = config.tile_tag
self.global_view_pos = config.global_view_pos
        # Learnable special embeddings used to format the image token sequence.
n_embed = self.projector_config.n_embed
embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
if self.tile_tag == "2D":
# <|view_separator|>, <|\n|>
self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
            # "seperator" (sic) keeps the misspelling from the original
            # implementation so checkpoint weight names still match.
self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
else:
raise ValueError(
f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=self.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> DeepseekOCRImagePixelInputs | None:
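        """Extract image tensors from the keyword arguments and wrap them in a
        DeepseekOCRImagePixelInputs container. Returns None when the request
        carries no image data (pixel_values missing or all zeros)."""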
pixel_values = kwargs.pop("pixel_values", None)
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
images_crop = kwargs.pop("images_crop", None)
if pixel_values is None or torch.sum(pixel_values).item() == 0:
return None
base_size = self.vision_config.image_size
return DeepseekOCRImagePixelInputs(
type="pixel_values",
data=pixel_values,
images_crop=images_crop,
images_spatial_crop=images_spatial_crop,
resolve_bindings={
"base_size": base_size,
},
)
def _encode_global_features(self, image_tensor: torch.Tensor) -> torch.Tensor:
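        """Encode the global view of an image through SAM and CLIP, project it
        into the language embedding space, and append an image-newline
        embedding to the end of each row of the square feature grid."""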
global_features_1 = self.sam_model(image_tensor)
global_features_2 = self.vision_model(image_tensor, global_features_1)
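        # Fuse the two encoders: drop the CLIP class token and concatenate with
        # the flattened SAM feature map along the hidden dimension.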
features = torch.cat(
(
global_features_2[:, 1:],
global_features_1.flatten(2).permute(0, 2, 1),
),
dim=-1,
)
features = self.projector(features)
_, hw, dim = features.shape
side = int(hw**0.5)
features = features.view(side, side, dim)
newline = self.image_newline[None, None, :].expand(side, 1, dim)
features = torch.cat([features, newline], dim=1)
return features.view(-1, dim)
def _encode_local_features(
self, patches: torch.Tensor, crop_shape: torch.Tensor
) -> torch.Tensor | None:
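        """Encode the high-resolution local tiles of an image, stitching the
        per-tile features back into their 2D layout with an image-newline
        embedding per row. Returns None when the image was not tiled
        (all-zero patches)."""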
if torch.sum(patches).item() == 0:
return None
local_features_1 = self.sam_model(patches)
local_features_2 = self.vision_model(patches, local_features_1)
features = torch.cat(
(
local_features_2[:, 1:],
local_features_1.flatten(2).permute(0, 2, 1),
),
dim=-1,
)
features = self.projector(features)
_, hw, dim = features.shape
patch_side = int(hw**0.5)
width_tiles = int(crop_shape[0].item())
height_tiles = int(crop_shape[1].item())
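        # Rearrange the per-tile feature grids into a single 2D feature map
        # that matches the spatial layout of the tiled image.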
features = (
features.view(height_tiles, width_tiles, patch_side, patch_side, dim)
.permute(0, 2, 1, 3, 4)
.reshape(height_tiles * patch_side, width_tiles * patch_side, dim)
)
newline = self.image_newline[None, None, :].expand(
height_tiles * patch_side, 1, dim
)
features = torch.cat([features, newline], dim=1)
return features.view(-1, dim)
def _pixel_values_to_embedding(
self,
pixel_values: torch.Tensor,
images_crop: torch.Tensor,
images_spatial_crop: torch.Tensor,
) -> NestedTensors:
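        """Build one embedding sequence per image: local tile features (if the
        image was tiled), then global features, terminated by the
        view-separator embedding."""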
images_in_this_batch = []
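        # An image is tiled only if either spatial-crop dimension exceeds 1;
        # untiled images consume zero rows when splitting images_crop.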
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
images_crop = images_crop.split(patches_per_image.tolist())
for jdx in range(images_spatial_crop.size(0)):
patches = images_crop[jdx]
image_ori = pixel_values[[jdx]]
crop_shape = images_spatial_crop[jdx]
global_features = self._encode_global_features(image_ori)
local_features = self._encode_local_features(patches, crop_shape)
if local_features is not None:
combined = torch.cat(
[local_features, global_features, self.view_seperator[None, :]],
dim=0,
)
else:
combined = torch.cat(
[global_features, self.view_seperator[None, :]], dim=0
)
images_in_this_batch.append(combined)
return images_in_this_batch
def _process_image_input(
self, image_input: DeepseekOCRImagePixelInputs
) -> torch.Tensor:
pixel_values = image_input.data
images_crop = image_input.images_crop
images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)
vision_features = self._pixel_values_to_embedding(
pixel_values=pixel_values,
images_crop=images_crop,
images_spatial_crop=images_spatial_crop,
)
return vision_features
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
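        """Return the image embeddings for this request, or None if it
        contains no image inputs."""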
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
):
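        # On non-first pipeline-parallel ranks the hidden states arrive through
        # intermediate_tensors, so any provided inputs_embeds are discarded.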
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
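        """Load checkpoint weights, remapping Hugging Face parameter names to
        this module's layout via hf_to_vllm_mapper."""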
loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return autoloaded_weights
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="projector",
tower_model=["sam_model", "vision_model"],
)
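

# A minimal usage sketch for illustration only; it is not part of the model
# definition. The checkpoint name and prompt below are assumptions, not
# guaranteed values for this model.
#
#     from vllm import LLM, SamplingParams
#     from PIL import Image
#
#     llm = LLM(model="deepseek-ai/DeepSeek-OCR", trust_remote_code=True)
#     image = Image.open("page.png").convert("RGB")
#     outputs = llm.generate(
#         {
#             "prompt": "<image>\nFree OCR.",
#             "multi_modal_data": {"image": image},
#         },
#         SamplingParams(max_tokens=512),
#     )
#     print(outputs[0].outputs[0].text)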