PVN3D Custom ONNX Op / TensorRT Plugin 设计说明
1. 目标本文用于给 PVN3D 的PointNet2自定义 CUDA 算子定义一套可落地的部署侧设计目标是支持导出完整ONNX支持构建完整TensorRT engine部署运行时不再依赖PyTorch和pvn3d.lib.pointnet2_utils._ext本文不是泛泛而谈的 PointNet2 设计文档而是直接结合当前仓库里的真实实现Python 封装pointnet2_utils.pypointnet2_modules.pyC/CUDA 扩展bindings.cppsampling.cppsampling_gpu.cuball_query.cppball_query_gpu.cugroup_points.cppgroup_points_gpu.cuinterpolate.cppinterpolate_gpu.cu2. 版本基线本文对应的部署基线沿用当前容器文档中的已核验版本Python3.8.20PyTorch1.10.0cu113ONNX1.14.1ONNX Runtime1.16.3TensorRT8.6.1这里的ONNX - TRT默认指TensorRT8.6.1自带 ONNX parsertrtexec或 TensorRT builder API而不是依赖单独的onnx-tensorrtPython 包。3. 现有 CUDA 扩展暴露出的算子集合从 bindings.cpp 可以确认当前_ext实际导出的是gather_pointsgather_points_gradfurthest_point_samplingthree_nnthree_interpolatethree_interpolate_gradball_querygroup_pointsgroup_points_grad对部署前向链真正有意义的是furthest_point_samplinggather_pointsball_querygroup_pointsthree_nnthree_interpolate*_grad相关接口只服务训练反向对部署态完整 ONNX / TRT engine 不必实现。4. 场景描述这些算子在 PVN3D 里到底干什么4.1 场景 ASet Abstraction 路径对应模块pointnet2_modules.py在PointnetSAModuleMSG.forward()里PointNet2 做的是从原始点云xyz[B,N,3]中选出npoint个中心点以这些中心点为球心在半径radius内搜邻域点对每个中心点收集固定数量nsample的邻域索引用这些索引把xyz和features分组为局部 patch对局部 patch 做 MLP 和池化也就是说这条链不是普通 CNN 的卷积 receptive field而是点云上的“显式邻域构造”furthest_point_samplegather_operationball_querygrouping_operation这是第一类场景。4.2 场景 BFeature Propagation 路径对应模块pointnet2_modules.py在PointnetFPModule.forward()里PointNet2 做的是对稀疏点集和稠密点集做 3 近邻搜索计算三近邻的距离倒数权重按三近邻索引把稀疏特征插值回稠密点集对应链路是three_nnthree_interpolate这是第二类场景。4.3 为什么这两类场景必须单独描述因为它们决定了 custom op 的接口边界SA路径主要是“离散索引生成 batched gather”FP路径主要是“近邻搜索 加权插值”如果不先按场景分开后面容易把 plugin 设计做成一堆彼此耦合、难以调试的杂合层。5. 从 CUDA 源码提炼出来的真实语义这部分非常关键因为 custom op 的行为必须和现有 CUDA kernel 保持一致。5.1furthest_point_sampling从 sampling_gpu.cu 可以确认输入是dataset[b, n, 3]输出是idxs[b, m]首个采样点固定从old 0开始temp[b, n]初始化为大数1e10迭代过程中使用“当前点到已选点集合的最小距离”的最大值作为下一个点对模长mag 1e-3的点直接跳过这意味着第一版 custom op 设计里要明确采样起点不是随机的而是固定0存在“近零点跳过”行为算子行为依赖固定的距离更新逻辑而不是任意 FPS 变体5.2ball_query从 ball_query_gpu.cu 可以确认输入是new_xyz[b, m, 3]xyz[b, n, 3]输出是idx[b, m, nsample]判定条件是d2 radius^2当找到第一个命中点k时会先把整个idx[j, :]全部填成k后续继续找到命中点时再按cnt位置覆写前面的槽位这意味着邻域不足nsample时不是填-1而是重复第一个命中的索引这是非常重要的行为约定custom op 和 TRT plugin 必须保持一致。5.3group_points从 group_points_gpu.cu 可以确认输入points[b, c, n]idx[b, npoints, nsample]输出out[b, c, npoints, nsample]本质就是按idx从points上取值并重排。它本身没有搜索语义依赖上游idx的定义是否一致。5.4gather_points从 sampling_gpu.cu 可以确认输入points[b, c, n]idx[b, m]输出out[b, c, m]它对应的是单索引 gather而不是group_points那种二维索引 gather。5.5three_nn从 interpolate_gpu.cu 可以确认输入unknown[b, n, 3]known[b, m, 3]输出dist2[b, n, 3]idx[b, n, 3]kernel 内输出的是平方距离dist2Python 封装层再对dist2做torch.sqrt()这意味着部署设计里有两个选择custom op 直接输出sqrt(dist2)对齐 Python 最终语义custom op 输出dist2再在图里补一个Sqrt推荐第一版选择第 2 种plugin 更贴近现有 CUDA kernel方便先保持实现简单5.6three_interpolate从 interpolate_gpu.cu 可以确认输入points[b, c, m]idx[b, n, 3]weight[b, n, 3]输出out[b, c, n]其语义严格就是三个索引位置的特征值乘对应权重后求和这意味着如果上游three_nn已经产出稳定的idx那么three_interpolate比搜索类算子更规则。6. Custom ONNX op 域与命名建议统一使用固定 domaincom.pvn3d.pointnet2建议的 op 名如下PVN3D_FurthestPointSamplePVN3D_GatherPointsPVN3D_BallQueryPVN3D_GroupPointsPVN3D_ThreeNNPVN3D_ThreeInterpolate这样做的目的和 TensorRT plugin 名一一对应后续图检查时容易定位不和其他第三方 PointNet2 实现混淆7. 每个 custom op 的接口设计7.1PVN3D_FurthestPointSample场景SA路径选中心点输入xyz:float[B, N, 3]属性npoint: int输出idx:int32[B, npoint]约束第一版只支持 static shape第一版只支持B1第一版只支持float32输入语义约定起始索引固定为0保持现有 kernel 的“近零点跳过”行为TensorRT plugin 建议名PVN3DFurthestPointSample_TRT7.2PVN3D_GatherPoints场景SA路径根据 FPS 输出索引取中心点坐标输入points:float[B, C, N]idx:int32[B, S]输出out:float[B, C, S]第一版建议先尝试标准 ONNXGather改写只有在 TRT parser 或运行期不稳定时才升级成 plugin如果最终也做 pluginplugin 名建议PVN3DGatherPoints_TRT7.3PVN3D_BallQuery场景SA路径在每个中心点周围搜固定邻域输入new_xyz:float[B, S, 3]xyz:float[B, N, 3]属性radius: floatnsample: int输出idx:int32[B, S, nsample]语义约定距离判定条件是d2 radius^2邻域不足时重复第一个命中索引不使用-1作为 paddingTensorRT plugin 建议名PVN3DBallQuery_TRT7.4PVN3D_GroupPoints场景SA路径把邻域索引转换成局部 patch 特征张量输入points:float[B, C, N]idx:int32[B, S, K]输出out:float[B, C, S, K]第一版建议先尝试标准图改写如果失败再做 plugin如果最终 plugin 化plugin 名建议PVN3DGroupPoints_TRT7.5PVN3D_ThreeNN场景FP路径从稀疏点集向稠密点集回传特征前先找 3 近邻输入unknown:float[B, N, 3]known:float[B, M, 3]输出dist2:float[B, N, 3]idx:int32[B, N, 3]说明推荐 custom op 直接输出dist2后续图里补Sqrt这样最贴近当前 CUDA kernelTensorRT plugin 建议名PVN3DThreeNN_TRT7.6PVN3D_ThreeInterpolate场景FP路径根据三近邻和权重做特征插值输入points:float[B, C, M]idx:int32[B, N, 3]weight:float[B, N, 3]输出out:float[B, C, N]第一版建议优先尝试标准图改写若在 TRT 8.6.1 下不稳定再实现 plugin如果最终 plugin 化plugin 名建议PVN3DThreeInterpolate_TRT8. 第一版实现边界为了降低复杂度第一版 custom op / plugin 明确建议限制为static shapebatch1num_points4096输入特征只支持float32索引统一int32先不考虑反向不要第一版就追求dynamic shapebatch1fp16kernel 内部全支持完整训练反向这些优化都应该放在完整链路先打通之后。9. ONNX symbolic 设计建议PyTorch 侧建议为每个算子注册独立 symbolic直接映射到 custom domain 节点。例如furthest_point_sample(xyz, npoint)-com.pvn3d.pointnet2::PVN3D_FurthestPointSampleball_query(radius, nsample, xyz, new_xyz)-com.pvn3d.pointnet2::PVN3D_BallQuerythree_nn(unknown, known)-com.pvn3d.pointnet2::PVN3D_ThreeNN原则属性显式传递不在 symbolic 里藏 shape 假设shape 假设写在 export 脚本和 manifest 里10. TensorRT plugin 设计建议每个 plugin 至少要定义清楚plugin nameplugin version输入输出数目dtype 支持shape 推导serialization 字段enqueue 所需参数建议统一 version1建议 plugin namespacepvn3d.pointnet2这样后续 parser 注册和日志定位会更清楚。11. 推荐的落地顺序按当前 PVN3D 的实际风险排序建议先实现PVN3D_ThreeNN再实现PVN3D_BallQuery再实现PVN3D_FurthestPointSample之后视情况决定GatherPoints、GroupPoints、ThreeInterpolate是否也 plugin 化原因ThreeNN逻辑闭合最适合做第一批原型BallQuery能把“邻域补位策略”这类关键行为尽早固定FurthestPointSample是 SA 路径里最强约束的离散算子12. 与现有混合部署的关系当前混合部署路径的价值仍然存在它是数值对齐基线它能验证rgb_backbone和fusion_head已经没问题它让我们只需专注PointNet2场景链也就是说后续 custom op / plugin 的测试应始终对齐当前这条链原生_extPointNet2 输出custom ONNX / TRT PointNet2 输出