完全开源的语言模型学习记录--TrilinearCIM架构
文章目录在这里插入图片描述一、一段话总结二、思维导图三、详细总结1. 研究动机与问题2. 核心技术方案3. 评估与结果4. 贡献与结论四、关键问题与答案https://arxiv.org/pdf/2604.07628Trilinear Compute-in-Memory Architecture for Energy-Efficient Transformer Acceleration一、一段话总结本文提出TrilinearCIM架构基于双栅铁电场效应晶体管DG‑FeFET通过背栅调制实现三操作数乘累加在无需运行时非易失性存储器NVM重编程的情况下完成Transformer自注意力全流程计算在BERT‑base与ViT‑base上验证相比传统FeFET CIM实现最高46.6%能耗降低、20.4%延迟降低9项GLUE任务中7项精度更优是首个纯NVM核内完成Transformer注意力计算的架构。二、思维导图## **研究背景** - Transformer自注意力O(N²)复杂度、动态Q/K/V - 传统CIM局限双操作数、需反复NVM写入 - 现有方案混合CMOS/降写入但未根除 ## **核心创新** - 器件DG‑FeFET顶栅存权重、背栅动态调制 - 算子三操作数MACA·B·C - 数据流融合投影与注意力、消除中间张量存储 ## **架构设计** - 层级芯片→Tile→PE→SubArray - 阵列无选择器结构、两种三线性配置 - 数字单元Softmax/LayerNorm/GELU专用SFU ## **评估框架** - TransCIM基于NeuroSim扩展、支持精度/PPA建模 - 配置7nm CMOS22nm FeFET、8bit输入权重 ## **实验结果** - 精度GLUE 7/9任务优于双线性CIM - PPA能耗-46.6%、延迟-20.4%、面积37.3% - 消融1b/6b最优精度、32×32最优延迟 ## **优势与局限** - 优势零写入、高能效、长序列友好 - 局限ViT精度下降、需硬件验证三、详细总结1. 研究动机与问题Transformer自注意力生成动态Q/K/V矩阵传统CIM需反复NVM重编程导致延迟高、能耗大、耐久性差。FeFET读写不对称写入延迟50ns、读取10ns写入能耗亚pJ级、读取fJ级。BERT‑base单次推理需**~75.5M次写入**严重限制器件寿命。2. 核心技术方案1器件基础DG‑FeFET顶栅TG非易失性存储静态权重。背栅BG挥发性电压调制沟道电导提供第三操作数通路。电导特性G D S ( V B G ) ≈ G 0 ⋅ ( 1 η B G ⋅ V B G ) G_{DS}(V_{BG})≈G_0·(1η_{BG}·V_{BG})GDS(VBG)≈G0⋅(1ηBG⋅VBG)工作区间G 0 ∈ [ 29 , 69 ] μ S G_0∈[29,69]μSG0∈[29,69]μSη ˉ B G 0.157 V − 1 \bar{η}_{BG}0.157V^{-1}ηˉBG0.157V−1。2三线性CIM原语实现三操作数乘累加YA·B·C替代传统双操作数输入·权重。权重一次性编程动态操作数通过背栅电压传入全程无NVM写入。3注意力数据流缩放Query生成R 1 X ⋅ W Q T / d k R_1X·W_Q^T/\sqrt{d_k}R1X⋅WQT/dk注意力分数合成R 2 R 1 ⋅ W K ⋅ X T R_2R_1·W_K·X^TR2R1⋅WK⋅XTValue聚合R e s u l t S c o r e ⋅ X ⋅ W V T ResultScore·X·W_V^TResultScore⋅X⋅WVT优势消除中间Q/K/V存储缓冲区需求降低**~3倍**。4硬件架构层级芯片级2×2 Tile→Tile级2×2 PE→PE级2×2 Array。阵列无选择器DG‑FeFET交叉阵列列级背栅驱动与DAC。数字单元专用SFU执行Softmax、LayerNorm、GELU。3. 评估与结果1实验配置模型BERT‑base、ViT‑base数据集GLUE、ImageNet/CIFAR。工艺7nm CMOS 22nm FeFET阵列64×64精度8bit输入/权重。2精度结果任务类型对比结果GLUE9项TrilinearCIM在7项优于双线性CIMMNLI3.14%、QNLI3.74%视觉任务TrilinearCIM略低于双线性因背栅DAC量化扭曲注意力峰值3PPA结果BERT‑base序列长度指标双线性三线性变化64延迟(ms)7.636.08-20.4%64能耗(μJ)1522813-46.6%128延迟(ms)8.196.67-18.6%128能耗(μJ)31321889-39.7%-面积--37.3%-TOPS/W9.6813.4739.2%4消融实验子阵列32×32延迟**-40.9%**64×64能效更高。精度1b/6b为最优精度点能耗-37.5%、延迟-26.0%、面积32.4%。序列长度越长三线性零写入优势越显著。4. 贡献与结论首次实现纯NVM核内Transformer注意力计算零运行时重编程。能效、延迟显著提升面积开销可控。更适合NLP长序列场景视觉任务需算法优化。四、关键问题与答案TrilinearCIM的核心创新点是什么答核心创新是基于DG‑FeFET的三操作数CIM原语利用背栅调制提供第三操作数通路将静态权重存于顶栅、动态操作数以背栅电压传入完全消除Transformer注意力推理时的NVM写入同时融合投影与注意力步骤降低缓冲区需求约3倍。TrilinearCIM相比传统CIM的性能收益与代价是什么答收益为最高46.6%能耗降低、20.4%延迟降低、TOPS/W提升39.2%9项GLUE任务7项精度更优代价是37.3%面积增加ViT视觉任务精度因背栅DAC量化有所下降。为何TrilinearCIM在NLP任务表现更好而视觉任务稍弱答NLP任务对噪声鲁棒性强小扰动不改变注意力结果ViT注意力分布稀疏、存在高幅值峰值背栅DAC的均匀量化会扭曲关键峰值且ViT激活通道方差更大放大量化误差导致精度下降。》实战测试随着序列长度以及隐藏层维度的升高虽然内存下降但是时耗明显增高importnumpyasnpimporttime# -----------------------------------------------------------------------------# 优化版传统注意力高维 长序列 专用# -----------------------------------------------------------------------------deftraditional_attention_fast(X,Wq,Wk,Wv,d_k,block512):N,dX.shape outnp.zeros((N,d_k),dtypenp.float16)foriinrange(0,N,block):XiX[i:iblock]QiXi Wq.T KiX Wk.T ViX Wv.T score(Qi Ki.T)/np.sqrt(d_k)attnnp.exp(score)/np.sum(np.exp(score),axis-1,keepdimsTrue)out[i:iblock]attn Vireturnout# -----------------------------------------------------------------------------# 极速版TrilinearCIM 注意力高维最强不生成Q/K/V# -----------------------------------------------------------------------------deftrilinear_attention_fast(X,Wq,Wk,Wv,d_k,block512):N,dX.shape outnp.zeros((N,d_k),dtypenp.float16)foriinrange(0,N,block):XiX[i:iblock]R1i(Xi Wq.T)/np.sqrt(d_k)scoreR1i Wk X.T attnnp.exp(score)/np.sum(np.exp(score),axis-1,keepdimsTrue)out[i:iblock]attn X Wv.Treturnout# -----------------------------------------------------------------------------# 测试长序列 1w 高维度 768/1024# -----------------------------------------------------------------------------if__name____main__:SEQ_LEN10000HIDDEN_DIM128# 你可以改成 512/768/1024 测试D_K64print(f序列长度{SEQ_LEN})print(f隐藏维度{HIDDEN_DIM})# 数据Xnp.random.randn(SEQ_LEN,HIDDEN_DIM).astype(np.float16)Wqnp.random.randn(D_K,HIDDEN_DIM).astype(np.float16)Wknp.random.randn(D_K,HIDDEN_DIM).astype(np.float16)Wvnp.random.randn(D_K,HIDDEN_DIM).astype(np.float16)# 传统t0time.time()out_tradtraditional_attention_fast(X,Wq,Wk,Wv,D_K)t1time.time()# 三线性out_tritrilinear_attention_fast(X,Wq,Wk,Wv,D_K)t2time.time()# 对比print(f\n传统耗时{t1-t0:.2f}s)print(f三线性耗时{t2-t1:.2f}s)print(f 三线性 加速比{(t1-t0)/(t2-t1):.2f}x)# cpu测试结果传统耗时126.50s 三线性耗时796.03s 三线性 加速比0.16xgpuimporttorchimporttime# 检查GPUdevicetorch.device(cudaiftorch.cuda.is_available()elsecpu)print(使用设备:,device)# -----------------------------------------------------------------------------# 传统注意力GPU# -----------------------------------------------------------------------------deftraditional_attention_gpu(X,Wq,Wk,Wv,d_k):QX Wq.T KX Wk.T VX Wv.T score(Q K.T)/torch.sqrt(torch.tensor(d_k,dtypeX.dtype))attntorch.softmax(score,dim-1)outattn Vreturnout# -----------------------------------------------------------------------------# TrilinearCIM 三线性注意力GPU无Q/K/V中间张量# -----------------------------------------------------------------------------deftrilinear_attention_gpu(X,Wq,Wk,Wv,d_k):R1(X Wq.T)/torch.sqrt(torch.tensor(d_k,dtypeX.dtype))scoreR1 Wk X.T attntorch.softmax(score,dim-1)outattn X Wv.Treturnout# -----------------------------------------------------------------------------# 超长篇 高维度 测试# -----------------------------------------------------------------------------if__name____main__:# 长序列 高维度SEQ_LEN10000HIDDEN_DIM2048# BERT-base 维度可改 1024D_K64# 数据放到 GPUXtorch.randn(SEQ_LEN,HIDDEN_DIM,devicedevice,dtypetorch.float16)Wqtorch.randn(D_K,HIDDEN_DIM,devicedevice,dtypetorch.float16)Wktorch.randn(D_K,HIDDEN_DIM,devicedevice,dtypetorch.float16)Wvtorch.randn(D_K,HIDDEN_DIM,devicedevice,dtypetorch.float16)torch.cuda.synchronize()t0time.time()# 传统out_tradtraditional_attention_gpu(X,Wq,Wk,Wv,D_K)torch.cuda.synchronize()t1time.time()# 三线性out_tritrilinear_attention_gpu(X,Wq,Wk,Wv,D_K)torch.cuda.synchronize()t2time.time()# 误差abs_errtorch.abs(out_trad-out_tri).mean().item()rel_errabs_err/torch.abs(out_trad).mean().item()# 输出结果print(f\n 序列长度{SEQ_LEN}| 维度{HIDDEN_DIM}| GPU 对比)print(f传统耗时 :{t1-t0:.3f}s)print(f三线性耗时 :{t2-t1:.3f}s)print(f 加速比 :{((t1-t0)/(t2-t1)):.2f}x)print(f平均绝对误差:{abs_err:.6f})print(f平均相对误差:{rel_err:.2%})序列长度10000|维度2048|GPU 对比 传统耗时:0.001s 三线性耗时:0.010s 加速比:0.13x 平均绝对误差: nan 平均相对误差: nan%