Skip to content

reduce_lr_on_plateau

ReduceLROnPlateau

Bases: Trace

Reduce learning rate based on evaluation results.

Parameters:

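| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Union[Model, Module]` | A model instance compiled with `fe.build`. | *required* |
| `metric` | `Optional[str]` | The metric name to be monitored. If `None`, the model's validation loss will be used as the metric. | `None` |
| `patience` | `int` | Number of consecutive eval epochs without improvement to wait before reducing the LR (and again after each reduction). | `10` |
| `factor` | `float` | Factor by which the learning rate is multiplied on each reduction. | `0.1` |
| `best_mode` | `str` | Higher is better (`"max"`) or lower is better (`"min"`). | `'min'` |
| `min_lr` | `float` | Minimum learning rate. | `1e-06` |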

Raises:

| Type | Description |
| --- | --- |
| `AssertionError` | If the loss cannot be inferred from the model and a metric was not provided. |
| `ValueError` | If `best_mode` is neither `'min'` nor `'max'`. |

Source code in fastestimator/fastestimator/trace/adapt/reduce_lr_on_plateau.py
@traceable()
class ReduceLROnPlateau(Trace):
    """Reduce learning rate based on evaluation results.

    At the end of every eval epoch the monitored metric is compared against the best value seen so far. If it has not
    improved for `patience` consecutive eval epochs, the model's learning rate is multiplied by `factor` (clamped so
    it never goes below `min_lr`) and the patience counter restarts. A learning rate that is already at or below
    `min_lr` is left untouched (it is never raised back up to `min_lr`).

    Args:
        model: A model instance compiled with fe.build.
        metric: The metric name to be monitored. If None, the model's validation loss will be used as the metric.
        patience: Number of epochs to wait before reducing LR again.
        factor: Reduce factor for the learning rate.
        best_mode: Higher is better ("max") or lower is better ("min").
        min_lr: Minimum learning rate.

    Raises:
        AssertionError: If the loss cannot be inferred from the `model` and a `metric` was not provided.
        ValueError: If `best_mode` is neither "min" nor "max".
    """
    system: System

    def __init__(self,
                 model: Union[tf.keras.Model, torch.nn.Module],
                 metric: Optional[str] = None,
                 patience: int = 10,
                 factor: float = 0.1,
                 best_mode: str = "min",
                 min_lr: float = 1e-6) -> None:
        if not metric:
            # Explicit raises instead of `assert` so the validation is not stripped under `python -O`; the exception
            # type stays AssertionError for backward compatibility with callers that catch it.
            if not hasattr(model, "loss_name"):
                raise AssertionError(
                    "ReduceLROnPlateau cannot infer model loss name. Provide a metric or use the model in an UpdateOp.")
            if len(model.loss_name) != 1:
                raise AssertionError("the model has more than one losses, please provide the metric explicitly")
            metric = next(iter(model.loss_name))
        super().__init__(mode="eval", inputs=metric, outputs=model.model_name + "_lr")
        self.fe_monitor_names.add(metric)
        self.model = model
        self.patience = patience
        self.factor = factor
        self.best_mode = best_mode
        self.min_lr = min_lr
        self.wait = 0  # consecutive eval epochs without improvement
        if self.best_mode == "min":
            # `np.inf` rather than `np.Inf`: the capitalized alias was removed in NumPy 2.0.
            self.best = np.inf
            self.monitor_op = lt
        elif self.best_mode == "max":
            self.best = -np.inf
            self.monitor_op = gt
        else:
            raise ValueError("best_mode must be either 'min' or 'max'")

    def on_epoch_end(self, data: Data) -> None:
        """Track the monitored metric and reduce the learning rate after `patience` epochs without improvement.

        Args:
            data: The eval-epoch data; the monitored metric is read from it and, when a reduction happens, the new
                learning rate is written to it under this trace's output key.
        """
        current = data[self.inputs[0]]
        if self.monitor_op(current, self.best):
            self.best = current
            self.wait = 0
            return
        self.wait += 1
        if self.wait < self.patience:
            return
        self.wait = 0
        old_lr = get_lr(self.model)
        # Only reduce while there is room above `min_lr`. Without this guard an LR already below `min_lr` would be
        # *raised* to `min_lr` by the max() below, and an LR sitting at `min_lr` would be re-"reduced" (and logged)
        # every `patience` epochs. Both sides are compared as float32 so a stored LR equal to `min_lr` up to
        # float32 rounding counts as "at the floor".
        if np.float32(old_lr) <= np.float32(self.min_lr):
            return
        new_lr = max(self.min_lr, np.float32(self.factor * old_lr))
        set_lr(self.model, new_lr)
        data.write_with_log(self.outputs[0], new_lr)
        print("FastEstimator-ReduceLROnPlateau: learning rate reduced to {}".format(new_lr))