diff --git a/oneflow/core/job/learning_rate_schedule_conf.proto b/oneflow/core/job/learning_rate_schedule_conf.proto index 47d297a360f..cc615b99988 100644 --- a/oneflow/core/job/learning_rate_schedule_conf.proto +++ b/oneflow/core/job/learning_rate_schedule_conf.proto @@ -36,6 +36,11 @@ message CosineDecayConf { optional double alpha = 2 [default = 0.0]; } +message CosineAnnealingDecayConf { + required int64 t_max = 1; + optional double eta_min = 2 [default = 0.0]; +} + message LinearCosineDecayConf { required int64 decay_batches = 1; optional double num_periods = 2 [default = 0.5]; @@ -48,6 +53,16 @@ message PiecewiseScalingConf { repeated double scales = 2; } +message StepConf { + required int64 step_size = 1; + optional double gamma = 2 [default = 0.1]; +} + +message MultiStepConf { + repeated int64 milestones = 1; + optional double gamma = 2 [default = 0.1]; +} + message LearningRateDecayConf { oneof type { ExponentialDecayConf exponential_conf = 2000; @@ -58,6 +73,9 @@ message LearningRateDecayConf { CosineDecayConf cosine_conf = 2005; LinearCosineDecayConf linear_cosine_conf = 2006; PiecewiseScalingConf piecewise_scaling_conf = 2007; + MultiStepConf multi_step_conf = 2008; + StepConf step_conf = 2009; + CosineAnnealingDecayConf cosine_annealing_conf = 2010; } } diff --git a/oneflow/core/kernel/learning_rate_schedule_kernel.cpp b/oneflow/core/kernel/learning_rate_schedule_kernel.cpp index 65c075e2803..254edfb23a5 100644 --- a/oneflow/core/kernel/learning_rate_schedule_kernel.cpp +++ b/oneflow/core/kernel/learning_rate_schedule_kernel.cpp @@ -148,6 +148,20 @@ double CosineDecayedLearningRate(const CosineDecayConf& conf, double lr, int64_t return lr * decayed; } +double CosineAnnealingDecayedLearningRate(const CosineAnnealingDecayConf& conf, double lr, + int64_t cur_batch_num) { + CHECK_GT(conf.t_max(), 0); + if (0 == cur_batch_num) { return lr; } + + const double PI = std::atan(1.0) * 4.0; + const double eta_min = conf.eta_min(); + CHECK_LT(eta_min, lr); + const double t_max_d = static_cast<double>(conf.t_max()); + const double cur_batch_num_d = static_cast<double>(cur_batch_num); + + return eta_min + (((lr - eta_min) * (1 + std::cos(PI * (cur_batch_num_d / t_max_d)))) / 2); +} + double LinearCosineDecayedLearningRate(const LinearCosineDecayConf& conf, double lr, int64_t cur_batch_num) { CHECK_GT(conf.decay_batches(), 0); @@ -174,6 +188,35 @@ double PiecewiseScalingLearningRate(const PiecewiseScalingConf& conf, double lr, return scales[i] * lr; } +double StepLearningRate(const StepConf& conf, double lr, int64_t cur_batch_num) { + const int64_t step_size = conf.step_size(); + CHECK_GE(step_size, 1); + const double gamma = conf.gamma(); + + double cur_batch = static_cast<double>(cur_batch_num); + double step = static_cast<double>(step_size); + size_t i = std::floor(cur_batch / step); + + return lr * std::pow(gamma, i); +} + +double MultiStepLearningRate(const MultiStepConf& conf, double lr, int64_t cur_batch_num) { + const PbRf<int64_t>& milestones = conf.milestones(); + CHECK_GE(milestones.size(), 1); + const double gamma = conf.gamma(); + + size_t i = 0; + if (cur_batch_num < milestones[milestones.size() - 1]) { + for (; i < milestones.size(); ++i) { + if (cur_batch_num < milestones[i]) { break; } + } + } else { + i = milestones.size(); + } + + return lr * std::pow(gamma, i); +} + double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) { if (conf.has_exponential_conf()) { return ExponentialDecayedLearningRate(conf.exponential_conf(), lr, cur_batch_num); @@ -187,10 +230,16 @@ double
GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int6 return PolynomialDecayedLearningRate(conf.polynomial_conf(), lr, cur_batch_num); } else if (conf.has_cosine_conf()) { return CosineDecayedLearningRate(conf.cosine_conf(), lr, cur_batch_num); + } else if (conf.has_cosine_annealing_conf()) { + return CosineAnnealingDecayedLearningRate(conf.cosine_annealing_conf(), lr, cur_batch_num); } else if (conf.has_linear_cosine_conf()) { return LinearCosineDecayedLearningRate(conf.linear_cosine_conf(), lr, cur_batch_num); } else if (conf.has_piecewise_scaling_conf()) { return PiecewiseScalingLearningRate(conf.piecewise_scaling_conf(), lr, cur_batch_num); + } else if (conf.has_step_conf()) { + return StepLearningRate(conf.step_conf(), lr, cur_batch_num); + } else if (conf.has_multi_step_conf()) { + return MultiStepLearningRate(conf.multi_step_conf(), lr, cur_batch_num); } else { UNIMPLEMENTED(); } diff --git a/python/oneflow/amp/grad_scaler.py b/python/oneflow/amp/grad_scaler.py index d5e84b49f64..374e920b486 100644 --- a/python/oneflow/amp/grad_scaler.py +++ b/python/oneflow/amp/grad_scaler.py @@ -33,7 +33,7 @@ def __init__( ) self._growth_interval = growth_interval - def generate_conf_for_graph(self, train_conf): + def _generate_conf_for_graph(self, train_conf): train_conf.mutable_dynamic_loss_scale_policy().set_initial_loss_scale( self._init_scale ) @@ -52,5 +52,5 @@ def __init__(self, scale_factor): self._scale_factor = scale_factor - def generate_conf_for_graph(self, train_conf): + def _generate_conf_for_graph(self, train_conf): train_conf.set_loss_scale_factor(self._scale_factor) diff --git a/python/oneflow/nn/graph/graph.py b/python/oneflow/nn/graph/graph.py index 450f97b0633..c73a9d35d3b 100644 --- a/python/oneflow/nn/graph/graph.py +++ b/python/oneflow/nn/graph/graph.py @@ -386,7 +386,7 @@ def _generate_config_proto(self): self.config.proto.set_job_name(self._name) if self._grad_scaler is not None: - self._grad_scaler.generate_conf_for_graph( + self._grad_scaler._generate_conf_for_graph( self.config.proto.mutable_train_conf() ) diff --git a/python/oneflow/nn/graph/optimizer.py b/python/oneflow/nn/graph/optimizer.py index 8e9811d70ab..e0c6e422c33 100644 --- a/python/oneflow/nn/graph/optimizer.py +++ b/python/oneflow/nn/graph/optimizer.py @@ -36,9 +36,9 @@ def __init__( def generate_optimizer_and_variable_configs(self, train_conf, vars_conf): if self._optimizer is not None: - opt_confs = self._optimizer.generate_conf_for_graph(train_conf, vars_conf) + opt_confs = self._optimizer._generate_conf_for_graph(train_conf, vars_conf) if self._lr_scheduler is not None: - self._lr_scheduler.generate_conf_for_graph(opt_confs) + self._lr_scheduler._generate_conf_for_graph(opt_confs) class VariableConfig(object): diff --git a/python/oneflow/nn/optimizer/adam.py b/python/oneflow/nn/optimizer/adam.py index ce073508d83..d5e5d38fc20 100644 --- a/python/oneflow/nn/optimizer/adam.py +++ b/python/oneflow/nn/optimizer/adam.py @@ -207,7 +207,7 @@ def step(self, closure: Callable = None): return loss - def generate_conf_for_graph(self, train_conf, vars_conf): + def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: optimizer_conf = train_conf.mutable_optimizer_conf().Add() diff --git a/python/oneflow/nn/optimizer/adamw.py b/python/oneflow/nn/optimizer/adamw.py index 508d7623689..bd7696ee909 100644 --- a/python/oneflow/nn/optimizer/adamw.py +++ b/python/oneflow/nn/optimizer/adamw.py @@ -209,7 +209,7 @@ def step(self, closure: Callable = 
None): self._state["step"] += 1 return loss - def generate_conf_for_graph(self, train_conf, vars_conf): + def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: optimizer_conf = train_conf.mutable_optimizer_conf().Add() diff --git a/python/oneflow/nn/optimizer/cosine_annealing_lr.py b/python/oneflow/nn/optimizer/cosine_annealing_lr.py index ce4172e3292..e0149fb8aeb 100644 --- a/python/oneflow/nn/optimizer/cosine_annealing_lr.py +++ b/python/oneflow/nn/optimizer/cosine_annealing_lr.py @@ -84,3 +84,13 @@ def get_lr(self): + self.eta_min for group in self._optimizer.param_groups ] + + def _generate_conf_for_graph(self, opt_confs): + for opt_conf in opt_confs: + learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay() + learning_rate_decay_conf.mutable_cosine_annealing_conf().set_t_max( + self.T_max + ) + learning_rate_decay_conf.mutable_cosine_annealing_conf().set_eta_min( + self.eta_min + ) diff --git a/python/oneflow/nn/optimizer/cosine_decay_lr.py b/python/oneflow/nn/optimizer/cosine_decay_lr.py index 50a612d41a3..3c1b3f83630 100644 --- a/python/oneflow/nn/optimizer/cosine_decay_lr.py +++ b/python/oneflow/nn/optimizer/cosine_decay_lr.py @@ -89,7 +89,7 @@ def get_lr(self): else: return [base_lr * self.alpha for base_lr in self.base_lrs] - def generate_conf_for_graph(self, opt_confs): + def _generate_conf_for_graph(self, opt_confs): # CosineDecayLR is the same as CosineDecayConf in nn.Graph for opt_conf in opt_confs: learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay() diff --git a/python/oneflow/nn/optimizer/multistep_lr.py b/python/oneflow/nn/optimizer/multistep_lr.py index df35d97d1ec..6262d9ed881 100644 --- a/python/oneflow/nn/optimizer/multistep_lr.py +++ b/python/oneflow/nn/optimizer/multistep_lr.py @@ -65,3 +65,12 @@ def get_lr(self): return [group["lr"] for group in self._optimizer.param_groups] else: return [group["lr"] * self.gamma for group in self._optimizer.param_groups] + + def _generate_conf_for_graph(self, opt_confs): + for opt_conf in opt_confs: + learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay() + for milestone in self.milestones: + learning_rate_decay_conf.mutable_multi_step_conf().add_milestones( + milestone + ) + learning_rate_decay_conf.mutable_multi_step_conf().set_gamma(self.gamma) diff --git a/python/oneflow/nn/optimizer/rmsprop.py b/python/oneflow/nn/optimizer/rmsprop.py index dd3fae8a1ae..ce7aaacb1e8 100644 --- a/python/oneflow/nn/optimizer/rmsprop.py +++ b/python/oneflow/nn/optimizer/rmsprop.py @@ -214,7 +214,7 @@ def step(self, closure: Callable = None): self._state["step"] = self._state["step"] + 1 return loss - def generate_conf_for_graph(self, train_conf, vars_conf): + def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: optimizer_conf = train_conf.mutable_optimizer_conf().Add() diff --git a/python/oneflow/nn/optimizer/sgd.py b/python/oneflow/nn/optimizer/sgd.py index 85fca7b4d80..f726cee2a25 100644 --- a/python/oneflow/nn/optimizer/sgd.py +++ b/python/oneflow/nn/optimizer/sgd.py @@ -163,7 +163,7 @@ def step(self, closure: Callable = None): self._state["step"] = self._state["step"] + 1 return loss - def generate_conf_for_graph(self, train_conf, vars_conf): + def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs = [] for param_group in self.param_groups: optimizer_conf = train_conf.mutable_optimizer_conf().Add() diff --git a/python/oneflow/nn/optimizer/step_lr.py 
b/python/oneflow/nn/optimizer/step_lr.py index e978b789ad2..b72d5bde746 100644 --- a/python/oneflow/nn/optimizer/step_lr.py +++ b/python/oneflow/nn/optimizer/step_lr.py @@ -57,3 +57,9 @@ def get_lr(self): return [group["lr"] for group in self._optimizer.param_groups] else: return [group["lr"] * self.gamma for group in self._optimizer.param_groups] + + def _generate_conf_for_graph(self, opt_confs): + for opt_conf in opt_confs: + learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay() + learning_rate_decay_conf.mutable_step_conf().set_step_size(self.step_size) + learning_rate_decay_conf.mutable_step_conf().set_gamma(self.gamma) diff --git a/python/oneflow/nn/optimizer/warm_up_lr.py b/python/oneflow/nn/optimizer/warm_up_lr.py index 76179c11b8b..04cd8464c7c 100644 --- a/python/oneflow/nn/optimizer/warm_up_lr.py +++ b/python/oneflow/nn/optimizer/warm_up_lr.py @@ -136,9 +136,9 @@ def get_lr(self): "got {}".format(self.warmup_method) ) - def generate_conf_for_graph(self, opt_confs): + def _generate_conf_for_graph(self, opt_confs): if self._inner_lr_sch is not None: - self._inner_lr_sch.generate_conf_for_graph(opt_confs) + self._inner_lr_sch._generate_conf_for_graph(opt_confs) if self.warmup_method == "linear": for opt_conf in opt_confs: warmup_conf = opt_conf.mutable_warmup_conf() diff --git a/python/oneflow/test/graph/test_graph_lr_scheduler.py b/python/oneflow/test/graph/test_graph_lr_with_warmup.py similarity index 100% rename from python/oneflow/test/graph/test_graph_lr_scheduler.py rename to python/oneflow/test/graph/test_graph_lr_with_warmup.py diff --git a/python/oneflow/test/graph/test_graph_lrs.py b/python/oneflow/test/graph/test_graph_lrs.py new file mode 100644 index 00000000000..6dac3a41270 --- /dev/null +++ b/python/oneflow/test/graph/test_graph_lrs.py @@ -0,0 +1,183 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import math +import unittest +import os +import numpy as np + +import oneflow as flow +import oneflow.unittest +from oneflow.nn.parameter import Parameter + + +def _test_linear_graph_train_with_lr_sch( + test_case, iter_num, device, get_opt_and_lr_sch +): + def train_with_module(iter_num=3): + linear = flow.nn.Linear(3, 8) + linear = linear.to(device) + flow.nn.init.constant_(linear.weight, -0.68758) + flow.nn.init.constant_(linear.bias, 0.23) + + opt, lr_sch = get_opt_and_lr_sch(linear.parameters()) + + x = flow.tensor( + [ + [-0.94630778, -0.83378579, -0.87060891], + [2.0289922, -0.28708987, -2.18369248], + [0.35217619, -0.67095644, -1.58943879], + [0.08086036, -1.81075924, 1.20752494], + [0.8901075, -0.49976737, -1.07153746], + [-0.44872912, -1.07275683, 0.06256855], + [-0.22556897, 0.74798368, 0.90416439], + [0.48339456, -2.32742195, -0.59321527], + ], + dtype=flow.float32, + device=device, + requires_grad=False, + ) + + def one_iter(): + of_out = linear(x) + of_out = of_out.sum() + + of_out.backward() + opt.step() + if lr_sch is not None: + lr_sch.step() + opt.zero_grad() + + return of_out.numpy(), linear.weight.numpy() + + check_list = [] + for i in range(iter_num): + check_list.append(one_iter()) + return check_list + + def train_with_graph(iter_num=3): + linear = flow.nn.Linear(3, 8) + linear = linear.to(device) + flow.nn.init.constant_(linear.weight, -0.68758) + flow.nn.init.constant_(linear.bias, 0.23) + + opt, lr_sch = get_opt_and_lr_sch(linear.parameters()) + + x = flow.tensor( + [ + [-0.94630778, -0.83378579, -0.87060891], + [2.0289922, -0.28708987, -2.18369248], + [0.35217619, -0.67095644, -1.58943879], + [0.08086036, -1.81075924, 1.20752494], + [0.8901075, -0.49976737, -1.07153746], + [-0.44872912, -1.07275683, 0.06256855], + [-0.22556897, 0.74798368, 0.90416439], + [0.48339456, -2.32742195, -0.59321527], + ], + dtype=flow.float32, + device=device, + requires_grad=False, + ) + + class LinearTrainGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.linear = linear + if lr_sch is None: + self.add_optimizer(opt) + else: + self.add_optimizer(opt, lr_sch=lr_sch) + + def build(self, x): + out = self.linear(x) + out = out.sum() + out.backward() + return out + + linear_t_g = LinearTrainGraph() + + def one_iter(): + of_graph_out = linear_t_g(x) + return of_graph_out.numpy(), linear_t_g.linear.weight.origin.numpy() + + check_list = [] + for i in range(iter_num): + check_list.append(one_iter()) + return check_list + + module_check_list = train_with_module(iter_num) + graph_check_list = train_with_graph(iter_num) + for i in range(iter_num): + # check equal on loss + test_case.assertTrue( + np.allclose( + module_check_list[i][0], + graph_check_list[i][0], + rtol=0.00001, + atol=0.00001, + ) + ) + # check equal on weight + test_case.assertTrue( + np.allclose( + module_check_list[i][1], + graph_check_list[i][1], + rtol=0.00001, + atol=0.00001, + ) + ) + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n1d() +class TestGraphLRs(flow.unittest.TestCase): + def test_step_lr(test_case): + def _lr_fn(parameters): + of_sgd = flow.optim.SGD(parameters, lr=0.001) + + step_lr = flow.optim.lr_scheduler.StepLR(of_sgd, step_size=7, gamma=0.1) + return of_sgd, step_lr + + _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cuda"), _lr_fn) + _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cpu"), _lr_fn) + + def test_multistep_lr(test_case): + def _lr_fn(parameters): + of_sgd = 
flow.optim.SGD(parameters, lr=0.001) + + multistep_lr = flow.optim.lr_scheduler.MultiStepLR( + of_sgd, milestones=[10, 15], gamma=0.1 + ) + return of_sgd, multistep_lr + + _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cuda"), _lr_fn) + _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cpu"), _lr_fn) + + def test_cosine_annealing_lr(test_case): + def _lr_fn(parameters): + of_sgd = flow.optim.SGD(parameters, lr=0.001) + + lr = flow.optim.lr_scheduler.CosineAnnealingLR( + of_sgd, T_max=5, eta_min=0.0001 + ) + return of_sgd, lr + + _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cuda"), _lr_fn) + _test_linear_graph_train_with_lr_sch(test_case, 21, flow.device("cpu"), _lr_fn) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/graph/test_graph_adam_optim.py b/python/oneflow/test/graph/test_graph_optim_adam.py similarity index 100% rename from python/oneflow/test/graph/test_graph_adam_optim.py rename to python/oneflow/test/graph/test_graph_optim_adam.py diff --git a/python/oneflow/test/graph/test_graph_adamw_optim.py b/python/oneflow/test/graph/test_graph_optim_adamw.py similarity index 100% rename from python/oneflow/test/graph/test_graph_adamw_optim.py rename to python/oneflow/test/graph/test_graph_optim_adamw.py diff --git a/python/oneflow/test/graph/test_graph_rmsprop_optim.py b/python/oneflow/test/graph/test_graph_optim_rmsprop.py similarity index 100% rename from python/oneflow/test/graph/test_graph_rmsprop_optim.py rename to python/oneflow/test/graph/test_graph_optim_rmsprop.py diff --git a/python/oneflow/test/graph/test_graph_sgd_optim.py b/python/oneflow/test/graph/test_graph_optim_sgd.py similarity index 100% rename from python/oneflow/test/graph/test_graph_sgd_optim.py rename to python/oneflow/test/graph/test_graph_optim_sgd.py
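For reviewers, a minimal usage sketch of how the schedulers touched by this diff are driven in graph mode. It mirrors the new test_graph_lrs.py above: the MultiStepLR arguments and the add_optimizer(opt, lr_sch=...) call are taken from that test, while the module shape, loop count, and the flow.randn input are illustrative only.

import oneflow as flow

linear = flow.nn.Linear(3, 8)
sgd = flow.optim.SGD(linear.parameters(), lr=0.001)
lr_sch = flow.optim.lr_scheduler.MultiStepLR(sgd, milestones=[10, 15], gamma=0.1)

class LinearTrainGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.linear = linear
        # Registering the scheduler here routes it through the renamed
        # _generate_conf_for_graph hooks, which serialize the schedule into the
        # new MultiStepConf / StepConf / CosineAnnealingDecayConf messages of
        # LearningRateDecayConf.
        self.add_optimizer(sgd, lr_sch=lr_sch)

    def build(self, x):
        out = self.linear(x).sum()
        out.backward()
        return out

graph = LinearTrainGraph()
for _ in range(21):
    # The effective lr now follows MultiStepLearningRate in
    # learning_rate_schedule_kernel.cpp: scaled by gamma at batch 10 and 15.
    loss = graph(flow.randn(8, 3))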