Transformer 模型结构根据前面介绍的所有组件结合起来就是一个完整的 Transformer 结构了。上图为论文《Attention is all you need》原文配图LayerNorm 是放在 Attention 层之后也就是“Post-Norm”结构但是在其发布的源码中LayerNormer 是放在 Attention 层之前也就是“Pre-Norm”。实际中Pre-Norm 结构可以使 loss 更稳定所以目前 LLM 一般采用“Pre-Norm”即输入先归一化Attention 层输入更稳定。“Post-Norm”的话Attention 输出可能很大。classTransformer(nn.Module):整体模型def__init__(self,args):super().__init__()# 必须输入词表大小和 block sizeassertargs.vocab_sizeisnotNoneassertargs.block_sizeisnotNoneself.argsargs self.transformernn.ModuleDict(dict(wtenn.Embedding(args.vocab_size,args.n_embd),wpePositionalEncoding(args),dropnn.Dropout(args.dropout),encoderEncoder(args),decoderDecoder(args),))# 最后的线性层输入是 n_embd输出是词表大小self.lm_headnn.Linear(args.n_embd,args.vocab_size,biasFalse)# 初始化所有的权重self.apply(self._init_weights)# 查看所有参数的数量print(number of parameters: %.2fM%(self.get_num_params()/1e6,))统计所有参数的数量defget_num_params(self,non_embeddingFalse):# non_embedding: 是否统计 embedding 的参数n_paramssum(p.numel()forpinself.parameters())# 如果不统计 embedding 的参数就减去ifnon_embedding:n_params-self.transformer.wte.weight.numel()returnn_params初始化权重def_init_weights(self,module):# 线性层和 Embedding 层初始化为正则分布ifisinstance(module,nn.Linear):torch.nn.init.normal_(module.weight,mean0.0,std0.02)ifmodule.biasisnotNone:torch.nn.init.zeros_(module.bias)elifisinstance(module,nn.Embedding):torch.nn.init.normal_(module.weight,mean0.0,std0.02)前向计算函数defforward(self,idx,targetsNone):# 输入为 idx维度为 (batch size, sequence length, 1)targets 为目标序列用于计算 lossdeviceidx.device b,tidx.size()asserttself.args.block_size,f不能计算该序列该序列长度为{t}, 最大序列长度只有{self.args.block_size}# 通过 self.transformer# 首先将输入 idx 通过 Embedding 层得到维度为 (batch size, sequence length, n_embd)print(idx,idx.size())# 通过 Embedding 层tok_embself.transformer.wte(idx)print(tok_emb,tok_emb.size())# 然后通过位置编码pos_embself.transformer.wpe(tok_emb)# 再进行 Dropoutxself.transformer.drop(pos_emb)# 然后通过 Encoderprint(x after wpe:,x.size())enc_outself.transformer.encoder(x)print(enc_out:,enc_out.size())# 再通过 Decoderxself.transformer.decoder(x,enc_out)print(x after decoder:,x.size())iftargetsisnotNone:# 训练阶段如果我们给了 targets就计算 loss# 先通过最后的 Linear 层得到维度为 (batch size, sequence length, vocab size)logitsself.lm_head(x)# 再跟 targets 计算交叉熵lossF.cross_entropy(logits.view(-1,logits.size(-1)),targets.view(-1),ignore_index-1)else:# 推理阶段我们只需要 logitsloss 为 None# 取 -1 是只取序列中的最后一个作为输出logitsself.lm_head(x[:,[-1],:])# note: using list [-1] to preserve the time dimlossNonereturnlogits,loss Transformer整体结构 1. Transformer的组成部分: 输入层: - wte: 词嵌入层 - wpe: 位置编码层 - drop: Dropout层 编码器: - encoder: 编码器 解码器: - decoder: 解码器 输出层: - lm_head: 线性层输出词表概率 2. 数据流动: 输入 idx (batch_size, seq_len) ↓ wte: 词嵌入 (batch_size, seq_len, n_embd) ↓ wpe: 位置编码 (batch_size, seq_len, n_embd) ↓ drop: Dropout (batch_size, seq_len, n_embd) ↓ encoder: 编码器 (batch_size, seq_len, n_embd) ↓ decoder: 解码器 (batch_size, seq_len, n_embd) ↓ lm_head: 线性层 (batch_size, seq_len, vocab_size) ↓ logits (batch_size, seq_len, vocab_size) __init__方法详解 1. 参数检查: assert args.vocab_size is not None assert args.block_size is not None 作用: - 确保vocab_size词表大小已设置 - 确保block_size最大序列长度已设置 - 如果没有设置会报错 2. 创建组件: self.transformer nn.ModuleDict(dict( wte nn.Embedding(args.vocab_size, args.n_embd), wpe PositionalEncoding(args), drop nn.Dropout(args.dropout), encoder Encoder(args), decoder Decoder(args), )) 解释: - wte: 词嵌入层 * 输入: token索引 (batch_size, seq_len) * 输出: 词向量 (batch_size, seq_len, n_embd) * 参数数量: vocab_size × n_embd - wpe: 位置编码层 * 输入: 词向量 (batch_size, seq_len, n_embd) * 输出: 添加位置编码后的向量 (batch_size, seq_len, n_embd) - drop: Dropout层 * 输入: 向量 (batch_size, seq_len, n_embd) * 输出: Dropout后的向量 (batch_size, seq_len, n_embd) * 作用: 防止过拟合 - encoder: 编码器 * 输入: 向量 (batch_size, seq_len, n_embd) * 输出: 编码后的向量 (batch_size, seq_len, n_embd) - decoder: 解码器 * 输入: 向量 编码器输出 * 输出: 解码后的向量 (batch_size, seq_len, n_embd) 3. 输出层: self.lm_head nn.Linear(args.n_embd, args.vocab_size, biasFalse) 解释: - 输入: 解码器输出 (batch_size, seq_len, n_embd) - 输出: 词表概率 (batch_size, seq_len, vocab_size) - biasFalse: 不使用偏置 - 参数数量: n_embd × vocab_size 4. 权重初始化: self.apply(self._init_weights) 解释: - 对所有线性层和Embedding层进行初始化 - 使用正态分布初始化 - mean0.0, std0.02 5. 参数统计: print(number of parameters: %.2fM % (self.get_num_params()/1e6,)) 解释: - 统计所有参数的数量 - 除以1e6转换为百万M - 例如: 10M 1000万参数 forward方法详解 1. 输入参数: def forward(self, idx, targetsNone): 参数: - idx: 输入序列 * 形状: (batch_size, seq_len) * 内容: token索引 - targets: 目标序列可选 * 形状: (batch_size, seq_len) * 内容: 目标token索引 * 用途: 计算loss 2. 参数检查: device idx.device b, t idx.size() assert t self.args.block_size 解释: - device: 获取设备CPU或GPU - b: batch_size - t: seq_len序列长度 - 检查序列长度是否超过最大长度 3. 词嵌入: tok_emb self.transformer.wte(idx) 数据变化: - 输入: idx (batch_size, seq_len) - 输出: tok_emb (batch_size, seq_len, n_embd) 例子: - idx: [[1, 2, 3], [4, 5, 6]] - tok_emb: [[[0.1, 0.2, ...], [0.3, 0.4, ...], [0.5, 0.6, ...]], ...] - 每个token索引转换为一个n_embd维的向量 4. 位置编码: pos_emb self.transformer.wpe(tok_emb) 数据变化: - 输入: tok_emb (batch_size, seq_len, n_embd) - 输出: pos_emb (batch_size, seq_len, n_embd) 作用: - 给每个位置添加位置信息 - tok_emb 位置编码 pos_emb 5. Dropout: x self.transformer.drop(pos_emb) 数据变化: - 输入: pos_emb (batch_size, seq_len, n_embd) - 输出: x (batch_size, seq_len, n_embd) 作用: - 随机丢弃一些神经元 - 防止过拟合 6. 编码器: enc_out self.transformer.encoder(x) 数据变化: - 输入: x (batch_size, seq_len, n_embd) - 输出: enc_out (batch_size, seq_len, n_embd) 作用: - 编码输入序列 - 提取特征 7. 解码器: x self.transformer.decoder(x, enc_out) 数据变化: - 输入1: x (batch_size, seq_len, n_embd) - 输入2: enc_out (batch_size, seq_len, n_embd) - 输出: x (batch_size, seq_len, n_embd) 作用: - 解码输入序列 - 结合编码器输出 8. 输出层: if targets is not None: # 训练阶段 logits self.lm_head(x) loss F.cross_entropy(...) else: # 推理阶段 logits self.lm_head(x[:, [-1], :]) loss None 训练阶段: - 输入: x (batch_size, seq_len, n_embd) - 输出: logits (batch_size, seq_len, vocab_size) - 计算loss: 交叉熵损失 推理阶段: - 输入: x[:, [-1], :] (batch_size, 1, n_embd) - 只取最后一个时间步 - 输出: logits (batch_size, 1, vocab_size) - loss None 9. 返回值: return logits, loss logits: 词表概率 loss: 损失训练阶段有值推理阶段为None 完整的数据流动示例 1. 参数: batch_size: 2 seq_len: 5 vocab_size: 10000 n_embd: 512 2. 数据流动: 步骤1: 输入 - idx: (2, 5) - 例如: [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]] 步骤2: 词嵌入 - tok_emb: (2, 5, 512) - 每个token索引转换为512维向量 步骤3: 位置编码 - pos_emb: (2, 5, 512) - tok_emb 位置编码 步骤4: Dropout - x: (2, 5, 512) - 随机丢弃一些神经元 步骤5: 编码器 - enc_out: (2, 5, 512) - 编码输入序列 步骤6: 解码器 - x: (2, 5, 512) - 解码输入序列 编码器输出 步骤7: 输出层 - logits: (2, 5, 10000) - 每个位置输出词表概率 3. 训练 vs 推理: 训练阶段: - 输出: logits (2, 5, 10000) - loss: 交叉熵损失 - 用途: 更新模型参数 推理阶段: - 输出: logits (2, 1, 10000) - loss: None - 用途: 生成下一个词 4. 关键点: - wte: 词嵌入层将token索引转换为向量 - wpe: 位置编码层添加位置信息 - encoder: 编码器提取特征 - decoder: 解码器生成输出 - lm_head: 输出层输出词表概率 - 训练时: 输出所有位置的logits计算loss - 推理时: 只输出最后一个位置的logits用于生成输入 idx (batch_size, seq_len) ↓ wte: 词嵌入 (batch_size, seq_len, n_embd) ↓ wpe: 位置编码 (batch_size, seq_len, n_embd) ↓ drop: Dropout (batch_size, seq_len, n_embd) ↓ encoder: 编码器 (batch_size, seq_len, n_embd) ↓ decoder: 解码器 (batch_size, seq_len, n_embd) ↓ lm_head: 线性层 (batch_size, seq_len, vocab_size) ↓ logits (batch_size, seq_len, vocab_size)关键组件训练 vs 推理各模块作用1、词嵌入层将 token 索引转换为向量2、位置编码添加位置信息3、编码器提取特征4、解码器生成输出5、输出层输出词表概率6、权重初始化正态分布