Skip to content

Commit

Permalink
Add Optional Redundant Clip Node to NodeUnit (#22888)
Browse files Browse the repository at this point in the history
Currently we have Clip/Relu with Q fusion on level 2. But for EPs that
are using NodeUnit, these optimizers are not applied. If we want to
remove such redundant Clip/Relu nodes, we need to add code to handle it
for each EP separately.

The PR detects a Clip/Relu is made redundant with a Q node, and add this
information to the corresponding QDQ NodeUnit, so that EPs can ignore
it, and can handle the target node only in the QDQ NodeUnit.
  • Loading branch information
centwang authored Jan 9, 2025
1 parent ca77de5 commit 4134cd9
Show file tree
Hide file tree
Showing 13 changed files with 748 additions and 664 deletions.
42 changes: 37 additions & 5 deletions onnxruntime/core/framework/node_unit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ std::vector<NodeUnitIODef> GetQDQIODefs(const Node& target_node, const QDQ::Node

Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer,
const Node& target_node,
const Node* redundant_clip_node,
gsl::span<const Node* const> dq_nodes,
gsl::span<const Node* const> q_nodes) {
// Within a QDQ node group, a target node input is the only consumer of each DQ.
Expand All @@ -176,6 +177,24 @@ Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer,
dq_node->Name(), ", target node: ", target_node.Name());
}

// If redundant_clip_node is present, currently we require target node has only one output edge, which is connected to
// the redundant_clip_node. The redundant_clip_node's output is consumed by the Q node that can be fused with itself.
if (redundant_clip_node) {
ORT_RETURN_IF_NOT(!graph_viewer.NodeProducesGraphOutput(target_node) && target_node.GetOutputEdgesCount() == 1 &&
target_node.OutputEdgesBegin()->GetNode().Index() == redundant_clip_node->Index(),
"QDQ node group cannot have target node with more than one output edge if there is redunant clip "
"node. target node: ",
target_node.Name());
ORT_RETURN_IF_NOT(
!graph_viewer.NodeProducesGraphOutput(*redundant_clip_node) && q_nodes.size() == 1 &&
redundant_clip_node->GetOutputEdgesCount() == 1 &&
redundant_clip_node->OutputEdgesBegin()->GetNode().Index() == q_nodes[0]->Index(),
"QDQ node group cannot have redudant clip node that doesn't have a single output edge to a Q node. "
"redundant clip node: ",
redundant_clip_node->Name());
return Status::OK();
}

// an output from the target node can have either Q consumers or direct consumers. it cannot have both.
// this must be checked on a per output basis.
// e.g. TopK produces values and indices. The indices output won't be quantized, so even if we replace the TopK QDQ
Expand Down Expand Up @@ -228,8 +247,10 @@ Status QDQ::NodeGroup::CanCreateNodeGroup(const GraphViewer& graph_viewer,

return Status::OK();
}

NodeUnit::NodeUnit(const Node& node)
: target_node_(node),
redundant_clip_node_(nullptr),
type_(Type::SingleNode),
input_edge_count_(node.GetInputEdgesCount()) {
InitForSingleNode();
Expand All @@ -238,11 +259,16 @@ NodeUnit::NodeUnit(const Node& node)
NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group)
: dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)},
target_node_(*graph_viewer.GetNode(node_group.target_node)),
redundant_clip_node_(node_group.redundant_clip_node.has_value()
? graph_viewer.GetNode(node_group.redundant_clip_node.value())
: nullptr),
q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)},
type_(Type::QDQGroup),
inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)},
outputs_{GetQDQIODefs(target_node_, node_group, false /* is_input */)} {
ORT_THROW_IF_ERROR(QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, dq_nodes_, q_nodes_));
outputs_{GetQDQIODefs((redundant_clip_node_ ? *redundant_clip_node_ : target_node_), node_group,
false /* is_input */)} {
ORT_THROW_IF_ERROR(
QDQ::NodeGroup::CanCreateNodeGroup(graph_viewer, target_node_, redundant_clip_node_, dq_nodes_, q_nodes_));

input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0),
[](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); });
Expand All @@ -253,8 +279,10 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g

// create output edges. each target node output either goes to Q node/s or non-Q node/s.
// ValidateNodeGroupQDQNodes ensures this.
auto cur_edge = target_node_.OutputEdgesBegin();
auto end_edge = target_node_.OutputEdgesEnd();
// If redundant clip node is present, the target node has only one output edge to the redundant clip node.
const Node& output_producer = redundant_clip_node_ ? *redundant_clip_node_ : target_node_;
auto cur_edge = output_producer.OutputEdgesBegin();
auto end_edge = output_producer.OutputEdgesEnd();
for (; cur_edge != end_edge; ++cur_edge) {
const Node& node = cur_edge->GetNode();

Expand All @@ -273,12 +301,13 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g
}
}

NodeUnit::NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node,
NodeUnit::NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node, const Node* redundant_clip_node,
gsl::span<const Node* const> q_nodes, Type unit_type,
gsl::span<const NodeUnitIODef> inputs, gsl::span<const NodeUnitIODef> outputs,
size_t input_edge_count, Node::EdgeSet output_edges)
: dq_nodes_(dq_nodes.begin(), dq_nodes.end()),
target_node_(target_node),
redundant_clip_node_(redundant_clip_node),
q_nodes_(q_nodes.begin(), q_nodes.end()),
type_(unit_type),
inputs_(inputs.begin(), inputs.end()),
Expand Down Expand Up @@ -389,6 +418,9 @@ Node::EdgeConstIterator NodeUnit::OutputEdgesEnd() const {
std::vector<const Node*> NodeUnit::GetAllNodesInGroup() const noexcept {
std::vector<const Node*> all_nodes = dq_nodes_;
all_nodes.push_back(&target_node_);
if (redundant_clip_node_) {
all_nodes.push_back(redundant_clip_node_);
}
all_nodes.insert(all_nodes.end(), q_nodes_.begin(), q_nodes_.end());
return all_nodes;
}
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/framework/node_unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ struct NodeGroup {
std::vector<NodeIndex> dq_nodes;
std::vector<NodeIndex> q_nodes;
NodeIndex target_node;
std::optional<NodeIndex> redundant_clip_node;

// Validator to check if the set of nodes can form a valid QDQ NodeGroup.
// Checks target node is only consumer of each DQ, and that the outputs remain valid if the QDQ node group was to
// be converted into a single node with a quantized operator.
static Status CanCreateNodeGroup(const GraphViewer& graph_viewer,
const Node& target_node,
const Node* redundant_clip_node,
gsl::span<const Node* const> dq_nodes,
gsl::span<const Node* const> q_nodes);
};
Expand Down Expand Up @@ -68,7 +70,7 @@ class NodeUnit {
public:
explicit NodeUnit(const Node& node);
explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group);
NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node,
NodeUnit(gsl::span<const Node* const> dq_nodes, const Node& target_node, const Node* redundant_clip_node,
gsl::span<const Node* const> q_nodes, Type unit_type,
gsl::span<const NodeUnitIODef> inputs, gsl::span<const NodeUnitIODef> outputs,
size_t input_edge_count, Node::EdgeSet output_edges);
Expand All @@ -87,6 +89,7 @@ class NodeUnit {
ProviderType GetExecutionProviderType() const noexcept;

const Node& GetNode() const noexcept { return target_node_; }
const Node* GetRedundantClipNode() const noexcept { return redundant_clip_node_; }
const std::vector<const Node*>& GetDQNodes() const noexcept { return dq_nodes_; }
const std::vector<const Node*>& GetQNodes() const noexcept { return q_nodes_; }
std::vector<const Node*> GetAllNodesInGroup() const noexcept;
Expand All @@ -106,6 +109,7 @@ class NodeUnit {

const std::vector<const Node*> dq_nodes_; // dq nodes for this NodeUnit, not necessarily all inputs
const Node& target_node_;
const Node* redundant_clip_node_; // Optional redundant clip node for the QDQ group, nullptr if not present.
const std::vector<const Node*> q_nodes_; // q-nodes for this NodeUnit. not necessarily all outputs
const Type type_;

Expand Down
126 changes: 126 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,130 @@ bool MatchDQNode(const Node& node) {

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

namespace {

bool GetDataTypeMinMax(int32_t data_type, int32_t& min, int32_t& max) {
switch (data_type) {
case ONNX_NAMESPACE::TensorProto::INT8:
min = static_cast<int32_t>(std::numeric_limits<int8_t>::min());
max = static_cast<int32_t>(std::numeric_limits<int8_t>::max());
break;
case ONNX_NAMESPACE::TensorProto::UINT8:
min = static_cast<int32_t>(std::numeric_limits<uint8_t>::min());
max = static_cast<int32_t>(std::numeric_limits<uint8_t>::max());
break;
case ONNX_NAMESPACE::TensorProto::INT16:
min = static_cast<int32_t>(std::numeric_limits<int16_t>::min());
max = static_cast<int32_t>(std::numeric_limits<int16_t>::max());
break;
case ONNX_NAMESPACE::TensorProto::UINT16:
min = static_cast<int32_t>(std::numeric_limits<uint16_t>::min());
max = static_cast<int32_t>(std::numeric_limits<uint16_t>::max());
break;
default:
return false;
}
return true;
}
bool GetQScalarScaleZp(const Graph& graph, const Node& q_node, float& scale, int32_t& zp, int32_t& data_type) {
assert(q_node.OpType() == QOpName);
const auto& q_input_defs = q_node.InputDefs();

const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = graph.GetConstantInitializer(q_input_defs[1]->Name(), true);
if (!scale_tensor_proto) {
return false;
}

// Support scalar float scale only for now. Need to extend to other float types if needed.
Initializer scale_initializer(*scale_tensor_proto, graph.ModelPath());
if (scale_initializer.dims().size() != 0 || scale_initializer.data_type() != ONNX_NAMESPACE::TensorProto::FLOAT) {
return false;
}

scale = *scale_initializer.data<float>();

if (q_input_defs.size() != 3 || !q_input_defs[2]->Exists()) {
int32_t output_dtype = ONNX_NAMESPACE::TensorProto::UNDEFINED;
const auto& q_attrs = q_node.GetAttributes();
if (auto it = q_attrs.find("output_dtype"); it != q_attrs.end()) {
output_dtype = static_cast<int32_t>(it->second.i());
}

data_type =
output_dtype == ONNX_NAMESPACE::TensorProto::UNDEFINED ? ONNX_NAMESPACE::TensorProto::UINT8 : output_dtype;
zp = 0;
return true;
}

const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = graph.GetConstantInitializer(q_input_defs[2]->Name(), true);
if (!zp_tensor_proto) {
return false;
}

Initializer zp_initializer(*zp_tensor_proto, graph.ModelPath());
if (zp_initializer.dims().size() != 0) {
return false;
}

data_type = zp_initializer.data_type();
switch (data_type) {
case ONNX_NAMESPACE::TensorProto::INT8:
zp = static_cast<int32_t>(*zp_initializer.data<int8_t>());
break;
case ONNX_NAMESPACE::TensorProto::UINT8:
zp = static_cast<int32_t>(*zp_initializer.data<uint8_t>());
break;
case ONNX_NAMESPACE::TensorProto::INT16:
zp = static_cast<int32_t>(*zp_initializer.data<int16_t>());
break;
case ONNX_NAMESPACE::TensorProto::UINT16:
zp = static_cast<int32_t>(*zp_initializer.data<uint16_t>());
break;
default:
return false;
}

return true;
}

} // namespace

bool IsClipMadeRedundantByQ(const Graph& graph, const Node& clip_node, const Node& q_node) {
float scale = 0.0f;
int32_t zp = 0;
int32_t data_type = 0;
if (!GetQScalarScaleZp(graph, q_node, scale, zp, data_type)) {
return false;
}

int32_t data_type_min = 0;
int32_t data_type_max = 0;
if (!GetDataTypeMinMax(data_type, data_type_min, data_type_max)) {
return false;
}

const std::string& clip_op_type = clip_node.OpType();
if (clip_op_type == "Relu") {
return zp == data_type_min;
}

if (clip_op_type == "Clip") {
float clip_min = 0.0f;
float clip_max = 0.0f;
if (!optimizer_utils::GetClipConstantMinMax(graph, clip_node, clip_min, clip_max)) {
return false;
}

int32_t q_clip_min = static_cast<int32_t>(::rint(clip_min / scale)) + zp;
int32_t q_clip_max = static_cast<int32_t>(::rint(clip_max / scale)) + zp;

// The Clip can be removed if its range entirely overlaps the quantization range.
// QClip range: [------------------]
// Quant range: [-------------]
return q_clip_min <= data_type_min && q_clip_max >= data_type_max;
}

return false;
}

} // namespace onnxruntime::QDQ
5 changes: 5 additions & 0 deletions onnxruntime/core/optimizer/qdq_transformer/qdq_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace onnxruntime {

class Node;
class Path;
class Graph;

namespace QDQ {

Expand Down Expand Up @@ -76,5 +77,9 @@ bool MatchQNode(const Node& node);
// Check DQ node op type, version, and domain.
bool MatchDQNode(const Node& node);
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

// Check if an clip node is made redundant by Q node.
bool IsClipMadeRedundantByQ(const Graph& graph, const Node& clip_node, const Node& q_node);

} // namespace QDQ
} // namespace onnxruntime
Loading

0 comments on commit 4134cd9

Please sign in to comment.