[model] support XiaomiMiMo/MiMo-V2.5 #9273

Open

Jintao-Huang wants to merge 1 commit into modelscope:main from Jintao-Huang:support_mimo_v2

Conversation

@Jintao-Huang (Collaborator)

No description provided.

@gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request introduces support for the MiMo-V2.5 model by adding the necessary constants, model architecture registration, and template definitions. The implementation includes a custom MiMoV2Loader and MiMoV2Template to handle specific model requirements like video key renaming and thinking mode support. Feedback was provided to improve the robustness of the _post_encode method in the template and to ensure proper handling of AWQ-quantized models during model loading.
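For readers skimming the review, the inference-time video key rename the summary mentions is a pure dictionary transformation. Here is a minimal standalone sketch; the two key names come from the diff excerpted below, everything else is illustrative:

```python
from typing import Any, Dict

def rename_video_key(inputs: Dict[str, Any]) -> Dict[str, Any]:
    # MiMo-V2.5's forward() expects 'video_pixel_values', while the
    # processor emits 'pixel_values_videos' (the Qwen2-VL convention
    # this template builds on).
    if 'pixel_values_videos' in inputs:
        inputs['video_pixel_values'] = inputs.pop('pixel_values_videos')
    return inputs

# The video tensor moves under the new key; other keys pass through.
batch = {'input_ids': [[1, 2, 3]], 'pixel_values_videos': 'video-tensor'}
assert 'video_pixel_values' in rename_video_key(batch)
assert 'pixel_values_videos' not in batch
```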

Comment on lines +526 to +537
```python
def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
    if not self.is_training:
        # During inference, rename key to match MiMo-V2.5 forward signature
        if 'pixel_values_videos' in inputs:
            inputs['video_pixel_values'] = inputs.pop('pixel_values_videos')
        return inputs
    # For training, compute embeddings manually
    input_ids = inputs['input_ids']
    base_model = self.get_base_model(model)
    inputs_embeds = base_model.model.embed_tokens(input_ids)
    inputs_embeds = self._get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config)
    return {'inputs_embeds': inputs_embeds}
```
@gemini-code-assist Bot (Contributor) commented, severity: high

The _post_encode implementation for MiMoV2Template has several improvement opportunities in the training branch:

  1. Robustness: It assumes base_model.model.embed_tokens exists. Using a check for language_model (similar to the parent Qwen2VLTemplate) makes it more resilient to different model architectures.
  2. Consistency: The inference branch returns the full inputs dictionary, while the training branch returns a new dictionary containing only inputs_embeds. While the framework might merge these, it is safer and more consistent to update the inputs dictionary in place and return it.
  3. Ambiguity: When inputs_embeds is provided, input_ids should ideally be removed from the inputs to avoid ambiguity in the model's forward call.
Suggested change

```python
def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
    if not self.is_training:
        # During inference, rename key to match MiMo-V2.5 forward signature
        if 'pixel_values_videos' in inputs:
            inputs['video_pixel_values'] = inputs.pop('pixel_values_videos')
        return inputs
    # For training, compute embeddings manually
    input_ids = inputs['input_ids']
    base_model = self.get_base_model(model)
    if hasattr(base_model.model, 'embed_tokens'):
        inputs_embeds = base_model.model.embed_tokens(input_ids)
    else:
        inputs_embeds = base_model.model.language_model.embed_tokens(input_ids)
    inputs_embeds = self._get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config)
    inputs['inputs_embeds'] = inputs_embeds
    inputs.pop('input_ids', None)
    return inputs
```
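On point 3, the ambiguity is concrete: Hugging Face decoder implementations generally reject a forward call that receives both tensors. A toy sketch of the pattern (not MiMo-specific; the exact error wording varies across models and versions):

```python
def toy_forward(input_ids=None, inputs_embeds=None):
    # Mirrors the guard most Hugging Face decoders apply in forward():
    # receiving both (or neither) input is ambiguous, so it is rejected.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError('Specify exactly one of input_ids or inputs_embeds')
    return inputs_embeds if inputs_embeds is not None else input_ids

# Popping 'input_ids' after computing 'inputs_embeds', as suggested above,
# keeps the call unambiguous:
toy_forward(inputs_embeds=[[0.0] * 8])  # ok
try:
    toy_forward(input_ids=[[1, 2]], inputs_embeds=[[0.0] * 8])
except ValueError as e:
    print(e)  # ambiguous call is rejected
```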

Comment on lines +872 to +875
```python
def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
    model = ModelLoader.get_model(self, model_dir, config, processor, model_kwargs)
    patch_get_input_embeddings(model.visual, 'patch_embed')
    return model
```
@gemini-code-assist Bot (Contributor) commented, severity: medium

The current implementation of get_model for MiMoV2Loader bypasses the logic in Qwen2VLLoader.get_model to avoid the hardcoded auto_model_cls, but it also misses the necessary check for AWQ-wrapped models. When a model is quantized with AWQ, the actual model components are often nested under a .model attribute. It is safer to replicate the base_model logic to ensure patch_get_input_embeddings is applied to the correct module.

Suggested change

```diff
-def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
-    model = ModelLoader.get_model(self, model_dir, config, processor, model_kwargs)
-    patch_get_input_embeddings(model.visual, 'patch_embed')
-    return model
+def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
+    model = ModelLoader.get_model(self, model_dir, config, processor, model_kwargs)
+    base_model = model.model if 'AWQ' in model.__class__.__name__ else model
+    patch_get_input_embeddings(base_model.visual, 'patch_embed')
+    return model
```
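To illustrate the nesting the reviewer describes, here is a hypothetical, self-contained sketch (the class names are invented for illustration; real AWQ wrappers come from the quantization tooling, but the extra .model level is the point):

```python
class ToyMiMoV2Model:
    """Stands in for the unwrapped multimodal model."""
    def __init__(self):
        self.visual = object()  # vision tower carrying patch_embed

class ToyAWQWrapper:
    """Stands in for an AWQ-quantized wrapper: note the extra .model level."""
    def __init__(self, model):
        self.model = model

def resolve_base_model(model):
    # The heuristic from the suggested change: AWQ wrappers expose the
    # real model one attribute deeper, under .model.
    return model.model if 'AWQ' in model.__class__.__name__ else model

plain = ToyMiMoV2Model()
wrapped = ToyAWQWrapper(ToyMiMoV2Model())
assert resolve_base_model(plain).visual is plain.visual
assert resolve_base_model(wrapped).visual is wrapped.model.visual
# Without the heuristic, patching wrapped.visual would fail: the wrapper
# itself has no 'visual' attribute.
```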
