instance_tracker

InstanceTracker

Bases: Trace

A Trace to track metrics by instance, for example, per-instance loss over time during training.

Use this in conjunction with ImageViewer or ImageSaver to see the graph at training end. This also automatically integrates with Traceability reports.
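
For orientation, here is a minimal sketch of how this trace might be wired into an estimator. The key names ("idx", "sample_ce", "ce_by_idx") and the toy data are illustrative assumptions, not requirements of the API: the dataset carries an integer index for every sample, a second CrossEntropy op with average_loss=False supplies the per-instance metric, and ImageViewer renders the resulting graph once training finishes.

```python
import numpy as np
import fastestimator as fe
from fastestimator.architecture.tensorflow import LeNet
from fastestimator.op.numpyop.univariate import Minmax
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.io import ImageViewer
from fastestimator.trace.xai import InstanceTracker

# Toy data; the "idx" column gives every sample a stable identity to track.
x = np.random.rand(200, 28, 28, 1).astype("float32")
y = np.random.randint(0, 10, size=(200, 1))
data = fe.dataset.NumpyDataset({"x": x, "y": y, "idx": np.arange(200)})

pipeline = fe.Pipeline(train_data=data, eval_data=data, batch_size=32,
                       ops=[Minmax(inputs="x", outputs="x")])

model = fe.build(model_fn=LeNet, optimizer_fn="adam")
network = fe.Network(ops=[
    ModelOp(model=model, inputs="x", outputs="y_pred"),
    CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),  # scalar loss for the update
    CrossEntropy(inputs=("y_pred", "y"), outputs="sample_ce", average_loss=False),  # per-instance loss
    UpdateOp(model=model, loss_name="ce")
])

traces = [
    InstanceTracker(index="idx", metric="sample_ce", outputs="ce_by_idx", mode="eval"),
    ImageViewer(inputs="ce_by_idx")  # displays the metric-vs-step graph at the end of training
]
estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=3, traces=traces)
estimator.fit()
```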

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `index` | `str` | A key containing data indices to track over time. | *required* |
| `metric` | `str` | The key of the metric by which to score data. | *required* |
| `n_max_to_keep` | `int` | At the end of training, the n samples with the highest metric score will be plotted. | `5` |
| `n_min_to_keep` | `int` | At the end of training, the n samples with the lowest metric score will be plotted. | `5` |
| `list_to_keep` | `Optional[Iterable[Any]]` | A list of particular indices to pay attention to. This can be used in addition to `n_max_to_keep` and/or `n_min_to_keep`, or set those to zero to only track specific indices. | `None` |
| `epoch_frequency` | `int` | How frequently to collect data. Increase this value to reduce RAM consumption. | `1` |
| `mode` | `Union[None, str, Iterable[str]]` | What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | `'eval'` |
| `ds_id` | `Union[None, str, Iterable[str]]` | What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all ds_ids except for a particular one, you can pass an argument like "!ds1". | `None` |
| `outputs` | `Optional[str]` | The name of the output which will be generated by this trace at the end of training. If None then it will default to `"<metric>_by_<index>"`. | `None` |
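
As a concrete (hypothetical) configuration, the sketch below tracks only three specific sample indices by zeroing out the automatic best/worst selection, and records every other epoch to cut memory use. The key names "idx" and "sample_ce" are assumed to exist in the batch data:

```python
tracker = InstanceTracker(index="idx",          # assumed key holding per-sample indices
                          metric="sample_ce",   # assumed key holding a per-instance metric
                          n_max_to_keep=0,
                          n_min_to_keep=0,
                          list_to_keep=[3, 17, 42],
                          epoch_frequency=2,
                          mode="eval")
```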

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `n_max_to_keep` or `n_min_to_keep` are invalid (negative). |

Source code in fastestimator/fastestimator/trace/xai/instance_tracker.py
@traceable()
class InstanceTracker(Trace):
    """A Trace to track metrics by instances, for example per-instance loss over time during training.

    Use this in conjunction with ImageViewer or ImageSaver to see the graph at training end. This also automatically
    integrates with Traceability reports.

    Args:
        index: A key containing data indices to track over time.
        metric: The key of the metric by which to score data.
        n_max_to_keep: At the end of training, the n samples with the highest metric score will be plotted.
        n_min_to_keep: At the end of training, the n samples with the lowest metric score will be plotted.
        list_to_keep: A list of particular indices to pay attention to. This can be used in addition to `n_max_to_keep`
            and/or `n_min_to_keep`, or set those to zero to only track specific indices.
        epoch_frequency: How frequently to collect data. Increase this value to reduce ram consumption.
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
        outputs: The name of the output which will be generated by this trace at the end of training. If None then it
            will default to "<metric>_by_<index>".

    Raises:
        ValueError: If `n_max_to_keep` or `n_min_to_keep` are invalid.
    """
    def __init__(self,
                 index: str,
                 metric: str,
                 n_max_to_keep: int = 5,
                 n_min_to_keep: int = 5,
                 list_to_keep: Optional[Iterable[Any]] = None,
                 epoch_frequency: int = 1,
                 mode: Union[None, str, Iterable[str]] = "eval",
                 ds_id: Union[None, str, Iterable[str]] = None,
                 outputs: Optional[str] = None):
        # TODO - highlight 'interesting' samples (sudden changes in relative ordering?)
        super().__init__(inputs=[index, metric], outputs=outputs or f"{metric}_by_{index}", mode=mode, ds_id=ds_id)
        self.points = []
        if n_max_to_keep < 0:
            raise ValueError(f"n_max_to_keep must be non-negative, but got {n_max_to_keep}")
        self.n_max_to_keep = n_max_to_keep
        if n_min_to_keep < 0:
            raise ValueError(f"n_min_to_keep must be non-negative, but got {n_min_to_keep}")
        self.n_min_to_keep = n_min_to_keep
        self.idx_to_keep = to_set(list_to_keep)
        # Ideally the step and metric would be separated to save space, but a given idx may not appear each epoch
        self.index_history = defaultdict(lambda: defaultdict(list))  # {mode: {idx: [(step, metric)]}}
        self.epoch_frequency = epoch_frequency

    @property
    def index_key(self) -> str:
        return self.inputs[0]

    @property
    def metric_key(self) -> str:
        return self.inputs[1]

    def on_batch_end(self, data: Data) -> None:
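        # Cache the raw index/metric arrays for this batch; they are aggregated at epoch end.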
        if self.system.epoch_idx % self.epoch_frequency == 0:
            self.points.append((to_number(data[self.index_key]), to_number(data[self.metric_key])))

    def on_epoch_end(self, data: Data) -> None:
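        # Collapse the cached batches into a single score per index (later batches overwrite
        # earlier ones within the epoch), then append (step, score) to the tracking history.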
        if self.system.epoch_idx % self.epoch_frequency == 0:
            idx_scores = {}
            for batch in self.points:
                for idx, metric in ((batch[0][i], batch[1][i]) for i in range(len(batch[0]))):
                    idx_scores[idx.item()] = metric.item()
            for idx, metric in idx_scores.items():
                if self.idx_to_keep and self.n_min_to_keep == 0 and self.n_max_to_keep == 0:
                    # We can only skip recording if max_to_keep and min_to_keep are 0 since otherwise we don't know
                    # which histories will need to be thrown out later.
                    if idx not in self.idx_to_keep:
                        # Skip labels which the user does not want to inspect
                        continue
                self.index_history[self.system.mode][idx].append((self.system.global_step, metric))
        self.points = []

    def on_end(self, data: Data) -> None:
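        # Build one Summary per tracked index so the metric-vs-step curves can be plotted together.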
        index_summaries = DefaultKeyDict(default=lambda x: Summary(name=x))
        for mode in self.mode:
            final_scores = sorted([(idx, elem[-1][1]) for idx, elem in self.index_history[mode].items()],
                                  key=lambda x: x[1])
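            # Slice the sorted scores from the back to pick the n_max_to_keep highest-scoring indices.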
            max_idx_list = {elem[0] for elem in final_scores[-1:-self.n_max_to_keep - 1:-1]}
            min_idx_list = {elem[0] for elem in final_scores[:self.n_min_to_keep]}
            target_idx_list = Set.union(min_idx_list, max_idx_list, self.idx_to_keep)
            for idx in target_idx_list:
                for step, score in self.index_history[mode][idx]:
                    index_summaries[idx].history[mode][self.metric_key][step] = score
        self.system.add_graph(self.outputs[0], list(index_summaries.values()))  # So traceability can draw it
        data.write_without_log(self.outputs[0], list(index_summaries.values()))

    def __getstate__(self) -> Dict[str, Any]:
        """Get a representation of the state of this object.

        This method is invoked by pickle.

        Returns:
            The information to be recorded by a pickle summary of this object.
        """
        state = self.__dict__.copy()
        state['index_history'] = dict(state['index_history'])
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Set this objects internal state from a dictionary of variables.

        This method is invoked by pickle.

        Args:
            state: The saved state to be used by this object.
        """
        index_history = defaultdict(lambda: defaultdict(list))
        index_history.update(state.get('index_history', {}))
        state['index_history'] = index_history
        self.__dict__.update(state)
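
The `__getstate__`/`__setstate__` pair appears to exist because `index_history` is a `defaultdict` whose default factory is a lambda, and lambdas cannot be pickled; the state is therefore downgraded to a plain dict on save and re-wrapped on load. A minimal standalone sketch of the same pattern (independent of FastEstimator, names hypothetical):

```python
import pickle
from collections import defaultdict

class Tracker:
    def __init__(self):
        # A lambda default_factory makes the defaultdict itself un-picklable.
        self.history = defaultdict(lambda: defaultdict(list))

    def __getstate__(self):
        state = self.__dict__.copy()
        state['history'] = dict(state['history'])  # drop the lambda before pickling
        return state

    def __setstate__(self, state):
        history = defaultdict(lambda: defaultdict(list))
        history.update(state.get('history', {}))
        state['history'] = history  # restore the auto-creating behavior after unpickling
        self.__dict__.update(state)

t = Tracker()
t.history['eval'][0].append((1, 0.5))
restored = pickle.loads(pickle.dumps(t))
assert restored.history['eval'][0] == [(1, 0.5)]
```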