MLIR专题10: 下译(Lowering)到LLVM IR
我们还是以toy例子为实例,完整的代码如下://===----------------------------------------------------------------------===// // ToyToLLVM RewritePatterns //===----------------------------------------------------------------------===// namespace { /// Lowers `toy.print` to a loop nest calling `printf` on each of the individual /// elements of the array. class PrintOpLowering : public ConversionPattern { public: explicit PrintOpLowering(MLIRContext *context) : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} LogicalResult matchAndRewrite(Operation *op, ArrayRefValue operands, ConversionPatternRewriter rewriter) const override { auto *context = rewriter.getContext(); auto memRefType = llvm::castMemRefType((*op-operand_type_begin())); auto memRefShape = memRefType.getShape(); auto loc = op-getLoc(); ModuleOp parentModule = op-getParentOfTypeModuleOp(); // Get a symbol reference to the printf function, inserting it if necessary. auto printfRef = getOrInsertPrintf(rewriter, parentModule); Value formatSpecifierCst = getOrCreateGlobalString( loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule); Value newLineCst = getOrCreateGlobalString( loc, rewriter, "nl", StringRef("\n\0", 2), parentModule); // Create a loop for each of the dimensions within the shape. SmallVectorValue, 4 loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { auto lowerBound = rewriter.createarith::ConstantIndexOp(loc, 0); auto upperBound = rewriter.createarith::ConstantIndexOp(loc, memRefShape[i]); auto step = rewriter.createarith::ConstantIndexOp(loc, 1); auto loop = rewriter.createscf::ForOp(loc, lowerBound, upperBound, step); for (Operation nested : make_early_inc_range(*loop.getBody())) rewriter.eraseOp(nested); loopIvs.push_back(loop.getInductionVar()); // Terminate the loop body. rewriter.setInsertionPointToEnd(loop.getBody()); // Insert a newline after each of the inner dimensions of the shape. 
if (i != e - 1) rewriter.createLLVM::CallOp(loc, getPrintfType(context), printfRef, newLineCst); rewriter.createscf::YieldOp(loc); rewriter.setInsertionPointToStart(loop.getBody()); } // Generate a call to printf for the current element of the loop. auto printOp = casttoy::PrintOp(op); auto elementLoad = rewriter.creatememref::LoadOp(loc, printOp.getInput(), loopIvs); rewriter.createLLVM::CallOp( loc, getPrintfType(context), printfRef, ArrayRefValue({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. rewriter.eraseOp(op); return success(); } private: /// Create a function declaration for printf, the signature is: /// * `i32 (i8*, ...)` static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) { auto llvmI32Ty = IntegerType::get(context, 32); auto llvmPtrTy = LLVM::LLVMPointerType::get(context); auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, /*isVarArg=*/true); return llvmFnType; } /// Return a symbol reference to the printf function, inserting it into the /// module if necessary. static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter rewriter, ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbolLLVM::LLVMFuncOp("pr