A class which tracks state information while the fe.Estimator is running.

This class is intentionally not @traceable.


Name Type Description Default
network BaseNetwork

The network instance being used by the current fe.Estimator.

pipeline Pipeline

The pipeline instance being used by the current fe.Estimator.

traces List[Union[Trace, Scheduler[Trace]]]

The traces provided to the current fe.Estimator.

mode Optional[str]

The current execution mode (or None for warmup).

num_devices int

How many GPUs are available for training.

log_steps Optional[int]

Log every n steps (0 to disable train logging, None to disable all logging).

total_epochs int

How many epochs training is expected to run for.

train_steps_per_epoch Optional[int]

Training epochs will be cut short or extended to complete N steps (or None if they will run to completion).

eval_steps_per_epoch Optional[int]

Evaluation epochs will be cut short or extended to complete N steps (or None if they will run to completion).

eval_log_steps Sequence[int]

The list of steps on which evaluation progress logs need to be printed.

system_config Optional[List[FeSummaryTable]]

A description of the initialization parameters defining the associated estimator.



Name Type Description
mode Optional[str]

The current execution mode of the estimator ('train', 'eval', or 'test'), or None during warmup.

ds_id str

The current dataset id, or an empty string if there is only one dataset in each mode.

exp_id int

A unique identifier for the current training experiment.

global_step Optional[int]

How many training steps have elapsed.

num_devices int

How many GPUs are available for training.

log_steps Optional[int]

Log every n steps (0 to disable train logging, None to disable all logging).

total_epochs int

How many epochs training is expected to run for.

epoch_idx int

The current epoch index for the training (starting from 1).

batch_idx Optional[int]

The current batch index within an epoch (starting from 1).

stop_training bool

A flag to signal that training should abort.

network BaseNetwork

A reference to the network being used.

pipeline Pipeline

A reference to the pipeline being used.

traces List[Union[Trace, Scheduler[Trace]]]

The traces being used.

train_steps_per_epoch Optional[int]

Training will be cut short or extended to complete N steps even if loader is not yet exhausted. If None, all data will be used.

eval_steps_per_epoch Optional[int]

Evaluation will be cut short or extended to complete N steps even if loader is not yet exhausted. If None, all data will be used.

eval_log_steps_request List[int]

The list of steps on which the user wants eval log printing.

eval_log_steps Tuple[List[int], int]

A tuple holding the steps on which eval logs will be printed and the total number of eval steps in this epoch.

summary Summary

An object to write experiment results to.

experiment_time str

A timestamp indicating when this model was trained.

custom_graphs Dict[str, List[Summary]]

A place to store extra graphs which are too complicated for the primary history.
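As an illustration of the `exp_id` attribute, the source listing derives it from a time-based UUID. The sketch below isolates that expression in a hypothetical helper named `make_exp_id` (the helper name is an assumption, not part of the library) to show that the result is always a signed 64-bit integer.

```python
import uuid


def make_exp_id() -> int:
    """Derive a signed 64-bit id from a time-based UUID (hypothetical helper)."""
    # uuid1() yields 128 bits. Reading them as a signed big-endian integer and
    # shifting right by 64 keeps only the top 64 bits, which hold the timestamp fields.
    return int.from_bytes(uuid.uuid1().bytes, byteorder='big', signed=True) >> 64


exp_id = make_exp_id()
print(exp_id)
```

Because the shift is arithmetic, the value may be negative, which is why the attribute is typed as a plain `int`.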

Source code in fastestimator/fastestimator/summary/
class System:
    """A class which tracks state information while the fe.Estimator is running.

    This class is intentionally not @traceable.

    Args:
        network: The network instance being used by the current fe.Estimator.
        pipeline: The pipeline instance being used by the current fe.Estimator.
        traces: The traces provided to the current fe.Estimator.
        mode: The current execution mode (or None for warmup).
        num_devices: How many GPUs are available for training.
        log_steps: Log every n steps (0 to disable train logging, None to disable all logging).
        total_epochs: How many epochs training is expected to run for.
        train_steps_per_epoch: Whether training iterations will be cut short or extended to complete N steps (or use None if they will run
            to completion)
        eval_steps_per_epoch: Whether evaluation iterations will be cut short or extended to complete N steps (or use None if they will run
            to completion)
        eval_log_steps: The list of steps on which evaluation progress logs need to be printed.
        system_config: A description of the initialization parameters defining the associated estimator.

    Attributes:
        mode: What is the current execution mode of the estimator ('train', 'eval', 'test'), None if warmup.
        ds_id: The current dataset id, Empty string if there is only one dataset in each mode.
        exp_id: A unique identifier for current training experiment.
        global_step: How many training steps have elapsed.
        num_devices: How many GPUs are available for training.
        log_steps: Log every n steps (0 to disable train logging, None to disable all logging).
        total_epochs: How many epochs training is expected to run for.
        epoch_idx: The current epoch index for the training (starting from 1).
        batch_idx: The current batch index within an epoch (starting from 1).
        stop_training: A flag to signal that training should abort.
        network: A reference to the network being used.
        pipeline: A reference to the pipeline being used.
        traces: The traces being used.
        train_steps_per_epoch: Training will be cut short or extended to complete N steps even if loader is not yet
            exhausted. If None, all data will be used.
        eval_steps_per_epoch: Evaluation will be cut short or extended to complete N steps even if loader is not yet
            exhausted. If None, all data will be used.
        eval_log_steps_request: The list of steps on which the user wants eval log printing.
        eval_log_steps: The steps on which eval logs will be printed, The total number of eval steps in this epoch.
        summary: An object to write experiment results to.
        experiment_time: A timestamp indicating when this model was trained.
        custom_graphs: A place to store extra graphs which are too complicated for the primary history.
    """

    mode: Optional[str]
    ds_id: str
    exp_id: int
    global_step: Optional[int]
    num_devices: int
    log_steps: Optional[int]
    total_epochs: int
    epoch_idx: int
    batch_idx: Optional[int]
    stop_training: bool
    network: BaseNetwork
    pipeline: Pipeline
    traces: List[Union['Trace', Scheduler['Trace']]]
    train_steps_per_epoch: Optional[int]
    eval_steps_per_epoch: Optional[int]
    eval_log_steps_request: List[int]
    eval_log_steps: Tuple[List[int], int]
    summary: Summary
    experiment_time: str
    custom_graphs: Dict[str, List[Summary]]

    def __init__(self,
                 network: BaseNetwork,
                 pipeline: Pipeline,
                 traces: List[Union['Trace', Scheduler['Trace']]],
                 mode: Optional[str] = None,
                 ds_id: str = '',
                 num_devices: int = get_num_gpus(),
                 log_steps: Optional[int] = None,
                 total_epochs: int = 0,
                 train_steps_per_epoch: Optional[int] = None,
                 eval_steps_per_epoch: Optional[int] = None,
                 eval_log_steps: Sequence[int] = (),
                 system_config: Optional[List[FeSummaryTable]] = None) -> None:
        self.network = network
        self.pipeline = pipeline
        self.eval_log_steps_request = to_list(eval_log_steps)
        self.eval_log_steps = ([], 0)
        self.traces = traces
        self.mode = mode
        self.ds_id = ds_id
        self.num_devices = num_devices
        self.log_steps = log_steps
        self.total_epochs = total_epochs
        self.batch_idx = None
        self.train_steps_per_epoch = train_steps_per_epoch
        self.eval_steps_per_epoch = eval_steps_per_epoch
        self.stop_training = False
        self.summary = Summary(None, system_config)
        self.experiment_time = ""
        self.custom_graphs = {}
        self._initialize_state()

    def steps_per_epoch(self) -> Optional[int]:
        if self.mode == 'train':
            return self.train_steps_per_epoch
        elif self.mode == 'eval':
            return self.eval_steps_per_epoch
        else:
            return None

    def _initialize_state(self) -> None:
        """Initialize the training state.
        """
        self.global_step = None
        self.epoch_idx = 0
        # Get a 64 bit random id related to current time
        self.exp_id = int.from_bytes(uuid.uuid1().bytes, byteorder='big', signed=True) >> 64

    def update_global_step(self) -> None:
        """Increment the current `global_step`.
        """
        if self.global_step is None:
            self.global_step = 1
        else:
            self.global_step += 1

    def update_batch_idx(self) -> None:
        """Increment the current `batch_idx`.
        """
        if self.batch_idx is None:
            self.batch_idx = 1
        else:
            self.batch_idx += 1

    def reset(self, summary_name: Optional[str] = None, system_config: Optional[str] = None) -> None:
        """Reset the current `System` for a new round of training, including a new `Summary` object.

        Args:
            summary_name: The name of the experiment. The `Summary` object will store information iff name is not None.
            system_config: A description of the initialization parameters defining the associated estimator.
        """
        self.experiment_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        self.mode = "train"
        self.ds_id = ''
        self.batch_idx = None
        self.stop_training = False
        self.summary = Summary(summary_name, system_config)
        self.custom_graphs = {}

    def reset_for_test(self, summary_name: Optional[str] = None) -> None:
        """Partially reset the current `System` object for a new round of testing.

        Args:
            summary_name: The name of the experiment. If not provided, the system will re-use the previous summary name.
        """
        self.experiment_time = self.experiment_time or datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        self.mode = "test"
        self.ds_id = ''
        if not self.stop_training:
            self.epoch_idx = self.total_epochs
        self.stop_training = False
        self.summary.name = summary_name or self.summary.name  # Keep old experiment name if new one not provided
        self.summary.history.pop('test', None)
        for graph_set in self.custom_graphs.values():
            for graph in graph_set:
                graph.history.pop('test', None)

    def write_summary(self, key: str, value: Any) -> None:
        """Write an entry into the `Summary` object (iff the experiment was named).

        Args:
            key: The key to write into the summary object.
            value: The value to write into the summary object.
        """
        if self.summary:
            self.summary.history[self.mode][key][self.global_step or 0] = value

    def add_graph(self, graph_name: str, graph: Union[Summary, List[Summary]]) -> None:
        """Write custom summary graphs into the System.

        This can be useful for things like the LabelTracker trace to interact with Traceability reports.

        Args:
            graph_name: The name of the graph (so that you can override it later if desired).
            graph: The custom summary to be tracked.
        """
        if isinstance(graph, Summary):
            self.custom_graphs[graph_name] = [graph]
        else:
            self.custom_graphs[graph_name] = list(graph)

    def save_state(self, save_dir: str) -> None:
        """Save training state.

        Args:
            save_dir: The directory into which to save the state
        """
        os.makedirs(save_dir, exist_ok=True)
        # Start with the high-level info. We could use pickle for this but having it human readable is nice.
        state = {key: value for key, value in self.__dict__.items() if is_restorable(value)[0]}
        with open(os.path.join(save_dir, 'system.json'), 'w') as fp:
            json.dump(state, fp, indent=4)
        # Save all of the models / optimizer states
        for model in
            save_model(model, save_dir=save_dir, save_optimizer=hasattr(model, "optimizer") and model.optimizer)
        # Save everything else
        objects = {
            'summary': self.summary,
            'custom_graphs': self.custom_graphs,
            'traces': [trace.__getstate__() if hasattr(trace, '__getstate__') else {} for trace in self.traces],
            'tops': [op.__getstate__() if hasattr(op, '__getstate__') else {} for op in],
            'slops': [sl.__getstate__() if hasattr(sl, '__getstate__') else {} for sl in],
            'pops': [op.__getstate__() if hasattr(op, '__getstate__') else {} for op in],
            'nops': [op.__getstate__() if hasattr(op, '__getstate__') else {} for op in self.pipeline.ops],
            'ds': {
                mode: {
                    key: value.__getstate__()
                    for key, value in ds.items() if hasattr(value, '__getstate__')
                for mode,
                ds in
        with open(os.path.join(save_dir, 'objects.pkl'), 'wb') as file:
            # We need to use a custom pickler here to handle MirroredStrategy, which will show up inside of tf
            # MirroredVariables in multi-gpu systems.
            p = pickle.Pickler(file)
            p.dispatch_table = copyreg.dispatch_table.copy()
            p.dispatch_table[MirroredStrategy] = pickle_mirroredstrategy
            p.dump(objects)

    def load_state(self, load_dir: str) -> None:
        """Load training state.

        Args:
            load_dir: The directory from which to reload the state.

        Raises:
            FileNotFoundError: If necessary files can not be found.
        """
        # Reload the high-level system information
        system_path = os.path.join(load_dir, 'system.json')
        if not os.path.exists(system_path):
            raise FileNotFoundError(f"Could not find system summary file at {system_path}")
        with open(system_path, 'r') as fp:
            state = json.load(fp)
        # Reload the models
        for model in
            self._load_model(model, load_dir)
        # Reload everything else
        objects_path = os.path.join(load_dir, 'objects.pkl')
        if not os.path.exists(objects_path):
            raise FileNotFoundError(f"Could not find the objects summary file at {objects_path}")
        with open(objects_path, 'rb') as file:
            objects = pickle.load(file)
        self.custom_graphs = objects['custom_graphs']
        self._load_list(objects, 'traces', self.traces)
        self._load_list(objects, 'tops',
        self._load_list(objects, 'slops',
        self._load_list(objects, 'pops',
        self._load_list(objects, 'nops', self.pipeline.ops)
        self._load_dict(objects, 'ds',

    @staticmethod
    def _load_model(model: Model, base_path: str) -> None:
        """Load model and optimizer weights from disk.

        Args:
            model: The model to be loaded.
            base_path: The folder where the model should be located.

        Raises:
            ValueError: If the model is of an unknown type.
            FileNotFoundError: If the model weights or optimizer state is missing.
        """
        if isinstance(model, tf.keras.Model):
            model_ext, optimizer_ext = 'h5', 'pkl'
        elif isinstance(model, torch.nn.Module):
            model_ext, optimizer_ext = 'pt', 'pt'
        else:
            raise ValueError(f"Unknown model type: {type(model)}")
        weights_path = os.path.join(base_path, f"{model.model_name}.{model_ext}")
        if not os.path.exists(weights_path):
            raise FileNotFoundError(f"Cannot find model weights file at {weights_path}")
        optimizer_path = os.path.join(base_path, f"{model.model_name}_opt.{optimizer_ext}")
        load_model(model, weights_path=weights_path, load_optimizer=os.path.exists(optimizer_path))

    @staticmethod
    def _load_list(states: Dict[str, Any], state_key: str, in_memory_objects: List[Any]) -> None:
        """Load a list of pickled states from the disk.

        Args:
            states: The states to be restored.
            state_key: Which state to select from the dictionary.
            in_memory_objects: The existing in memory objects to be updated.

        Raises:
            ValueError: If the number of saved states does not match the number of in-memory objects.
        """
        states = states[state_key]
        if not isinstance(states, list):
            raise ValueError(f"Expected {state_key} to contain a list, but found a {type(states)}")
        if len(states) != len(in_memory_objects):
            raise ValueError("Expected saved {} to contain {} objects, but found {} instead".format(
                state_key, len(in_memory_objects), len(states)))
        for obj, state in zip(in_memory_objects, states):
            if hasattr(obj, '__setstate__'):
                obj.__setstate__(state)
            elif hasattr(obj, '__dict__'):
                obj.__dict__.update(state)
            else:
                # Might be a None or something else that can't be updated
                pass

    @staticmethod
    def _load_dict(states: Dict[str, Any], state_key: str, in_memory_objects: Dict[Any, Any]) -> None:
        """Load a dictionary of pickled states from the disk.

        Args:
            states: The states to be restored.
            state_key: Which state to select from the dictionary.
            in_memory_objects: The existing in memory objects to be updated.

        Raises:
            ValueError: If the configuration of saved states does not match the number of in-memory objects.
            FileNotFoundError: If the desired state file cannot be found.
        """
        states = states[state_key]
        if not isinstance(states, dict):
            raise ValueError(f"Expected {state_key} to contain a dict, but found a {type(states)}")
        # Note that not being a subset is different from being a superset
        if not states.keys() <= in_memory_objects.keys():
            raise ValueError("Saved {} contained unexpected keys: {}".format(state_key,
                                                                             states.keys() - in_memory_objects.keys()))
        for key, state in states.items():
            obj = in_memory_objects[key]
            if hasattr(obj, '__setstate__'):
                obj.__setstate__(state)
            elif hasattr(obj, '__dict__'):
                obj.__dict__.update(state)
            elif isinstance(obj, dict):
                System._load_dict(states, key, obj)
            else:
                # Might be a None or something else that can't be updated
                pass


Write custom summary graphs into the System.

This can be useful for things like the LabelTracker trace to interact with Traceability reports.


Name Type Description Default
graph_name str

The name of the graph (so that you can override it later if desired).

graph Union[Summary, List[Summary]]

The custom summary to be tracked.
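The argument handling described here can be sketched without the library. The example below is a minimal illustration, assuming a hypothetical `FakeSummary` class in place of the real `Summary` and a module-level dictionary in place of `System.custom_graphs`. It shows that a single summary is wrapped in a list, and that reusing a graph name replaces the earlier entry.

```python
from typing import Dict, List, Union


class FakeSummary:
    """Hypothetical stand-in for the library's Summary; it only keeps a name."""
    def __init__(self, name: str) -> None:
        self.name = name


# Stand-in for System.custom_graphs
custom_graphs: Dict[str, List[FakeSummary]] = {}


def add_graph(graph_name: str, graph: Union[FakeSummary, List[FakeSummary]]) -> None:
    # A single summary is wrapped in a list; anything else is copied into a new list.
    if isinstance(graph, FakeSummary):
        custom_graphs[graph_name] = [graph]
    else:
        custom_graphs[graph_name] = list(graph)


add_graph("labels", FakeSummary("a"))
add_graph("labels", [FakeSummary("b"), FakeSummary("c")])  # same name, so the first entry is replaced
print([summary.name for summary in custom_graphs["labels"]])  # ['b', 'c']
```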

Source code in fastestimator/fastestimator/summary/
def add_graph(self, graph_name: str, graph: Union[Summary, List[Summary]]) -> None:
    """Write custom summary graphs into the System.

    This can be useful for things like the LabelTracker trace to interact with Traceability reports.

    Args:
        graph_name: The name of the graph (so that you can override it later if desired).
        graph: The custom summary to be tracked.
    """
    if isinstance(graph, Summary):
        self.custom_graphs[graph_name] = [graph]
    else:
        self.custom_graphs[graph_name] = list(graph)


Load training state.


Name Type Description Default
load_dir str

The directory from which to reload the state.



Type Description
FileNotFoundError

If necessary files cannot be found.
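Loading restores state into objects that already exist in memory rather than replacing them. The sketch below illustrates that idea with a standalone function modeled on the `_load_list` helper in the class listing. It assumes that the helper calls `__setstate__` when an object defines it and otherwise updates the object's `__dict__`; `Counter` is a hypothetical stand-in for a trace or op.

```python
from typing import Any, Dict, List


class Counter:
    """Hypothetical stateful object standing in for a trace or op."""
    def __init__(self) -> None:
        self.count = 0


def load_list(states: Dict[str, Any], state_key: str, in_memory_objects: List[Any]) -> None:
    saved = states[state_key]
    if len(saved) != len(in_memory_objects):
        raise ValueError(f"Expected {len(in_memory_objects)} saved states, but found {len(saved)}")
    # Saved states are matched to in-memory objects purely by position.
    for obj, state in zip(in_memory_objects, saved):
        if hasattr(obj, '__setstate__'):
            obj.__setstate__(state)
        elif hasattr(obj, '__dict__'):
            obj.__dict__.update(state)


counters = [Counter(), Counter()]
load_list({'traces': [{'count': 3}, {}]}, 'traces', counters)
print([counter.count for counter in counters])  # [3, 0]
```

Since matching is positional, the estimator being restored presumably needs to be built with the same traces and ops, in the same order, as the one that was saved.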

Source code in fastestimator/fastestimator/summary/
def load_state(self, load_dir: str) -> None:
    """Load training state.

    Args:
        load_dir: The directory from which to reload the state.

    Raises:
        FileNotFoundError: If necessary files can not be found.
    """
    # Reload the high-level system information
    system_path = os.path.join(load_dir, 'system.json')
    if not os.path.exists(system_path):
        raise FileNotFoundError(f"Could not find system summary file at {system_path}")
    with open(system_path, 'r') as fp:
        state = json.load(fp)
    # Reload the models
    for model in
        self._load_model(model, load_dir)
    # Reload everything else
    objects_path = os.path.join(load_dir, 'objects.pkl')
    if not os.path.exists(objects_path):
        raise FileNotFoundError(f"Could not find the objects summary file at {objects_path}")
    with open(objects_path, 'rb') as file:
        objects = pickle.load(file)
    self.custom_graphs = objects['custom_graphs']
    self._load_list(objects, 'traces', self.traces)
    self._load_list(objects, 'tops',
    self._load_list(objects, 'slops',
    self._load_list(objects, 'pops',
    self._load_list(objects, 'nops', self.pipeline.ops)
    self._load_dict(objects, 'ds',


Reset the current System for a new round of training, including a new Summary object.


Name Type Description Default
summary_name Optional[str]

The name of the experiment. The Summary object will store information iff name is not None.

system_config Optional[str]

A description of the initialization parameters defining the associated estimator.

Source code in fastestimator/fastestimator/summary/
def reset(self, summary_name: Optional[str] = None, system_config: Optional[str] = None) -> None:
    """Reset the current `System` for a new round of training, including a new `Summary` object.

    Args:
        summary_name: The name of the experiment. The `Summary` object will store information iff name is not None.
        system_config: A description of the initialization parameters defining the associated estimator.
    """
    self.experiment_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    self.mode = "train"
    self.ds_id = ''
    self.batch_idx = None
    self.stop_training = False
    self.summary = Summary(summary_name, system_config)
    self.custom_graphs = {}


Partially reset the current System object for a new round of testing.


Name Type Description Default
summary_name Optional[str]

The name of the experiment. If not provided, the system will re-use the previous summary name.

Source code in fastestimator/fastestimator/summary/
def reset_for_test(self, summary_name: Optional[str] = None) -> None:
    """Partially reset the current `System` object for a new round of testing.

    Args:
        summary_name: The name of the experiment. If not provided, the system will re-use the previous summary name.
    """
    self.experiment_time = self.experiment_time or datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    self.mode = "test"
    self.ds_id = ''
    if not self.stop_training:
        self.epoch_idx = self.total_epochs
    self.stop_training = False
    self.summary.name = summary_name or self.summary.name  # Keep old experiment name if new one not provided
    self.summary.history.pop('test', None)
    for graph_set in self.custom_graphs.values():
        for graph in graph_set:
            graph.history.pop('test', None)


Save training state.


Name Type Description Default
save_dir str

The directory into which to save the state
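The source listing writes two files: a human-readable `system.json` for simple attributes and an `objects.pkl` for everything else. The sketch below illustrates that split using only the standard library. The `is_simple` filter is a hypothetical stand-in for the library's `is_restorable` check, and the attribute and object dictionaries are made-up sample data.

```python
import json
import os
import pickle
import tempfile
from typing import Any, Dict


def is_simple(value: Any) -> bool:
    """Hypothetical stand-in for the library's is_restorable check."""
    return value is None or isinstance(value, (bool, int, float, str))


def save_state(attributes: Dict[str, Any], objects: Dict[str, Any], save_dir: str) -> None:
    os.makedirs(save_dir, exist_ok=True)
    # Human-readable part: keep only values that JSON can represent directly.
    state = {key: value for key, value in attributes.items() if is_simple(value)}
    with open(os.path.join(save_dir, 'system.json'), 'w') as fp:
        json.dump(state, fp, indent=4)
    # Everything else goes into a single pickle file.
    with open(os.path.join(save_dir, 'objects.pkl'), 'wb') as fp:
        pickle.dump(objects, fp)


attributes = {'mode': 'train', 'epoch_idx': 3, 'global_step': 120, 'traces': [object()]}
objects = {'traces': [{'best_loss': 0.25}]}
with tempfile.TemporaryDirectory() as save_dir:
    save_state(attributes, objects, save_dir)
    with open(os.path.join(save_dir, 'system.json')) as fp:
        saved_state = json.load(fp)
    with open(os.path.join(save_dir, 'objects.pkl'), 'rb') as fp:
        saved_objects = pickle.load(fp)
print(saved_state)  # the 'traces' entry is filtered out of the JSON file
```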

Source code in fastestimator/fastestimator/summary/
def save_state(self, save_dir: str) -> None:
    """Save training state.

    Args:
        save_dir: The directory into which to save the state
    """
    os.makedirs(save_dir, exist_ok=True)
    # Start with the high-level info. We could use pickle for this but having it human readable is nice.
    state = {key: value for key, value in self.__dict__.items() if is_restorable(value)[0]}
    with open(os.path.join(save_dir, 'system.json'), 'w') as fp:
        json.dump(state, fp, indent=4)
    # Save all of the models / optimizer states
    for model in
        save_model(model, save_dir=save_dir, save_optimizer=hasattr(model, "optimizer") and model.optimizer)
    # Save everything else
    objects = {
        'summary': self.summary,
        'custom_graphs': self.custom_graphs,
        'traces': [trace.__getstate__() if hasattr(trace, '__getstate__') else {} for trace in self.traces],
        'tops': [op.__getstate__() if hasattr(op, '__getstate__') else {} for op in],
        'slops': [sl.__getstate__() if hasattr(sl, '__getstate__') else {} for sl in],
        'pops': [op.__getstate__() if hasattr(op, '__getstate__') else {} for op in],
        'nops': [op.__getstate__() if hasattr(op, '__getstate__') else {} for op in self.pipeline.ops],
        'ds': {
            mode: {
                key: value.__getstate__()
                for key, value in ds.items() if hasattr(value, '__getstate__')
            for mode,
            ds in
    with open(os.path.join(save_dir, 'objects.pkl'), 'wb') as file:
        # We need to use a custom pickler here to handle MirroredStrategy, which will show up inside of tf
        # MirroredVariables in multi-gpu systems.
        p = pickle.Pickler(file)
        p.dispatch_table = copyreg.dispatch_table.copy()
        p.dispatch_table[MirroredStrategy] = pickle_mirroredstrategy
        p.dump(objects)


Increment the current batch_idx.

Source code in fastestimator/fastestimator/summary/
def update_batch_idx(self) -> None:
    """Increment the current `batch_idx`.
    """
    if self.batch_idx is None:
        self.batch_idx = 1
    else:
        self.batch_idx += 1


Increment the current global_step.
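Both step counters start as None and become 1 on the first update, so the first counted step is numbered 1 rather than 0. A minimal sketch of that pattern, using a hypothetical `StepCounter` class rather than the real `System`:

```python
from typing import Optional


class StepCounter:
    """Hypothetical reduction of System to its global_step counter."""
    def __init__(self) -> None:
        self.global_step: Optional[int] = None  # None means no step has run yet

    def update_global_step(self) -> None:
        if self.global_step is None:
            self.global_step = 1
        else:
            self.global_step += 1


counter = StepCounter()
counter.update_global_step()
counter.update_global_step()
print(counter.global_step)  # 2
```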

Source code in fastestimator/fastestimator/summary/
def update_global_step(self) -> None:
    """Increment the current `global_step`.
    """
    if self.global_step is None:
        self.global_step = 1
    else:
        self.global_step += 1


Write an entry into the Summary object (iff the experiment was named).


Name Type Description Default
key str

The key to write into the summary object.

value Any

The value to write into the summary object.
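The indexing used by this method implies a history nested as mode, then key, then step. The sketch below assumes that shape and models it with nested `defaultdict` objects; the real `Summary.history` may be built differently. It also shows that a `global_step` of None is recorded under step 0.

```python
from collections import defaultdict
from typing import Any, Optional

# Assumed shape: history[mode][key][step] = value
history = defaultdict(lambda: defaultdict(dict))


def write_summary(mode: Optional[str], key: str, value: Any, global_step: Optional[int]) -> None:
    # `global_step or 0` maps None (no step taken yet) onto step 0.
    history[mode][key][global_step or 0] = value


write_summary('train', 'loss', 0.9, None)
write_summary('train', 'loss', 0.5, 1)
print(dict(history['train']['loss']))  # {0: 0.9, 1: 0.5}
```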

Source code in fastestimator/fastestimator/summary/
def write_summary(self, key: str, value: Any) -> None:
    """Write an entry into the `Summary` object (iff the experiment was named).

    Args:
        key: The key to write into the summary object.
        value: The value to write into the summary object.
    """
    if self.summary:
        self.summary.history[self.mode][key][self.global_step or 0] = value


A custom reduce function to use when Pickle encounters a tf MirroredStrategy.

This relies on the fact that the tf strategy will already be set before the System.load_state method gets called.


Name Type Description Default
obj MirroredStrategy

The MirroredStrategy instance.



Type Description
Tuple[Callable, Tuple]

The mechanism to construct a new instance of the MirroredStrategy. See Python docs on the __reduce__ method.
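The same technique can be tried without TensorFlow. The sketch below is an analogy under stated assumptions: it registers a reduce function in a pickler's `dispatch_table` for the current thread object, which pickle cannot normally serialize, and tells pickle to call `threading.current_thread` at load time instead. The names `reduce_thread` and `buffer` are illustrative only.

```python
import copyreg
import io
import pickle
import threading


def reduce_thread(obj):
    """Rebuild a thread by calling threading.current_thread() when unpickling."""
    return threading.current_thread, ()


buffer = io.BytesIO()
p = pickle.Pickler(buffer)
p.dispatch_table = copyreg.dispatch_table.copy()
# dispatch_table lookups use the exact type, so register the concrete class of this thread.
p.dispatch_table[type(threading.current_thread())] = reduce_thread
p.dump({"owner": threading.current_thread(), "step": 7})

restored = pickle.loads(buffer.getvalue())
print(restored["owner"] is threading.current_thread())  # True
```

Only a reference to the getter is stored in the pickle, so whatever object the getter returns at load time is what ends up in the restored structure. This mirrors why the strategy must already be set before `System.load_state` is called.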

Source code in fastestimator/fastestimator/summary/
def pickle_mirroredstrategy(obj: MirroredStrategy) -> Tuple[Callable, Tuple]:
    """A custom reduce function to use when Pickle encounters a tf MirroredStrategy.

    This relies on the fact that the tf strategy will already be set before the System.load_state method gets called.

    Args:
        obj: The MirroredStrategy instance.

    Returns:
        The mechanism to construct a new instance of the MirroredStrategy. See Python docs on the __reduce__ method.
    """
    return tf.distribute.get_strategy, ()