Skip to content

model_saver

ModelSaver

Bases: Trace

Save model weights based on epoch frequency during training.

Parameters:

Name Type Description Default
model Union[Model, Module]

A model instance compiled with fe.build.

required
save_dir str

Folder path into which to save the model.

required
frequency int

Model saving frequency in epoch(s).

1
max_to_keep Optional[int]

Maximum number of latest saved files to keep. If 0 or None, every saved model is kept (no old files are deleted).

None
save_architecture bool

Whether to save the full model architecture in addition to the model weights. This option is only available for TensorFlow models at present, and will generate a folder containing several files. The model can then be re-instantiated even without access to the original code by calling: tf.keras.models.load_model(&lt;path to model folder&gt;).

False

Raises:

Type Description
ValueError

If max_to_keep is negative, or if save_architecture is used with a PyTorch model.

Source code in fastestimator/fastestimator/trace/io/model_saver.py
@traceable()
class ModelSaver(Trace):
    """Save model weights based on epoch frequency during training.

    When `max_to_keep` is a positive number, the paths of the most recent saves are tracked and, once that many have
    been written, each new save deletes the oldest tracked checkpoint. A path which is still tracked (for example
    because the newest save just overwrote it) is never deleted.

    Args:
        model: A model instance compiled with fe.build.
        save_dir: Folder path into which to save the `model`. If None or empty, nothing is saved.
        frequency: Model saving frequency in epoch(s).
        max_to_keep: Maximum number of latest saved files to keep. If 0 or None, all models will be saved.
        save_architecture: Whether to save the full model architecture in addition to the model weights. This option is
            only available for TensorFlow models at present, and will generate a folder containing several files. The
            model can then be re-instantiated even without access to the original code by calling:
            tf.keras.models.load_model(<path to model folder>).

    Raises:
        ValueError: If `max_to_keep` is negative, or if save_architecture is used with a PyTorch model.
    """
    def __init__(self,
                 model: Union[tf.keras.Model, torch.nn.Module],
                 save_dir: str,
                 frequency: int = 1,
                 max_to_keep: Optional[int] = None,
                 save_architecture: bool = False) -> None:
        super().__init__(mode="train")
        self.model = model
        self.save_dir = save_dir
        self.frequency = frequency
        self.save_architecture = save_architecture
        if save_architecture and isinstance(model, torch.nn.Module):
            raise ValueError("Sorry, architecture saving is not currently enabled for PyTorch")
        if max_to_keep is not None and max_to_keep < 0:
            raise ValueError(f"max_to_keep should be a non-negative integer, but got {max_to_keep}")
        # Fixed-length queue, newest path on the left, pre-filled with None placeholders so that it is always full:
        # the right-most slot is exactly the entry which the next appendleft() pushes out. A maxlen of 0
        # (max_to_keep of 0 or None) keeps the queue permanently empty, which disables pruning.
        queue_len = max_to_keep or 0
        self.file_queue = deque([None] * queue_len, maxlen=queue_len)

    def on_epoch_end(self, data: Data) -> None:
        """Save the model if this epoch is a saving epoch, then prune the checkpoint which fell out of the queue.

        Args:
            data: Not used by this method.
        """
        # No model will be saved when save_dir is None, which makes smoke test easier.
        if not self.save_dir or self.system.epoch_idx % self.frequency != 0:
            return
        model_name = f"{self.model.model_name}_epoch_{self.system.epoch_idx}"
        model_path = save_model(model=self.model,
                                save_dir=self.save_dir,
                                model_name=model_name,
                                save_architecture=self.save_architecture)
        print(f"FastEstimator-ModelSaver: Saved model to {model_path}")
        # Remember what the append below is about to push out (None while the queue still holds placeholders).
        evicted_path = self.file_queue[-1] if self.file_queue.maxlen else None
        self.file_queue.appendleft(model_path)
        # Only delete the evicted file if no tracked entry still points at it. The same path can recur if epoch
        # indices repeat (e.g. this trace instance being reused for another run); deleting unconditionally would
        # then wipe out a checkpoint that was just written or that is still among the latest `max_to_keep`.
        if evicted_path and evicted_path not in self.file_queue:
            os.remove(evicted_path)
            if self.save_architecture:
                # Also remove the folder named like the weights file minus its extension.
                shutil.rmtree(os.path.splitext(evicted_path)[0])
            print(f"FastEstimator-ModelSaver: Removed model {evicted_path} due to file number exceeding max_to_keep")