# Online Softmax Tail Handling on A2 Triple-Bridge Kernels

> cannbot-skills: CANNBot is a family of agents that improve development efficiency for CANN; this repository provides reusable Skills modules for it. Project address: https://gitcode.com/cann/cannbot-skills

Read this file when debugging or extending an a2 (`easyasc.a2`, `deviceb3`) normalized online softmax kernel with delayed `p`/`pv` stages and a non-aligned `S2` tail or `S1` tail.

Typical targets:

- `agent/example/kernels/a2/flash_attn_full.py`
- `agent/example/kernels/a2/flash_attn_full_pj_hif8.py`
- `agent/example/kernels/a2/flash_attn_full_pj_hif8_causal.py`

Do not use this file as the first reference for generic tail bugs. For generic GM-boundary tail rules, read `agent/references/constraints/tail-safety.md` first. This file only covers the extra rule that appears once the kernel has:

- a running `row_max`
- a running `row_sum`
- a delayed `expdiff`
- a delayed `p @ v`

## Goal

Handle a non-aligned `S2` tail or `S1` tail without breaking the normalized online softmax math.

The two axes are **not** symmetric:

- an `S2` tail means invalid **columns** inside otherwise valid score rows
- an `S1` tail means invalid **rows** inside an otherwise full local score tile

The stable rules are therefore different:

- `S2` still uses `valid_n` at GM boundaries, but also needs score-domain `-inf` masking before `rowmax`
- `S1` uses `valid_m` at GM boundaries, then masks only the local invalid rows after `score - rowmax` and before `exp`

## 1. Why GM-boundary slicing alone is not enough

The generic tail rule still applies:

- local tensors stay full-tile sized
- only GM loads/stores use `valid_n`

That prevents out-of-bounds reads, but it is **not** enough for online softmax.

If the last `k`/`v` tile is loaded with `valid_n < TILE_N`, the padded columns look like zeros in the staged full tile.
That creates a second problem:

- `rowmax(score_j)` can see the padded columns
- `curr_m = maximum(prev_m, rowmax(score_j))` can become too large
- `expdiff_j = exp(prev_m - curr_m)` then rescales the previously accumulated state incorrectly
- `row_sum` and `out` are both corrupted, even if `p_j` is later masked to zero

So, for normalized online softmax:

- padded tail columns must behave like `-inf` before `rowmax`
- the same padded columns then naturally become `0` after `exp`

## 2. Do not start from a `p`-domain-only fix

A `p`-domain-only tail mask is insufficient for normalized online softmax.

It can fix:

- the delayed `p @ v`
- any later use of `p_j`

It cannot fix:

- `rowmax(score_j)`
- `curr_m`
- the delayed `expdiff_j`
- `row_sum`

If the kernel has a running `row_max` / `row_sum`, fix the score tile first.

## 3. Stable semantic rule for invalid tail columns

For the last `S2` tile:

- before `cmax`: invalid columns must look like `-inf`
- after `exp`: invalid columns must behave like `0`

This rule preserves the exact reference update:

- `curr_m = maximum(prev_m, rowmax(score_j_valid_only))`
- `p_j = exp(score_j_valid_only - curr_m)`
- `row_sum = row_sum * expdiff_j + p_j.sum(-1)`

You do **not** need a separate `p`-domain tail mask if the score tile already uses this `-inf` rule and the delayed `v` load also uses `valid_n`.

## 4. Stable a2 implementation shape

For the current validated flash-attention kernels:

- `TILE_N = 128`
- the score tile is processed in vec as `[HALF_M, TILE_N]`
- the practical split is two `[HALF_M, 64]` halves

That gives a stable rule:

- the left half handles columns `[0:64)`
- the right half handles columns `[64:128)`

Tail cases:

- `valid_n == 128`: both halves are fully valid
- `64 < valid_n < 128`: the left half is fully valid; the right half needs a suffix invalid mask
- `valid_n == 64`: the left half is fully valid; the right half is fully invalid
- `0 < valid_n < 64`: the left half needs a suffix invalid mask; the right half is fully invalid
- `valid_n == 0`: both halves are fully invalid
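The column-tail rules in sections 1 to 4 can be checked with a small plain-Python sketch of one score row. It is not easyasc code: `NEG_LARGE`, `online_softmax`, `normalize`, `reference_softmax`, and `half_valid_cols` are hypothetical names, and `acc` keeps the rescaled per-column `p` values as a stand-in for the `p @ v` accumulator. In exact arithmetic a too-large but consistently used `curr_m` would cancel in `out / row_sum`; what actually breaks is that `exp` underflows. Reduced-precision `p`/`pv` paths underflow at much milder score gaps, so the usage below exaggerates the scores to make float64 show the same collapse.

```python
import math

# Hypothetical finite stand-in for -inf (the text only asks for a finite negative fill).
NEG_LARGE = -1.0e30


def online_softmax(scores, tile_n, score_mask):
    """Online softmax of one score row, visited in S2 tiles of width tile_n.

    The last tile is zero-padded to tile_n, like a full-tile local buffer staged
    from a valid_n-sliced GM load. score_mask=True applies the score-domain fix;
    score_mask=False applies only a p-domain mask after exp.
    Returns (row_max, row_sum, acc).
    """
    row_max, row_sum, acc = -math.inf, 0.0, []
    for start in range(0, len(scores), tile_n):
        valid_n = min(tile_n, len(scores) - start)
        pad = tile_n - valid_n
        tile = scores[start:start + valid_n] + [0.0] * pad  # padded columns read as zeros
        if score_mask:
            tile = tile[:valid_n] + [NEG_LARGE] * pad  # invalid columns look like -inf before the max
        curr_m = max(row_max, max(tile))
        expdiff = math.exp(row_max - curr_m)  # exp(-inf) == 0.0 on the first tile
        p = [math.exp(s - curr_m) for s in tile]  # NEG_LARGE columns become exactly 0.0
        if not score_mask:
            p = p[:valid_n] + [0.0] * pad  # p-domain-only mask
        row_sum = row_sum * expdiff + sum(p)
        acc = [a * expdiff for a in acc] + p[:valid_n]  # v is loaded with valid_n
        row_max = curr_m
    return row_max, row_sum, acc


def normalize(row_sum, acc):
    """Final out / row_sum; None marks a row_sum that collapsed to 0.0 (0 / 0)."""
    return None if row_sum == 0.0 else [a / row_sum for a in acc]


def reference_softmax(scores):
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    total = sum(e)
    return [x / total for x in e]


def half_valid_cols(valid_n, half=64):
    """Valid column counts (left, right) of the two [HALF_M, 64] halves of a 128-wide tile."""
    return max(0, min(valid_n, half)), max(0, min(valid_n - half, half))


if __name__ == "__main__":
    row = [-800.0] * 130  # the last tile has valid_n = 2
    print(normalize(*online_softmax(row, 128, score_mask=True)[1:])[0])  # 1/130
    print(normalize(*online_softmax(row, 128, score_mask=False)[1:]))  # None: row_sum underflowed
```

In the `score_mask=False` run, the padded zeros lift `curr_m` from `-800` to `0`, so both `expdiff` and the valid `p_j` underflow to `0.0` and the whole row is lost even though the padded `p_j` were zeroed. A half whose count from `half_valid_cols` is strictly between `0` and `64` is the one that needs the suffix invalid mask.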
## 5. Why vec mask + finite negative sentinel is the simplest score-domain fix

For `float` vec ops on a2:

- the active mask prefix length is `64`
- the same `64`-lane mask prefix is reused for each repeat

That matches a `[HALF_M, 64]` score half perfectly:

- one row uses one repeat
- each row wants the same tail-column mask

So the stable suffix invalidation pattern is:

1. compute a 64-bit suffix-invalid mask
2. `set_mask(0, low_mask)`
3. `dup(score_half, neg_large)`
4. `reset_mask()`

This is usually simpler than materializing a `[HALF_M, 64]` flag tensor and then doing `select(...)` on the score half. The intent is still `-inf` behavior, but the concrete fill value should stay finite on hardware paths.

Read next for exact vec mask semantics:

- `agent/references/constraints/mask.md`

## 6. Bit order and mask meaning

Instruction semantics:

- `low` writes `mask[0:64]`
- `bit0` maps to the lowest logical lane in that prefix
- `bit63` maps to the highest logical lane in that prefix

Stub call note:

- the current a2 stub is called as `set_mask(mask_high, mask_low)`
- so a low-only score-half mask is written with `set_mask(0, low_mask)`

So, for a suffix invalid mask on one 64-column score half:

- columns `[0:valid_cols)` should be `0`
- columns `[valid_cols:64)` should be `1`

Examples:

- `valid_cols = 64`: no invalid bits
- `valid_cols = 63`: only `bit63` is `1`
- `valid_cols = 10`: bits `[10:63]` are `1`
- `valid_cols = 0`: all bits are `1`

Validated repository tests:

- `testcases/simulator/micro/test_simulator_v2_muladddst_mask.py`
- `testcases/simulator/micro/test_simulator_v2_vec_ops_extended.py`
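A plain-Python sketch of the bit layout in sections 5 and 6, assuming bit `i` of the low word maps to lane `i` as stated there. `suffix_invalid_mask`, `masked_dup`, and `mask_score_half` are hypothetical helpers; `masked_dup` only emulates the effect of `set_mask(0, low_mask)` + `dup(...)` + `reset_mask()` on one 64-lane row and is not the easyasc API.

```python
MASK_BITS = 64
ALL_ONES = (1 << MASK_BITS) - 1


def suffix_invalid_mask(valid_cols):
    """Low-word mask with bits [valid_cols:64) set, i.e. the invalid tail lanes."""
    if not 0 <= valid_cols <= MASK_BITS:
        raise ValueError("valid_cols must be in [0, 64]")
    return ALL_ONES & ~((1 << valid_cols) - 1)


def masked_dup(row, low_mask, value):
    """Write value only into lanes whose mask bit is 1; other lanes keep their score."""
    if len(row) != MASK_BITS:
        raise ValueError("one repeat covers exactly 64 float lanes")
    return [value if (low_mask >> lane) & 1 else x for lane, x in enumerate(row)]


def mask_score_half(score_half, valid_cols, neg_large=-1.0e30):
    """Reuse the same 64-lane mask for every row (one repeat per row) of a [HALF_M, 64] half."""
    low_mask = suffix_invalid_mask(valid_cols)
    return [masked_dup(row, low_mask, neg_large) for row in score_half]
```

Because every row of the half reuses the same `low_mask`, one mask computation per tile is enough; `suffix_invalid_mask(10)` evaluates to `18446744073709550592`, the same value that section 7 warns about constructing directly as a `uint64`.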
## 7. Stable scalar-mask construction trick

The obvious unsigned construction, building a huge `uint64` literal such as `18446744073709550592` (`2**64 - 1024`, the suffix-invalid mask for `valid_cols = 10`), can trip the simulator's scalar cast path, because the current runtime first creates a Python/Torch signed integer before converting it to `uint64`.

The stable workaround is:

1. start from signed `-1`
2. left-shift it `valid_cols` times
3. then assign the signed result into a `uint64` `Var`

For one 64-lane score half, this builds the same suffix-invalid bit pattern:

```python
@func()
def build_suffix_invalid_mask(valid_cols: Var, out_mask: Var):
    signed_mask = Var(-1, DT.int64)
    two_i64 = Var(2, DT.int64)
    for _ in range(0, valid_cols):
        # multiplying by 2 is a left shift by one bit
        signed_mask <<= signed_mask * two_i64
    # store the int64 bit pattern into the uint64 output Var
    out_mask <<= signed_mask
```

Why this works:

- `-1 << valid_cols` equals the desired suffix-invalid mask in two's complement
- for `valid_cols` in `[0, 63]`, the intermediate signed values stay representable in `int64` (the smallest is `-2**63`)
- the final `uint64` assignment preserves the bit pattern

## 8. Minimal integration recipe

For a normalized online softmax stage-1 score tile:

1. load `k` with `valid_n`
2. stage the full `[HALF_M, TILE_N]` score tile
3. apply score-tail masking only when `valid_n < TILE_N`
4. only then run:
   - `vmax(...)`
   - `cmax(...)`
   - the delayed `expdiff`
   - `exp(...)`
   - `cadd(...)`
5. stage the delayed `p`
6. later, load `v` with the recomputed previous-tile `valid_n`

The score-tail masking point should be:

- after the scale is applied
- before any `rowmax` / `cmax`

## 9. Minimal validation set

Do not validate only aligned cases.

For `TILE_N = 128`, keep at least:

- one aligned baseline: `S2 % 128 == 0`
- one small left-half tail: `S2 % 128 == 10`
- one first-right-half case: `S2 % 128 == 65`
- one mid-right-half case: `S2 % 128 == 96`
- one last-column case: `S2 % 128 == 127`

For `flash_attn_full_pj_hif8.py`, the validated runnable regression lives in the kernel self-check:

- `agent/example/kernels/a2/flash_attn_full_pj_hif8.py`
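The two's-complement identity behind section 7 can be checked in plain Python. This is not easyasc code: Python integers are unbounded, so the sketch checks the `int64` range explicitly and uses `& MASK64` as the `uint64` reinterpretation. `signed_shift_mask` is a hypothetical mirror of `build_suffix_invalid_mask`, and `unsigned_suffix_mask` is the direct construction it should agree with.

```python
INT64_MIN = -(1 << 63)
MASK64 = (1 << 64) - 1


def signed_shift_mask(valid_cols):
    """Start from signed -1 and double it valid_cols times, as build_suffix_invalid_mask does."""
    signed_mask = -1
    for _ in range(valid_cols):
        signed_mask *= 2
        if signed_mask < INT64_MIN:
            raise OverflowError("valid_cols > 63 leaves the int64 range")
    return signed_mask & MASK64  # same 64 bits, read as uint64


def unsigned_suffix_mask(valid_cols):
    """Direct unsigned construction: bits [valid_cols:64) set."""
    return MASK64 ^ ((1 << valid_cols) - 1)
```

For every `valid_cols` in `[0, 63]` the two constructions agree bit for bit, and `signed_shift_mask(10)` reproduces `18446744073709550592` without ever materializing that value as a signed literal.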
## 10. Why an `S1` tail is a different problem

Do not try to solve an `S1` tail by reusing the `S2` column-tail mental model.

For an `S1` tail:

- the invalid region is a suffix of **rows**, not columns
- `q` must use `valid_m` at the GM boundary
- the final `out` must also use `valid_m` at the GM boundary
- the vec side still sees a fixed physical `[HALF_M, TILE_N]` score tile

Current validated a2 flash-attention shape:

- the two vec subblocks read fixed physical row ranges
- subblock `0` reads rows `[0:64)`
- subblock `1` reads rows `[64:128)`
- this is **not** the a5-style `CeilDiv(valid_m, 2)` compact half split

So the stable local quantity is `local_valid_m = clamp(valid_m - sb_row, 0, HALF_M)`, where:

- `valid_m` is the tile-level valid query-row count
- `sb_row` is the fixed physical subblock row origin (`0` or `64`)

## 11. Stable `S1` implementation rule

For a normalized online softmax stage-1 score tile with an `S1` tail:

1. load `q` with `valid_m`
2. rely on the current `gm_to_l1_nd2nz` zero-fill behavior for the local tail rows
3. run the normal score tile, `rowmax`, `curr_m`, and `expdiff` flow on the full local score tile
4. after `score_j - curr_m`, but before the `exp`, overwrite the local invalid row suffix with a sufficiently negative finite sentinel
5. keep the delayed `p` / `pv` path full-tile sized
6. write back only `local_valid_m` rows to GM

Why this point is stable:

- masking invalid rows **before** `cmax` gives those rows a sentinel `rowmax`, and the following subtraction on them becomes unstable, analogous to `-inf - (-inf)`
- masking them **after** subtracting `curr_m` preserves the valid-row online softmax math
- the invalid local rows then become `0` after `exp`, so they contribute nothing to the delayed `p @ v`

Current repository tolerance:

- invalid `S1` tail rows may still become `NaN` after the final `out / row_sum` on local UB rows (their `row_sum` is `0`, so the division is `0 / 0`)
- this is acceptable because those rows are not written back to GM
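The `S1` row rule in sections 10 and 11 can be sketched in plain Python for one local score tile. This is a simplification, assuming a single `S2` tile so that `curr_m` is just the row max; `local_valid_m`, `mask_rows_after_shift`, `tile_p`, and `premask_then_shift` are hypothetical names, and `NEG_LARGE` is an assumed sentinel value.

```python
import math

HALF_M = 64
NEG_LARGE = -1.0e30  # hypothetical finite sentinel


def local_valid_m(valid_m, sb_row, half_m=HALF_M):
    """clamp(valid_m - sb_row, 0, HALF_M) for a fixed physical subblock origin sb_row."""
    return max(0, min(valid_m - sb_row, half_m))


def mask_rows_after_shift(shifted, n_valid):
    """Overwrite the invalid row suffix of (score - curr_m) with the sentinel."""
    return [row if r < n_valid else [NEG_LARGE] * len(row) for r, row in enumerate(shifted)]


def tile_p(scores, n_valid):
    """p = exp(score - curr_m), with the S1 row mask applied after the shift and before exp."""
    shifted = [[s - max(row) for s in row] for row in scores]  # curr_m taken per row
    return [[math.exp(x) for x in row] for row in mask_rows_after_shift(shifted, n_valid)]


def premask_then_shift(row_len):
    """Counter-example: an invalid row set to -inf *before* the row max."""
    row = [-math.inf] * row_len
    m = max(row)
    return [s - m for s in row]  # -inf - (-inf) is nan
```

For example, `tile_p([[1.0, 2.0], [0.0, 0.0]], n_valid=1)` treats the second row as a zero-filled invalid `q` row: without the mask it would produce `p = [1.0, 1.0]` and leak into the delayed `p @ v`, while with the mask it becomes `[0.0, 0.0]`. `premask_then_shift` shows the `NaN` that pre-`cmax` masking with a true `-inf` would produce.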
## 12. Minimal `S1` validation set

Do not validate only one row-tail case.

For `TILE_M = 128`, keep at least:

- one aligned baseline: `S1 % 128 == 0`
- one one-row tail: `S1 % 128 == 1`
- one last-row-in-first-half case: `S1 % 128 == 63`
- one exact-half case: `S1 % 128 == 64`
- one first-row-in-second-half case: `S1 % 128 == 65`
- one last-row case: `S1 % 128 == 127`
- one larger shape beyond two tiles, for example `S1 = 257`
- one multi-head shape

Keep `S2` aligned while validating the new `S1` path first, so failures are easier to attribute.

## 13. Causal diagonal tiles on a2

Read this when extending the same normalized online softmax pipeline from plain tail handling to left-up causal masking (`k_pos <= q_pos`).

The stable tile classification is:

- `nt < lmt`: the tile is fully valid
- `nt == lmt`: the tile is diagonal and contains mixed valid/invalid columns
- `nt > lmt`: the tile is fully invalid and should be skipped

For the current validated causal kernel, the stable scheduling rule is:

- clamp the stage-1/stage-2 loop to `active_tiles_n = Min(tiles_n, lmt + 1)`
- still keep the outer `n_loops + 1`-style drain shape by iterating to `active_tiles_n + 1`
- this preserves the delayed `p @ v` flush while removing future fully invalid tiles

The diagonal tile is **not** a plain `valid_n` tail:

- invalid columns vary by row
- the stable local quantity is `valid_cols = sb_row + row + 1`
- `sb_row` is the fixed subblock row origin (`0` or `64`)

Stable implementation rule for the diagonal tile:

1. load and scale the full `[HALF_M, TILE_N]` score tile
2. prebuild reusable packed-bit masks once per subblock, before the main tile loop:
   - `causal_mask_left: Tensor(DT.uint8, [HALF_M, HALF_N // 8], Position.UB)`
   - `causal_mask_right: Tensor(DT.uint8, [HALF_M, HALF_N // 8], Position.UB)`
3. initialize one reusable integer column-index row for `[0, 1, ..., 63]`; the current validated kernel writes two `int32` entries at a time through an `int64` reinterpret to keep the `SetValueTo(...)` count low
4. use a Python-unrolled row loop (`py_range(HALF_M)`) only for the per-row threshold, and synthesize packed mask bytes with `compare_scalar(...)`:
   - if `sb_row == 0`, build only `causal_mask_left[row]` with threshold `row + 1`
   - if `sb_row == 64`,
     fill `causal_mask_left` with all ones and build only `causal_mask_right[row]` with threshold `row + 1`
5. apply the packed masks with `select(..., SelectMode.TENSOR_SCALAR)` before `cmax` / `rowmax`
6. if the same tile is also the final `S2` tail tile, apply the diagonal causal mask first and the `valid_n` tail mask second

Why this is the stable path:

- it matches the current hardware/simulator rule that `compare_scalar(...)` and `select(...)` use packed-bit `uint8` control
- it keeps the control path cheap by building the causal masks once per subblock instead of reconstructing them on every diagonal-tile visit
- it avoids the large simulator overhead of byte-by-byte `SetValueTo(...)` loops for mask construction
- it avoids trying to repair causal semantics later in the `p` or `pv` path
- it preserves the exact online softmax updates for `row_max`, `expdiff`, and `row_sum`

Minimal causal validation set:

- one `S1 == S2` aligned case
- one `S1 == S2` unaligned case
- one `S1 < S2` case
- one `S1 > S2` case
- one multi-head case

Validated runnable example:

- `agent/example/kernels/a2/flash_attn_full_pj_hif8_causal.py`

## 14. Files to study

- `agent/example/kernels/a2/flash_attn_full_pj_hif8.py`
- `agent/example/kernels/a2/flash_attn_full_pj_hif8_causal.py`
- `testcases/simulator/micro/test_simulator_v2_muladddst_mask.py`
- `testcases/simulator/micro/test_simulator_v2_vec_ops_extended.py`
- `agent/references/constraints/tail-safety.md`
- `agent/references/constraints/mask.md`
- `agent/references/patterns/a2-cube-vec-cube-vec-softmax.md`
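For the causal diagonal tile in section 13, the tile classification and the per-row thresholds can be sketched in plain Python. The sketch assumes square tiles (`TILE_M == TILE_N`), so `lmt` is the query-tile index whose key tile `nt == lmt` sits on the diagonal. `classify_tile`, `active_tiles`, `diagonal_half_cols`, and `pack_keep_row` are hypothetical names, and the LSB-first bit order inside each packed byte is an assumption to confirm against `agent/references/constraints/mask.md`.

```python
HALF_N = 64  # columns per packed-mask half


def classify_tile(nt, lmt):
    """Left-up causal (k_pos <= q_pos) class of key tile nt for query tile lmt."""
    if nt < lmt:
        return "full"
    if nt == lmt:
        return "diagonal"
    return "skip"


def active_tiles(tiles_n, lmt):
    """Stage-1/stage-2 loop bound: Min(tiles_n, lmt + 1)."""
    return min(tiles_n, lmt + 1)


def diagonal_half_cols(sb_row, row):
    """(left, right) valid columns of one diagonal-tile row, from valid_cols = sb_row + row + 1."""
    valid_cols = sb_row + row + 1
    return min(valid_cols, HALF_N), max(0, min(valid_cols - HALF_N, HALF_N))


def pack_keep_row(valid_cols):
    """Pack keep flags (1 = keep the score) for one 64-column half into HALF_N // 8 bytes.

    Assumption: LSB-first bit order inside each byte.
    """
    bits = [1 if c < valid_cols else 0 for c in range(HALF_N)]
    return bytes(
        sum(bits[8 * b + i] << i for i in range(8)) for b in range(HALF_N // 8)
    )
```

With `sb_row == 64`, `diagonal_half_cols(64, row)` returns `(64, row + 1)`: the left half is all ones and only the right half uses threshold `row + 1`, matching step 4. With `sb_row == 0`, the right count is always `0`, so only the left mask needs per-row thresholds.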