[model] support XiaomiMiMo/MiMo-V2.5 #9273
Jintao-Huang wants to merge 1 commit into modelscope:main
Conversation
Code Review
This pull request introduces support for the MiMo-V2.5 model by adding the necessary constants, model architecture registration, and template definitions. The implementation includes a custom MiMoV2Loader and MiMoV2Template to handle specific model requirements like video key renaming and thinking mode support. Feedback was provided to improve the robustness of the _post_encode method in the template and to ensure proper handling of AWQ-quantized models during model loading.
```python
def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
    if not self.is_training:
        # During inference, rename key to match MiMo-V2.5 forward signature
        if 'pixel_values_videos' in inputs:
            inputs['video_pixel_values'] = inputs.pop('pixel_values_videos')
        return inputs
    # For training, compute embeddings manually
    input_ids = inputs['input_ids']
    base_model = self.get_base_model(model)
    inputs_embeds = base_model.model.embed_tokens(input_ids)
    inputs_embeds = self._get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config)
    return {'inputs_embeds': inputs_embeds}
```
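For context, the inference-branch rename above can be exercised in isolation. The sketch below uses a plain dict standing in for the real tensor inputs (values are illustrative): MiMo-V2.5's `forward` expects `video_pixel_values`, while the HF processor emits `pixel_values_videos`.

```python
# Plain-dict sketch of the inference-branch key rename; toy values
# stand in for the real video tensors produced by the processor.
inputs = {'input_ids': [1, 2, 3], 'pixel_values_videos': [[0.1, 0.2]]}
if 'pixel_values_videos' in inputs:
    inputs['video_pixel_values'] = inputs.pop('pixel_values_videos')

assert 'pixel_values_videos' not in inputs       # old key removed
assert inputs['video_pixel_values'] == [[0.1, 0.2]]  # value preserved
```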
The `_post_encode` implementation for `MiMoV2Template` has several improvement opportunities in the training branch:
- Robustness: It assumes `base_model.model.embed_tokens` exists. Using a check for `language_model` (similar to the parent `Qwen2VLTemplate`) makes it more resilient to different model architectures.
- Consistency: The inference branch returns the full `inputs` dictionary, while the training branch returns a new dictionary containing only `inputs_embeds`. While the framework might merge these, it is safer and more consistent to update the `inputs` dictionary in place and return it.
- Ambiguity: When `inputs_embeds` is provided, `input_ids` should ideally be removed from the inputs to avoid ambiguity in the model's forward call.
```python
def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
    if not self.is_training:
        # During inference, rename key to match MiMo-V2.5 forward signature
        if 'pixel_values_videos' in inputs:
            inputs['video_pixel_values'] = inputs.pop('pixel_values_videos')
        return inputs
    # For training, compute embeddings manually
    input_ids = inputs['input_ids']
    base_model = self.get_base_model(model)
    if hasattr(base_model.model, 'embed_tokens'):
        inputs_embeds = base_model.model.embed_tokens(input_ids)
    else:
        inputs_embeds = base_model.model.language_model.embed_tokens(input_ids)
    inputs_embeds = self._get_inputs_embeds_hf(inputs_embeds, inputs, model.visual, self.processor, model.config)
    inputs['inputs_embeds'] = inputs_embeds
    inputs.pop('input_ids', None)
    return inputs
```

```python
def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
    model = ModelLoader.get_model(self, model_dir, config, processor, model_kwargs)
    patch_get_input_embeddings(model.visual, 'patch_embed')
    return model
```
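Returning to the `_post_encode` suggestion: the `hasattr` fallback it proposes can be exercised with stand-in modules. The class names below are illustrative dummies, not the repo's classes; they only mimic the two layouts the fallback must handle.

```python
# Stand-in modules to exercise the hasattr fallback: some architectures
# expose embed_tokens directly on the model body, others nest it under
# a language_model submodule. Dummy classes are invented for the sketch.
class Embed:
    def __call__(self, ids):
        return [[float(i)] for i in ids]  # toy embedding lookup

class WithEmbed:
    def __init__(self):
        self.embed_tokens = Embed()  # embed_tokens exposed directly

class WithLanguageModel:
    def __init__(self):
        self.language_model = WithEmbed()  # embed_tokens nested one level down

def embed(model_body, input_ids):
    # Prefer embed_tokens directly; fall back to the nested language_model
    if hasattr(model_body, 'embed_tokens'):
        return model_body.embed_tokens(input_ids)
    return model_body.language_model.embed_tokens(input_ids)

assert embed(WithEmbed(), [1, 2]) == embed(WithLanguageModel(), [1, 2])
```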
The current implementation of get_model for MiMoV2Loader bypasses the logic in Qwen2VLLoader.get_model to avoid the hardcoded auto_model_cls, but it also misses the necessary check for AWQ-wrapped models. When a model is quantized with AWQ, the actual model components are often nested under a .model attribute. It is safer to replicate the base_model logic to ensure patch_get_input_embeddings is applied to the correct module.
```diff
 def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
     model = ModelLoader.get_model(self, model_dir, config, processor, model_kwargs)
-    patch_get_input_embeddings(model.visual, 'patch_embed')
+    base_model = model.model if 'AWQ' in model.__class__.__name__ else model
+    patch_get_input_embeddings(base_model.visual, 'patch_embed')
     return model
```
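The unwrap check in the suggestion can be sketched with dummy classes (purely illustrative names, not the actual AWQ wrapper types): an AWQ-quantized checkpoint often wraps the real model, so `.visual` lives one level down under `.model`.

```python
# Illustrative dummies for the AWQ-wrapper check above; real wrappers
# come from the quantization library, these classes are invented.
class Visual:
    pass

class MiMoV2Model:
    def __init__(self):
        self.visual = Visual()

class MiMoV2ForCausalLM_AWQ:
    def __init__(self):
        self.model = MiMoV2Model()  # real modules nested under .model

model = MiMoV2ForCausalLM_AWQ()
base_model = model.model if 'AWQ' in model.__class__.__name__ else model
assert hasattr(base_model, 'visual')  # wrapper unwrapped correctly

plain = MiMoV2Model()
# Non-AWQ models pass through unchanged
assert (plain.model if 'AWQ' in plain.__class__.__name__ else plain) is plain
```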
No description provided.