Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstSimplify] Use multi-op replacement when simplify select #121708

Merged
merged 5 commits into from
Jan 7, 2025

Conversation

goldsteinn
Copy link
Contributor

@goldsteinn goldsteinn commented Jan 5, 2025

  • [InstSimplify] Refactor simplifyWithOpsReplaced to allow multiple replacements; NFC
  • [InstSimplify] Use multi-op replacement when simplify select

In the case of select X | Y == 0 :... or select X & Y == -1 : ...
we can do more simplifications by trying to replace both X and Y
with the respective constant at once.

Handles some cases for #121672 more generically.

@llvmbot
Copy link
Member

llvmbot commented Jan 5, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: None (goldsteinn)

Changes
  • [InstSimplify] Refactor simplifyWithOpsReplaced to allow multiple replacements; NFC
  • [InstSimplify] Use multi-op replacement when simplify select

Full diff: https://github.com/llvm/llvm-project/pull/121708.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+75-48)
  • (modified) llvm/test/Transforms/InstCombine/select.ll (+10-25)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 8567a0504f54e1..aba71c5355ad46 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4275,25 +4275,23 @@ Value *llvm::simplifyFCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
   return ::simplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
 }
 
-static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                                     const SimplifyQuery &Q,
-                                     bool AllowRefinement,
-                                     SmallVectorImpl<Instruction *> *DropFlags,
-                                     unsigned MaxRecurse) {
+static Value *simplifyWithOpsReplaced(Value *V,
+                                      ArrayRef<std::pair<Value *, Value *>> Ops,
+                                      const SimplifyQuery &Q,
+                                      bool AllowRefinement,
+                                      SmallVectorImpl<Instruction *> *DropFlags,
+                                      unsigned MaxRecurse) {
   assert((AllowRefinement || !Q.CanUseUndef) &&
          "If AllowRefinement=false then CanUseUndef=false");
-
-  // Trivial replacement.
-  if (V == Op)
-    return RepOp;
+  for (const auto &OpAndRepOp : Ops) {
+    // Trivial replacement.
+    if (V == OpAndRepOp.first)
+      return OpAndRepOp.second;
+  }
 
   if (!MaxRecurse--)
     return nullptr;
 
-  // We cannot replace a constant, and shouldn't even try.
-  if (isa<Constant>(Op))
-    return nullptr;
-
   auto *I = dyn_cast<Instruction>(V);
   if (!I)
     return nullptr;
@@ -4303,11 +4301,6 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   if (isa<PHINode>(I))
     return nullptr;
 
-  // For vector types, the simplification must hold per-lane, so forbid
-  // potentially cross-lane operations like shufflevector.
-  if (Op->getType()->isVectorTy() && !isNotCrossLaneOperation(I))
-    return nullptr;
-
   // Don't fold away llvm.is.constant checks based on assumptions.
   if (match(I, m_Intrinsic<Intrinsic::is_constant>()))
     return nullptr;
@@ -4316,12 +4309,30 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   if (isa<FreezeInst>(I))
     return nullptr;
 
+  SmallVector<std::pair<Value *, Value *>> ValidReplacements{};
+  for (const auto &OpAndRepOp : Ops) {
+    // We cannot replace a constant, and shouldn't even try.
+    if (isa<Constant>(OpAndRepOp.first))
+      return nullptr;
+
+    // For vector types, the simplification must hold per-lane, so forbid
+    // potentially cross-lane operations like shufflevector.
+    if (OpAndRepOp.first->getType()->isVectorTy() &&
+        !isNotCrossLaneOperation(I))
+      continue;
+    ValidReplacements.emplace_back(OpAndRepOp);
+  }
+
+  if (ValidReplacements.empty())
+    return nullptr;
+
   // Replace Op with RepOp in instruction operands.
   SmallVector<Value *, 8> NewOps;
   bool AnyReplaced = false;
   for (Value *InstOp : I->operands()) {
-    if (Value *NewInstOp = simplifyWithOpReplaced(
-            InstOp, Op, RepOp, Q, AllowRefinement, DropFlags, MaxRecurse)) {
+    if (Value *NewInstOp =
+            simplifyWithOpsReplaced(InstOp, ValidReplacements, Q,
+                                    AllowRefinement, DropFlags, MaxRecurse)) {
       NewOps.push_back(NewInstOp);
       AnyReplaced = InstOp != NewInstOp;
     } else {
@@ -4372,7 +4383,9 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
       // by assumption and this case never wraps, so nowrap flags can be
       // ignored.
       if ((Opcode == Instruction::Sub || Opcode == Instruction::Xor) &&
-          NewOps[0] == RepOp && NewOps[1] == RepOp)
+          any_of(ValidReplacements, [=](const auto &Rep) {
+            return NewOps[0] == NewOps[1] && NewOps[0] == Rep.second;
+          }))
         return Constant::getNullValue(I->getType());
 
       // If we are substituting an absorber constant into a binop and extra
@@ -4382,10 +4395,10 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
       // (Op == 0) ? 0 : (Op & -Op)            --> Op & -Op
       // (Op == 0) ? 0 : (Op * (binop Op, C))  --> Op * (binop Op, C)
       // (Op == -1) ? -1 : (Op | (binop C, Op) --> Op | (binop C, Op)
-      Constant *Absorber =
-          ConstantExpr::getBinOpAbsorber(Opcode, I->getType());
+      Constant *Absorber = ConstantExpr::getBinOpAbsorber(Opcode, I->getType());
       if ((NewOps[0] == Absorber || NewOps[1] == Absorber) &&
-          impliesPoison(BO, Op))
+          any_of(ValidReplacements,
+                 [=](const auto &Rep) { return impliesPoison(BO, Rep.first); }))
         return Absorber;
     }
 
@@ -4453,6 +4466,15 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
                                   /*AllowNonDeterministic=*/false);
 }
 
+static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                                     const SimplifyQuery &Q,
+                                     bool AllowRefinement,
+                                     SmallVectorImpl<Instruction *> *DropFlags,
+                                     unsigned MaxRecurse) {
+  return simplifyWithOpsReplaced(V, {{Op, RepOp}}, Q, AllowRefinement,
+                                 DropFlags, MaxRecurse);
+}
+
 Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
                                     const SimplifyQuery &Q,
                                     bool AllowRefinement,
@@ -4595,17 +4617,16 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
 
 /// Try to simplify a select instruction when its condition operand is an
 /// integer equality or floating-point equivalence comparison.
-static Value *simplifySelectWithEquivalence(Value *CmpLHS, Value *CmpRHS,
-                                            Value *TrueVal, Value *FalseVal,
-                                            const SimplifyQuery &Q,
-                                            unsigned MaxRecurse) {
-  if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q.getWithoutUndef(),
-                             /* AllowRefinement */ false,
-                             /* DropFlags */ nullptr, MaxRecurse) == TrueVal)
+static Value *simplifySelectWithEquivalence(
+    ArrayRef<std::pair<Value *, Value *>> Replacements, Value *TrueVal,
+    Value *FalseVal, const SimplifyQuery &Q, unsigned MaxRecurse) {
+  if (simplifyWithOpsReplaced(FalseVal, Replacements, Q.getWithoutUndef(),
+                              /* AllowRefinement */ false,
+                              /* DropFlags */ nullptr, MaxRecurse) == TrueVal)
     return FalseVal;
-  if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
-                             /* AllowRefinement */ true,
-                             /* DropFlags */ nullptr, MaxRecurse) == FalseVal)
+  if (simplifyWithOpsReplaced(TrueVal, Replacements, Q,
+                              /* AllowRefinement */ true,
+                              /* DropFlags */ nullptr, MaxRecurse) == FalseVal)
     return FalseVal;
 
   return nullptr;
@@ -4699,10 +4720,10 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
   // the arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
   if (Pred == ICmpInst::ICMP_EQ) {
-    if (Value *V = simplifySelectWithEquivalence(CmpLHS, CmpRHS, TrueVal,
+    if (Value *V = simplifySelectWithEquivalence({{CmpLHS, CmpRHS}}, TrueVal,
                                                  FalseVal, Q, MaxRecurse))
       return V;
-    if (Value *V = simplifySelectWithEquivalence(CmpRHS, CmpLHS, TrueVal,
+    if (Value *V = simplifySelectWithEquivalence({{CmpRHS, CmpLHS}}, TrueVal,
                                                  FalseVal, Q, MaxRecurse))
       return V;
 
@@ -4712,11 +4733,14 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
     if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) &&
         match(CmpRHS, m_Zero())) {
       // (X | Y) == 0 implies X == 0 and Y == 0.
-      if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
-                                                   Q, MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence(
+              {{X, CmpRHS}, {Y, CmpRHS}}, TrueVal, FalseVal, Q, MaxRecurse))
+        return V;
+      if (Value *V = simplifySelectWithEquivalence({{X, CmpRHS}}, TrueVal,
+                                                   FalseVal, Q, MaxRecurse))
         return V;
-      if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
-                                                   Q, MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence({{Y, CmpRHS}}, TrueVal,
+                                                   FalseVal, Q, MaxRecurse))
         return V;
     }
 
@@ -4724,11 +4748,14 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
     if (match(CmpLHS, m_And(m_Value(X), m_Value(Y))) &&
         match(CmpRHS, m_AllOnes())) {
       // (X & Y) == -1 implies X == -1 and Y == -1.
-      if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
-                                                   Q, MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence(
+              {{X, CmpRHS}, {Y, CmpRHS}}, TrueVal, FalseVal, Q, MaxRecurse))
         return V;
-      if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
-                                                   Q, MaxRecurse))
+      if (Value *V = simplifySelectWithEquivalence({{X, CmpRHS}}, TrueVal,
+                                                   FalseVal, Q, MaxRecurse))
+        return V;
+      if (Value *V = simplifySelectWithEquivalence({{Y, CmpRHS}}, TrueVal,
+                                                   FalseVal, Q, MaxRecurse))
         return V;
     }
   }
@@ -4757,11 +4784,11 @@ static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F,
   // This transforms is safe if at least one operand is known to not be zero.
   // Otherwise, the select can change the sign of a zero operand.
   if (IsEquiv) {
-    if (Value *V =
-            simplifySelectWithEquivalence(CmpLHS, CmpRHS, T, F, Q, MaxRecurse))
+    if (Value *V = simplifySelectWithEquivalence({{CmpLHS, CmpRHS}}, T, F, Q,
+                                                 MaxRecurse))
       return V;
-    if (Value *V =
-            simplifySelectWithEquivalence(CmpRHS, CmpLHS, T, F, Q, MaxRecurse))
+    if (Value *V = simplifySelectWithEquivalence({{CmpRHS, CmpLHS}}, T, F, Q,
+                                                 MaxRecurse))
       return V;
   }
 
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 0168a804239a89..85637b27d8adcf 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -3937,11 +3937,8 @@ entry:
 define i32 @src_or_eq_0_and_xor(i32 %x, i32 %y) {
 ; CHECK-LABEL: @src_or_eq_0_and_xor(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[XOR]]
-; CHECK-NEXT:    ret i32 [[COND]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[XOR]]
 ;
 entry:
   %or = or i32 %y, %x
@@ -3956,11 +3953,8 @@ entry:
 define i32 @src_or_eq_0_xor_and(i32 %x, i32 %y) {
 ; CHECK-LABEL: @src_or_eq_0_xor_and(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[AND]]
-; CHECK-NEXT:    ret i32 [[COND]]
+; CHECK-NEXT:    [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[AND]]
 ;
 entry:
   %or = or i32 %y, %x
@@ -4438,11 +4432,8 @@ define i32 @src_no_trans_select_and_eq0_xor_and(i32 %x, i32 %y) {
 
 define i32 @src_no_trans_select_or_eq0_or_and(i32 %x, i32 %y) {
 ; CHECK-LABEL: @src_no_trans_select_or_eq0_or_and(
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[OR0:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[OR0]], i32 0, i32 [[AND]]
-; CHECK-NEXT:    ret i32 [[COND]]
+; CHECK-NEXT:    [[AND:%.*]] = and i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[AND]]
 ;
   %or = or i32 %x, %y
   %or0 = icmp eq i32 %or, 0
@@ -4453,11 +4444,8 @@ define i32 @src_no_trans_select_or_eq0_or_and(i32 %x, i32 %y) {
 
 define i32 @src_no_trans_select_or_eq0_or_xor(i32 %x, i32 %y) {
 ; CHECK-LABEL: @src_no_trans_select_or_eq0_or_xor(
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[OR0:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[OR0]], i32 0, i32 [[XOR]]
-; CHECK-NEXT:    ret i32 [[COND]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[XOR]]
 ;
   %or = or i32 %x, %y
   %or0 = icmp eq i32 %or, 0
@@ -4492,11 +4480,8 @@ define i32 @src_no_trans_select_or_eq0_xor_or(i32 %x, i32 %y) {
 
 define i32 @src_no_trans_select_and_ne0_xor_or(i32 %x, i32 %y) {
 ; CHECK-LABEL: @src_no_trans_select_and_ne0_xor_or(
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[OR0_NOT:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[OR0_NOT]], i32 0, i32 [[XOR]]
-; CHECK-NEXT:    ret i32 [[COND]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[XOR]]
 ;
   %or = or i32 %x, %y
   %or0 = icmp ne i32 %or, 0

@goldsteinn goldsteinn changed the title goldsteinn/simplify multiple ops [InstSimplify] Use multi-op replacement when simplify select Jan 5, 2025
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please check the compile-time impact?

llvm/lib/Analysis/InstructionSimplify.cpp Outdated Show resolved Hide resolved
@goldsteinn
Copy link
Contributor Author

Can you please check the compile-time impact?

Compile time data: https://llvm-compile-time-tracker.com/compare.php?from=c56b74315f57acb1b285ddc77b07031b773549b7&to=0c2ed79c43b8f7833566172e691dc702a1a2a5aa&stat=instructions%3Au

Maybe a bit on kimwitu++ 52485M 52624M (+0.27%), but other than that seems fine (should also improve without the extra calls after rebasing on your patch).

// potentially cross-lane operations like shufflevector.
if (OpAndRepOp.first->getType()->isVectorTy() &&
!isNotCrossLaneOperation(I))
continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should just return nullptr here as well. Handling this more precisely is not worth constructing the separate vector. (In practice all the types are the same anyway.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think we need to create second vec anyways for const case (which should be continue as opposed to return).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The const case should also return (as you already do).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think thats a mistake. If say we have (select (icmp eq (or x, y), 0), (add x, 1), ...), we don't want to bail on simplifying the add just because its rhs is constant.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not what this check does. It protects against doing something like "replace 0 with x".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if we bail when any op is constant we will end up bailing on the entire simplification no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, yeah I had it the other way around. I guess we are guarding against (icmp eq (or x, 1), 0) which will get simplification elsewhere. Will change.

@@ -4372,7 +4383,9 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
// by assumption and this case never wraps, so nowrap flags can be
// ignored.
if ((Opcode == Instruction::Sub || Opcode == Instruction::Xor) &&
NewOps[0] == RepOp && NewOps[1] == RepOp)
any_of(ValidReplacements, [=](const auto &Rep) {
return NewOps[0] == NewOps[1] && NewOps[0] == Rep.second;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can move NewOps[0] == NewOps[1] out of the any_of.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also add a test for the & == -1 case?

for (const auto &OpAndRepOp : Ops) {
// We cannot replace a constant, and shouldn't even try.
if (isa<Constant>(OpAndRepOp.first))
return nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize this is pre-existing, but I do wonder why this is checked here and not before the trivial replacement at the top.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ill move.

…eplacements; NFC

This allows for multiple ops to be replaced at once. This can be
useful if say replacing in the context of `X | Y == 0`, where some
replacements will only be available if we replace `X` and `Y` with `0`
at the same time.
In the case of `select X | Y == 0 :...` or `select X & Y == -1 : ...`
we can do more simplifications by trying to replace both `X` and `Y`
with the respective constant at once.

Handles some cases for llvm#121672 more generically.
@goldsteinn goldsteinn force-pushed the goldsteinn/simplify-multiple-ops branch from fed5ee2 to 4d864c7 Compare January 6, 2025 22:27
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. No hits on llvm-opt-benchmark, but the generalization of simplifyWithOpReplaced itself is pretty reasonable and could be more useful in the future (e.g. to handle and/or of icmp in the select condition).

@goldsteinn goldsteinn merged commit 6192faf into llvm:main Jan 7, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants