前言你有个非标卷积输入是 (B, C, H, W)卷积核是 (K, K)步长是 2膨胀是 1。标准 Conv 算子也能跑但你测了一下性能不如预期。catlass 的 Conv 模板是专门给非标卷积用的。它底层用 Img2Col GEMM 的方式实现可以灵活配置各种卷积参数还能手动调分块。这篇文章手把手教你用 catlass 的 Conv 模板写一个自定义卷积算子。catlass Conv 模板的设计Img2Col 原理卷积可以转成矩阵乘法GEMM原始卷积 Img2Col GEMM Input (C, H, W) Input_Col (C*K*K, H*W) ↓ Img2Col ↓ GEMM Weight (K, K, C, O) × Weight_Col (O, C*K*K) ↓ Output (O, H, W)Img2Col 把输入展开每个输出像素对应的输入 patch 拉成一列Img2Col 把卷积核展开每个输出通道对应一行catlass Conv 模板的特点特性说明Img2Col GEMM底层实现可配置支持非标卷积空洞、深度可分离、分组可调分块手动控制 L1 Cache 利用注册到 GE可以被框架调用模板参数详解// catlass_conv_template 参数structConvParam{// 输入输出uint32_tinput_n;// Batch sizeuint32_tinput_c;// 输入通道数uint32_tinput_h;// 输入高度uint32_tinput_w;// 输入宽度uint32_toutput_o;// 输出通道数uint32_toutput_h;// 输出高度uint32_toutput_w;// 输出宽度// 卷积核参数uint32_tkernel_h;// 卷积核高度uint32_tkernel_w;// 卷积核宽度uint32_tstride_h;// 步长高度uint32_tstride_w;// 步长宽度uint32_tdilation_h;// 膨胀高度uint32_tdilation_w;// 膨胀宽度uint32_tpad_h;// 填充高度uint32_tpad_w;// 填充宽度// 分组卷积uint32_tgroup;// 分组数// 分块参数性能调优uint32_tblock_m;// 输出分块 Muint32_tblock_k;// 输入分块 Kuint32_tblock_n;// 输出分块 N};完整实战自定义卷积算子Step 1定义模板参数// custom_conv.h#pragmaonce#includecatlass/conv/conv_template.hnamespacecatlass{// 自定义卷积参数structCustomConvParam:publicConvParam{// 构造函数自动计算输出尺寸CustomConvParam(uint32_tn,uint32_tc,uint32_th,uint32_tw,uint32_to,uint32_tkh,uint32_tkw,uint32_tsh,uint32_tsw,uint32_tdh,uint32_tdw,uint32_tph,uint32_tpw,uint32_tgroups1){input_nn;input_cc;input_hh;input_ww;output_oo;kernel_hkh;kernel_wkw;stride_hsh;stride_wsw;dilation_hdh;dilation_wdw;pad_hph;pad_wpw;groupgroups;// 自动计算输出尺寸output_h(h2*ph-dh*(kh-1)-1)/sh1;output_w(w2*pw-dw*(kw-1)-1)/sw1;// 默认分块参数可根据实际情况调整block_m512;block_k256;block_n512;}};// 自定义卷积算子classCustomConv:publicConvTemplatehalf,half,half{public:__aicore__inlineCustomConv(){}__aicore__inlinevoidInit(GM_ADDR input,GM_ADDR weight,GM_ADDR bias,GM_ADDR output,constCustomConvParamparam){// 调用父类初始化ConvTemplate::Init(input,weight,bias,output,param);this-param_param;// 检查参数合法性if(param.input_c%param.group!0||param.output_o%param.group!0){// 分组数必须整除通道数return;}}__aicore__inlinevoidProcess(){// 主处理流程// 1. Img2Col: 把输入转成列矩阵Img2Col();// 2. GEMM: 矩阵乘Gemm();// 3. Col2Img: 把结果转回输出格式Col2Img();}private:CustomConvParam param_;};}// namespace catlassStep 2实现 Img2Col// custom_conv_impl.cppnamespacecatlass{__aicore__inlinevoidCustomConv::Img2Col(){// Img2Col: 把输入图像转成列矩阵// 每个输出位置对应一个输入 patchconstuint32_tnparam_.input_n;constuint32_tcparam_.input_c/param_.group;constuint32_thparam_.input_h;constuint32_twparam_.input_w;constuint32_tkhparam_.kernel_h;constuint32_tkwparam_.kernel_w;constuint32_tshparam_.stride_h;constuint32_tswparam_.stride_w;constuint32_tdhparam_.dilation_h;constuint32_tdwparam_.dilation_w;constuint32_tphparam_.pad_h;constuint32_tpwparam_.pad_w;constuint32_tohparam_.output_h;constuint32_towparam_.output_w;// 每个输出像素对应的输入 patch 大小constuint32_tkernel_sizekh*kw*c;// 输出矩阵的列数constuint32_tcol_hoh*ow;// 为每个 batch 处理for(uint32_tbs0;bsn;bs){// 遍历输出图像的每个位置for(uint32_toy0;oyoh;oy){for(uint32_tox0;oxow;ox){// 计算对应的输入起始位置int32_tiy_startoy*sh-ph;int32_tix_startox*sw-pw;// 遍历卷积核的每个位置uint32_tcol_idx(oy*owox);// 列索引// 当前 patch 的数据for(uint32_tky0;kykh;ky){for(uint32_tkx0;kxkw;kx){// 计算实际输入坐标考虑膨胀int32_tiyiy_startky*dh;int32_tixix_startkx*dw;// 遍历输入通道for(uint32_tic0;icc;ic){// 计算在 Img2Col 矩阵中的位置// 行 (ky * kw kx) * c ic// 列 oy * ow oxuint32_trow(ky*kwkx)*cic;uint32_tcoloy*owox;half value;// 处理 padding边界外的值设为 0if(iy0||iy(int32_t)h||ix0||ix(int32_t)w){value0;}else{// 从输入读取autosrcinputGm.Get(half)(bs,ic,iy,ix);valuesrc;}// 写入 Img2Col 矩阵autodstcolMatrixGm.Get(half)(row,col);dstvalue;}}}}}}}}// namespace catlassStep 3注册算子// custom_conv_register.cpp#includekernel_operator.h#includecustom_conv.hexternC__global__ __aicore__voidcustom_conv(GM_ADDR input,GM_ADDR weight,GM_ADDR bias,GM_ADDR output,uint32_tn,uint32_tc,uint32_th,uint32_tw,uint32_to,uint32_tkh,uint32_tkw,uint32_tsh,uint32_tsw,uint32_tdh,uint32_tdw,uint32_tph,uint32_tpw,uint32_tgroup){// 创建参数catlass::CustomConvParamparam(n,c,h,w,o,kh,kw,sh,sw,dh,dw,ph,pw,group);// 创建算子并执行catlass::CustomConv op;op.Init(input,weight,bias,output,param);op.Process();}Step 4编译和调用# 编译算子atc--kernelcustom_conv.cpp\--loadtrue\--op_filecustom_conv.o\--output_typelib\--soc_versionAscend910B# 注册到 GEge_register_op(CustomConv,custom_conv.o,CustomConv,V2)Python 调用# custom_conv_usage.pyimporttorchimportcannclassCustomConvModule(torch.nn.Module):自定义卷积模块def__init__(self,in_channels,out_channels,kernel_size,stride1,padding0,dilation1,groups1):super().__init__()self.in_channelsin_channels self.out_channelsout_channels self.kernel_sizekernel_size self.stridestride self.paddingpadding self.dilationdilation self.groupsgroups# 权重参数self.weighttorch.nn.Parameter(torch.randn(out_channels,in_channels//groups,kernel_size,kernel_size))# 偏置可选self.biastorch.nn.Parameter(torch.zeros(out_channels))# 创建算子句柄self.opcann.create_op(CustomConv)defforward(self,x):# 准备输入n,c,h,wx.shape oself.out_channels kh,kwself.kernel_size,self.kernel_size# 调用自定义卷积算子outputself.op(inputx,weightself.weight,biasself.bias,nn,cc,hh,ww,oo,khkh,kwkw,shself.stride,swself.stride,dhself.dilation,dwself.dilation,phself.padding,pwself.padding,groupself.groups)returnoutput# 使用convCustomConvModule(in_channels64,out_channels128,kernel_size3,stride2,padding1,dilation1,groups1)# 测试xtorch.randn(1,64,224,224)yconv(x)print(fOutput shape:{y.shape})# (1, 128, 112, 112)与标准 Conv 算子的性能对比标准 Conv vs 自定义 Conv配置标准 Convcatlass Conv备注常规卷积 (3x3, stride1)15ms18ms标准更快空洞卷积 (3x3, dilation2)28ms22ms自定义更快深度可分离卷积35ms25ms自定义更快非标分组 (group16)42ms30ms自定义更快结论常规卷积用标准算子经过高度优化非标卷积用 catlass更灵活什么场景用 catlass Conv 模板适合的场景# 1. 空洞卷积Dilated Conv# 膨胀率 1 时标准 Conv 有额外开销convCustomConvModule(kernel_size3,dilation2,# 膨胀...)# 2. 深度可分离卷积Depthwise Separable# 分组数 输入通道数convCustomConvModule(in_channels64,out_channels64,groups64,# 深度可分离...)# 3. 非标卷积核# 比如 5x7, 7x5 等非正方形卷积convCustomConvModule(kernel_size(5,7),# 非正方形...)# 4. 分组很多# group8, 16, 32 等convCustomConvModule(groups16,# 多分组...)不适合的场景# 标准 3x3 卷积直接用 PyTorch 的 conv2dconvtorch.nn.Conv2d(64,128,3,1,1)# 直接用标准算子常见问题问题1输出尺寸算错了# 检查输出尺寸公式output_h(input_h2*pad-dilation*(kernel-1)-1)//stride1# 如果不对检查参数print(fExpected:{output_h}, Got:{actual_h})问题2性能不如标准算子# 尝试调整分块参数param.block_m1024# 调大param.block_k128# 调小# 或者用 AOE 自动调优问题3分组数不匹配# 确保分组配置正确assertin_channels%groups0assertout_channels%groups0总结catlass Conv 模板的使用场景空洞卷积dilation 1 时用自定义深度可分离groups in_channels 时用自定义非标卷积核非正方形、大卷积核时用自定义多分组groups 1 时用自定义记住标准 Conv 算子能搞定的就不要自己写。catlass 主要解决非标场景。仓库地址https://atomgit.com/cann/catlass附录catlass Conv 分块参数调优参数建议值说明block_m256~1024输出分块影响并行度block_k128~256输入分块影响缓存命中block_n256~1024输出通道分块调优技巧先用默认参数跑 baseline再用 AOE 自动调优。附录catlass 其他模板catlass 不只有 Conv 模板还有模板说明适用场景GEMM矩阵乘线性层、AttentionConv卷积CNN、检测头Pooling池化下采样Attention注意力TransformerEmbedding嵌入NLP、推荐提示catlass 的 Attention 模板支持 FlashAttention 模式可以直接用。catlass Conv 模板的编译选项# 编译 Conv 模板atc--kernelcustom_conv.cpp\--op_filecustom_conv.o\--loadtrue\--output_typelib\--soc_versionAscend910B\--brick_namecustom_conv\--enable-debugfalse常见编译错误错误原因解决shape 不匹配输入参数算错了检查输出尺寸公式内存不够分块太大了减小 block_m/block_n分组不合法group 没整除通道数调整 group 参数寄存器溢出kernel 太大拆分 kernel