Skip to content

Conversation

@aniragil
Copy link
Contributor

@aniragil aniragil commented Jan 12, 2026

The pretty-printing for emitc.expression breaks for expressions taking the same value as operand multiple times.

Passing the same value as operand more than once is redundant, and is therefore not the canonical form of emitc.expression. However, since transformations affecting emitc.expression operands may cause this to happen, emitc.expression must retain its support for recurring operands.

This PR fixes this issue by shadowing the region arguments only when the operands are unique, printing and parsing an explicit basic block otherwise. In addition, a canonicalization pattern removing recurring operands is added.

Fixes #172952.

The pretty-printing for `emitc.expression` breaks for expressions taking
the same value as operand multiple times.

Passing the same value as operand more than once is redundant, and is
therefore not the canonical form of `emitc.expression`. However, since
transformations affecting `emitc.expression` operands may cause this
to happen, `emitc.expression` must retain its support for recurring
operands.

This PR fixes this issue by shadowing the region arguments only when the
operands are unique, printing and parsing an explicit basic block
otherwise. In addition, a canonicalization pattern removing recurring
operands is added.

Fixes issue llvm#172952.
@llvmbot
Copy link
Member

llvmbot commented Jan 12, 2026

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-emitc

Author: Gil Rapaport (aniragil)

Changes

The pretty-printing for emitc.expression breaks for expressions taking the same value as operand multiple times.

Passing the same value as operand more than once is redundant, and is therefore not the canonical form of emitc.expression. However, since transformations affecting emitc.expression operands may cause this to happen, emitc.expression must retain its support for recurring operands.

This PR fixes this issue by shadowing the region arguments only when the operands are unique, printing and parsing an explicit basic block otherwise. In addition, a canonicalization pattern removing recurring operands is added.

Fixes issue #172952.


Full diff: https://github.com/llvm/llvm-project/pull/175535.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+1)
  • (modified) mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h (+4)
  • (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+34-10)
  • (modified) mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp (+52)
  • (modified) mlir/test/Dialect/EmitC/form-expressions.mlir (+19)
  • (modified) mlir/test/Dialect/EmitC/invalid_ops.mlir (+13)
  • (modified) mlir/test/Dialect/EmitC/ops.mlir (+23-1)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index c1820904f2665..b638130e24b24 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -566,6 +566,7 @@ def EmitC_ExpressionOp
 
   let hasVerifier = 1;
   let hasCustomAssemblyFormat = 1;
+  let hasCanonicalizer = 1;
 
   let extraClassDeclaration = [{
     bool hasSideEffects() {
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index bdf6d0985e6db..67b17d3c0d573 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -25,6 +25,10 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
 // Populate functions
 //===----------------------------------------------------------------------===//
 
+/// Populates `patterns` with expression canonicalization patterns.
+void populateExpressionCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context);
+
 /// Populates `patterns` with expression-related patterns.
 void populateExpressionPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index b0566dd10f490..3d49ec0d3a78a 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/EmitC/IR/EmitCInterfaces.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -412,6 +413,11 @@ LogicalResult DereferenceOp::verify() {
 // ExpressionOp
 //===----------------------------------------------------------------------===//
 
+void ExpressionOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  populateExpressionCanonicalizationPatterns(results, context);
+}
+
 ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand> operands;
   if (parser.parseOperandList(operands))
@@ -435,27 +441,45 @@ ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
                             "expected single return type");
   result.addTypes(fnType.getResults());
   Region *body = result.addRegion();
+  DenseSet<Value> uniqueOperands(result.operands.begin(),
+                                 result.operands.end());
+  bool enableNameShadowing = uniqueOperands.size() == result.operands.size();
   SmallVector<OpAsmParser::Argument> argsInfo;
-  for (auto [unresolvedOperand, operandType] :
-       llvm::zip(operands, fnType.getInputs())) {
-    OpAsmParser::Argument argInfo;
-    argInfo.ssaName = unresolvedOperand;
-    argInfo.type = operandType;
-    argsInfo.push_back(argInfo);
+  if (enableNameShadowing) {
+    for (auto [unresolvedOperand, operandType] :
+         llvm::zip(operands, fnType.getInputs())) {
+      OpAsmParser::Argument argInfo;
+      argInfo.ssaName = unresolvedOperand;
+      argInfo.type = operandType;
+      argsInfo.push_back(argInfo);
+    }
   }
-  if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true))
+  SMLoc beforeRegionLoc = parser.getCurrentLocation();
+  if (parser.parseRegion(*body, argsInfo, enableNameShadowing))
     return failure();
+  if (!enableNameShadowing) {
+    if (body->front().getArguments().size() < result.operands.size()) {
+      return parser.emitError(
+          beforeRegionLoc, "with recurring operands expected block arguments");
+    }
+  }
   return success();
 }
 
 void emitc::ExpressionOp::print(OpAsmPrinter &p) {
   p << ' ';
-  p.printOperands(getDefs());
+  auto operands = getDefs();
+  p.printOperands(operands);
   p << " : ";
   p.printFunctionalType(getOperation());
-  p.shadowRegionArgs(getRegion(), getDefs());
+  DenseSet<Value> uniqueOperands(operands.begin(), operands.end());
+  bool printEntryBlockArgs = true;
+  if (uniqueOperands.size() == operands.size()) {
+    p.shadowRegionArgs(getRegion(), getDefs());
+    printEntryBlockArgs = false;
+  }
   p << ' ';
-  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
+  p.printRegion(getRegion(), printEntryBlockArgs);
 }
 
 Operation *ExpressionOp::getRootOp() {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index f8469b8f0ed67..bfcb4a140ee9f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -149,8 +149,60 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
   }
 };
 
+/// Canonicalization pattern that rewrites an expression taking the same value
+/// as operand more than once into one whose operands are all distinct.
+struct RemoveRecurringExpressionOperands
+    : public OpRewritePattern<ExpressionOp> {
+  using OpRewritePattern<ExpressionOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ExpressionOp expressionOp,
+                                PatternRewriter &rewriter) const override {
+    SetVector<Value> uniqueOperands;
+    DenseMap<Value, unsigned> uniqueIndexOf;
+
+    // Record each distinct operand's position among the unique operands, which
+    // is the index of its block argument in the new expression.
+    for (Value operand : expressionOp.getDefs())
+      if (uniqueOperands.insert(operand))
+        uniqueIndexOf[operand] = uniqueOperands.size() - 1;
+
+    // If every operand is unique, bail out.
+    if (uniqueOperands.size() == expressionOp.getDefs().size())
+      return failure();
+
+    // Create a new expression with unique operands.
+    rewriter.setInsertionPointAfter(expressionOp);
+    auto uniqueExpression = emitc::ExpressionOp::create(
+        rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
+        uniqueOperands.getArrayRef(), expressionOp.getDoNotInline());
+    Block &uniqueExpressionBody = uniqueExpression.createBody();
+
+    // Map each original block argument to the unique block argument taking
+    // the same operand.
+    IRMapping mapper;
+    Block *expressionBody = expressionOp.getBody();
+    for (auto [operand, arg] :
+         llvm::zip(expressionOp.getOperands(), expressionBody->getArguments()))
+      mapper.map(arg, uniqueExpressionBody.getArgument(uniqueIndexOf[operand]));
+
+    rewriter.setInsertionPointToStart(&uniqueExpressionBody);
+    for (Operation &opToClone : *expressionOp.getBody())
+      rewriter.clone(opToClone, mapper);
+
+    // Complete the rewrite.
+    rewriter.replaceOp(expressionOp, uniqueExpression);
+
+    return success();
+  }
+};
+
 } // namespace
 
+void mlir::emitc::populateExpressionCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<RemoveRecurringExpressionOperands>(patterns.getContext());
+}
+
 void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) {
+  populateExpressionCanonicalizationPatterns(patterns, patterns.getContext());
   patterns.add<FoldExpressionOp>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/EmitC/form-expressions.mlir b/mlir/test/Dialect/EmitC/form-expressions.mlir
index 7b6723989e260..58eac4381ccb7 100644
--- a/mlir/test/Dialect/EmitC/form-expressions.mlir
+++ b/mlir/test/Dialect/EmitC/form-expressions.mlir
@@ -20,6 +20,25 @@ func.func @single_expression(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) ->
   return %c : i1
 }
 
+// CHECK-LABEL:   func.func @expression_recurring_args(
+// CHECK-SAME:      %[[ARG0:.*]]: i32,
+// CHECK-SAME:      %[[ARG1:.*]]: i32) -> i1 {
+// CHECK:           %[[EXPRESSION_0:.*]] = emitc.expression %[[ARG1]], %[[ARG0]] : (i32, i32) -> i1 {
+// CHECK:             %[[VAL_0:.*]] = mul %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32
+// CHECK:             %[[VAL_1:.*]] = sub %[[VAL_0]], %[[ARG0]] : (i32, i32) -> i32
+// CHECK:             %[[VAL_2:.*]] = cmp lt, %[[VAL_1]], %[[ARG1]] : (i32, i32) -> i1
+// CHECK:             yield %[[VAL_2]] : i1
+// CHECK:           }
+// CHECK:           return %[[EXPRESSION_0]] : i1
+// CHECK:         }
+
+func.func @expression_recurring_args(%arg0: i32, %arg1: i32) -> i1 {
+  %a = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
+  %b = emitc.sub %a, %arg0 : (i32, i32) -> i32
+  %c = emitc.cmp lt, %b, %arg1 :(i32, i32) -> i1
+  return %c : i1
+}
+
 // CHECK-LABEL: func.func @multiple_expressions(
 // CHECK-SAME:      %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32, %[[VAL_3:.*]]: i32) -> (i32, i32) {
 // CHECK:         %[[VAL_4:.*]] = emitc.expression %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : (i32, i32, i32) -> i32 {
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index d1601bed29ca9..0d878e90cdf0c 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -379,6 +379,19 @@ emitc.func @test_expression_op_outside_expression() {
 
 // -----
 
+func.func @test_expression_recurring_operands(%arg0: i32, %arg1: i32) -> i32 {
+  // expected-error @+1 {{'emitc.expression' with recurring operands expected block arguments}}
+  %r = emitc.expression %arg0, %arg1, %arg0 : (i32, i32, i32) -> i32 {
+    %a = emitc.rem %arg0, %arg1 : (i32, i32) -> i32
+    %b = emitc.add %a, %arg0 : (i32, i32) -> i32
+    %c = emitc.mul %b, %a : (i32, i32) -> i32
+    emitc.yield %c : i32
+  }
+  return %r : i32
+}
+
+// -----
+
 // expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}}
 emitc.func @multiple_results(%0: i32) -> (i32, i32) {
   emitc.return %0 : i32
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index b2c8b843ec14b..2f7544b5db096 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize | FileCheck -check-prefix=CANON %s
 
 // CHECK: emitc.include <"test.h">
 // CHECK: emitc.include "test.h"
@@ -213,6 +213,28 @@ func.func @test_expression_multiple_uses(%arg0: i32, %arg1: i32) -> i32 {
   return %r : i32
 }
 
+// CANON-LABEL:   func.func @test_expression_recurring_operands(
+// CANON-SAME:      %[[ARG0:.*]]: i32,
+// CANON-SAME:      %[[ARG1:.*]]: i32) -> i32 {
+// CANON:           %[[EXPRESSION_0:.*]] = emitc.expression %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32 {
+// CANON:             %[[VAL_0:.*]] = rem %[[ARG0]], %[[ARG1]] : (i32, i32) -> i32
+// CANON:             %[[VAL_1:.*]] = add %[[VAL_0]], %[[ARG0]] : (i32, i32) -> i32
+// CANON:             %[[VAL_2:.*]] = mul %[[VAL_1]], %[[VAL_0]] : (i32, i32) -> i32
+// CANON:             yield %[[VAL_2]] : i32
+// CANON:           }
+// CANON:           return %[[EXPRESSION_0]] : i32
+// CANON:         }
+func.func @test_expression_recurring_operands(%arg0: i32, %arg1: i32) -> i32 {
+  %r = emitc.expression %arg0, %arg1, %arg0 : (i32, i32, i32) -> i32 {
+  ^bb0(%x: i32, %y: i32, %z: i32):
+    %a = emitc.rem %x, %y : (i32, i32) -> i32
+    %b = emitc.add %a, %z : (i32, i32) -> i32
+    %c = emitc.mul %b, %a : (i32, i32) -> i32
+    emitc.yield %c : i32
+  }
+  return %r : i32
+}
+
 func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
   emitc.for %i0 = %arg0 to %arg1 step %arg2 {
     %0 = emitc.call_opaque "func_const"(%i0) : (index) -> i32

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mlir][EmitC] emitc.expression custom print/parse broken in common edge cases

2 participants