Scatter 算子 API 描述【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力涵盖算子生成、算子优化等领域支撑模型选型、训练效果评估统一量化评估标准识别Agent能力短板构建CANN领域评测平台推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench1. 算子简介将 updates 按索引 indices 更新到 data 中。主要应用场景嵌入表更新与稀疏梯度回传One-hot 编码生成图神经网络中的消息聚合scatter_add稀疏张量的构造与更新算子特征难度等级L2ScatterUpdate三输入单输出按指定维度将 updates 的值写入到 data 对应索引位置2. 算子定义数学公式对于 3D 张量当 dim0 时$$ y[\text{index}[i][j][k]][j][k] \text{src}[i][j][k] $$更一般地对于任意维度 dim$$ y[\text{index}_0][\text{index}_1] \cdots [\text{index}_{\text{dim}}] \cdots [\text{index}_{n-1}] \text{updates}[i_0][i_1] \cdots [i_{n-1}] $$其中 $\text{index}_d \text{indices}[i_0][i_1] \cdots [i_{n-1}]$ 当 $d \text{dim}$否则 $\text{index}_d i_d$。当指定 reduce 参数时add$y[\ldots] y[\ldots] \text{updates}[\ldots]$multiply$y[\ldots] y[\ldots] \times \text{updates}[\ldots]$amax$y[\ldots] \max(y[\ldots], \text{updates}[\ldots])$amin$y[\ldots] \min(y[\ldots], \text{updates}[\ldots])$3. 接口规范算子原型cann_bench.scatter(Tensor data, int dim, Tensor indices, Tensor updates, str? reduceNone) - Tensor y输入参数说明参数类型默认值描述dataTensor必选输入数据张量dimint必选沿哪个维度 scatterindicesTensor必选索引张量值必须在 [0, data.size(dim)) 范围内updatesTensor必选更新值张量与 data 维度数相同reducestrNone聚合方式可选值None(update), add, multiply, amin, amax输出参数Shapedtype描述y与 data 相同与 data 相同输出张量scatter 结果与 data 形状相同数据类型data dtypeindices dtypeupdates dtype输出 dtypefloat16int32 / int64float16float16float32int32 / int64float32float32bfloat16int32 / int64bfloat16bfloat16int32int32 / int64int32int32int64int32 / int64int64int64规则与约束data、indices、updates 的维度数必须相同indices 的每个维度大小不能超过对应 data 或 updates 的维度大小indices 中的值必须在 [0, data.size(dim)) 范围内updates 和 data 的 dtype 必须一致indices 的 dtype 必须为 int32 或 int64reduce 为 None 时执行直接覆盖更新为 add 时执行累加为 multiply 时执行累乘为 amax/amin 时取最大/最小值输出 shape 与 data shape 完全一致支持范围输入 tensor 各维度与参数的支持范围维度 / 参数范围备注data维度数1 ~ 8cases.csv 实测 1 ~ 5data、indices、updates维度数必须相同data各维度大小1 ~ 2097152cases.csv 实测 2 ~ 1048583一维大张量场景indices各维度大小1 ~ 2097152cases.csv 实测 2 ~ 8193每维 ≤ 对应data维度大小updates各维度大小1 ~ 2097152cases.csv 实测 2 ~ 8193shape 须与indices一致indices值[0, data.size(dim))cases.csv 实测覆盖完整索引范围dim0 ~ 7cases.csv 实测 0 / 1支持负数索引等价范围为[-rank, rank-1]reduceNone/add/multiply/amin/amaxcases.csv 实测全部 5 种取值4. 精度要求采用生态算子精度标准进行验证。误差指标平均相对误差MERE采样点中相对误差平均值$$ \text{MERE} \text{avg}(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)\text{1e-7}}) $$最大相对误差MARE采样点中相对误差最大值$$ \text{MARE} \max(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)\text{1e-7}}) $$通过标准数据类型FLOAT16BFLOAT16FLOAT32HiFLOAT32FLOAT8 E4M3FLOAT8 E5M2通过阈值(Threshold)2^-102^-72^-132^-112^-32^-2当平均相对误差 MERE Threshold最大相对误差 MARE 10 * Threshold 时判定为通过。5. 标准 Golden 代码import torch Scatter算子Torch Golden参考实现 将updates按索引indices更新到data中 公式: y[i] updates[j] where indices[j] i def scatter( data: torch.Tensor, dim: int, indices: torch.Tensor, updates: torch.Tensor, reduce: str None ) - torch.Tensor: 将updates按索引indices更新到data中 公式: y[i] updates[j] where indices[j] i Args: data: 输入数据张量 dim: 沿哪个维度scatter indices: 索引张量 updates: 更新值张量 reduce: 聚合方式 Returns: 输出张量scatter结果 y data.clone() if reduce is None or reduce update: y.scatter_(dim, indices.long(), updates) elif reduce add: y.scatter_add_(dim, indices.long(), updates) return y6. 额外信息算子调用示例import torch import cann_bench data torch.randn(1024, 1024, dtypetorch.float16, devicenpu) indices torch.randint(0, 1024, (1024, 512), dtypetorch.int32, devicenpu) updates torch.randn(1024, 512, dtypetorch.float16, devicenpu) y cann_bench.scatter(data, 1, indices, updates) # dim1, 直接更新 # reduceadd 模式 data torch.randn(2048, 512, dtypetorch.float32, devicenpu) indices torch.randint(0, 2048, (1024, 512), dtypetorch.int32, devicenpu) updates torch.randn(1024, 512, dtypetorch.float32, devicenpu) y cann_bench.scatter(data, 0, indices, updates, reduceadd)【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力涵盖算子生成、算子优化等领域支撑模型选型、训练效果评估统一量化评估标准识别Agent能力短板构建CANN领域评测平台推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考