terminate_on_nan

TerminateOnNaN

Bases: Trace

Ends training if a NaN value is detected.

By default (monitor_names=None), it monitors all loss values at the end of each batch and at the end of each epoch. If one or more inputs are specified, only those values are monitored. Inputs may be loss keys and/or keys corresponding to the outputs of other traces (e.g., accuracy).
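The key-selection rules described above can be sketched in plain Python. This is only an illustration, not the library's code: `resolve_monitor_keys`, `loss_keys`, and `trace_outputs` are hypothetical names, and the sketch assumes the loss keys and trace output keys are already known.

```python
from typing import Iterable, Set, Union


def resolve_monitor_keys(monitor_names: Union[None, str, Iterable[str]],
                         loss_keys: Set[str],
                         trace_outputs: Iterable[str]) -> Set[str]:
    """Hypothetical helper mirroring the documented key-selection rules."""
    if monitor_names is None:
        # Default: watch every loss value.
        return set(loss_keys)
    # Accept a single key or an iterable of keys.
    names = [monitor_names] if isinstance(monitor_names, str) else list(monitor_names)
    if not names:
        return set(loss_keys)
    if "*" in names:
        # Wildcard: watch all losses plus everything other traces output.
        return set(loss_keys) | set(trace_outputs)
    # Otherwise watch only the keys that were named explicitly.
    return set(names)


loss_keys = {"ce"}              # assumed loss key, for illustration only
trace_outputs = ["accuracy"]    # assumed output of another trace

print(sorted(resolve_monitor_keys(None, loss_keys, trace_outputs)))          # ['ce']
print(sorted(resolve_monitor_keys("*", loss_keys, trace_outputs)))           # ['accuracy', 'ce']
print(sorted(resolve_monitor_keys(["accuracy"], loss_keys, trace_outputs)))  # ['accuracy']
```

Note that naming keys explicitly replaces the default rather than adding to it: in the last call the loss key `ce` is no longer monitored.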

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| monitor_names | Union[None, str, Iterable[str]] | Key(s) to monitor for NaN values. If None, all loss values will be monitored. "*" will monitor all trace output keys and losses. | None |
| mode | Union[None, str, Set[str]] | What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | None |
| ds_id | Union[None, str, Iterable[str]] | What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all ds_ids except for a particular one, you can pass an argument like "!ds1". | None |
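To make the "!" convention for `mode` concrete, here is a minimal sketch of how such an argument could be expanded into a set of modes. `expand_modes` and `ALL_MODES` are hypothetical names used only for illustration; the library's own parsing may differ, and the handling of arguments that mix plain and negated entries is an assumption of this sketch.

```python
from typing import Iterable, Set, Union

# The four modes named in the parameter description.
ALL_MODES = {"train", "eval", "test", "infer"}


def expand_modes(mode: Union[None, str, Iterable[str]]) -> Set[str]:
    """Hypothetical helper: turn a mode argument into the set of modes it selects."""
    if mode is None:
        # None means "run regardless of mode".
        return set(ALL_MODES)
    specs = [mode] if isinstance(mode, str) else list(mode)
    negated = {s[1:] for s in specs if s.startswith("!")}
    positive = {s for s in specs if not s.startswith("!")}
    # Assumption: a negated entry selects every mode except the named one,
    # and any plain entries are added on top.
    selected = (ALL_MODES - negated) if negated else set()
    return selected | positive


print(sorted(expand_modes(None)))      # ['eval', 'infer', 'test', 'train']
print(sorted(expand_modes("!infer")))  # ['eval', 'test', 'train']
print(sorted(expand_modes("train")))   # ['train']
```

The `ds_id` argument follows the same pattern according to its description, except that the full set of dataset ids depends on the datasets in use rather than on a fixed list.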
Source code in fastestimator/fastestimator/trace/adapt/terminate_on_nan.py
@traceable()
class TerminateOnNaN(Trace):
    """End Training if a NaN value is detected.

    By default (monitor_names=None) it will monitor all loss values at the end of each batch. If one or more inputs are
    specified, it will only monitor those values. Inputs may be loss keys and/or the keys corresponding to the outputs
    of other traces (ex. accuracy).

    Args:
        monitor_names: key(s) to monitor for NaN values. If None, all loss values will be monitored. "*" will monitor
            all trace output keys and losses.
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
    """
    def __init__(
        self,
        monitor_names: Union[None, str, Iterable[str]] = None,
        mode: Union[None, str, Set[str]] = None,
        ds_id: Union[None, str, Iterable[str]] = None,
    ) -> None:
        super().__init__(inputs=monitor_names, mode=mode, ds_id=ds_id)
        self.monitor_keys = {}
        self.in_list = True

    def on_epoch_begin(self, data: Data) -> None:
        if not self.inputs:
            self.monitor_keys = self.system.network.get_loss_keys()
        elif "*" in self.inputs:
            self.monitor_keys = self.system.network.get_loss_keys()
            for trace in get_current_items(self.system.traces, run_modes=self.system.mode, epoch=self.system.epoch_idx):
                self.monitor_keys.update(trace.outputs)
        else:
            self.monitor_keys = self.inputs

    def on_batch_end(self, data: Data) -> None:
        for key in self.monitor_keys:
            if key in data:
                if check_nan(data[key]):
                    self.system.stop_training = True
                    print("FastEstimator-TerminateOnNaN: NaN Detected in: {}".format(key))

    def on_epoch_end(self, data: Data) -> None:
        for key in self.monitor_keys:
            if key in data:
                if check_nan(data[key]):
                    self.system.stop_training = True
                    print("FastEstimator-TerminateOnNaN: NaN Detected in: {}".format(key))
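The check performed in `on_batch_end` and `on_epoch_end` above can be imitated without FastEstimator installed. The sketch below is a simplified stand-in under stated assumptions: `StubSystem`, `contains_nan`, and `check_batch` are hypothetical names, `contains_nan` only handles Python floats and nested lists or tuples (the library's `check_nan` works on its own data types), and the loop assumes the training system stops as soon as it sees `stop_training` set.

```python
import math
from typing import Any, Dict, Iterable


class StubSystem:
    """Hypothetical stand-in for the training system; it only holds the stop flag."""
    def __init__(self) -> None:
        self.stop_training = False


def contains_nan(value: Any) -> bool:
    """Simplified stand-in for check_nan: floats and nested lists/tuples only."""
    if isinstance(value, (list, tuple)):
        return any(contains_nan(v) for v in value)
    return isinstance(value, float) and math.isnan(value)


def check_batch(system: StubSystem, monitor_keys: Iterable[str], data: Dict[str, Any]) -> None:
    """Mirror of the per-batch check: flag the system if a monitored key holds NaN."""
    for key in monitor_keys:
        # Keys missing from the current data are skipped, as in the source above.
        if key in data and contains_nan(data[key]):
            system.stop_training = True
            print("NaN detected in: {}".format(key))


system = StubSystem()
batches = [{"ce": 0.7}, {"ce": float("nan")}, {"ce": 0.5}]  # made-up loss values
steps_run = 0
for batch in batches:
    steps_run += 1
    check_batch(system, ["ce"], batch)
    if system.stop_training:  # assumed behavior of the surrounding training loop
        break

print(steps_run)  # 2
```

The third batch is never processed because the flag is raised on the second one. As in the source, the trace itself only sets the flag and prints a message; actually halting is left to the system that owns the loop.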