graph support eager lrs (#6262)
* add multistep lr, refine

* add steplr and cosine annealing lr for graph

Co-authored-by: oneflow-ci-bot <[email protected]>
strint and oneflow-ci-bot authored Sep 13, 2021
1 parent 74775c1 commit 76e78fd
Showing 20 changed files with 287 additions and 12 deletions.
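Taken together, these changes let the eager LR schedulers (StepLR, MultiStepLR, CosineAnnealingLR) drive the learning-rate decay of an nn.Graph training job. Before the per-file diffs, a minimal sketch of how one of the new schedulers might be attached to a graph; the flow.optim.lr_scheduler export path and the add_optimizer(optim, lr_sch=...) signature are assumptions, not shown in this diff:

import oneflow as flow

model = flow.nn.Linear(8, 4)
optimizer = flow.optim.SGD(model.parameters(), lr=0.1)
# Any scheduler wired up in this commit (StepLR, MultiStepLR,
# CosineAnnealingLR) should translate to a learning_rate_decay conf.
lr_scheduler = flow.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[30, 80], gamma=0.1
)

class TrainGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.model = model
        # add_optimizer(..., lr_sch=...) is an assumed signature; per the
        # optimizer.py diff below, the pair ends up in an OptDict, which
        # calls _generate_conf_for_graph() when the job conf is built.
        self.add_optimizer(optimizer, lr_sch=lr_scheduler)

    def build(self, x, y):
        loss = ((self.model(x) - y) ** 2).mean()
        loss.backward()
        return loss

graph = TrainGraph()
# loss = graph(x, y)  # each call advances the training step on the graph side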
18 changes: 18 additions & 0 deletions oneflow/core/job/learning_rate_schedule_conf.proto
@@ -36,6 +36,11 @@ message CosineDecayConf {
optional double alpha = 2 [default = 0.0];
}

message CosineAnnealingDecayConf {
required int64 t_max = 1;
optional double eta_min = 2 [default = 0.0];
}

message LinearCosineDecayConf {
required int64 decay_batches = 1;
optional double num_periods = 2 [default = 0.5];
@@ -48,6 +53,16 @@ message PiecewiseScalingConf {
repeated double scales = 2;
}

message StepConf {
required int64 step_size = 1;
optional double gamma = 2 [default = 0.1];
}

message MultiStepConf {
repeated int64 milestones = 1;
optional double gamma = 2 [default = 0.1];
}

message LearningRateDecayConf {
oneof type {
ExponentialDecayConf exponential_conf = 2000;
@@ -58,6 +73,9 @@ message LearningRateDecayConf {
CosineDecayConf cosine_conf = 2005;
LinearCosineDecayConf linear_cosine_conf = 2006;
PiecewiseScalingConf piecewise_scaling_conf = 2007;
MultiStepConf multi_step_conf = 2008;
StepConf step_conf = 2009;
CosineAnnealingDecayConf cosine_annealing_conf = 2010;
}
}

49 changes: 49 additions & 0 deletions oneflow/core/kernel/learning_rate_schedule_kernel.cpp
@@ -148,6 +148,20 @@ double CosineDecayedLearningRate(const CosineDecayConf& conf, double lr, int64_t
return lr * decayed;
}

double CosineAnnealingDecayedLearningRate(const CosineAnnealingDecayConf& conf, double lr,
int64_t cur_batch_num) {
CHECK_GT(conf.t_max(), 0);
if (0 == cur_batch_num) { return lr; }

const double PI = std::atan(1.0) * 4.0;
const double eta_min = conf.eta_min();
CHECK_LT(eta_min, lr);
const double t_max_d = static_cast<double>(conf.t_max());
const double cur_batch_num_d = static_cast<double>(cur_batch_num);

return eta_min + (((lr - eta_min) * (1 + std::cos(PI * (cur_batch_num_d / t_max_d)))) / 2);
}

double LinearCosineDecayedLearningRate(const LinearCosineDecayConf& conf, double lr,
int64_t cur_batch_num) {
CHECK_GT(conf.decay_batches(), 0);
@@ -174,6 +188,35 @@ double PiecewiseScalingLearningRate(const PiecewiseScalingConf& conf, double lr,
return scales[i] * lr;
}

double StepLearningRate(const StepConf& conf, double lr, int64_t cur_batch_num) {
const int64_t step_size = conf.step_size();
CHECK_GE(step_size, 1);
const double gamma = conf.gamma();

double cur_batch = static_cast<double>(cur_batch_num);
double step = static_cast<double>(step_size);
size_t i = std::floor(cur_batch / step);

return lr * std::pow(gamma, i);
}

double MultiStepLearningRate(const MultiStepConf& conf, double lr, int64_t cur_batch_num) {
const PbRf<int64_t>& milestones = conf.milestones();
CHECK_GE(milestones.size(), 1);
const double gamma = conf.gamma();

size_t i = 0;
if (cur_batch_num < milestones[milestones.size() - 1]) {
for (; i < milestones.size(); ++i) {
if (cur_batch_num < milestones[i]) { break; }
}
} else {
i = milestones.size();
}

return lr * std::pow(gamma, i);
}

double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int64_t cur_batch_num) {
if (conf.has_exponential_conf()) {
return ExponentialDecayedLearningRate(conf.exponential_conf(), lr, cur_batch_num);
@@ -187,10 +230,16 @@ double GetDecayedLearningRate(const LearningRateDecayConf& conf, double lr, int6
return PolynomialDecayedLearningRate(conf.polynomial_conf(), lr, cur_batch_num);
} else if (conf.has_cosine_conf()) {
return CosineDecayedLearningRate(conf.cosine_conf(), lr, cur_batch_num);
} else if (conf.has_cosine_annealing_conf()) {
return CosineAnnealingDecayedLearningRate(conf.cosine_annealing_conf(), lr, cur_batch_num);
} else if (conf.has_linear_cosine_conf()) {
return LinearCosineDecayedLearningRate(conf.linear_cosine_conf(), lr, cur_batch_num);
} else if (conf.has_piecewise_scaling_conf()) {
return PiecewiseScalingLearningRate(conf.piecewise_scaling_conf(), lr, cur_batch_num);
} else if (conf.has_step_conf()) {
return StepLearningRate(conf.step_conf(), lr, cur_batch_num);
} else if (conf.has_multi_step_conf()) {
return MultiStepLearningRate(conf.multi_step_conf(), lr, cur_batch_num);
} else {
UNIMPLEMENTED();
}
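In closed form, with t = cur_batch_num and lr the base rate, the three new rules above compute:

\[
\begin{aligned}
\text{CosineAnnealing:}\quad & lr(t) = \eta_{\min} + \tfrac{1}{2}\,(lr - \eta_{\min})\bigl(1 + \cos(\pi t / t_{\max})\bigr)\\
\text{Step:}\quad & lr(t) = lr \cdot \gamma^{\lfloor t / \text{step\_size}\rfloor}\\
\text{MultiStep:}\quad & lr(t) = lr \cdot \gamma^{\,\#\{m \in \text{milestones} \,:\, m \le t\}}
\end{aligned}
\]

Cosine annealing therefore starts at lr (consistent with the cur_batch_num == 0 early return) and reaches eta_min at t = t_max, while the step variants multiply by gamma at each boundary crossed; the MultiStep form assumes the milestones are listed in ascending order, which is what the counting loop above expects.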
4 changes: 2 additions & 2 deletions python/oneflow/amp/grad_scaler.py
@@ -33,7 +33,7 @@ def __init__(
)
self._growth_interval = growth_interval

def generate_conf_for_graph(self, train_conf):
def _generate_conf_for_graph(self, train_conf):
train_conf.mutable_dynamic_loss_scale_policy().set_initial_loss_scale(
self._init_scale
)
@@ -52,5 +52,5 @@ def __init__(self, scale_factor):

self._scale_factor = scale_factor

def generate_conf_for_graph(self, train_conf):
def _generate_conf_for_graph(self, train_conf):
train_conf.set_loss_scale_factor(self._scale_factor)
2 changes: 1 addition & 1 deletion python/oneflow/nn/graph/graph.py
@@ -386,7 +386,7 @@ def _generate_config_proto(self):
self.config.proto.set_job_name(self._name)

if self._grad_scaler is not None:
self._grad_scaler.generate_conf_for_graph(
self._grad_scaler._generate_conf_for_graph(
self.config.proto.mutable_train_conf()
)

4 changes: 2 additions & 2 deletions python/oneflow/nn/graph/optimizer.py
@@ -36,9 +36,9 @@ def __init__(

def generate_optimizer_and_variable_configs(self, train_conf, vars_conf):
if self._optimizer is not None:
opt_confs = self._optimizer.generate_conf_for_graph(train_conf, vars_conf)
opt_confs = self._optimizer._generate_conf_for_graph(train_conf, vars_conf)
if self._lr_scheduler is not None:
self._lr_scheduler.generate_conf_for_graph(opt_confs)
self._lr_scheduler._generate_conf_for_graph(opt_confs)


class VariableConfig(object):
2 changes: 1 addition & 1 deletion python/oneflow/nn/optimizer/adam.py
@@ -207,7 +207,7 @@ def step(self, closure: Callable = None):

return loss

def generate_conf_for_graph(self, train_conf, vars_conf):
def _generate_conf_for_graph(self, train_conf, vars_conf):
new_opt_confs = []
for param_group in self.param_groups:
optimizer_conf = train_conf.mutable_optimizer_conf().Add()
2 changes: 1 addition & 1 deletion python/oneflow/nn/optimizer/adamw.py
@@ -209,7 +209,7 @@ def step(self, closure: Callable = None):
self._state["step"] += 1
return loss

def generate_conf_for_graph(self, train_conf, vars_conf):
def _generate_conf_for_graph(self, train_conf, vars_conf):
new_opt_confs = []
for param_group in self.param_groups:
optimizer_conf = train_conf.mutable_optimizer_conf().Add()
10 changes: 10 additions & 0 deletions python/oneflow/nn/optimizer/cosine_annealing_lr.py
@@ -84,3 +84,13 @@ def get_lr(self):
+ self.eta_min
for group in self._optimizer.param_groups
]

def _generate_conf_for_graph(self, opt_confs):
for opt_conf in opt_confs:
learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay()
learning_rate_decay_conf.mutable_cosine_annealing_conf().set_t_max(
self.T_max
)
learning_rate_decay_conf.mutable_cosine_annealing_conf().set_eta_min(
self.eta_min
)
2 changes: 1 addition & 1 deletion python/oneflow/nn/optimizer/cosine_decay_lr.py
@@ -89,7 +89,7 @@ def get_lr(self):
else:
return [base_lr * self.alpha for base_lr in self.base_lrs]

def generate_conf_for_graph(self, opt_confs):
def _generate_conf_for_graph(self, opt_confs):
# CosineDecayLR is the same as CosineDecayConf in nn.Graph
for opt_conf in opt_confs:
learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay()
9 changes: 9 additions & 0 deletions python/oneflow/nn/optimizer/multistep_lr.py
@@ -65,3 +65,12 @@ def get_lr(self):
return [group["lr"] for group in self._optimizer.param_groups]
else:
return [group["lr"] * self.gamma for group in self._optimizer.param_groups]

def _generate_conf_for_graph(self, opt_confs):
for opt_conf in opt_confs:
learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay()
for milestone in self.milestones:
learning_rate_decay_conf.mutable_multi_step_conf().add_milestones(
milestone
)
learning_rate_decay_conf.mutable_multi_step_conf().set_gamma(self.gamma)
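For concreteness, a small sketch of the schedule MultiStepConf encodes; the milestone and gamma values here are illustrative, not defaults from this diff:

def multi_step_schedule(base_lr, milestones, gamma, step):
    # Mirrors MultiStepConf / MultiStepLearningRate: decay by gamma once per
    # milestone already passed (milestones are expected in ascending order).
    return base_lr * gamma ** sum(1 for m in milestones if step >= m)

# Illustrative: base_lr=0.1, milestones=[30, 80], gamma=0.1
#   steps 0-29 -> 0.1, steps 30-79 -> 0.01, steps 80+ -> 0.001
assert abs(multi_step_schedule(0.1, [30, 80], 0.1, 45) - 0.01) < 1e-12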
2 changes: 1 addition & 1 deletion python/oneflow/nn/optimizer/rmsprop.py
@@ -214,7 +214,7 @@ def step(self, closure: Callable = None):
self._state["step"] = self._state["step"] + 1
return loss

def generate_conf_for_graph(self, train_conf, vars_conf):
def _generate_conf_for_graph(self, train_conf, vars_conf):
new_opt_confs = []
for param_group in self.param_groups:
optimizer_conf = train_conf.mutable_optimizer_conf().Add()
2 changes: 1 addition & 1 deletion python/oneflow/nn/optimizer/sgd.py
@@ -163,7 +163,7 @@ def step(self, closure: Callable = None):
self._state["step"] = self._state["step"] + 1
return loss

def generate_conf_for_graph(self, train_conf, vars_conf):
def _generate_conf_for_graph(self, train_conf, vars_conf):
new_opt_confs = []
for param_group in self.param_groups:
optimizer_conf = train_conf.mutable_optimizer_conf().Add()
6 changes: 6 additions & 0 deletions python/oneflow/nn/optimizer/step_lr.py
@@ -57,3 +57,9 @@ def get_lr(self):
return [group["lr"] for group in self._optimizer.param_groups]
else:
return [group["lr"] * self.gamma for group in self._optimizer.param_groups]

def _generate_conf_for_graph(self, opt_confs):
for opt_conf in opt_confs:
learning_rate_decay_conf = opt_conf.mutable_learning_rate_decay()
learning_rate_decay_conf.mutable_step_conf().set_step_size(self.step_size)
learning_rate_decay_conf.mutable_step_conf().set_gamma(self.gamma)
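Likewise, a sketch of the schedule StepConf encodes; the step_size and gamma values are illustrative:

def step_schedule(base_lr, step_size, gamma, step):
    # Mirrors StepConf / StepLearningRate: multiply by gamma every step_size steps.
    return base_lr * gamma ** (step // step_size)

# Illustrative: base_lr=0.1, step_size=30, gamma=0.1
#   steps 0-29 -> 0.1, steps 30-59 -> 0.01, steps 60-89 -> 0.001, ...
assert abs(step_schedule(0.1, 30, 0.1, 31) - 0.01) < 1e-12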
4 changes: 2 additions & 2 deletions python/oneflow/nn/optimizer/warm_up_lr.py
@@ -136,9 +136,9 @@ def get_lr(self):
"got {}".format(self.warmup_method)
)

def generate_conf_for_graph(self, opt_confs):
def _generate_conf_for_graph(self, opt_confs):
if self._inner_lr_sch is not None:
self._inner_lr_sch.generate_conf_for_graph(opt_confs)
self._inner_lr_sch._generate_conf_for_graph(opt_confs)
if self.warmup_method == "linear":
for opt_conf in opt_confs:
warmup_conf = opt_conf.mutable_warmup_conf()
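Because WarmUpLR now forwards _generate_conf_for_graph to its inner scheduler, the new graph-supported schedulers can also be wrapped in a warm-up phase. A hedged sketch; the WarmUpLR constructor arguments (warmup_factor, warmup_iters, warmup_method) and the export paths are assumptions based on the fields this file references, not taken from the diff:

import oneflow as flow

model = flow.nn.Linear(8, 4)
optimizer = flow.optim.SGD(model.parameters(), lr=0.1)
cosine = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
# Assumed constructor: linear warm-up for a few steps, after which the inner
# cosine annealing schedule (wired into the graph conf above) takes over.
warmup_cosine = flow.optim.lr_scheduler.WarmUpLR(
    cosine, warmup_factor=0.1, warmup_iters=5, warmup_method="linear"
)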