CANN/catlass TLA布局设计
TLA Layouts【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass本文介绍 TLATensor Layout Abstraction中的Layout。如果把 Tensor 看成“逻辑上的多维数组”那么Layout负责回答以下问题一个逻辑坐标(i, j, ...)对应到哪一个线性地址。这块 Tensor 在逻辑上有多大。当底层存在分块、对齐或填充时哪些位置是逻辑有效数据。因此Layout可以理解为“逻辑坐标到内存地址的映射规则”。算法通常依赖这套规则访问数据而不直接依赖底层物理排布。这样同一段计算逻辑就可以适配普通 ND 布局、行优先、列优先以及zN、nZ等分形布局。先建立三个基本概念逻辑坐标 coordcoord表示元素在 Tensor 逻辑空间中的位置约定如下坐标从 0 开始计数。坐标单位是“元素”不是字节也不是 tile 编号。coord的 rank 必须与 Tensor 或 Layout 的逻辑维度一致。即使底层采用zN、nZ这类嵌套布局coord仍然描述逻辑上的行列位置例如(row, col)。例如对一个逻辑形状为(8, 16)的矩阵coord (2, 4)表示第 2 行、第 4 列的元素。它不关心这块数据在内存中是按行连续、按列连续还是按分形块组织。逻辑形状与内存布局在 TLA 中这两个概念被刻意分离逻辑形状从使用者视角看Tensor 有多少行、多少列。内存布局这些逻辑元素在内存中如何排布跨一个维度移动时需要跳过多少位置。Layout的核心价值就是把“逻辑上多大”和“内存里怎样排”同时表达清楚。Tail tile当矩阵尺寸不是 tile 大小的整数倍时边界 tile 往往只包含部分有效元素。这类边界 tile 通常称为 tail tile。TLA 使用originShape表达“逻辑上实际有效的范围”。因此用户通常不需要手工推导每个边缘 tile 的真实尺寸。基础类型TupleTLA 以tla::tuple为基础。它与std::tuple的用途相似都是表达定长元素序列不同之处在于TLA 对模板元编程和高性能场景做了定制。IntTupleIntTuple是 TLA 中最常用的基础概念之一。它可以是一个整数例如int{2}、size_t{16}。一个编译期整数例如Int3{}或别名_3。一个由以上元素递归组成的 tuple例如make_tuple(int{2}, Int3{})。因此IntTuple既可以表示一维尺寸也可以表示带层次结构的嵌套尺寸。常用操作如下rank(IntTuple)返回元素个数。getI(IntTuple)返回第I个元素。depth(IntTuple)返回嵌套层数普通整数的depth为 0。IntTuple不仅用于Layout也用于Shape、Stride等类型定义见include/tla/layout.hpp。Layout 由什么组成Layout本质上由三个IntTuple组成Shape、Stride和OriginShape。字段作用关注点Shape用于内存布局计算的尺寸描述决定布局结构不一定等于逻辑实际尺寸Stride各维度上的步长决定坐标如何映射到线性地址OriginShapeTensor 的逻辑实际尺寸决定哪些元素在逻辑上有效可以先把它们理解成Shape说明“内存按什么结构排”。Stride说明“每跨一步跳多远”。OriginShape说明“逻辑上到底有多少有效数据”。这里最容易混淆的是Shape和OriginShape。两者并不重复Shape面向布局计算允许包含对齐、分块和填充后的结构。OriginShape面向逻辑语义只描述真实有效的数据范围。OriginShape用于把“内存怎样排”与“逻辑上哪些数据有效”区分开。Shape服务于布局计算可能包含对齐、分块或填充后的尺寸。OriginShape服务于逻辑语义描述真实有效的数据范围。例如一个逻辑大小为100 x 100的矩阵采用zN布局时可能出现originShape (100, 100)shape ((16, 7), (16, 7))原因是16 * 7 112说明底层内存按112 x 112的块化结构组织。但逻辑上只有100 x 100是有效元素。这也是 TLA 能自动处理 tail tile 的基础。用户在 block 层和 kernel 层通常只需要按 tile 编程边界有效范围由originShape传递和裁剪无需每一层都手动判断尾块。Layout 的常用接口Layout提供了一组与IntTuple风格一致的访问接口rank(Layout)布局的逻辑维度。getI(Layout)取出第I个分量。depth(Layout)布局的嵌套层数。shape(Layout)返回Shape。stride(Layout)返回Stride。originShape(Layout)返回OriginShape。另外还提供递归版本的辅助接口例如getI0, I1, ..., IN(x)逐层向下取子单元。rankI...(x)查看某个子单元的 rank。depthI...(x)查看某个子单元的 depth。shapeI...(x)查看某个子单元的 shape。originShapeI...(x)查看某个子单元的 origin shape。Layout 构造Layout支持静态整数、动态整数及其混合构造也支持普通矩阵布局和 Ascend 常用内部布局。在昇腾 CUBE 核内部常见内部格式包括zN、nZ、zZ、nN、L0C等在 GEMV、Scale、Bias 等场景中也会使用一维VectorLayout。using namespace tla; // 1. 直接给 shape 和 strideoriginShape 由系统推导 Layout w2xh4 MakeLayout(MakeShape(Int2{}, 4), MakeStride(Int12{}, Int1{})); // 2. 嵌套布局originShape 隐式推导为 (16*2, 16*3) (32, 48) Layout w32xh48 MakeLayout(MakeShape(MakeShape(16, 2), MakeShape(16, 3)), MakeStride(MakeStride(16, 256), MakeStride(1, 512))); // 3. 显式指定 originShape Layout w2xh4_explicit MakeLayout(MakeShape(Int2{}, 4), MakeStride(Int12{}, Int1{}), MakeShape(2, 4)); Layout w32xh48_explicit MakeLayout(MakeShape(MakeShape(16, 2), MakeShape(16, 3)), MakeStride(MakeStride(16, 256), MakeStride(1, 512)), MakeShape(32, 48)); // 4. rank2 时也可以用 LayoutTag (rows, cols) 构造 auto rm MakeLayoutfloat, Catlass::layout::RowMajor(2, 4); // 5. 一维 VectorLayout auto vec MakeLayout(128);其中MakeLayout返回Layout。MakeShape返回Shape。MakeStride返回Stride。上面的布局可写成w2xh4 : (_2, 4):(_12, _1) w32xh48 : ((16, 2), (16, 3)):((16, 256), (1, 512))读法如下前一部分是Shape。后一部分是Stride。如果省略OriginShape表示它可由Shape推导或与逻辑尺寸一致。从直观例子理解 Shape 与 Stride2x3 行优先shape (2, 3) stride (3, 1)含义是行维度前进一步线性地址增加 3。列维度前进一步线性地址增加 1。因此线性地址顺序为逻辑坐标线性地址(0, 0)0(0, 1)1(0, 2)2(1, 0)3(1, 1)4(1, 2)52x3 列优先shape (2, 3) stride (1, 2)含义是行维度前进一步线性地址增加 1。列维度前进一步线性地址增加 2。因此线性地址顺序为逻辑坐标线性地址(0, 0)0(1, 0)1(0, 1)2(1, 1)3(0, 2)4(1, 2)5以zN为例理解嵌套布局示例布局shape ((4, 2), (4, 3)) stride ((4, 16), (1, 32))可以理解为行方向先以 4 为一个内层块再沿行方向重复 2 次。列方向先以 4 为一个内层块再沿列方向重复 3 次。子块内部如何走、子块之间如何跳分别由嵌套Stride给出。关键点不在于记住每个数字而在于理解TLA 用嵌套Shape和Stride显式表达分块布局的结构层次而不是把这类格式硬编码进算法。坐标如何映射为索引在 TLA 中可以使用tla::crd2offset(coord, shape, stride)将逻辑坐标转换为线性索引。约束如下coord、shape、stride的 rank 必须一致。coord表示逻辑元素坐标而不是字节偏移。auto shape ShapeShape_4, _2, Shape_4, _3{}; auto stride StrideStride_4, _16, Stride_1, _32{}; print(crd2offset(tla::MakeCoord(1, 5), shape, stride)); // 37这段代码表示在一个逻辑大小为(8, 12)、底层按分形格式排布的矩阵中逻辑坐标(1, 5)对应的线性索引为37。获取 TileLayoutTileLayout 可以通过GetTileLayout获取template class Layout, class TileShape, class Coord auto GetTileLayout(Layout const layout, TileShape const tileShape, Coord const coord); using namespace tla; Layout a LayoutShapeShape_4, _2, Shape_4, _3, StrideStride_4, _16, Stride_1, _32, Shape_8, _12{}; Layout a0 GetTileLayout(a, MakeShape(4, 4), MakeCoord(6, 10)); // 结果可理解为stride 保持不变逻辑有效范围裁剪为 (2, 2)参数语义如下tileShape期望取出的 tile 大小单位是元素。coordtile 左上角在父 layout 逻辑空间中的元素坐标单位也是元素。也就是说coord (6, 10)的含义是“从逻辑第 6 行、第 10 列开始取 tile”而不是“第 6 个 tile、第 10 个 tile”。GetTileLayout的核心语义GetTileLayout返回的是一个 tile 视图的Layout不会改变底层数据排布。它主要做三件事保留原有stride()因为底层内存布局没有变化。用tileShape构造 tile 的shape()当父布局带有嵌套结构时返回结果会在需要时保持同样的结构层次。根据父 layout 的originShape()和起始coord计算 tile 的originShape()。其中第 3 步最关键$$ origin_shape[d] \min(tileShape[d], \max(origin_base[d] - coord[d], 0)) $$它表示“从当前位置开始在逻辑上还剩多少有效元素”。因此中间区域的 tileoriginShape tileShape。触边的 tail tileoriginShape会自动缩小。“按父 layout 的结构转换成对应的shape()”是什么意思这句话的含义是当父布局本身是嵌套布局时tile 的shape()也需要保持同样的结构层次这样后续访问规则才能继续复用。例如父布局的行和列都按16为内层块组织parent shape ((16, 7), (16, 7)) parent originShape (100, 100)如果希望取一个逻辑大小为(32, 48)的 tile那么这个 tile 的逻辑尺寸可以直接写成(32, 48)但在父布局是zN的前提下它对应的shape()会按父布局的结构表达成tile logical size (32, 48) tile shape ((16, 2), (16, 3))这里发生的是“结构转换”不是“重新排布数据”逻辑上tile 仍然是32 x 48。布局上它被表达成“每维一个 16 的内层块再乘以外层块个数”。stride()仍继承自父布局因此访问规则不变。这样做的目的是保证父 layout 和 tile layout 在结构层次上保持一致。参数约束tileShape与coord都必须是一层 tuple即depth 1。rank(coord) rank(tileShape)。不同布局下的行为如果父 layout 是普通 vector 或 matrix返回 layout 的shape()通常就等于tileShape。如果父 layout 是嵌套或分形布局例如zN、nZ、zZ、L0C当前实现仅支持rank 2并会把(rows, cols)形式的tileShape转换成与父布局同结构的嵌套Shape。【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考