Table Of Contents
Table Of Contents

Source code for gluonts.trainer.learning_rate_scheduler

# Copyright 2018, Inc. or its affiliates. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
# or in the "license" file accompanying this file. This file is distributed
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# Third-party imports
import mxnet as mx
import numpy as np

[docs]class MetricAttentiveScheduler(mx.lr_scheduler.LRScheduler): r""" This scheduler decreases the learning rate based on the value of some validation metric to be optimized (maximized or minimized). The value of such metric is provided by calling the `step` method on the scheduler. A `patience` parameter must be provided, and the scheduler will reduce the learning rate if no improvement in the metric is done before `patience` observations of the metric. Examples: `patience = 0`: learning rate will decrease at every call to `step`, regardless of the metric value `patience = 1`: learning rate is reduced as soon `step` is called with a metric value which does not improve over the best encountered `patience = 10`: learning rate is reduced if no improvement in the metric is recorded in 10 successive calls to `step` Parameters ---------- objective String, can either be `"min"` or `"max"` patience The patience to observe before reducing the learning rate, nonnegative integer. base_lr Initial learning rate to be used. decay_factor Factor (between 0 and 1) by which to decrease the learning rate. min_lr Lower bound for the learning rate, learning rate will never go below `min_lr` """ def __init__( self, objective: str, patience: int, base_lr: float = 0.01, decay_factor: float = 0.5, min_lr: float = 0.0, ) -> None: assert base_lr > 0, f"base_lr should be positive, got {base_lr}" assert ( base_lr > min_lr ), f"base_lr should greater than min_lr, {base_lr} <= {min_lr}" assert ( 0 < decay_factor < 1 ), f"decay_factor factor should be between 0 and 1, got {decay_factor}" assert patience >= 0, f"patience should be nonnegative, got {patience}" assert objective in [ "min", "max", ], f"objective should be 'min' or 'max', got {objective}" super(MetricAttentiveScheduler, self).__init__(base_lr=base_lr) self.decay_factor = decay_factor self.patience = patience self.objective = objective self.min_lr = min_lr self.best_metric = np.Inf if objective == "min" else -np.Inf self.prev_change = 0 self.epoch_no = 0 self.curr_lr = None def __call__(self, num_update: int) -> float: if self.curr_lr is None: self.curr_lr = self.base_lr assert self.curr_lr is not None return self.curr_lr
[docs] def step(self, metric_value: float) -> None: """ Inform the scheduler of the new value of the metric that is being optimized. This method should be invoked at regular intervals (e.g. at the end of every epoch, after computing a validation score). Parameters ---------- metric_value Value of the metric that is being optimized. """ if self.curr_lr is None: self.curr_lr = self.base_lr assert self.curr_lr is not None metric_improved = ( self.objective == "min" and metric_value < self.best_metric ) or (self.objective == "max" and metric_value > self.best_metric) if metric_improved: self.best_metric = metric_value self.prev_change = self.epoch_no if self.epoch_no - self.prev_change >= self.patience: self.curr_lr = max(self.min_lr, self.decay_factor * self.curr_lr) self.prev_change = self.epoch_no self.epoch_no += 1