Skip to content

tensorboard

TensorBoard

Bases: Trace

Output data for use in TensorBoard.

Note that if you plan to run a tensorboard server simultaneously with training, you may want to consider using the --reload_multifile=true flag until their multi-writer use case is finished: https://github.com/tensorflow/tensorboard/issues/1063

Parameters:

Name Type Description Default
log_dir str

Path of the directory where the log files to be parsed by TensorBoard should be saved.

'logs'
update_freq Union[None, int, str]

'batch', 'epoch', integer, or strings like '10s', '15e'. When using 'batch', writes the losses and metrics to TensorBoard after each batch. The same applies for 'epoch'. If using an integer, let's say 1000, the callback will write the metrics and losses to TensorBoard every 1000 samples. You can also use strings like '8s' to indicate every 8 steps or '5e' to indicate every 5 epochs. Note that writing too frequently to TensorBoard can slow down your training. You can use None to disable updating, but this will make the trace mostly useless.

100
write_graph bool

Whether to visualize the graph in TensorBoard. The log file can become quite large when write_graph is set to True.

True
write_images Union[None, str, List[str]]

If a string or list of strings is provided, the corresponding keys will be written to TensorBoard images.

None
weight_histogram_freq Union[None, int, str]

Frequency (in epochs) at which to compute activation and weight histograms for the layers of the model. Same argument format as update_freq.

None
paint_weights bool

If True the system will attempt to visualize model weights as an image.

False
embedding_freq Union[None, int, str]

Frequency at which to write embeddings. Same argument format as update_freq.

'epoch'
write_embeddings Union[None, str, List[str]]

If a string or list of strings is provided, the corresponding keys will be written to TensorBoard embeddings.

None
embedding_labels Union[None, str, List[str]]

Keys corresponding to label information for the write_embeddings.

None
embedding_images Union[None, str, List[str]]

Keys corresponding to raw images to be associated with the write_embeddings.

None
Source code in fastestimator/fastestimator/trace/io/tensorboard.py
@traceable()
class TensorBoard(Trace):
    """Output data for use in TensorBoard.

    Note that if you plan to run a tensorboard server simultaneously with training, you may want to consider using the
    --reload_multifile=true flag until their multi-writer use case is finished:
    https://github.com/tensorflow/tensorboard/issues/1063

    Args:
        log_dir: Path of the directory where the log files to be parsed by TensorBoard should be saved.
        update_freq: 'batch', 'epoch', integer, or strings like '10s', '15e'. When using 'batch', writes the losses and
            metrics to TensorBoard after each batch. The same applies for 'epoch'. If using an integer, let's say 1000,
            the callback will write the metrics and losses to TensorBoard every 1000 samples. You can also use strings
            like '8s' to indicate every 8 steps or '5e' to indicate every 5 epochs. Note that writing too frequently to
            TensorBoard can slow down your training. You can use None to disable updating, but this will make the trace
            mostly useless.
        write_graph: Whether to visualize the graph in TensorBoard. The log file can become quite large when write_graph
            is set to True.
        write_images: If a string or list of strings is provided, the corresponding keys will be written to TensorBoard
            images.
        weight_histogram_freq: Frequency (in epochs) at which to compute activation and weight histograms for the layers
            of the model. Same argument format as `update_freq`.
        paint_weights: If True the system will attempt to visualize model weights as an image.
        embedding_freq: Frequency at which to write the `write_embeddings`. Same argument format as `update_freq`.
        write_embeddings: If a string or list of strings is provided, the corresponding keys will be written to
            TensorBoard embeddings.
        embedding_labels: Keys corresponding to label information for the `write_embeddings`.
        embedding_images: Keys corresponding to raw images to be associated with the `write_embeddings`.
    """
    writer: _BaseWriter

    # TODO - support for per-instance tracking

    def __init__(self,
                 log_dir: str = 'logs',
                 update_freq: Union[None, int, str] = 100,
                 write_graph: bool = True,
                 write_images: Union[None, str, List[str]] = None,
                 weight_histogram_freq: Union[None, int, str] = None,
                 paint_weights: bool = False,
                 embedding_freq: Union[None, int, str] = 'epoch',
                 write_embeddings: Union[None, str, List[str]] = None,
                 embedding_labels: Union[None, str, List[str]] = None,
                 embedding_images: Union[None, str, List[str]] = None) -> None:
        super().__init__(inputs=["*"] + to_list(write_images) + to_list(write_embeddings) + to_list(embedding_labels) +
                         to_list(embedding_images))
        self.root_log_dir = log_dir
        self.update_freq = parse_freq(update_freq)
        self.write_graph = write_graph
        self.painted_graphs = set()
        self.write_images = to_set(write_images)
        self.histogram_freq = parse_freq(weight_histogram_freq)
        if paint_weights and self.histogram_freq.freq == 0:
            # Painting weights only happens when histograms are written, so default to once per epoch
            self.histogram_freq.is_step = False
            self.histogram_freq.freq = 1
        self.paint_weights = paint_weights
        if write_embeddings is None and embedding_labels is None and embedding_images is None:
            # Speed up if-check short-circuiting later
            embedding_freq = None
        self.embedding_freq = parse_freq(embedding_freq)
        write_embeddings = to_list(write_embeddings)
        # Auxiliary keys (labels / images) must line up 1:1 with the embedding keys, padded with None when absent
        embedding_labels = self._align_embedding_aux(embedding_labels, len(write_embeddings), 'embedding_labels')
        embedding_images = self._align_embedding_aux(embedding_images, len(write_embeddings), 'embedding_images')
        self.write_embeddings = list(zip(write_embeddings, embedding_labels, embedding_images))
        self.collected_embeddings = defaultdict(list)

    @staticmethod
    def _align_embedding_aux(keys: Union[None, str, List[str]], n_embeddings: int,
                             arg_name: str) -> List[Union[None, str]]:
        """Validate that auxiliary embedding keys line up 1:1 with the requested embeddings.

        Args:
            keys: The user-provided auxiliary key(s) (labels or images), or None.
            n_embeddings: How many embedding keys were requested via `write_embeddings`.
            arg_name: The name of the argument being validated (used in the error message).

        Returns:
            A list of length `n_embeddings`, padded out with None entries if no keys were provided.

        Raises:
            AssertionError: If a non-empty `keys` does not have exactly `n_embeddings` entries.
        """
        keys = to_list(keys)
        if not keys:
            return [None] * n_embeddings
        assert len(keys) == n_embeddings, \
            f"Expected {n_embeddings} {arg_name} keys, but received {len(keys)}. Use None to pad out the list if " \
            "you have values for only a subset of all embeddings."
        return keys

    def on_begin(self, data: Data) -> None:
        # Lazily build the backend-appropriate writer here since the network type isn't known until run time
        print("FastEstimator-Tensorboard: writing logs to {}".format(
            os.path.abspath(os.path.join(self.root_log_dir, self.system.experiment_time))))
        self.writer = _TfWriter(self.root_log_dir, self.system.experiment_time, self.system.network) if isinstance(
            self.system.network, TFNetwork) else _TorchWriter(
                self.root_log_dir, self.system.experiment_time, self.system.network)
        if self.write_graph and self.system.global_step == 1:
            self.painted_graphs = set()

    def on_batch_end(self, data: Data) -> None:
        # (Re-)draw the graph whenever the set of active models differs from what has been painted so far
        if self.write_graph and self.system.network.ctx_models.symmetric_difference(self.painted_graphs):
            self.writer.write_epoch_models(mode=self.system.mode, epoch=self.system.epoch_idx)
            self.painted_graphs = self.system.network.ctx_models
        # Collect embeddings if present in batch but viewing per epoch. Don't aggregate during training though
        if self.system.mode != 'train' and self.embedding_freq.freq and not self.embedding_freq.is_step and \
                self.system.epoch_idx % self.embedding_freq.freq == 0:
            for elem in self.write_embeddings:
                name, lbl, img = elem
                if name in data:
                    self.collected_embeddings[name].append((data.get(name), data.get(lbl), data.get(img)))
        # Handle embeddings if viewing per step
        if self.embedding_freq.freq and self.embedding_freq.is_step and \
                self.system.global_step % self.embedding_freq.freq == 0:
            self.writer.write_embeddings(
                mode=self.system.mode,
                step=self.system.global_step,
                embeddings=filter(
                    lambda x: x[1] is not None,
                    map(lambda t: (t[0], data.get(t[0]), data.get(t[1]), data.get(t[2])), self.write_embeddings)))
        if self.system.mode != 'train':
            return
        if self.histogram_freq.freq and self.histogram_freq.is_step and \
                self.system.global_step % self.histogram_freq.freq == 0:
            self.writer.write_weights(mode=self.system.mode,
                                      models=self.system.network.models,
                                      step=self.system.global_step,
                                      visualize=self.paint_weights)
        if self.update_freq.freq and self.update_freq.is_step and self.system.global_step % self.update_freq.freq == 0:
            self.writer.write_scalars(mode=self.system.mode,
                                      step=self.system.global_step,
                                      scalars=filter(lambda x: is_number(x[1]), data.items()))
            self.writer.write_images(
                mode=self.system.mode,
                step=self.system.global_step,
                images=filter(lambda x: x[1] is not None, map(lambda y: (y, data.get(y)), self.write_images)))

    def on_epoch_end(self, data: Data) -> None:
        if self.system.mode == 'train' and self.histogram_freq.freq and not self.histogram_freq.is_step and \
                self.system.epoch_idx % self.histogram_freq.freq == 0:
            self.writer.write_weights(mode=self.system.mode,
                                      models=self.system.network.models,
                                      step=self.system.global_step,
                                      visualize=self.paint_weights)
        # Write out any embeddings which were aggregated over batches.
        # If any batch was missing a component, drop that component entirely rather than concat-ing a partial list.
        for name, val_list in self.collected_embeddings.items():
            embeddings = None if any(x[0] is None for x in val_list) else concat([x[0] for x in val_list])
            labels = None if any(x[1] is None for x in val_list) else concat([x[1] for x in val_list])
            imgs = None if any(x[2] is None for x in val_list) else concat([x[2] for x in val_list])
            self.writer.write_embeddings(mode=self.system.mode,
                                         step=self.system.global_step,
                                         embeddings=[(name, embeddings, labels, imgs)])
        self.collected_embeddings.clear()
        # Get any embeddings which were generated externally on epoch end
        if self.embedding_freq.freq and (self.embedding_freq.is_step
                                         or self.system.epoch_idx % self.embedding_freq.freq == 0):
            self.writer.write_embeddings(
                mode=self.system.mode,
                step=self.system.global_step,
                embeddings=filter(
                    lambda x: x[1] is not None,
                    map(lambda t: (t[0], data.get(t[0]), data.get(t[1]), data.get(t[2])), self.write_embeddings)))
        if self.update_freq.freq and (self.update_freq.is_step or self.system.epoch_idx % self.update_freq.freq == 0):
            self.writer.write_scalars(mode=self.system.mode,
                                      step=self.system.global_step,
                                      scalars=filter(lambda x: is_number(x[1]), data.items()))
            self.writer.write_images(
                mode=self.system.mode,
                step=self.system.global_step,
                images=filter(lambda x: x[1] is not None, map(lambda y: (y, data.get(y)), self.write_images)))

    def on_end(self, data: Data) -> None:
        # Flush and release the underlying summary writer(s)
        self.writer.close()