CANN社区Median算子设计
需求背景required【免费下载链接】cann-ops-competitions本仓库用于 CANN 开源社区各类竞赛、开源课题、社区任务等课题发布、开发者作品提交和展示。项目地址: https://gitcode.com/cann/cann-ops-competitions需求来源本任务来源于昇腾CANN社区任务2026任务序号04-9要求参考torch.median功能基于Ascend C编程语言在昇腾NPU上实现功能一致的Median算子对齐aclnnMedian与aclnnMedianDim所有走入aicore的数据类型替代现有小算子拼接实现验收通过后合入昇腾算子开源仓ops-nn/experimental/index。背景介绍Median算子实现优化基于Median算子历史小算子拼接版本使用Ascend C编程语言进行重构和优化。对标接口torch.median(input)/torch.median(input, dim, keepdim)参考实现https://gitcode.com/cann/ops-nn/blob/master/index/gather_v2/op_api/aclnn_median.cpp开源仓地址https://gitcode.com/cann/ops-nnMedian算子拼接版本现状分析torch.median有两种调用形态官方均通过小算子拼接实现①aclnnMedian全局中位数torch.median(input)将输入展平为1D后整体排序取下中位数n个元素排序后下标(n-1)/2返回标量。 ②aclnnMedianDim按轴中位数torch.median(input, dim)沿dim轴排序取下中位数及其原始索引返回(values, indices)。需特别支持dim轴维度为1的退化场景。 ③ 偶数长度取下中位数lower median与NumPy取均值语义不同须严格对齐PyTorch。 ④ 拼接路径Contiguous → 转dim到末维 → Sort全排序argsort→ Gather/Slice取中点与索引多个独立kernel完成。当前拼接版本存在的核心问题全量Sort 中间张量落盘导致访存冗余取中位数仅需第k小拼接版却做O(n·log n)全排序Sort的values/indices两份中间张量写回GM再Gather取一行HBM往返多、kernel启动多。reduce长度小、batch多时kernel启动延迟主导带宽浪费严重。Ascend C可通过单kernel融合解决CopyIn一次性把一条reduce段搬入UBCompute用quickselect只求第k小免全排序、中间结果驻留UBCopyOut一次写出indices回扫匹配首个等值下标。整体1次kernel、2次GM↔UB搬运。拼接版整体流程如下图所示算子功能规格规格项描述算子名称aclnnMedian / aclnnMedianDim全局中位数展平排序取下标(n-1)/2输出标量按轴中位数沿dim取下中位数输出values与首个匹配indices输入input (Tensor)支持fp16/fp32/bf16/int16/int32/int64/uint8/int8属性dim (int)、keepdim (bool)MedianDim专用输出values同input dtype、indicesint64仅MedianDim算子原型名称类别dtypeformat介绍input输入fp16/fp32/bf16/int系列ND任意shapedim属性int-归约轴MedianDimkeepdim属性bool-是否保留轴values输出同inputND全局标量/降一维indices输出int64ND下中位数原始下标首个相关约束Atlas A2 训练系列产品 / Atlas A3 系列产品偶数长度取下中位数与PyTorch一致indices取首个等值须支持dim轴维度1支持泛化各类合法shape需求分析required外部组件依赖不涉及外部组件依赖。内部适配模块适配aclnnMedian/aclnnMedianDim接口及图模式调用。需求描述使用Ascend C实现Median全局与按轴两形态对齐torch.median下中位语义用单kernel选第k小替代全排序Gather拼接降低kernel启动与GM搬运扩展并对齐所有走aicore的数据类型。需求拆解全局Median展平选第(n-1)/2小输出标量按轴MedianDim沿dim选下中位数首个索引支持keepdim与dim1quickselect免全排序bf16核内Cast fp32比较再回写精度满足AscendOpTest默认阈值性能不劣于小算子拼接版详细设计required算子分析数学公式$$\text{median} \text{sorted}(x){\lfloor (n-1)/2 \rfloor},\quad \text{MedianDim: } y{...} \text{sorted}(x_{...,,dim})_{\lfloor (k-1)/2 \rfloor}$$算子特性归约/选择输出依赖整条reduce段非逐元素适合一核处理整行下中位偶数取下标(k-1)/2indices取首个等值确定可复现免全排序只求第k小O(n)选择优于O(n·log n)排序无跨行依赖行间独立按行分核无需跨核同步支持数据类型fp16、fp32、bf16、int16、int32、int64、uint8、int8bf16核内转fp32算子实现整体架构Host侧设计mid计算mid(k-1)/2k为全局总数或dim长度分核按输出行数均分core偶数对齐行超UB走workspace分块TilingKey0全局1MedianDim带indices2dim1直通TilingDatatotalRows/redLen/mid/tileLen/formerNum/tailNumKernel侧设计CopyIn一次搬入整行bf16先Cast fp32Computequickselect求第k小valuesindices再扫一遍取首个等值下标CopyOut写valuesindices官方拼接 vs 我方融合差异差异点拼接版Ascend C原因算法全排序O(nlogn)选第k小O(n)中位只需k-th中间张量sortargsort落GMUB驻留免HBM往返kernelSortGather多次1次融合减启动dim1仍排序直通退化优化支持硬件支持的芯片版本涉及勾选Atlas A2 训练系列产品√Atlas A3 系列产品√使能方式上层框架涉及勾选Pytorch训练/推理√Aclnn直调√算子约束限制dim须支持维度1偶数取下中位indices取首个等值超大行经workspace特性交叉分析归约/选择算子行间独立不涉及广播/量化冲突。可维可测分析精度标准/性能标准验收标准描述标准来源精度标准满足AscendOpTest默认阈值int全等fp16/bf16 rtol≈1e-3fp32≈1e-4AscendOpTest性能标准不劣于aclnnMedian/MedianDim小算子拼接版社区任务要求兼容性分析提供与PyTorchtorch.median等效功能aclnn接口遵循CANN规范新算子不涉及兼容性问题。【免费下载链接】cann-ops-competitions本仓库用于 CANN 开源社区各类竞赛、开源课题、社区任务等课题发布、开发者作品提交和展示。项目地址: https://gitcode.com/cann/cann-ops-competitions创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考