early_stopping

EarlyStopping

Bases: Trace

Stop training when a monitored quantity has stopped improving.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| monitor | str | Quantity to be monitored. | 'loss' |
| min_delta | float | Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of at most min_delta will count as no improvement. | 0.0 |
| patience | int | Number of consecutive epochs with no improvement after which training will be stopped. | 0 |
| compare | str | One of {"min", "max"}. In "min" mode, training will stop when the monitored quantity has stopped decreasing; in "max" mode, it will stop when the monitored quantity has stopped increasing. | 'min' |
| baseline | Optional[float] | Baseline value for the monitored quantity. If set, it replaces the initial best value, so training will stop if the model doesn't show improvement over the baseline. | None |
| mode | str | What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | 'eval' |
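The parameters above combine into a single rule that is applied once per epoch. The sketch below replays that rule over a plain list of per-epoch values, assuming the monitored value is available every epoch. `simulate_early_stopping` is a hypothetical helper, not part of FastEstimator; the list stands in for the `Data` dictionary and `self.system` object that the real Trace uses.

```python
import math
from operator import gt, lt
from typing import List, Optional, Tuple


def simulate_early_stopping(values: List[float],
                            min_delta: float = 0.0,
                            patience: int = 0,
                            compare: str = "min",
                            baseline: Optional[float] = None) -> Optional[Tuple[int, float, int]]:
    """Hypothetical replay of the EarlyStopping rule over per-epoch values.

    Epochs are numbered from 1. Returns (stop_epoch, best_value, best_epoch) when a stop
    is triggered, or None if training would run through every epoch in `values`.
    """
    if compare not in ("min", "max"):
        raise ValueError("compare can only be 'min' or 'max'")
    delta = abs(min_delta)
    if compare == "min":
        monitor_op = lt
        delta *= -1  # so `current - delta` becomes `current + |min_delta|`
    else:
        monitor_op = gt
    # Without a baseline, any finite first value counts as an improvement.
    if baseline is not None:
        best = baseline
    else:
        best = math.inf if compare == "min" else -math.inf
    wait = 0
    for epoch, current in enumerate(values, start=1):
        if monitor_op(current - delta, best):
            best = current
            wait = 0
        else:
            wait += 1  # consecutive epochs without improvement
            if wait >= patience:
                # Same best-epoch bookkeeping as the message printed by on_epoch_end.
                return epoch, best, epoch - wait
    return None


# Loss plateaus after epoch 3; two non-improving epochs trigger the stop at epoch 5.
print(simulate_early_stopping([1.0, 0.8, 0.75, 0.79, 0.76, 0.74], min_delta=0.02, patience=2))
# → (5, 0.75, 3)
```

Because `wait` is incremented before it is compared against `patience`, `patience=0` and `patience=1` both stop at the first non-improving epoch. When a `baseline` is never beaten, the reported best value is the baseline itself.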

Raises:

| Type | Description |
|---|---|
| ValueError | If compare is an invalid value or more than one monitor is provided. |

Source code in fastestimator/fastestimator/trace/adapt/early_stopping.py
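from operator import gt, lt
from typing import Optional

import numpy as np

from fastestimator.trace.trace import Trace
from fastestimator.util.data import Data
from fastestimator.util.traceability_util import traceable


@traceable()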
class EarlyStopping(Trace):
    """Stop training when a monitored quantity has stopped improving.

    Args:
        monitor: Quantity to be monitored.
        min_delta: Minimum change in the monitored quantity to qualify as an improvement, i.e. an
            absolute change of at most min_delta will count as no improvement.
        patience: Number of consecutive epochs with no improvement after which training will be stopped.
        compare: One of {"min", "max"}. In "min" mode, training will stop when the quantity monitored
            has stopped decreasing; in "max" mode it will stop when the quantity monitored has stopped increasing.
        baseline: Baseline value for the monitored quantity. If set, it replaces the initial best value, so
            training will stop if the model doesn't show improvement over the baseline.
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".

    Raises:
        ValueError: If `compare` is an invalid value or more than one `monitor` is provided.
    """
    def __init__(self,
                 monitor: str = "loss",
                 min_delta: float = 0.0,
                 patience: int = 0,
                 compare: str = 'min',
                 baseline: Optional[float] = None,
                 mode: str = 'eval') -> None:
        super().__init__(inputs=monitor, mode=mode)

        if len(self.inputs) != 1:
            raise ValueError("EarlyStopping supports only one monitor key")
        if compare not in ['min', 'max']:
            raise ValueError("compare can only be `min` or `max`")

        self.monitored_key = monitor
        self.fe_monitor_names.add(monitor)
        self.min_delta = abs(min_delta)
        self.wait = 0
        self.best = 0
        self.patience = patience
        self.baseline = baseline
        if compare == 'min':
            self.monitor_op = lt
            self.min_delta *= -1
        else:
            self.monitor_op = gt

    def on_begin(self, data: Data) -> None:
        self.wait = 0
        if self.baseline is not None:
            self.best = self.baseline
        else:
            self.best = np.inf if self.monitor_op == lt else -np.inf  # np.Inf was removed in NumPy 2.0

    def on_epoch_end(self, data: Data) -> None:
        current = data[self.monitored_key]
        if current is None:
            return
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.system.stop_training = True
                print("FastEstimator-EarlyStopping: '{}' triggered an early stop. Its best value was {} at epoch {}".
                      format(self.monitored_key, self.best, self.system.epoch_idx - self.wait))
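The sign flip in `__init__` (`self.min_delta *= -1` in "min" mode) lets `on_epoch_end` use one expression, `monitor_op(current - self.min_delta, self.best)`, for both directions. The following stand-alone sketch isolates that predicate; `is_improvement` is a hypothetical name, not a FastEstimator function. Because the comparison is strict, a change of exactly `min_delta` does not count as an improvement.

```python
from operator import gt, lt


def is_improvement(current: float, best: float, min_delta: float, compare: str) -> bool:
    """Hypothetical stand-alone copy of the test in EarlyStopping.on_epoch_end."""
    delta = abs(min_delta)
    if compare == "min":
        delta *= -1  # __init__ negates min_delta in "min" mode
        op = lt
    else:
        op = gt
    # min: current + |min_delta| < best  -> must drop by MORE than |min_delta|
    # max: current - |min_delta| > best  -> must rise by MORE than |min_delta|
    return op(current - delta, best)


print(is_improvement(0.85, 1.0, 0.1, "min"))   # True: dropped by 0.15
print(is_improvement(0.75, 1.0, 0.25, "min"))  # False: dropped by exactly min_delta
print(is_improvement(0.95, 0.8, 0.1, "max"))   # True: rose by 0.15
```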