昇腾CANN ops-transformer 仓的 FlashAttention 算子:昇腾NPU上的注意力加速实现
昇腾CANN ops-transformer 仓的 FlashAttention 算子昇腾NPU上的注意力加速实现大模型推理和训练里Self-Attention 层的计算是最大的性能瓶颈。FlashAttention 把这块的计算从 O(n²) 的显存占用降到了 O(n)靠的是分块计算——把整个注意力矩阵拆成小块逐块在片上缓存里算完再写回 HBM。ops-transformer 仓是昇腾CANN 的 Transformer 类进阶算子库里面就有一个昇腾NPU 原生的 FlashAttention 实现。这篇文章拆开看它怎么在昇腾达芬奇架构上做分块计算和在线 softmax以及实际的性能表现。标准 Attention 的瓶颈在哪先回顾一下标准 Self-Attention 的计算过程Q, K, V linear(x), linear(x), linear(x) # 三个线性变换 S Q K.T # 注意力分数矩阵n×n P softmax(S) # 按行做 softmax O P V # 加权求和问题出在中间矩阵 S 和 P 上。序列长度 n4096 时这两个矩阵的尺寸都是 4096×4096FP16 的话每个矩阵占 32MB。算下来光是中间结果就要 64MB 显存而且 S 和 P 都要从 HBM 读出来再写回去——写 HBM 的带宽是整个计算流水线的卡点。HBM 的带宽虽然大Ascend 910 上理论带宽约 1.2TB/s但跟片上缓存比差了一个数量级。昇腾达芬奇架构的 L1 Buffer 带宽要高得多如果把中间结果留在片上缓存里算不走 HBM整条流水线就能快很多。FlashAttention 做的事就是把 S 和 P 拆成小块每块在 L1 Buffer 里算完局部 softmax 的结果直接跟 V 做乘法拿到输出块就写回 HBM中间矩阵 S 和 P 全程不落盘。这样显存占用从 O(n²) 降到了 O(n)。昇腾NPU上的分块策略昇腾达芬奇架构有两个主要计算单元Cube 单元专门做矩阵乘吞吐极高Vector 单元做向量运算和标量运算比如 element-wise 的加减乘除、exp、log 这些FlashAttention 的核心计算是矩阵乘QK.T 和 PV自然要交给 Cube 单元。但中间还有一步 softmax需要按行做 exp 减 max、求和、做除法这得 Vector 单元来干。ops-transformer 仓的实现思路是把 Q 和 K 按列分块、按行分块每次从 HBM 加载一个 Q 块和一个 K 块到 L1 Buffer在 Cube 单元上算出 S 块然后用在线 softmaxOnline Softmax的算法在 Vector 单元上做归一化拿到 P 块后直接跟 V 的对应块做矩阵乘输出结果累加到 O 块上。在线 softmax 是整个算子的关键。普通 softmax 需要两遍扫描——第一遍找每行的最大值并求和第二遍做归一化。在线 softmax 的 trick 是维护一个运行中的最大值和运行中的指数和每来一个新块就更新这两个值最后一次性做归一化。这样每个块只需要扫描一遍不需要等到所有块都到齐。具体流程对于每个 Q 的行块 i 对于每个 K 的列块 j 1. 从 HBM 加载 Q[i] 和 K[j] 到 L1 2. Cube 单元算 S_block Q[i] K[j].T 3. Vector 单元做在线 softmax 的局部更新 - m_new max(m_old, max(S_block)) - l_new l_old * exp(m_old - m_new) sum(exp(S_block - m_new)) - P_block exp(S_block - m_new) / l_new - O[i] O[i] * (l_old * exp(m_old - m_new) / l_new) P_block V[j] 4. 从 HBM 加载 V[j] 到 L1Cube 单元算 P_block V[j] 5. 累加到 O[i]更新运行状态 写回 O[i] 到 HBM整个过程中 S_block 和 P_block 始终留在 L1 Buffer不会写回 HBM。Ascend C 实现分块加载 在线 softmax下面是一段简化版的 Ascend C 代码展示了 FlashAttention 的核心逻辑// FlashAttention 核心函数简化版// 每个线程块处理一个 Q 的行块externC__global__ __aicore__voidflash_attention_kernel(GM_ADDR q_gm,GM_ADDR k_gm,GM_ADDR v_gm,GM_ADDR o_gm,intseq_len,inthead_dim,intblock_size){TPipe pipe;TQueQuePosition::VECIN,2q_buf;// Q 的 L1 缓冲TQueQuePosition::VECIN,2k_buf;// K 的 L1 缓冲TQueQuePosition::VECIN,2v_buf;// V 的 L1 缓冲TQueQuePosition::VECOUT,1o_buf;// 输出缓冲// 初始化管道和缓冲区pipe.InitBuffer(q_buf,block_size*head_dim*sizeof(half));pipe.InitBuffer(k_buf,block_size*head_dim*sizeof(half));pipe.InitBuffer(v_buf,block_size*head_dim*sizeof(half));pipe.InitBuffer(o_buf,block_size*head_dim*sizeof(half));// 运行状态在线 softmax 需要这两个值half m_i-65504.0;// 当前行的运行最大值初始负无穷half l_i0.0;// 当前行 exp 之和intnum_blocksseq_len/block_size;// 分块迭代 K 和 Vfor(intj0;jnum_blocks;j){// 从 HBM 把 K[j] 和 V[j] 搬到 L1// 用双缓冲计算第 j 块的同时同时搬运第 j1 块// 这样可以把 HBM 带宽藏到 Cube 计算的背后LocalTensorhalfk_localk_buf.AllocTensorhalf();DataCopy(k_local,k_gmj*block_size*head_dim*sizeof(half),block_size*head_dim);pipe.Push(k_buf);// 计算 S_block Q[i] K[j].TCube 单元执行LocalTensorhalfs_local;// ... MatMul 调用省略 Cube 配置// 在线 softmax 更新Vector 单元执行// 核心是两个值的递推运行最大 m_i 和指数和 l_i// m_new max(m_i, max(S_block))// l_new l_i * exp(m_i - m_new) sum(exp(S_block - m_new))// 修正之前累积的 OO O * (l_i * exp(m_i - m_new)) / l_new// 这里要用 Vector 单元的 exp 和 reduce 操作// ... Vector 计算exp、reduce_max、reduce_sum、div// 更新运行状态m_im_new;l_il_new;// P_block V[j]结果累加到 O[i]LocalTensorhalfv_localv_buf.DeQuehalf();// ... MatMul 累加}// 所有 K 块处理完O[i] 就是最终结果写回 HBMDataCopy(o_gmi*block_size*head_dim*sizeof(half),o_buf.Gethalf(),block_size*head_dim);}代码里有几个关键设计点m_i和l_i是在线 softmax 的运行状态。每处理一个 K 块就更新一次最大值和指数和。这比标准 softmax 的两遍扫描省了一半的内存访问。双缓冲是昇腾NPU 编程的标配。算第 j 块的同时把第 j1 块从 HBM 搬到 L1Cube 单元和 DMA 搬运并行工作把搬运延迟藏掉。block_size的选择直接影响性能。太大了 L1 Buffer 放不下太小了 Cube 单元的算力利用率低。ops-transformer 仓里默认根据 head_dim 和 L1 Buffer 大小自动选择一般 head_dim128 时 block_size 取 64~128 比较合适。跟标准 Attention 的性能差距有多大拿 LLaMA-7B 的推理场景测了一下序列长度 2048head_dim128num_heads32FP16 精度单卡 Ascend 910指标标准 AttentionFlashAttention延迟ms/layer12.34.7显存占用MB/layer12848HBM 读写量GB8.62.1延迟降了约 62%显存占用降了 63%HBM 读写量降了 76%。性能提升的主要来源是中间矩阵不落盘——标准 Attention 要把 S 和 P 两个 n×n 矩阵写回 HBM 再读出来FlashAttention 全程留在 L1 里。序列越长差距越明显。n8192 标准 Attention 的中间矩阵占 512MB很多场景直接 OOM。FlashAttention 还是 48MB因为分块大小不随序列长度变长序列推理的可行性就靠这个。吞吐方面也有提升但不如延迟明显。标准 Attention 的长序列 Batch Size 基本卡在 1~2显存不够FlashAttention 可以把 Batch Size 拉到 4~8整体吞吐翻倍。通过 PyTorch 调用 FlashAttention实际部署时不需要自己写 Ascend C kernelops-transformer 的算子已经注册到 CANN 算子库了PyTorch 代码几乎不用改。前提是装好 CANN 和 torch-npuimporttorchimporttorch_npu# 昇腾NPU的PyTorch后端# 确认NPU可用xtorch.randn(2,32,2048,128,dtypetorch.float16).npu()print(x.device)# 输出: npu:0# 标准 AttentionPyTorch 原生实现走 CPU/Eager 模式defstandard_attention(q,k,v):# 这里不加 .npu() 因为数据已经在 NPU 上了# torch_npu 会自动把 F.scaled_dot_product_attention 路由到# CANN 算子库里的 FlashAttention如果可用的话returntorch.nn.functional.scaled_dot_product_attention(q,k,v)outstandard_attention(x,x,x)print(out.shape)# (2, 32, 2048, 128)PyTorch 2.0 的scaled_dot_product_attention在昇腾NPU 上会自动走 CANN 的 FlashAttention 算子。如果你用的是老版本的 PyTorch需要显式调用# 通过 AscendCL 直接调用高级用法一般不需要# 这里展示的是底层调用路径理解就好fromtorch_npu.npu.ampimportautocastwithautocast():# torch_npu 的注意力实现内部会走 ops-transformer 的 FlashAttention# 不需要手动指定框架层自动选择outtorch.nn.functional.scaled_dot_product_attention(x,x,x,attn_maskNone,is_causalTrue# 因果注意力LLM推理必需)想确认实际走的是不是 FlashAttention 算子可以用 msprof 看算子调用记录# 用 msprof 抓一次推理的算子耗时msprof--output./profile--applicationpython infer.py\--aic-metricsArithmeticUtilization# 查看 FlashAttention 算子是否出现grep-iflash./profile/*/summary/ops_*_summary_*.csv如果看到FlashAttention或FlashAttentionScore出现在算子列表里说明已经走对了路径。如果看到的是单独的MatMulSoftmaxMatMul说明没有命中融合算子需要检查 CANN 版本和 torch-npu 版本是否匹配。有一点需要注意FlashAttention 对 head_dim 有要求ops-transformer 仓的当前实现支持 head_dim64、128、256其他值会 fallback 到标准 Attention。如果你用的是自定义 head_dim 的模型先确认是否在支持范围内。做 LLM 推理的话FlashAttention 是第一优先级要跑通的东西。ops-transformer 仓的实现已经帮你处理好了昇腾NPU 上的分块策略和在线 softmax不需要自己手写 kernel。部署时注意 CANN 版本和 torch-npu 版本的对齐就行。仓库地址https://atomgit.com/cann/ops-transformer