-
Notifications
You must be signed in to change notification settings - Fork 15.8k
[mlir][emitc] Fix recurring operands in expression #175535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][emitc] Fix recurring operands in expression #175535
Conversation
The pretty-printing for `emitc.expression` breaks for expressions taking the same value as operand multiple times. Passing the same value as operand more than once is redundant, and is therefore not the canonical form of `emitc.expression. However, since transformations affecting `emitc.expression` operands may cause this to happen, `emitc.expression` must retain its support for recurring operands. This PR fixes this issue by shadowing the region arguments only when the operands are unique, printing and parsing an explicit basic block otherwise. In addition, a canonicalization pattern removing recurring operands is added. Fixes issue llvm#172952.
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-emitc Author: Gil Rapaport (aniragil) ChangesThe pretty-printing for Passing the same value as operand more than once is redundant, and is therefore not the canonical form of This PR fixes this issue by shadowing the region arguments only when the operands are unique, printing and parsing an explicit basic block otherwise. In addition, a canonicalization pattern removing recurring operands is added. Fixes issue #172952. Full diff: https://github.com/llvm/llvm-project/pull/175535.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index c1820904f2665..b638130e24b24 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -566,6 +566,7 @@ def EmitC_ExpressionOp
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
+ let hasCanonicalizer = 1;
let extraClassDeclaration = [{
bool hasSideEffects() {
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index bdf6d0985e6db..67b17d3c0d573 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -25,6 +25,10 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
// Populate functions
//===----------------------------------------------------------------------===//
+/// Populates `patterns` with expression canonicalization patterns.
+void populateExpressionCanonicalizationPatterns(RewritePatternSet &patterns,
+ MLIRContext *context);
+
/// Populates `patterns` with expression-related patterns.
void populateExpressionPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index b0566dd10f490..3d49ec0d3a78a 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -412,6 +413,11 @@ LogicalResult DereferenceOp::verify() {
// ExpressionOp
//===----------------------------------------------------------------------===//
+void ExpressionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ populateExpressionCanonicalizationPatterns(results, context);
+}
+
ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand> operands;
if (parser.parseOperandList(operands))
@@ -435,27 +441,45 @@ ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
"expected single return type");
result.addTypes(fnType.getResults());
Region *body = result.addRegion();
+ DenseSet<Value> uniqueOperands(result.operands.begin(),
+ result.operands.end());
+ bool enableNameShadowing = uniqueOperands.size() == result.operands.size();
SmallVector<OpAsmParser::Argument> argsInfo;
- for (auto [unresolvedOperand, operandType] :
- llvm::zip(operands, fnType.getInputs())) {
- OpAsmParser::Argument argInfo;
- argInfo.ssaName = unresolvedOperand;
- argInfo.type = operandType;
- argsInfo.push_back(argInfo);
+ if (enableNameShadowing) {
+ for (auto [unresolvedOperand, operandType] :
+ llvm::zip(operands, fnType.getInputs())) {
+ OpAsmParser::Argument argInfo;
+ argInfo.ssaName = unresolvedOperand;
+ argInfo.type = operandType;
+ argsInfo.push_back(argInfo);
+ }
}
- if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true))
+ SMLoc beforeRegionLoc = parser.getCurrentLocation();
+ if (parser.parseRegion(*body, argsInfo, enableNameShadowing))
return failure();
+ if (!enableNameShadowing) {
+ if (body->front().getArguments().size() < result.operands.size()) {
+ return parser.emitError(
+ beforeRegionLoc, "with recurring operands expected block arguments");
+ }
+ }
return success();
}
void emitc::ExpressionOp::print(OpAsmPrinter &p) {
p << ' ';
- p.printOperands(getDefs());
+ auto operands = getDefs();
+ p.printOperands(operands);
p << " : ";
p.printFunctionalType(getOperation());
- p.shadowRegionArgs(getRegion(), getDefs());
+ DenseSet<Value> uniqueOperands(operands.begin(), operands.end());
+ bool printEntryBlockArgs = true;
+ if (uniqueOperands.size() == operands.size()) {
+ p.shadowRegionArgs(getRegion(), getDefs());
+ printEntryBlockArgs = false;
+ }
p << ' ';
- p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+ p.printRegion(getRegion(), printEntryBlockArgs);
}
Operation *ExpressionOp::getRootOp() {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index f8469b8f0ed67..bfcb4a140ee9f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -149,8 +149,60 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
}
};
+struct RemoveRecurringExpressionOperands
+ : public OpRewritePattern<ExpressionOp> {
+ using OpRewritePattern<ExpressionOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(ExpressionOp expressionOp,
+ PatternRewriter &rewriter) const override {
+ SetVector<Value> uniqueOperands;
+ DenseMap<Value, int> firstIndexOf;
+
+ // Collect duplicate operands and prepare to remove excessive copies.
+ for (auto [i, operand] : llvm::enumerate(expressionOp.getDefs())) {
+ if (uniqueOperands.contains(operand))
+ continue;
+ uniqueOperands.insert(operand);
+ firstIndexOf[operand] = i;
+ }
+
+ // If every operand is unique, bail out.
+ if (uniqueOperands.size() == expressionOp.getDefs().size())
+ return failure();
+
+ // Create a new expression with unique operands.
+ rewriter.setInsertionPointAfter(expressionOp);
+ auto uniqueExpression = emitc::ExpressionOp::create(
+ rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
+ uniqueOperands.getArrayRef(), expressionOp.getDoNotInline());
+ Block &uniqueExpressionBody = uniqueExpression.createBody();
+
+ // Map each original block arguments to the unique block argument taking
+ // the same operand.
+ IRMapping mapper;
+ Block *expressionBody = expressionOp.getBody();
+ for (auto [operand, arg] :
+ llvm::zip(expressionOp.getOperands(), expressionBody->getArguments()))
+ mapper.map(arg, uniqueExpressionBody.getArgument(firstIndexOf[operand]));
+
+ rewriter.setInsertionPointToStart(&uniqueExpressionBody);
+ for (Operation &opToClone : *expressionOp.getBody())
+ rewriter.clone(opToClone, mapper);
+
+ // Complete the rewrite.
+ rewriter.replaceOp(expressionOp, uniqueExpression);
+
+ return success();
+ }
+};
+
} // namespace
+void mlir::emitc::populateExpressionCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<RemoveRecurringExpressionOperands>(patterns.getContext());
+}
+
void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) {
+ populateExpressionCanonicalizationPatterns(patterns, patterns.getContext());
patterns.add<FoldExpressionOp>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/EmitC/form-expressions.mlir b/mlir/test/Dialect/EmitC/form-expressions.mlir
index 7b6723989e260..58eac4381ccb7 100644
--- a/mlir/test/Dialect/EmitC/form-expressions.mlir
+++ b/mlir/test/Dialect/EmitC/form-expressions.mlir
@@ -20,6 +20,25 @@ func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
return %c : i1
}
+// CHECK-LABEL: func.func @expression_recurring_args(
+// CHECK-SAME: %[[ARG0:.*]]: i32,
+// CHECK-SAME: %[[ARG1:.*]]: i32) -> i1 {
+// CHECK: %[[EXPRESSION_0:.*]] = emitc.expression %[[ARG1]], %[[ARG0]] : (i32, i32) -> i1 {
+// CHECK: %[[VAL_0:.*]] = mul %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32
+// CHECK: %[[VAL_1:.*]] = sub %[[VAL_0]], %[[ARG0]] : (i32, i32) -> i32
+// CHECK: %[[VAL_2:.*]] = cmp lt, %[[VAL_1]], %[[ARG1]] : (i32, i32) -> i1
+// CHECK: yield %[[VAL_2]] : i1
+// CHECK: }
+// CHECK: return %[[EXPRESSION_0]] : i1
+// CHECK: }
+
+func.func @expression_recurring_args(%arg0: i32, %arg1: i32) -> i1 {
+ %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.sub %a, %arg0 : (i32, i32) -> i32
+ %c = emitc.cmp lt, %b, %arg1 :(i32, i32) -> i1
+ return %c : i1
+}
+
// CHECK-LABEL: func.func @multiple_expressions(
// CHECK-SAME: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> (i32, i32) {
// CHECK: %[[VAL_4:.*]] = emitc.expression %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : (i32, i32, i32) -> i32 {
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index d1601bed29ca9..0d878e90cdf0c 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -379,6 +379,19 @@ emitc.func @test_expression_op_outside_expression() {
// -----
+func.func @test_expression_recurring_operands(%arg0: i32, %arg1: i32) -> i32 {
+ // expected-error @+1 {{'emitc.expression' with recurring operands expected block arguments}}
+ %r = emitc.expression %arg0, %arg1, %arg0 : (i32, i32, i32) -> i32 {
+ %a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
+ %b = emitc.add %a, %arg0 : (i32, i32) -> i32
+ %c = emitc.mul %b, %a : (i32, i32) -> i32
+ emitc.yield %c : i32
+ }
+ return %r : i32
+}
+
+// -----
+
// expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}}
emitc.func @multiple_results(%0: i32) -> (i32, i32) {
emitc.return %0 : i32
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index b2c8b843ec14b..2f7544b5db096 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize | FileCheck -check-prefix=CANON %s
// CHECK: emitc.include <"test.h">
// CHECK: emitc.include "test.h"
@@ -213,6 +213,28 @@ func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 {
return %r : i32
}
+// CANON-LABEL: func.func @test_expression_recurring_operands(
+// CANON-SAME: %[[ARG0:.*]]: i32,
+// CANON-SAME: %[[ARG1:.*]]: i32) -> i32 {
+// CANON: %[[EXPRESSION_0:.*]] = emitc.expression %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32 {
+// CANON: %[[VAL_0:.*]] = rem %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32
+// CANON: %[[VAL_1:.*]] = add %[[VAL_0]], %[[ARG0]] : (i32, i32) -> i32
+// CANON: %[[VAL_2:.*]] = mul %[[VAL_1]], %[[VAL_0]] : (i32, i32) -> i32
+// CANON: yield %[[VAL_2]] : i32
+// CANON: }
+// CANON: return %[[EXPRESSION_0]] : i32
+// CANON: }
+func.func @test_expression_recurring_operands(%arg0: i32, %arg1: i32) -> i32 {
+ %r = emitc.expression %arg0, %arg1, %arg0 : (i32, i32, i32) -> i32 {
+ ^bb0(%x: i32, %y: i32, %z: i32):
+ %a = emitc.rem %x, %y : (i32, i32) -> i32
+ %b = emitc.add %a, %z : (i32, i32) -> i32
+ %c = emitc.mul %b, %a : (i32, i32) -> i32
+ emitc.yield %c : i32
+ }
+ return %r : i32
+}
+
func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
emitc.for %i0 = %arg0 to %arg1 step %arg2 {
%0 = emitc.call_opaque "func_const"(%i0) : (index) -> i32
|
The pretty-printing for
emitc.expressionbreaks for expressions taking the same value as operand multiple times.Passing the same value as operand more than once is redundant, and is therefore not the canonical form of
emitc.expression. However, since transformations affectingemitc.expressionoperands may cause this to happen,emitc.expression` must retain its support for recurring operands.This PR fixes this issue by shadowing the region arguments only when the operands are unique, printing and parsing an explicit basic block otherwise. In addition, a canonicalization pattern removing recurring operands is added.
Fixes #172952.