Skip to content



Bases: Trace

A Trace to track metrics by instances, for example per-instance loss over time during training.

Use this in conjunction with ImageViewer or ImageSaver to see the graph at training end. This also automatically integrates with Traceability reports.


Parameters:

Name Type Description Default
index str

A key containing data indices to track over time.

metric str

The key of the metric by which to score data.

n_max_to_keep int

At the end of training, the n samples with the highest metric score will be plotted.

n_min_to_keep int

At the end of training, the n samples with the lowest metric score will be plotted.

list_to_keep Optional[Iterable[Any]]

A list of particular indices to pay attention to. This can be used in addition to n_max_to_keep and/or n_min_to_keep, or set those to zero to only track specific indices.

epoch_frequency int

How frequently (in epochs) to collect data. Increase this value to reduce RAM consumption.

mode Union[None, str, Iterable[str]]

What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train".

ds_id Union[None, str, Iterable[str]]

What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all ds_ids except for a particular one, you can pass an argument like "!ds1".

outputs Optional[str]

The name of the output which will be generated by this trace at the end of training. If None then it will default to "<metric>_by_<index>".



Raises:

Type Description

ValueError

If n_max_to_keep or n_min_to_keep are invalid.

Source code in fastestimator/fastestimator/trace/xai/
class InstanceTracker(Trace):
    """A Trace to track metrics by instances, for example per-instance loss over time during training.

    Use this in conjunction with ImageViewer or ImageSaver to see the graph at training end. This also automatically
    integrates with Traceability reports.

    Args:
        index: A key containing data indices to track over time.
        metric: The key of the metric by which to score data.
        n_max_to_keep: At the end of training, the n samples with the highest metric score will be plotted.
        n_min_to_keep: At the end of training, the n samples with the lowest metric score will be plotted.
        list_to_keep: A list of particular indices to pay attention to. This can be used in addition to `n_max_to_keep`
            and/or `n_min_to_keep`, or set those to zero to only track specific indices.
        epoch_frequency: How frequently to collect data. Increase this value to reduce RAM consumption.
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
        outputs: The name of the output which will be generated by this trace at the end of training. If None then it
            will default to "<metric>_by_<index>".

    Raises:
        ValueError: If `n_max_to_keep` or `n_min_to_keep` are invalid.
    """

    def __init__(self,
                 index: str,
                 metric: str,
                 n_max_to_keep: int = 5,
                 n_min_to_keep: int = 5,
                 list_to_keep: Optional[Iterable[Any]] = None,
                 epoch_frequency: int = 1,
                 mode: Union[None, str, Iterable[str]] = "eval",
                 ds_id: Union[None, str, Iterable[str]] = None,
                 outputs: Optional[str] = None):
        """Configure the tracker and validate its arguments.

        Raises:
            ValueError: If `n_max_to_keep` or `n_min_to_keep` is negative, or if `epoch_frequency` is less than 1.
        """
        # TODO - highlight 'interesting' samples (sudden changes in relative ordering?)
        # The trace consumes the index and metric keys (in that order) and emits a single output key.
        super().__init__(inputs=[index, metric], outputs=outputs or f"{metric}_by_{index}", mode=mode, ds_id=ds_id)
        # Per-batch (indices, metrics) pairs buffered during the current epoch; cleared at every epoch end.
        self.points = []
        if n_max_to_keep < 0:
            raise ValueError(f"n_max_to_keep must be non-negative, but got {n_max_to_keep}")
        self.n_max_to_keep = n_max_to_keep
        if n_min_to_keep < 0:
            raise ValueError(f"n_min_to_keep must be non-negative, but got {n_min_to_keep}")
        self.n_min_to_keep = n_min_to_keep
        # The frequency is used as a modulus against the epoch index, so zero would only fail later (with a
        # ZeroDivisionError) once training has already started. Reject it up front instead.
        if epoch_frequency < 1:
            raise ValueError(f"epoch_frequency must be a positive integer, but got {epoch_frequency}")
        self.idx_to_keep = to_set(list_to_keep)
        # Ideally the step and metric would be separated to save space, but a given idx may not appear each epoch
        self.index_history = defaultdict(lambda: defaultdict(list))  # {mode: {idx: [(step, metric)]}}
        self.epoch_frequency = epoch_frequency

    def index_key(self) -> str:
        return self.inputs[0]

    def metric_key(self) -> str:
        return self.inputs[1]

    def on_batch_end(self, data: Data) -> None:
        """Buffer this batch's indices and metric values if the current epoch is one being sampled.

        Args:
            data: The batch data, from which the index and metric entries are read.
        """
        # Only every `epoch_frequency`-th epoch is collected.
        if self.system.epoch_idx % self.epoch_frequency != 0:
            return
        batch_indices = to_number(data[self.index_key])
        batch_scores = to_number(data[self.metric_key])
        self.points.append((batch_indices, batch_scores))

    def on_epoch_end(self, data: Data) -> None:
        if self.system.epoch_idx % self.epoch_frequency == 0:
            idx_scores = {}
            for batch in self.points:
                for idx, metric in ((batch[0][i], batch[1][i]) for i in range(len(batch[0]))):
                    idx_scores[idx.item()] = metric.item()
            for idx, metric in idx_scores.items():
                if self.idx_to_keep and self.n_min_to_keep == 0 and self.n_max_to_keep == 0:
                    # We can only skip recording if max_to_keep and min_to_keep are 0 since otherwise we don't know
                    # which histories will need to be thrown out later.
                    if idx not in self.idx_to_keep:
                        # Skip labels which the user does not want to inspect
                self.index_history[self.system.mode][idx].append((self.system.global_step, metric))
        self.points = []

    def on_end(self, data: Data) -> None:
        """Select the indices worth reporting and publish their metric histories.

        For every mode of this trace, indices are ranked by their most recent metric value. The `n_min_to_keep` lowest,
        the `n_max_to_keep` highest, and everything in `idx_to_keep` are selected, and their histories are copied into
        one Summary per index. The summaries are then handed to the system as a graph and written to `data` under this
        trace's output key.

        Args:
            data: The final data object, which receives the list of summaries.
        """
        summaries = DefaultKeyDict(default=lambda key: Summary(name=key))
        for mode in self.mode:
            mode_history = self.index_history[mode]
            # Rank indices by their last recorded score (ascending).
            ranked = [(idx, records[-1][1]) for idx, records in mode_history.items()]
            ranked.sort(key=lambda pair: pair[1])
            lowest = {pair[0] for pair in ranked[:self.n_min_to_keep]}
            # Walk backwards from the end so that n_max_to_keep == 0 selects nothing.
            highest = {pair[0] for pair in ranked[-1:-self.n_max_to_keep - 1:-1]}
            selected = Set.union(lowest, highest, self.idx_to_keep)
            for idx in selected:
                for step, score in mode_history[idx]:
                    summaries[idx].history[mode][self.metric_key][step] = score
        output_key = self.outputs[0]
        self.system.add_graph(output_key, list(summaries.values()))  # So traceability can draw it
        data.write_without_log(output_key, list(summaries.values()))

    def __getstate__(self) -> Dict[str, Any]:
        """Get a representation of the state of this object.

        This method is invoked by pickle.

            The information to be recorded by a pickle summary of this object.
        state = self.__dict__.copy()
        state['index_history'] = dict(state['index_history'])
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Set this objects internal state from a dictionary of variables.

        This method is invoked by pickle.

            state: The saved state to be used by this object.
        index_history = defaultdict(lambda: defaultdict(list))
        index_history.update(state.get('index_history', {}))
        state['index_history'] = index_history