aclnnQuantGroupedMatMulAlltoAllvV2【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer 查看源码产品支持情况产品是否支持Ascend 950PR/Ascend 950DT√Atlas A3 训练系列产品/Atlas A3 推理系列产品×Atlas A2 训练系列产品/Atlas A2 推理系列产品×Atlas 200I/500 A2 推理产品×Atlas 推理系列产品×Atlas 训练系列产品×功能说明接口功能完成量化的路由专家GroupedMatMul、Unpermute、AlltoAllv融合并实现与共享专家MatMul并行融合先计算后通信支持Pertensor-Pertensor、Mx量化模式。计算公式路由专家 $$ gmmY (gmmX gmmWeight) * gmmXScale * gmmWeightScale \ unpermuteOut Unpermute(gmmY) \ y AlltoAllv(unpermuteOut) $$共享专家$$ mmY (mmX mmWeight) * mmXScaleOptional * mmWeightScaleOptional $$ 相较于aclnnQuantGroupedMatMulAlltoAllv接口该接口变更如下新增commMode参数用户根据该参数指定芯片使用的通信引擎。Ascend 950PR/Ascend 950DT 支持空字符串、ai_cpu和ccu。指定空字符串时根据卡数调用通信引擎卡数小于等于8时调用CCU引擎否则调用AI_CPU引擎。函数原型该算子分为两段式接口必须先调用aclnnQuantGroupedMatMulAlltoAllvV2GetWorkspaceSize接口获取入参并根据计算流程计算所需workspace大小以及包含了算子计算流程的执行器再调用aclnnQuantGroupedMatMulAlltoAllvV2接口执行计算。aclnnStatus aclnnQuantGroupedMatMulAlltoAllvV2GetWorkspaceSize( const aclTensor* gmmX, const aclTensor* gmmWeight, const aclTensor* gmmXScale, const aclTensor* gmmWeightScale, const aclTensor* sendCountsTensorOptional, const aclTensor* recvCountsTensorOptional, const aclTensor* mmXOptional, const aclTensor* mmWeightOptional, const aclTensor* mmXScaleOptional, const aclTensor* mmWeightScaleOptional, const aclTensor* commQuantScaleOptional, int64_t gmmXQuantMode, int64_t gmmWeightQuantMode, int64_t mmXQuantMode, int64_t mmWeightQuantMode, int64_t commQuantMode, int64_t commQuantDtypeOptional, int64_t groupSize, const char* group, const char* commMode, int64_t epWorldSize, const aclIntArray* sendCounts, const aclIntArray* recvCounts, bool transGmmWeight, bool transMmWeight, aclTensor* y, aclTensor* mmYOptional, uint64_t* workspaceSize, aclOpExecutor** executor)aclnnStatus aclnnQuantGroupedMatMulAlltoAllvV2( void* workspace, uint64_t workspaceSize, aclOpExecutor* executor, aclrtStream stream)aclnnQuantGroupedMatMulAlltoAllvV2GetWorkspaceSize参数说明参数名输入/输出描述使用说明数据类型数据格式维度(shape)非连续TensorgmmX输入公式中的输入 gmmX。shape (A, H1)。HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1ND2xgmmWeight输入公式中的输入 gmmWeight。shape (e, H1, N1)。e 为每卡部署的专家数H1 为 hidden sizeN1 为路由专家 FFN 中间维度。HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1ND3xgmmXScale输入gmmX 的量化系数。pertensor量化shape (1)。mx量化shape (A, ceil(H1/64), 2)FLOAT32、FLOAT8_E8M0NDpertensor量化1mx量化3xgmmWeightScale输入gmmWeight 的量化系数。pertensor量化shape (1)。mx量化shape (e, N1, ceil(H1/64), 2)weight转置时为(e, ceil(H1/64), N1, 2)FLOAT32、FLOAT8_E8M0NDpertensor量化1mx量化3xsendCountsTensorOptional输入AlltoAllv 使用的 send count。当前仅支持空。shape (e * ep, )。e 为每卡部署的专家个数ep 为 ep 域大小。INT64ND1xrecvCountsTensorOptional输入AlltoAllv 使用的 recv count。当前仅支持空。shape (e * ep, )。e 为每卡部署的专家个数ep 为 ep 域大小。INT64ND1xmmXOptional输入公式中的输入 mmX。shape (bs, H2)。bs 为每卡部署的专家个数H2 为 hidden size。HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1ND2xmmWeightOptional输入公式中的输入 mmWeight。shape (H2, N2)。H2 为 hidden sizeN2 为共享专家 FFN 的中间层维度。HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2、FLOAT4_E2M1ND2xmmXScaleOptional输入mmX 的量化系数。pertensor量化shape (1)。mx量化shape (BS, ceil(H2/64), 2)FLOAT32、FLOAT8_E8M0NDpertensor量化1mx量化3xmmWeightScaleOptional输入mmWeight 的量化系数。pertensor量化shape(1)。mx量化: shape (N2, ceil(H2/64), 2)weight转置时为(ceil(H2/64), N2, 2)FLOAT32、FLOAT8_E8M0NDpertensor量化1mx量化3xcommQuantScaleOptional输入低比特通信量化系数。预留参数当前仅支持空。FLOAT32ND1xgmmXQuantMode输入gmmX 的量化模式。必须传入量化模式当前支持 1 pertensor量化和 6mx量化。INT64-1xgmmWeightQuantMode输入gmmWeight 的量化模式。必须传入量化模式当前支持 1 pertensor量化和 6mx量化。INT64-1xmmXQuantMode输入mmX 的量化模式。mmX 非空则必须传入量化模式当前支持 1 pertensor量化和 6mx量化。INT64-1xmmWeightQuantMode输入mmWeight 的量化模式。mmWeight 不为空则必须传入量化模式当前支持 1 pertensor量化和 6mx量化。INT64-1xcommQuantMode输入低比特通信量化模式。当前低比特功能预留必须传入 0表示不量化。INT64-1xcommQuantDtypeOptional输入低比特通信的数据类型。当前低比特功能预留必须传入 -1。INT64-1xgroupSize输入PerGroup 量化分组大小。用于 Matmul 计算三个方向上的量化分组大小预留参数仅支持配置为 0取值不生效。groupSize 输入由 3 个方向的 groupSizeMgroupSizeNgroupSizeK 三个值拼接组成每个值占 16 位共占用 int64_t 类型 groupSize 的低 48 位高 16 位无效计算公式为groupSize groupSizeK | groupSizeN 16 | groupSizeM 32。INT64---group输入通信域标识。字符串长度需大于 0小于 128。char*---commMode输入指定当前通信类型。支持输入、ai_cpu和ccu。char*---epWorldSize输入通信域大小。支持 2/4/8/16/32/64/128/256。INT64---sendCounts输入AlltoAllv 使用的 send count。表示其他Rank向当前rank上各expert发送的token数量。支持的维度为 e * ep。按sendCounts[fromRank][expertId]一维展开, 例如e3时顺序为e0,e1,e2,e0,e1,e2, ...aclIntArray*元素类型 INT64ND--recvCounts输入AlltoAllv 使用的 recv count。表示AlltoAllv后本卡需要接收到的token数量。支持的维度为 e * ep。按recvCounts[fromRank][expertId]一维展开, 例如e3时顺序为e0,e1,e2,e0,e1,e2, ...aclIntArray*元素类型 INT64ND--transGmmWeight输入gmm 的右矩阵是否转置。必须传入无默认值。BOOLND--transMmWeight输入mm 的右矩阵是否转置。必须传入无默认值。BOOLND--y输出grouped matmul 计算输出。不支持空 Tensor。shape (BSK, N1)。FLOAT16、BFLOAT16ND2xmmYOptional输出matmul 计算输出。shape (bs, N1)。FLOAT16、BFLOAT16ND2xworkspaceSize输出返回需要在 Device 侧申请的 workspace 大小。-UINT64ND--executor输出返回 op 执行器包含了算子计算流程。-aclOpExecutor*ND--gmmXQuantMode、gmmWeightQuantMode、mmXQuantMode、mmWeightQuantMode、commQuantMode的枚举值跟量化模式关系如下:0: 非量化1: pertensor2: perchannel3: pertoken4: pergroup5: perblock6: mx量化7: pertoken动态量化返回值返回aclnnStatus状态码具体参见aclnn返回码。第一阶段接口完成入参校验出现以下场景报错返回值错误码描述ACLNN_ERR_PARAM_NULLPTR161001输入和输出的必选参数Tensor是空指针。ACLNN_ERR_PARAM_INVALID161002输入和输出的数据类型不在支持的范围内。aclnnQuantGroupedMatMulAlltoAllvV2参数说明参数名输入/输出描述workspace输入在Device侧申请的workspace内存地址。workspaceSize输入在Device侧申请的workspace大小由第一段接口aclnnQuantGroupedMatMulAlltoAllvV2GetWorkspaceSize获取。executor输入op执行器包含了算子计算流程。stream输入指定执行任务的Stream。返回值返回aclnnStatus状态码具体参见aclnn返回码。约束说明确定性计算aclnnQuantGroupedMatMulAlltoAllvV2默认确定性实现。通信引擎约束Ascend 950PR/Ascend 950DT支持CCU通信。e * epWorldSize最大支持256e表示单卡上的专家数量最大支持到32epWorldSize支持2/4/8/16/32/64/128/256;gmmX的shape(A, H1)A为sendCounts之和H1取值范围(0, 65536);gmmWeight的shape(e, H1, N1)N1取值范围(0, 65536);y的shape为(BSK, N1)第一维其中K的范围[2, 8]BSK为recvCounts之和;mmX是共享专家的左矩阵shape为(BS, H2)H2的取值范围(0, 12288]mmWeight是共享专家的右矩阵shape为(H2 N2)N2的取值范围(0, 65536)sendCounts为发送到其他卡的token数数组大小为e * epWorldSize;recvCounts从其他卡的token数数组大小为e * epWorldSize;路由专家和共享专家量化Scale、Mode等均为必选低比特通信Mode为必选参数DType和Scale为可选当Mode为非0时需要提供DType和Scale参数说明里shape使用的变量BSK本卡接收的token数是recvCounts参数累加之和取值范围(0, 52428800)。H1表示路由专家hidden size隐藏层大小取值范围(0, 65536)。H2表示共享专家hidden size隐藏层大小取值范围(0, 12288]。e表示单卡上专家个数0e32e * epWorldSize最大支持256。N1表示路由专家 FFN 的中间层维度取值范围(0, 65536)。N2表示共享专家 FFN 的中间层维度取值范围(0, 65536)。BSbatch sequence size。K表示选取TopK个专家K的范围[2, 8]。A本卡发送的token数是sendCounts参数累加之和。ep通信域内所有卡的 A 参数的累加和等于所有卡上的 BSK 参数的累加和。mx量化且gmmX与gmmWeight为FLOAT4_E2M1时H1和H2必须为偶数且不能为2同时transGmmWeight和transMmWeight为false情况下N1和N2必须为偶数。gmmWeight和gmmWeightScale的转置状态必须保持一致同时转置或同时不转置。mmWeight和mmWeightScale同样需要保持转置状态一致。groupSize:仅当gmmXScale/gmmWeightScale/mmXScale/mmWeightScale输入都是2维及以上数据时groupSize取值有效其他场景需传入0。groupSize值支持公式推导传入的groupSize内部会按如下公式分解得到groupSizeM、groupSizeN、groupSizeK当其中有1个或多个为0会根据gmmX/gmmWeight/mmX/mmWeight/gmmXScale/gmmWeightScale/mmXScale/mmWeightScale输入shape重新设置groupSizeM、groupSizeN、groupSizeK用于计算。设置原理如果groupSizeM0表示m方向量化分组值由接口推导推导公式为groupSizeM m / scaleM需保证m能被scaleM整除其中m与gmmX/mmX shape中的m一致scaleM与gmmXScale/mmXScale shape中的m一致如果groupSizeK0表示k方向量化分组值由接口推导推导公式为groupSizeK k / scaleK需保证k能被scaleK整除其中k与gmmX/mmX shape中的k一致scaleK与gmmXScale/mmXScale shape中的k一致如果groupSizeN0表示n方向量化分组值由接口推导推导公式为groupSizeN n / scaleN需保证n能被scaleN整除其中n与gmmWeight/mmWeight shape中的n一致scaleN与gmmWeightScale/mmWeightScale shape中的n一致。 $$ groupSize groupSizeK | groupSizeN 16 | groupSizeM 32 $$如果满足重新设置条件当gmmXScale/gmmWeightScale/mmXScale/mmWeightScale输入是2维及以上时且数据类型都为FLOAT8_E8M0时[groupSizeMgroupSizeNgroupSizeK]取值组合会推导为[1, 1, 32]对应groupSize的值为4295032864。量化参数约束当前版本支持pertensor量化、mx量化。类型约束pertensor量化gmmXgmmWeightgmmXScalegmmWeightScalemmXScalemmWeightScaleyHIFLOAT8HIFLOAT8FLOAT32FLOAT32FLOAT32FLOAT32FLOAT16/BFLOAT16mx量化gmmXgmmWeightgmmXScalegmmWeightScalemmXScalemmWeightScaleyFLOAT8_E4M3FNFLOAT8_E4M3FNFLOAT8_E8M0FLOAT8_E8M0FLOAT8_E8M0FLOAT8_E8M0FLOAT16/BFLOAT16FLOAT8_E4M3FNFLOAT8_E5M2FLOAT8_E8M0FLOAT8_E8M0FLOAT8_E8M0FLOAT8_E8M0FLOAT16/BFLOAT16FLOAT8_E5M2FLOAT8_E4M3FNFLOAT8_E8M0FLOAT8_E8M0FLOAT8_E8M0FLOAT8_E8M0FLOAT16/BFLOAT16FLOAT8_E5M2FLOAT8_E5M2FLOAT8_E8M0FLOAT8_E8M0FLOAT8_E8M0FLOAT8_E8M0FLOAT16/BFLOAT16FLOAT4_E2M1FLOAT4_E2M1FLOAT8_E8M0FLOAT8_E8M0FLOAT8_E8M0FLOAT8_E8M0FLOAT16/BFLOAT16mmX类型与gmmX类型保持一致mmWeight类型与gmmWeight类型保持一致mmY类型与y类型保持一致。调用示例示例代码如下仅供参考具体编译和执行过程请参考编译与运行样例。说明本示例代码调用了部分HCCL集合通信库接口HcclGetCommName、HcclCommInitAll、HcclCommDestroy, 请参考HCCL API (C)。Ascend 950PR/Ascend 950DT #include thread #include iostream #include string #include vector #include acl/acl.h #include hccl/hccl.h #include aclnnop/aclnn_quant_grouped_mat_mul_allto_allv_v2.h #define CHECK_RET(cond, return_expr) \ do { \ if (!(cond)) { \ return_expr; \ } \ } while (0) #define LOG_PRINT(message, ...) \ do { \ printf(message, ##__VA_ARGS__); \ } while (0) int64_t GetShapeSize(const std::vectorint64_t shape) { int64_t shape_size 1; for (auto i : shape) { shape_size * i; } return shape_size; } template typename T int CreateAclTensor(const std::vectorT hostData, const std::vectorint64_t shape, void **deviceAddr, aclDataType dataType, aclTensor **tensor) { auto size GetShapeSize(shape) * sizeof(T); auto ret aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] aclrtMalloc failed. ret: %d\n, ret); return ret); ret aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] aclrtMemcpy failed. ret: %d\n, ret); return ret); std::vectorint64_t strides(shape.size(), 1); for (int64_t i shape.size() - 2; i 0; i--) { strides[i] shape[i 1] * strides[i 1]; } *tensor aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, shape.data(), shape.size(), *deviceAddr); return 0; } struct Args { int rankId; HcclComm hcclComm; aclrtStream stream; aclrtContext context; }; // shape 基本信息 constexpr int64_t EP_WORLD_SIZE 8; constexpr int64_t BS 4096; constexpr int64_t K 2; constexpr int64_t H 7168; constexpr int64_t e 4; constexpr int64_t N1 4096; constexpr int64_t N2 4096; constexpr int64_t A BS * K; std::vectorint16_t pGmmyData(BS *K *N1, 0); std::vectorint16_t pmmYData(BS *N2, 0); int LaunchOneThreadAlltoAllvGmm(Args args) { int ret aclrtSetCurrentContext(args.context); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] aclrtSetCurrentContext failed. ret: %d\n, ret); return ret); char hcomName[128] {0}; ret HcclGetCommName(args.hcclComm, hcomName); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] HcclGetEpCommName failed. ret: %d\n, ret); return -1); std::vectorint64_t gmmXShape {A, H}; std::vectorint64_t gmmWShape {e, H, N1}; std::vectorint64_t gmmXScaleShape {1}; std::vectorint64_t gmmWScaleShape {1}; std::vectorint64_t yShape {BS * K, N1}; std::vectorint64_t mmXShape {BS, H}; std::vectorint64_t mmWShape {H, N2}; std::vectorint64_t mmXScaleShape {1}; std::vectorint64_t mmWScaleShape {1}; std::vectorint64_t mmYShape {BS, N2}; std::vectorint64_t sendCountsShape {EP_WORLD_SIZE * e}; std::vectorint64_t recvCountsShape {EP_WORLD_SIZE * e}; std::vectorint64_t sendCountsList(EP_WORLD_SIZE * e, A / (EP_WORLD_SIZE * e)); std::vectorint64_t recvCountsList(EP_WORLD_SIZE * e, A / (EP_WORLD_SIZE * e)); void *gmmXDeviceAddr nullptr; void *gmmWDeviceAddr nullptr; void *gmmXScaleDeviceAddr nullptr; void *gmmWScaleDeviceAddr nullptr; void *yDeviceAddr nullptr; void *mmXDeviceAddr nullptr; void *mmWDeviceAddr nullptr; void *mmXScaleDeviceAddr nullptr; void *mmWScaleDeviceAddr nullptr; void *mmYDeviceAddr nullptr; aclTensor *gmmX nullptr; aclTensor *gmmW nullptr; aclTensor *gmmXScale nullptr; aclTensor *gmmWScale nullptr; aclTensor *y nullptr; aclTensor *mmX nullptr; aclTensor *mmW nullptr; aclTensor *mmXScale nullptr; aclTensor *mmWScale nullptr; aclTensor *mmY nullptr; aclTensor *sendCountsTensor nullptr; aclTensor *recvCountsTensor nullptr; aclTensor *commQuantScaleOptional nullptr; int64_t gmmXQuantMode 1; int64_t gmmWQuantMode 1; int64_t mmXQuantMode 1; int64_t mmWQuantMode 1; int64_t commQuantMode 0; int64_t commQuantDtypeOptional -1; int64_t groupSize 0; uint64_t workspaceSize 0; aclOpExecutor *executor nullptr; void *workspaceAddr nullptr; long long gmmXShapeSize GetShapeSize(gmmXShape); long long gmmWShapeSize GetShapeSize(gmmWShape); long long gmmXScaleShapeSize GetShapeSize(gmmXScaleShape); long long gmmWScaleShapeSize GetShapeSize(gmmWScaleShape); long long yShapeSize GetShapeSize(yShape); long long mmXShapeSize GetShapeSize(mmXShape); long long mmWShapeSize GetShapeSize(mmWShape); long long mmXScaleShapeSize GetShapeSize(mmXScaleShape); long long mmWScaleShapeSize GetShapeSize(mmWScaleShape); long long mmYShapeSize GetShapeSize(mmYShape); std::vectoruint8_t gmmXHostData(gmmXShapeSize, (args.rankId 1) * 1024); // HIFLOAT8 std::vectoruint8_t gmmWHostData(gmmWShapeSize, (args.rankId 1) * 512); std::vectorfloat gmmXScaleHostData(gmmXScaleShapeSize, 1); std::vectorfloat gmmWScaleHostData(gmmWScaleShapeSize, 1); std::vectorint16_t yHostData(yShapeSize, 65535); std::vectoruint8_t mmXHostData(mmXShapeSize, (args.rankId 1) * 1024); // HIFLOAT8 std::vectoruint8_t mmWHostData(mmWShapeSize, (args.rankId 1) * 512); std::vectorfloat mmXScaleHostData(mmXScaleShapeSize, 1); std::vectorfloat mmWScaleHostData(mmWScaleShapeSize, 1); std::vectorint16_t mmYHostData(mmYShapeSize, 0); ret CreateAclTensor(gmmXHostData, gmmXShape, gmmXDeviceAddr, aclDataType::ACL_HIFLOAT8, gmmX); CHECK_RET(ret ACL_SUCCESS, return ret); ret CreateAclTensor(gmmWHostData, gmmWShape, gmmWDeviceAddr, aclDataType::ACL_HIFLOAT8, gmmW); CHECK_RET(ret ACL_SUCCESS, return ret); ret CreateAclTensor(gmmXScaleHostData, gmmXScaleShape, gmmXScaleDeviceAddr, aclDataType::ACL_FLOAT, gmmXScale); CHECK_RET(ret ACL_SUCCESS, return ret); ret CreateAclTensor(gmmWScaleHostData, gmmWScaleShape, gmmWScaleDeviceAddr, aclDataType::ACL_FLOAT, gmmWScale); CHECK_RET(ret ACL_SUCCESS, return ret); ret CreateAclTensor(yHostData, yShape, yDeviceAddr, aclDataType::ACL_FLOAT16, y); CHECK_RET(ret ACL_SUCCESS, return ret); ret CreateAclTensor(mmXHostData, mmXShape, mmXDeviceAddr, aclDataType::ACL_HIFLOAT8, mmX); CHECK_RET(ret ACL_SUCCESS, return ret); ret CreateAclTensor(mmWHostData, mmWShape, mmWDeviceAddr, aclDataType::ACL_HIFLOAT8, mmW); CHECK_RET(ret ACL_SUCCESS, return ret); ret CreateAclTensor(mmXScaleHostData, mmXScaleShape, mmXScaleDeviceAddr, aclDataType::ACL_FLOAT, mmXScale); CHECK_RET(ret ACL_SUCCESS, return ret); ret CreateAclTensor(mmWScaleHostData, mmWScaleShape, mmWScaleDeviceAddr, aclDataType::ACL_FLOAT, mmWScale); CHECK_RET(ret ACL_SUCCESS, return ret); ret CreateAclTensor(mmYHostData, mmYShape, mmYDeviceAddr, aclDataType::ACL_FLOAT16, mmY); CHECK_RET(ret ACL_SUCCESS, return ret); aclIntArray *sendCounts aclCreateIntArray(sendCountsList.data(), sendCountsList.size()); aclIntArray *recvCounts aclCreateIntArray(recvCountsList.data(), recvCountsList.size()); // 调用第一阶段接口 ret aclnnQuantGroupedMatMulAlltoAllvV2GetWorkspaceSize( gmmX, gmmW, gmmXScale, gmmWScale, sendCountsTensor, recvCountsTensor, mmX, mmW, mmXScale, mmWScale, commQuantScaleOptional, gmmXQuantMode, gmmWQuantMode, mmXQuantMode, mmWQuantMode, commQuantMode, commQuantDtypeOptional, groupSize, hcomName, ccu, EP_WORLD_SIZE, sendCounts, recvCounts, false, false, y, mmY, workspaceSize, executor); CHECK_RET( ret ACL_SUCCESS, LOG_PRINT([ERROR] aclnnQuantGroupedMatMulAlltoAllvV2GetWorkspaceSize failed. ret %d \n, ret); return ret); if (workspaceSize 0) { ret aclrtMalloc(workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] aclrtMalloc workspace failed. ret %d \n, ret); return ret); } // 调用第二阶段接口 ret aclnnQuantGroupedMatMulAlltoAllvV2(workspaceAddr, workspaceSize, executor, args.stream); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] aclnnQuantGroupedMatMulAlltoAllvV2 failed. ret %d \n, ret); return ret); // 固定写法同步等待任务执行结束 ret aclrtSynchronizeStreamWithTimeout(args.stream, 10000000); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] aclrtSynchronizeStreamWithTimeout failed. ret %d \n, ret); return ret); LOG_PRINT([INFO] device_%d aclnnQuantGroupedMatMulAlltoAllvV2 execute successfully.\n, args.rankId); // 释放device资源需要根据具体API的接口定义修改 if (args.rankId 0) { size_t size A * N1 * sizeof(int16_t); aclrtMemcpy(pGmmyData.data(), size, yDeviceAddr, size, ACL_MEMCPY_DEVICE_TO_HOST); } if (gmmX ! nullptr) { aclDestroyTensor(gmmX); } if (gmmW ! nullptr) { aclDestroyTensor(gmmW); } if (gmmXScale ! nullptr) { aclDestroyTensor(gmmXScale); } if (gmmWScale ! nullptr) { aclDestroyTensor(gmmWScale); } if (y ! nullptr) { aclDestroyTensor(y); } if (mmX ! nullptr) { aclDestroyTensor(mmX); } if (mmW ! nullptr) { aclDestroyTensor(mmW); } if (mmXScale ! nullptr) { aclDestroyTensor(mmXScale); } if (mmWScale ! nullptr) { aclDestroyTensor(mmWScale); } if (mmY ! nullptr) { aclDestroyTensor(mmY); } if (gmmXDeviceAddr ! nullptr) { aclrtFree(gmmXDeviceAddr); } if (gmmWDeviceAddr ! nullptr) { aclrtFree(gmmWDeviceAddr); } if (gmmXScaleDeviceAddr ! nullptr) { aclrtFree(gmmXScaleDeviceAddr); } if (gmmWScaleDeviceAddr ! nullptr) { aclrtFree(gmmWScaleDeviceAddr); } if (yDeviceAddr ! nullptr) { aclrtFree(yDeviceAddr); } if (mmXDeviceAddr ! nullptr) { aclrtFree(mmXDeviceAddr); } if (mmWDeviceAddr ! nullptr) { aclrtFree(mmWDeviceAddr); } if (mmXScaleDeviceAddr ! nullptr) { aclrtFree(mmXScaleDeviceAddr); } if (mmWScaleDeviceAddr ! nullptr) { aclrtFree(mmWScaleDeviceAddr); } if (mmYDeviceAddr ! nullptr) { aclrtFree(mmYDeviceAddr); } if (workspaceSize 0) { aclrtFree(workspaceAddr); } HcclCommDestroy(args.hcclComm); aclrtDestroyStream(args.stream); aclrtDestroyContext(args.context); aclrtResetDevice(args.rankId); return 0; } int main(int argc, char *argv[]) { int ret aclInit(nullptr); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] aclInit failed. ret %d \n, ret); return ret); aclrtStream stream[EP_WORLD_SIZE]; aclrtContext context[EP_WORLD_SIZE]; for (uint32_t rankId 0; rankId EP_WORLD_SIZE; rankId) { ret aclrtSetDevice(rankId); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] aclrtSetDevice failed. ret %d \n, ret); return ret); ret aclrtCreateContext(context[rankId], rankId); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] aclrtCreateContext failed. ret %d \n, ret); return ret); ret aclrtCreateStream(stream[rankId]); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] aclrtCreateStream failed. ret %d \n, ret); return ret); } int32_t devices[EP_WORLD_SIZE]; for (int i 0; i EP_WORLD_SIZE; i) { devices[i] i; } //初始化集合通信域 HcclComm comms[EP_WORLD_SIZE]; ret HcclCommInitAll(EP_WORLD_SIZE, devices, comms); CHECK_RET(ret ACL_SUCCESS, LOG_PRINT([ERROR] HcclCommInitAll failed. ret %d \n, ret); return ret); Args args[EP_WORLD_SIZE]; // 启动多线程 std::vectorstd::unique_ptrstd::thread threads(EP_WORLD_SIZE); for (uint32_t rankId 0; rankId EP_WORLD_SIZE; rankId) { args[rankId].rankId rankId; args[rankId].hcclComm comms[rankId]; args[rankId].stream stream[rankId]; args[rankId].context context[rankId]; threads[rankId].reset(new std::thread(LaunchOneThreadAlltoAllvGmm, std::ref(args[rankId]))); } for (uint32_t rankId 0; rankId EP_WORLD_SIZE; rankId) { threads[rankId]-join(); } aclFinalize(); return 0; }【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考