Commit 74775c1
[Feat.] flow.reshape support Gradient Accumulation and Scalar (#6254)
* ReshapeFunctor: in lazy mode, mark dim 0 as -1 when dim 0 is unchanged

* Reshape op shape inference supports an unchanged dim 0

* add error note

* Add test graph reshape acc script

* add test reshape scalar

* remove print

* refine error log

* Replace reshape conf in grad acc pass (#6268)

Co-authored-by: Yao Chi <[email protected]>
Co-authored-by: oneflow-ci-bot <[email protected]>
3 people authored Sep 13, 2021
1 parent ed611d3 commit 74775c1
Showing 5 changed files with 174 additions and 11 deletions.
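
For orientation, here is a minimal usage sketch of what this commit enables, distilled from the tests added below; the concrete tensor values and shapes are illustrative, not part of the commit.

import oneflow as flow

# Scalar tensors can now round-trip through flow.reshape.
x = flow.tensor(2.0)           # shape ()
a = flow.reshape(x, (1, 1))    # shape (1, 1)
b = flow.reshape(a, ())        # back to shape ()

# A batch-preserving reshape can be written with -1 on dim 0. Under nn.Graph
# gradient accumulation, the rewrite pass changed below marks dim 0 as -1
# automatically so the recorded conf still matches the smaller micro-batches.
y = flow.randn(6, 8)
z = flow.reshape(y, (-1, 2, 4))  # shape (6, 2, 4)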
6 changes: 4 additions & 2 deletions oneflow/core/functional/impl/array_functor.cpp
@@ -681,13 +681,15 @@ class ReshapeFunctor {
     size_t x_count = x->shape()->Count(0);
     MutableAttrMap attrs;
     if (need_infer_axis == -1) {
-      CHECK_EQ_OR_RETURN(shape.Count(0), x_count);
+      CHECK_EQ_OR_RETURN(shape.Count(0), x_count)
+          << "\n Shape " << shape.ToString() << " is invalid for input shape "
+          << x->shape()->ToString();
       JUST(attrs.SetAttr<Shape>("shape", shape));
     } else {
       Shape infered_shape = shape;
       infered_shape.Set(need_infer_axis, x_count / count);
       CHECK_EQ_OR_RETURN(infered_shape.Count(0), x_count)
-          << "Shape " << shape.ToString() << " is invalid for input of shape "
+          << "\n Shape " << shape.ToString() << " is invalid for input shape "
           << x->shape()->ToString();
       JUST(attrs.SetAttr<Shape>("shape", infered_shape));
     }
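As a quick illustration of the check above (a sketch, not part of the commit): an element-count mismatch now reports both the requested shape and the input shape in the error message.

import oneflow as flow

try:
    flow.reshape(flow.ones(2, 3), (4, 2))  # 6 elements cannot be viewed as 8
except Exception as e:
    # Expected to mention something like:
    #   Shape (4,2) is invalid for input shape (2,3)
    # (the exact prefix/wrapping depends on the CHECK macro).
    print(e)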
20 changes: 20 additions & 0 deletions oneflow/core/job_rewriter/gradient_accumulation_rewrite_pass.cpp
@@ -159,6 +159,26 @@ Maybe<void> GradientAccumulationRewritePass::Apply(Job* job, JobPassCtx* ctx) const {
                                          return_pack_op.output("out", 0));
       CHECK_EQ(return_in_lbn, old_val);
       return Maybe<void>::Ok();
+    } else if (op_conf.has_user_conf() && op_conf.user_conf().op_type_name() == "reshape") {
+      const LogicalBlobId in_lbi = node->op().BnInOp2Lbi(node->op().SoleIbn());
+      const LogicalBlobId out_lbi = node->op().BnInOp2Lbi(node->op().SoleObn());
+      const Shape& in_shape = node->LogicalBlobDesc4Lbi(in_lbi).shape();
+      const Shape& out_shape = node->LogicalBlobDesc4Lbi(out_lbi).shape();
+      if (in_shape.NumAxes() > 0 && out_shape.NumAxes() > 0 && in_shape.At(0) == out_shape.At(0)) {
+        // NOTE(chengcheng):
+        //   In nn.Graph gradient accumulation, the reshape conf recorded at job-build time and
+        //   the conf needed after inserting acc/unpack ops may NOT be equal, because dim 0 is
+        //   scaled; so dim 0 is set to -1 for dynamic inference.
+        OperatorConf* new_reshape_op_conf = GetOperatorConf4Modify(op_conf);
+        AttrValue* attr_val = &(*new_reshape_op_conf->mutable_user_conf()->mutable_attr())["shape"];
+        CHECK(attr_val->has_at_shape());
+        ShapeProto* shape_conf = attr_val->mutable_at_shape();
+        CHECK_GT(shape_conf->dim_size(), 0);
+        shape_conf->set_dim(0, -1);
+        LOG(INFO) << " Replace ReshapeOpConf from: " << op_conf.DebugString() << " to "
+                  << new_reshape_op_conf->DebugString() << " for dynamic infer by insert unpack.";
+      }
+      return Maybe<void>::Ok();
     } else {
       return Maybe<void>::Ok();
     }
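The effect of this pass, restated as a standalone sketch (plain Python, not OneFlow internals; the helper name is invented for illustration): when a reshape keeps dim 0, the stored shape attribute is rewritten so that dim 0 is re-inferred at runtime.

def rewrite_reshape_shape_attr(in_shape, out_shape, shape_attr):
    # Mirror of the branch above: mark dim 0 as -1 when it is unchanged.
    if len(in_shape) > 0 and len(out_shape) > 0 and in_shape[0] == out_shape[0]:
        return (-1,) + tuple(shape_attr[1:])
    return tuple(shape_attr)

# At build time the logical batch is 6; with 2 accumulation steps each
# micro-batch is 3, so a fixed conf of (6, 2, 4) would no longer match.
# With dim 0 rewritten to -1, the op re-infers (3, 2, 4) per micro-batch.
print(rewrite_reshape_shape_attr((6, 8), (6, 2, 4), (6, 2, 4)))  # (-1, 2, 4)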
42 changes: 33 additions & 9 deletions oneflow/user/ops/reshape_op.cpp
@@ -39,32 +39,56 @@ Maybe<void> InferNdSbpFn(user_op::InferNdSbpFnContext* ctx) {
 }

 Maybe<void> LogicalTensorDescInferFn(user_op::InferContext* ctx) {
-  const Shape& shape = ctx->Attr<Shape>("shape");
+  Shape shape = ctx->Attr<Shape>("shape");
   const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0);
   user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("out", 0);
   const Shape& in_shape = in_tensor_desc.shape();
   Shape* out_shape = out_tensor_desc->mut_shape();
   CHECK_OR_RETURN(in_tensor_desc.is_dynamic() == false);
   *out_tensor_desc = in_tensor_desc;
-  CHECK_GE_OR_RETURN(shape.NumAxes(), 1);
-  DimVector dim_vec = {shape.dim_vec().begin(), shape.dim_vec().end()};
-  FOR_RANGE(int32_t, i, 0, dim_vec.size()) { CHECK_GE_OR_RETURN(dim_vec.at(i), 0); }
-  *out_shape = Shape(dim_vec);
+  if (in_shape.NumAxes() == 0 || shape.NumAxes() == 0) {
+    // NOTE(chengcheng): input/output Scalar
+    // do nothing
+  } else {
+    CHECK_GE_OR_RETURN(shape.NumAxes(), 1);
+    CHECK_GE_OR_RETURN(in_shape.NumAxes(), 1);
+    for (int i = 1 /* skip dim 0 */; i < shape.NumAxes(); ++i) {
+      // NOTE(chengcheng): ONLY dim-0 may be -1 for infer
+      CHECK_GE_OR_RETURN(shape.At(i), 0);
+    }
+    if (shape.At(0) == -1) {
+      // NOTE(chengcheng): dim-0 unchanged for input.
+      shape.Set(0, in_shape.At(0));
+    }
+  }
+  *out_shape = shape;
   CHECK_EQ_OR_RETURN(out_shape->elem_cnt(), in_shape.elem_cnt());
   return Maybe<void>::Ok();
 }

 Maybe<void> TensorDescInferFn(user_op::InferContext* ctx) {
-  const Shape& shape = ctx->Attr<Shape>("shape");
-  CHECK_GE_OR_RETURN(shape.NumAxes(), 1);
-  FOR_RANGE(int32_t, i, 0, shape.NumAxes()) { CHECK_GT_OR_RETURN(shape.At(i), 0); }
+  Shape shape = ctx->Attr<Shape>("shape");
   const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0);
   user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("out", 0);
   const Shape& in_shape = in_tensor_desc.shape();
   Shape* out_shape = out_tensor_desc->mut_shape();
   CHECK_OR_RETURN(in_tensor_desc.is_dynamic() == false);
   *out_tensor_desc->mut_shape() = in_tensor_desc.shape();
   *out_tensor_desc->mut_is_dynamic() = in_tensor_desc.is_dynamic();
+  if (in_shape.NumAxes() == 0 || shape.NumAxes() == 0) {
+    // NOTE(chengcheng): input/output Scalar
+    // do nothing
+  } else {
+    CHECK_GE_OR_RETURN(shape.NumAxes(), 1);
+    CHECK_GE_OR_RETURN(in_shape.NumAxes(), 1);
+    for (int i = 1 /* skip dim 0 */; i < shape.NumAxes(); ++i) {
+      // NOTE(chengcheng): ONLY dim-0 may be -1 for infer
+      CHECK_GE_OR_RETURN(shape.At(i), 0);
+    }
+    if (shape.At(0) == -1) {
+      // NOTE(chengcheng): dim-0 unchanged for input.
+      shape.Set(0, in_shape.At(0));
+    }
+  }
   const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0);
   *out_shape = *JUST(GetPhysicalShape(shape, nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx()));
   CHECK_EQ_OR_RETURN(out_shape->elem_cnt(), in_shape.elem_cnt());
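A compact restatement of the new inference rule shared by both functions (an illustrative Python sketch, not the C++ implementation; the helper name is invented):

import math

def infer_reshape_out_shape(in_shape, shape_attr):
    # Scalar input or scalar shape attribute passes through unchanged;
    # otherwise only dim 0 may be -1, and it is taken from the input.
    out = tuple(shape_attr)
    if len(in_shape) > 0 and len(out) > 0:
        assert all(d >= 0 for d in out[1:]), "only dim 0 may be -1"
        if out[0] == -1:
            out = (in_shape[0],) + out[1:]
    assert math.prod(out) == math.prod(in_shape), "element counts must match"
    return out

print(infer_reshape_out_shape((6, 8), (-1, 2, 4)))  # (6, 2, 4)
print(infer_reshape_out_shape((), (1, 1)))          # (1, 1), scalar input allowed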
103 changes: 103 additions & 0 deletions python/oneflow/test/graph/test_graph_reshape_acc.py
@@ -0,0 +1,103 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import os
import numpy as np

import oneflow as flow
import oneflow.unittest


def _test_graph_reshape_acc(test_case):
class StageLayerModule(flow.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = flow.nn.Linear(10, 8, False)
self.linear2 = flow.nn.Linear(8, 10, False)
flow.nn.init.constant_(self.linear1.weight, 0.023)
flow.nn.init.constant_(self.linear2.weight, 1.23)

def forward(self, x):
out0 = self.linear1(x)
out0 = flow.reshape(out0, (-1, 2, 4))
out0 = out0 + 1.0
out0 = out0 * 2.0
out0 = flow.reshape(out0, (-1, 8))
out1 = self.linear2(out0)
return out1

P0 = flow.placement("cuda", {0: [0]})
P1 = flow.placement("cuda", {0: [1]})
B = flow.sbp.broadcast

class PipelineModule(flow.nn.Module):
def __init__(self):
super().__init__()
self.layer_0 = StageLayerModule()
self.layer_1 = StageLayerModule()
self.layer_0.to_consistent(P0, B)
self.layer_1.to_consistent(P1, B)

def forward(self, x):
# stage 0
x = flow.flatten(x, start_dim=1)
in0 = x.to_consistent(P0, B)
out0 = self.layer_0(in0)
# stage 1
in1 = out0.to_consistent(P1, B)
out1 = self.layer_1(in1)
return out1

pp_m = PipelineModule()
pp_m.train()
sgd = flow.optim.SGD(pp_m.parameters(), lr=0.001)

class PipelineGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.pp_m = pp_m
self.pp_m.layer_0.config.stage_id = 0
self.pp_m.layer_1.config.stage_id = 1
self.loss_fn = flow.nn.CrossEntropyLoss()
self.config.set_gradient_accumulation_steps(2)
self.add_optimizer(sgd)

def build(self, x, y):
out = self.pp_m(x)
y = y.to_consistent(P1, B)
loss = self.loss_fn(out, y)
loss.backward()
return loss

pp_g = PipelineGraph()

for i in range(20):
x = flow.randn(6, 2, 5)
y = flow.randint(0, 10, (6,))
x = x.to_consistent(P0, B)
y = y.to_consistent(P1, B)
out = pp_g(x, y)


@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@flow.unittest.skip_unless_1n2d()
class TestGraphReshapeAcc(oneflow.unittest.TestCase):
def test_graph_reshape_acc(test_case):
_test_graph_reshape_acc(test_case)


if __name__ == "__main__":
unittest.main()
14 changes: 14 additions & 0 deletions python/oneflow/test/modules/test_reshape.py
@@ -66,6 +66,19 @@ def _test_reshape_backward(test_case, device):
     test_case.assertTrue(np.allclose(np_grad, input.grad.numpy(), 0.0001, 0.0001))


+def _test_reshape_scalar(test_case, device):
+    x = flow.tensor(2.0, device=flow.device(device))
+    test_case.assertTrue(np.array_equal(x.shape, ()))
+    a = flow.reshape(x, (1,))
+    test_case.assertTrue(np.array_equal(a.shape, (1,)))
+    b = flow.reshape(x, (1, 1, 1, 1,))
+    test_case.assertTrue(np.array_equal(b.shape, (1, 1, 1, 1)))
+    c = flow.reshape(b, ())
+    test_case.assertTrue(np.array_equal(c.shape, ()))
+    d = flow.reshape(x, ())
+    test_case.assertTrue(np.array_equal(d.shape, ()))
+
+
 @flow.unittest.skip_unless_1n1d()
 class TestModule(flow.unittest.TestCase):
     def test_reshape(test_case):
@@ -74,6 +87,7 @@ def test_reshape(test_case):
             _test_reshape,
             _test_reshape_tuple,
             _test_reshape_backward,
+            _test_reshape_scalar,
         ]
         arg_dict["device"] = ["cpu", "cuda"]
         for arg in GenArgList(arg_dict):
