Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVPTX] Remove NVPTX::IMAD opcode, and rely on intruction selection only #121724

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

peterbell10
Copy link
Contributor

@peterbell10 peterbell10 commented Jan 6, 2025

I noticed that NVPTX will sometimes emit mad.lo to multiply by 1, e.g. in https://gcc.godbolt.org/z/4j47Y9W4c.

This happens when DAGCombiner operates on the add before the mul, so the imad contraction happens regardless of whether the mul could have been simplified.

To fix this, I remove NVPTXISD::IMAD and only combine to mad during selection. This would allow the default DAGCombiner patterns to simplify the graph without any NVPTX-specific intervention.

I noticed that NVPTX will sometimes emit `mad.lo` to multiply by 1,
e.g. in https://gcc.godbolt.org/z/45W3Wcnxz

This happens when DAGCombiner operates on the add before the mul, so
the imad contraction happens regardless of whether the mul could have
been simplified.

This PR adds:
```
mad x 1 y => add x y
mad x -1 y => sub y x
mad x 0 y => y
mad x y 0 => mul x y
mad c0 c1 z => add z (C0 * C1)
```

Another option might be to remove `NVPTXISD::IMAD` and only combine to
mad during selection. This would allow the normal DAGCombiner patterns
to simplify the graph without any NVPTX-specific intervention. However,
it also risks DAGCombiner breaking up the mul-add patterns, which
is why I haven't done it that way.
@peterbell10 peterbell10 marked this pull request as ready for review January 6, 2025 03:09
@llvmbot llvmbot added backend:NVPTX llvm:SelectionDAG SelectionDAGISel as well labels Jan 6, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 6, 2025

@llvm/pr-subscribers-backend-nvptx

@llvm/pr-subscribers-llvm-selectiondag

Author: None (peterbell10)

Changes

I noticed that NVPTX will sometimes emit mad.lo to multiply by 1, e.g. in https://gcc.godbolt.org/z/45W3Wcnxz.

This happens when DAGCombiner operates on the add before the mul, so the imad contraction happens regardless of whether the mul could have been simplified.

To fix this, I add some combiner patterns for IMAD. In particular, this PR adds:

mad x 1 y => add x y
mad x -1 y => sub y x
mad x 0 y => y
mad x y 0 => mul x y
mad c0 c1 z => add z (c0 * c1)

Another option might be to remove NVPTXISD::IMAD and only combine to mad during selection. This would allow the default DAGCombiner patterns to simplify the graph without any NVPTX-specific intervention. However, it also risks DAGCombiner breaking up the mul-add patterns, which is why I haven't done it that way.

I found testing this change to be quite tricky as there is no mad intrinsic so we have to write add (mul x y) z in llvm IR which will be simplified by SelectionDAG::getNode before we even get the DAGCombiner. So, I've also added a couple of debug flags to disable default simplifications:

--selectiondag-simplify-nodes=false  # disables simplifications in SelectionDAG::getNode
--combiner-generic-combines=false    # disables generic DAGCombiner patterns

For the SelectionDAG flag there is a significant number of lines changed, but it's mostly removing duplicated code that had been copy-pasted to all of the getNode implementations and factoring it out into a private getNodeImpl so I think this is okay. I'd be happy to split this into a separate PR if required.


Full diff: https://github.com/llvm/llvm-project/pull/121724.diff

7 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+5)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+9-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+51-138)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+49)
  • (added) llvm/test/CodeGen/NVPTX/combine-mad-only.ll (+87)
  • (modified) llvm/test/CodeGen/NVPTX/combine-mad.ll (+20)
  • (modified) llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll (+1-1)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index ff7caec41855fd..3a015c8df2066a 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2460,6 +2460,11 @@ class SelectionDAG {
   SDNode *FindNodeOrInsertPos(const FoldingSetNodeID &ID, const SDLoc &DL,
                               void *&InsertPos);
 
+  SDValue getNodeImpl(unsigned Opcode, const SDLoc &DL, EVT VT,
+                      ArrayRef<SDValue> Ops, SDNodeFlags Flags);
+  SDValue getNodeImpl(unsigned Opcode, const SDLoc &DL, SDVTList VTs,
+                      ArrayRef<SDValue> Ops, SDNodeFlags Flags);
+
   /// Maps to auto-CSE operations.
   std::vector<CondCodeSDNode*> CondCodeNodes;
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6b2501591c81a3..6d75809cdaf69f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -153,6 +153,13 @@ static cl::opt<bool> EnableVectorFCopySignExtendRound(
     "combiner-vector-fcopysign-extend-round", cl::Hidden, cl::init(false),
     cl::desc(
         "Enable merging extends and rounds into FCOPYSIGN on vector types"));
+
+static cl::opt<bool>
+    EnableGenericCombines("combiner-generic-combines", cl::Hidden,
+                          cl::init(true),
+                          cl::desc("Enable generic DAGCombine patterns. Useful "
+                                   "for testing target-specific combines."));
+
 namespace {
 
   class DAGCombiner {
@@ -251,7 +258,8 @@ namespace {
         : DAG(D), TLI(D.getTargetLoweringInfo()),
           STI(D.getSubtarget().getSelectionDAGInfo()), OptLevel(OL), AA(AA) {
       ForCodeSize = DAG.shouldOptForSize();
-      DisableGenericCombines = STI && STI->disableGenericCombines(OptLevel);
+      DisableGenericCombines = !EnableGenericCombines ||
+                               (STI && STI->disableGenericCombines(OptLevel));
 
       MaximumLegalStoreInBits = 0;
       // We use the minimum store size here, since that's all we can guarantee
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 10e8ba93359fbd..6a3799e02edd94 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -115,6 +115,10 @@ static cl::opt<unsigned>
     MaxSteps("has-predecessor-max-steps", cl::Hidden, cl::init(8192),
              cl::desc("DAG combiner limit number of steps when searching DAG "
                       "for predecessor nodes"));
+static cl::opt<bool> EnableSimplifyNodes(
+    "selectiondag-simplify-nodes", cl::Hidden, cl::init(true),
+    cl::desc("Enable SelectionDAG::getNode simplifications. Useful for testing "
+             "DAG combines."));
 
 static void NewSDValueDbgMsg(SDValue V, StringRef Msg, SelectionDAG *G) {
   LLVM_DEBUG(dbgs() << Msg; V.getNode()->dump(G););
@@ -6157,23 +6161,46 @@ static SDValue foldCONCAT_VECTORS(const SDLoc &DL, EVT VT,
 }
 
 /// Gets or creates the specified node.
-SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT) {
+SDValue SelectionDAG::getNodeImpl(unsigned Opcode, const SDLoc &DL, EVT VT,
+                                  ArrayRef<SDValue> Ops,
+                                  const SDNodeFlags Flags) {
   SDVTList VTs = getVTList(VT);
-  FoldingSetNodeID ID;
-  AddNodeIDNode(ID, Opcode, VTs, {});
-  void *IP = nullptr;
-  if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP))
-    return SDValue(E, 0);
+  return getNodeImpl(Opcode, DL, VTs, Ops, Flags);
+}
 
-  auto *N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-  CSEMap.InsertNode(N, IP);
+SDValue SelectionDAG::getNodeImpl(unsigned Opcode, const SDLoc &DL,
+                                  SDVTList VTs, ArrayRef<SDValue> Ops,
+                                  const SDNodeFlags Flags) {
+  SDNode *N;
+  // Don't CSE glue-producing nodes
+  if (VTs.VTs[VTs.NumVTs - 1] != MVT::Glue) {
+    FoldingSetNodeID ID;
+    AddNodeIDNode(ID, Opcode, VTs, Ops);
+    void *IP = nullptr;
+    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
+      E->intersectFlagsWith(Flags);
+      return SDValue(E, 0);
+    }
+
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    createOperands(N, Ops);
+    CSEMap.InsertNode(N, IP);
+  } else {
+    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
+    createOperands(N, Ops);
+  }
 
+  N->setFlags(Flags);
   InsertNode(N);
   SDValue V = SDValue(N, 0);
   NewSDValueDbgMsg(V, "Creating new node: ", this);
   return V;
 }
 
+SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT) {
+  return getNodeImpl(Opcode, DL, VT, {}, SDNodeFlags{});
+}
+
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               SDValue N1) {
   SDNodeFlags Flags;
@@ -6185,6 +6212,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
                               SDValue N1, const SDNodeFlags Flags) {
   assert(N1.getOpcode() != ISD::DELETED_NODE && "Operand is DELETED_NODE!");
+  if (!EnableSimplifyNodes)
+    return getNodeImpl(Opcode, DL, VT, {N1}, Flags);
 
   // Constant fold unary operations with a vector integer or float operand.
   switch (Opcode) {
@@ -6501,31 +6530,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   }
 
-  SDNode *N;
-  SDVTList VTs = getVTList(VT);
-  SDValue Ops[] = {N1};
-  if (VT != MVT::Glue) { // Don't CSE glue producing nodes
-    FoldingSetNodeID ID;
-    AddNodeIDNode(ID, Opcode, VTs, Ops);
-    void *IP = nullptr;
-    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
-      return SDValue(E, 0);
-    }
-
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    N->setFlags(Flags);
-    createOperands(N, Ops);
-    CSEMap.InsertNode(N, IP);
-  } else {
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-  }
-
-  InsertNode(N);
-  SDValue V = SDValue(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getNodeImpl(Opcode, DL, VT, {N1}, Flags);
 }
 
 static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
@@ -7219,6 +7224,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   assert(N1.getOpcode() != ISD::DELETED_NODE &&
          N2.getOpcode() != ISD::DELETED_NODE &&
          "Operand is DELETED_NODE!");
+  if (!EnableSimplifyNodes)
+    return getNodeImpl(Opcode, DL, VT, {N1, N2}, Flags);
 
   canonicalizeCommutativeBinop(Opcode, N1, N2);
 
@@ -7665,32 +7672,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     }
   }
 
-  // Memoize this node if possible.
-  SDNode *N;
-  SDVTList VTs = getVTList(VT);
-  SDValue Ops[] = {N1, N2};
-  if (VT != MVT::Glue) {
-    FoldingSetNodeID ID;
-    AddNodeIDNode(ID, Opcode, VTs, Ops);
-    void *IP = nullptr;
-    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
-      return SDValue(E, 0);
-    }
-
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    N->setFlags(Flags);
-    createOperands(N, Ops);
-    CSEMap.InsertNode(N, IP);
-  } else {
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-  }
-
-  InsertNode(N);
-  SDValue V = SDValue(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getNodeImpl(Opcode, DL, VT, {N1, N2}, Flags);
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
@@ -7708,6 +7690,9 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
          N2.getOpcode() != ISD::DELETED_NODE &&
          N3.getOpcode() != ISD::DELETED_NODE &&
          "Operand is DELETED_NODE!");
+  if (!EnableSimplifyNodes)
+    return getNodeImpl(Opcode, DL, VT, {N1, N2, N3}, Flags);
+
   // Perform various simplifications.
   switch (Opcode) {
   case ISD::FMA:
@@ -7862,33 +7847,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   }
   }
-
-  // Memoize node if it doesn't produce a glue result.
-  SDNode *N;
-  SDVTList VTs = getVTList(VT);
-  SDValue Ops[] = {N1, N2, N3};
-  if (VT != MVT::Glue) {
-    FoldingSetNodeID ID;
-    AddNodeIDNode(ID, Opcode, VTs, Ops);
-    void *IP = nullptr;
-    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
-      return SDValue(E, 0);
-    }
-
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    N->setFlags(Flags);
-    createOperands(N, Ops);
-    CSEMap.InsertNode(N, IP);
-  } else {
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-  }
-
-  InsertNode(N);
-  SDValue V = SDValue(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getNodeImpl(Opcode, DL, VT, {N1, N2, N3}, Flags);
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
@@ -10343,6 +10302,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     assert(Op.getOpcode() != ISD::DELETED_NODE &&
            "Operand is DELETED_NODE!");
 #endif
+  if (!EnableSimplifyNodes)
+    return getNodeImpl(Opcode, DL, VT, Ops, Flags);
 
   switch (Opcode) {
   default: break;
@@ -10411,34 +10372,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   }
 
-  // Memoize nodes.
-  SDNode *N;
-  SDVTList VTs = getVTList(VT);
-
-  if (VT != MVT::Glue) {
-    FoldingSetNodeID ID;
-    AddNodeIDNode(ID, Opcode, VTs, Ops);
-    void *IP = nullptr;
-
-    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
-      return SDValue(E, 0);
-    }
-
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-
-    CSEMap.InsertNode(N, IP);
-  } else {
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTs);
-    createOperands(N, Ops);
-  }
-
-  N->setFlags(Flags);
-  InsertNode(N);
-  SDValue V(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getNodeImpl(Opcode, DL, VT, Ops, Flags);
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
@@ -10458,6 +10392,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
                               ArrayRef<SDValue> Ops, const SDNodeFlags Flags) {
   if (VTList.NumVTs == 1)
     return getNode(Opcode, DL, VTList.VTs[0], Ops, Flags);
+  if (!EnableSimplifyNodes)
+    return getNodeImpl(Opcode, DL, VTList, Ops, Flags);
 
 #ifndef NDEBUG
   for (const auto &Op : Ops)
@@ -10622,30 +10558,7 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, SDVTList VTList,
 #endif
   }
 
-  // Memoize the node unless it returns a glue result.
-  SDNode *N;
-  if (VTList.VTs[VTList.NumVTs-1] != MVT::Glue) {
-    FoldingSetNodeID ID;
-    AddNodeIDNode(ID, Opcode, VTList, Ops);
-    void *IP = nullptr;
-    if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) {
-      E->intersectFlagsWith(Flags);
-      return SDValue(E, 0);
-    }
-
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
-    createOperands(N, Ops);
-    CSEMap.InsertNode(N, IP);
-  } else {
-    N = newSDNode<SDNode>(Opcode, DL.getIROrder(), DL.getDebugLoc(), VTList);
-    createOperands(N, Ops);
-  }
-
-  N->setFlags(Flags);
-  InsertNode(N);
-  SDValue V(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getNodeImpl(Opcode, DL, VTList, Ops, Flags);
 }
 
 SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 5c1f717694a4c7..c4529c9151bc2b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5164,6 +5164,53 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
 }
 
+static SDValue
+PerformIMADCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, SDValue N2,
+                               TargetLowering::DAGCombinerInfo &DCI) {
+  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+  ConstantSDNode *N2C = dyn_cast<ConstantSDNode>(N2);
+  EVT VT = N0->getValueType(0);
+  SDLoc DL(N);
+  SDNodeFlags Flags = N->getFlags();
+
+  // mad x 1 y => add x y
+  if (N1C && N1C->isOne())
+    return DCI.DAG.getNode(ISD::ADD, DL, VT, N0, N2, Flags);
+
+  // mad x -1 y => sub y x
+  if (N1C && N1C->isAllOnes()) {
+    Flags.setNoUnsignedWrap(false);
+    return DCI.DAG.getNode(ISD::SUB, DL, VT, N2, N0, Flags);
+  }
+
+  // mad x 0 y => y
+  if (N1C && N1C->isZero())
+    return N2;
+
+  // mad x y 0 => mul x y
+  if (N2C && N2C->isZero())
+    return DCI.DAG.getNode(ISD::MUL, DL, VT, N0, N1, Flags);
+
+  // mad c0 c1 x => add x (c0*c1)
+  if (SDValue C =
+          DCI.DAG.FoldConstantArithmetic(ISD::MUL, DL, VT, {N0, N1}, Flags))
+    return DCI.DAG.getNode(ISD::ADD, DL, VT, N2, C, Flags);
+
+  return {};
+}
+
+static SDValue PerformIMADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  SDValue N2 = N->getOperand(2);
+  SDValue res = PerformIMADCombineWithOperands(N, N0, N1, N2, DCI);
+  if (res)
+    return res;
+
+  return PerformIMADCombineWithOperands(N, N1, N0, N2, DCI);
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5198,6 +5245,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
       return PerformVSELECTCombine(N, DCI);
     case ISD::BUILD_VECTOR:
       return PerformBUILD_VECTORCombine(N, DCI);
+    case NVPTXISD::IMAD:
+      return PerformIMADCombine(N, DCI);
   }
   return SDValue();
 }
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad-only.ll b/llvm/test/CodeGen/NVPTX/combine-mad-only.ll
new file mode 100644
index 00000000000000..fb4bcc39b5a64d
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/combine-mad-only.ll
@@ -0,0 +1,87 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | FileCheck %s
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | FileCheck %s
+; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -mtriple=nvptx -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -O1 --selectiondag-simplify-nodes=false --combiner-generic-combines=false --debug-counter=early-cse=100, | %ptxas-verify %}
+
+;; mad x 1 y => add y x
+define i32 @test_mad_mul_1(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_1(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_mul_1_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_mad_mul_1_param_1];
+; CHECK-NEXT:    add.s32 %r3, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 %x, 1
+  %add = add i32 %mul, %y
+  ret i32 %add
+}
+
+;; mad x -1 y => sub y x
+define i32 @test_mad_mul_neg_1(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_neg_1(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_mul_neg_1_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_mad_mul_neg_1_param_1];
+; CHECK-NEXT:    sub.s32 %r3, %r2, %r1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 %x, -1
+  %add = add i32 %mul, %y
+  ret i32 %add
+}
+
+;; mad x 0 y => y
+define i32 @test_mad_mul_0(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_mul_0(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_mul_0_param_1];
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 %x, 0
+  %add = add i32 %mul, %y
+  ret i32 %add
+}
+
+;; mad x y 0 => mul x y
+define i32 @test_mad_add_0(i32 %x, i32 %y) {
+; CHECK-LABEL: test_mad_add_0(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_add_0_param_0];
+; CHECK-NEXT:    ld.param.u32 %r2, [test_mad_add_0_param_1];
+; CHECK-NEXT:    mul.lo.s32 %r3, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 %x, %y
+  %add = add i32 %mul, 0
+  ret i32 %add
+}
+
+;; mad c0 c1 x => add x (c0*c1)
+define i32 @test_mad_fold_mul(i32 %x) {
+; CHECK-LABEL: test_mad_fold_mul(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_fold_mul_param_0];
+; CHECK-NEXT:    add.s32 %r2, %r1, 12;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %mul = mul i32 4, 3
+  %add = add i32 %mul, %x
+  ret i32 %add
+}
diff --git a/llvm/test/CodeGen/NVPTX/combine-mad.ll b/llvm/test/CodeGen/NVPTX/combine-mad.ll
index 1b22cfde39725f..7d523a835a1f3f 100644
--- a/llvm/test/CodeGen/NVPTX/combine-mad.ll
+++ b/llvm/test/CodeGen/NVPTX/combine-mad.ll
@@ -183,3 +183,23 @@ define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
   %add = add i32 %c, %sel
   ret i32 %add
 }
+
+;; This case relies on mad x 1 y => add x y, previously we emit:
+;;     mad.lo.s32      %r3, %r1, 1, %r2;
+define i32 @test_mad_fold(i32 %x) {
+; CHECK-LABEL: test_mad_fold(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u32 %r1, [test_mad_fold_param_0];
+; CHECK-NEXT:    mul.hi.s32 %r2, %r1, -2147221471;
+; CHECK-NEXT:    add.s32 %r3, %r1, %r2;
+; CHECK-NEXT:    shr.u32 %r4, %r3, 31;
+; CHECK-NEXT:    shr.s32 %r5, %r3, 12;
+; CHECK-NEXT:    add.s32 %r6, %r5, %r4;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r6;
+; CHECK-NEXT:    ret;
+  %div = sdiv i32 %x, 8191
+  ret i32 %div
+}
diff --git a/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll b/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
index 27a523b9dd91d2..de19d2983f3435 100644
--- a/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
+++ b/llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll
@@ -12,7 +12,7 @@
 ; CHECK-NOT: __local_depot
 
 ; CHECK-32:       ld.param.u32  %r[[SIZE:[0-9]]], [test_dynamic_stackalloc_param_0];
-; CHECK-32-NEXT:  mad.lo.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 1, 7;
+; CHECK-32-NEXT:  add.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 7;
 ; CHECK-32-NEXT:  and.b32         %r[[SIZE3:[0-9]]], %r[[SIZE2]], -8;
 ; CHECK-32-NEXT:  alloca.u32  %r[[ALLOCA:[0-9]]], %r[[SIZE3]], 16;
 ; CHECK-32-NEXT:  cvta.local.u32  %r[[ALLOCA]], %r[[ALLOCA]];

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option might be to remove NVPTXISD::IMAD and only combine to mad during selection. This would allow the default DAGCombiner patterns to simplify the graph without any NVPTX-specific intervention.

This would be better without a stronger justification for why an intermediate node is useful

However, it also risks DAGCombiner breaking up the mul-add patterns, which is why I haven't done it that way.

Should have a more concrete justification. Generally combines that would break the pattern are new instances of the pattern to handle (e.g. multiply-to-shift + add)

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp Outdated Show resolved Hide resolved
@peterbell10
Copy link
Contributor Author

peterbell10 commented Jan 6, 2025

This would be better without a stronger justification for why an intermediate node is useful

Perhaps the NVPTX maintainers know the history better here.

Generally combines that would break the pattern are new instances of the pattern to handle (e.g. multiply-to-shift + add)

That's a very relevant example. multiply-to-shift is a pessimisation if it breaks up a mad, since mad.lo has the same throughput as shift, but is able to fuse the additional add in the same instruction.

Copy link

github-actions bot commented Jan 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@peterbell10 peterbell10 changed the title [NVPTX][SelectionDAG] Add IMAD combine rules + infra to disable default SelectionDAG rules for testing [NVPTX] Remove NVPTX::IMAD opcode, and rely on intruction selection only Jan 7, 2025
@peterbell10
Copy link
Contributor Author

Okay, I've reworked this to remove NVPTX::IMAD and instead rely on tablegen patterns to fuse the operations. I looked back in the history and it seems it was just always done that way since mad was added
eafe26d#diff-cd5ad648f0fbe25520e97c78f6f302492b4d713effce854ec6393bb39bb58ded

Even today there are only one or two references to IMAD in combine patterns, so I think this is likely fine. The i128 lit test even seems to be doing a better job of generating mad now.

@peterbell10 peterbell10 removed the llvm:SelectionDAG SelectionDAGISel as well label Jan 7, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants