Skip to content

image_saver

ImageSaver

Bases: Trace

A trace that saves images to the disk.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `inputs` | `Union[str, Sequence[str]]` | Key(s) of images to be saved. | *required* |
| `save_dir` | `str` | The directory into which to write the images. The default is evaluated once, when the module is imported. | `os.getcwd()` |
| `mode` | `Union[None, str, Iterable[str]]` | What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | `('eval', 'test')` |
Source code in fastestimator/fastestimator/trace/io/image_saver.py
@traceable()
class ImageSaver(Trace):
    """A trace that saves images to the disk.

    Images are saved from both `on_epoch_end` and `on_end`. For every key in `inputs` that is present in `data`:

    * A `Display` is rendered through its own `show` method into `<key>_<mode>_epoch_<epoch>.png`.
    * A `Summary`, or a non-empty list/tuple of `Summary` objects, is plotted with `visualize_logs` into that same
      file name.
    * Any other value is iterated element by element (presumably a batch of images), and each element is rendered
      as an `ImageDisplay` into `<key>_<mode>_epoch_<epoch>_elem_<idx>.png`.

    `save_dir` (and any missing parent directories) is created on first write if it does not already exist.

    Args:
        inputs: Key(s) of images to be saved.
        save_dir: The directory into which to write the images. Note that the default value is evaluated once, when
            this class is defined (i.e. at import time). It is therefore the working directory at import time, not at
            construction time.
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
    """

    def __init__(self,
                 inputs: Union[str, Sequence[str]],
                 save_dir: str = os.getcwd(),
                 mode: Union[None, str, Iterable[str]] = ("eval", "test")) -> None:
        super().__init__(inputs=inputs, mode=mode)
        self.save_dir = save_dir

    def on_epoch_end(self, data: Data) -> None:
        self._save_images(data)

    def on_end(self, data: Data) -> None:
        self._save_images(data)

    def _save_images(self, data: Data) -> None:
        """Write every requested key found in `data` to `self.save_dir`.

        Args:
            data: The data passed to this trace. Keys listed in `self.inputs` that are absent from it are skipped.
        """
        for key in self.inputs:
            if key not in data:
                continue
            imgs = data[key]
            self._make_save_dir()
            prefix = "{}_{}_epoch_{}".format(key, self.system.mode, self.system.epoch_idx)
            im_path = os.path.join(self.save_dir, prefix + ".png")
            if isinstance(imgs, Display):
                imgs.show(save_path=im_path, verbose=False)
                self._report_saved(im_path)
            elif isinstance(imgs, Summary):
                visualize_logs([imgs], save_path=im_path, verbose=False)
                self._report_saved(im_path)
            elif isinstance(imgs, (list, tuple)) and imgs and all(isinstance(img, Summary) for img in imgs):
                # The sequence must be non-empty: `all` over an empty iterable is True, which would otherwise hand
                # `visualize_logs` nothing to plot.
                visualize_logs(imgs, save_path=im_path, verbose=False)
                self._report_saved(im_path)
            else:
                # Treat the value as a collection of individual images and save each one to its own file.
                for idx, img in enumerate(imgs):
                    elem_path = os.path.join(self.save_dir, "{}_elem_{}.png".format(prefix, idx))
                    ImageDisplay(image=img, title=key).show(save_path=elem_path, verbose=False)
                    self._report_saved(elem_path)

    def _make_save_dir(self) -> None:
        """Create `self.save_dir` (including missing parents) if it does not exist yet.

        An empty `save_dir` makes `os.path.join` produce paths relative to the current working directory. In that
        case there is nothing to create, and `os.makedirs("")` would raise.
        """
        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)

    @staticmethod
    def _report_saved(path: str) -> None:
        """Print the standard confirmation message for an image written to `path`."""
        print("FastEstimator-ImageSaver: saved image to {}".format(path))