MTP Model 模型适配指南【免费下载链接】cann-recipes-infer本项目针对LLM与多模态模型推理业务中的典型模型、加速算法提供基于CANN平台的优化样例项目地址: https://gitcode.com/cann/cann-recipes-infer概述MTP (Multi-Token Prediction) 流程已接入框架中。相关实现文件和文档如下实现文件: mtp_worker.py - MTPWorker 类实现文档文件: mtp_design.md - MTP 执行机制设计文档若想在框架上实现新模型的 MTP 特性需要自定义 MTP 类。重要提示框架实现的投机算法需要从主模型传递prev_hidden_states到 MTP 模型MTP 模型需支持该输入接口;推荐在已有的主模型上搭建 MTP 模型复用主模型的lm_head、embed_tokens等组件框架会默认处理权重共享。本指南以DeepSeek-R1中的 MTP 实现为例介绍如何在 DeepSeekV3ForCausalLM 主模型基础上搭建 MTP 模型。2. 核心类组成结构下面这张图只关注 DeepSeek-R1中 4 个核心类之间的继承和组合关系。DeepseekV3ModelMTPLayer只是“额外 MTP 层容器”它不负责embed token也不负责最终 logits。DeepseekV3ModelMTP复用了ForCausalLM的外层接口形态但内部 model 已经 换成了DeepseekV3ModelMTPLayer。2.1 DeepseekV3ModelMTPLayerMTP 专属的 Transformer 层容器继承自DeepseekV3Model。成员变量类型说明embed_tokensNone复用主模型的 lm_headmtp_start_layer_idxintMTP 层起始索引 config.num_hidden_layerslayersModuleDictMTP 专属的 decoder 层集合layers 结构:Key:num_hidden_layers i(如 60, 61, ...)Value:DeepseekV3DecoderLayer实例数量:config.num_nextn_predict_layersDeepseekV3ModelMTPLayer 推理流程DeepseekV3ModelMTPLayer是 MTP 层的容器其 forward 函数根据mtp_layer_idx参数选择并执行指定的 MTP decoder 层。流程图输入: hidden_states, kv_len, actual_seq_lengths_kv, cos_sin, ..., mtp_layer_idx │ ├─ 根据索引获取指定 MTP 层 │ └─ layer get_layer(mtp_layer_idx) │ └─ 返回 layers[mtp_start_layer_idx mtp_layer_idx] │ └─ 调用该层的 forward 函数 └─ return layer.forward(hidden_states, kv_len, actual_seq_lengths_kv, ...) ├─ Self-Attention 计算 ├─ MoE 前馈网络计算 └─ 残差连接 层归一化 输出: hidden_states关键代码def forward( self, hidden_states: torch.Tensor, mtp_layer_idx: Optional[int] 0, # 指定执行哪个 MTP 层 ... ) - torch.Tensor: # 根据索引获取对应的 MTP decoder 层 layer self.get_layer(mtp_layer_idx) # 调用该层的 forward 函数 return layer.forward( hidden_states, ... )关键点mtp_layer_idx从 0 开始对应config.num_hidden_layers mtp_layer_idx层MTP decoder 层的结构与主模型的 decoder 层相同2.2 DeepseekV3ModelMTPMTP 模型主类继承自DeepseekV3ForCausalLM。成员变量类型作用is_mtpbool TrueMTP 模式标志modelDeepseekV3ModelMTPLayerMTP transformer 层lm_headNone复用主模型的 lm_headrotary_embDeepseekV3YarnRotaryEmbedding位置编码层shared_head_normDeepseekV3RMSNorm共享头归一化层enormDeepseekV3RMSNorm当前帧 hidden state 归一化hnormDeepseekV3RMSNorm上一帧 hidden state 归一化eh_projReplicatedLinear特征融合:[h_t, h_{t-1}] → hDeepseekV3ModelMTP 推理流程MTP 模型主类的推理流程包含特征融合、位置编码获取、多层计算等步骤。流程图输入: input_ids, prev_hidden_states (上一token输出) │ ├─ step 1: 获取 embeddings │ └─ calc_input_embeddings() │ ├─ 复用主模型的 embed_tokens │ └─ 返回 hidden_states │ ├─ step 2: 归一化 │ ├─ hidden_states_e enorm(hidden_states) │ └─ prev_hidden_states_h hnorm(prev_hidden_states) │ ├─ step 3: 特征融合 │ ├─ hidden_states_fused concat([hidden_states_e, prev_hidden_states_h]) │ └─ hidden_states eh_proj(hidden_states_fused) │ ├─ step 4: 获取位置编码 │ └─ cos_sin rotary_emb(hidden_states, kv_len, ...) │ ├─ step 5: Transformer 层计算 │ └─ 遍历所有 MTP 层 │ ├─ for i in range(num_nextn_predict_layers): │ │ └─ hidden_states ModelMTPLayer(mtp_layer_idxi)(hidden_states, ...) │ └─ 逐层执行 MTP 专属的 decoder 层 │ ├─ step 6: 共享头归一化 │ └─ prev_hidden_states, _ shared_head_norm(hidden_states, residual) │ └─ step 7: 输出 logits └─ logits forward_lm_head(prev_hidden_states, ...)关键代码def forward( self, input_ids: torch.Tensor, prev_hidden_states: torch.Tensor, forward_metadata: ForwardMetaData, ... ): is_prefill forward_metadata.is_prefill kv_len forward_metadata.kv_len # Step 1: 获取 embeddings (复用主模型) hidden_states self.model.calc_input_embeddings(input_ids, ...) # Step 2: 归一化 hidden_states self.enorm(hidden_states) prev_hidden_states self.hnorm(prev_hidden_states) # Step 3: 特征融合 hidden_states_fused concat([hidden_states, prev_hidden_states], dim-1) hidden_states self.eh_proj(hidden_states_fused) # Step 4: 获取位置编码 cos_sin self.rotary_emb(hidden_states, kv_len, ...) # Step 5: Transformer 层计算 residual None for i in range(self.config.num_nextn_predict_layers): residual, hidden_states self.model.forward( hidden_states, kv_len, ..., mtp_layer_idxi, past_residualresidual ) # Step 6: 共享头归一化 prev_hidden_states, _ self.shared_head_norm(hidden_states, residual) # Step 7: 输出 logits (复用主模型) logits self.forward_lm_head(prev_hidden_states, ...) return logits, prev_hidden_statesMTP 独有权重映射MTP 模型需要额外加载的权重如下当出现checkpoint权重名与模型参数名不一致的情况需要在 MTP 模型的load_weights函数进行映射checkpoint权重名模型参数名说明shared_head.normshared_head_norm共享头归一化enormenorme 分支归一化hnormhnormh 分支归一化eh_projeh_proj特征融合投影注意:MTP 层索引从num_hidden_layers开始例如主模型有 60 层MTP 有 1 层则 MTP 层为 60embed_tokens.weight和lm_head.weight可以无需加载复用主模型3. 实现步骤在框架中实现模型的 MTP 特性需要完成以下三个步骤3.1 步骤一定义 MTP 类在模型文件中定义 MTP 相关类。以 DeepSeek-R1 为例在models/deepseek_r1/models/modeling_deepseek.py中定义核心类DeepseekV3ModelMTPLayer- MTP Transformer 层容器DeepseekV3ModelMTP- MTP 模型主类MTP 模型关键特征class DeepseekV3ModelMTP(DeepseekV3ForCausalLM): is_mtp True # MTP 模式标志 model DeepseekV3ModelMTPLayer # MTP 专属层 lm_head None # 复用主模型 rotary_emb ... # 位置编码 shared_head_norm ... # 共享头归一化 enorm ... # 当前帧归一化 hnorm ... # 上一帧归一化 eh_proj ... # 特征融合投影3.2 步骤二注册 MTP 模型将 MTP 模型类注册到框架的模型字典中。在executor/core/entrypoints/support_models.py中添加from models.deepseek_r1.models.modeling_deepseek import DeepseekV3ForCausalLM, DeepseekV3ModelMTP from models.deepseek_r1.models.configuration_deepseek import DeepseekV3Config model_dict { deepseek_r1: (DeepseekV3ForCausalLM, DeepseekV3ModelMTP, DeepseekV3Config) }model_dict 结构说明第一个元素: 主模型类 (DeepseekV3ForCausalLM)第二个元素: MTP 模型类 (DeepseekV3ModelMTP)第三个元素: 配置类 (DeepseekV3Config)3.3 步骤三使能 MTP 的配置在 YAML 配置文件中设置model_config中的next_n参数来启用 MTPmodel_config: next_n: 3 # 每步预测的推测 token 数量 0 时启用 MTP框架自动处理框架检查next_n 0满足条件时自动使用 MTP 模型框架通过share_weights_from_main_model自动共享lm_head和embed_tokens权重【免费下载链接】cann-recipes-infer本项目针对LLM与多模态模型推理业务中的典型模型、加速算法提供基于CANN平台的优化样例项目地址: https://gitcode.com/cann/cann-recipes-infer创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考