Skip to content

_per_ds

per_ds

A class decorator which converts regular traces into dataset-sensitive traces.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `clz` |  | The base class to be converted. | *required* |

Returns:

| Type | Description |
|------|-------------|
|  | A dataset-aware version of the class. Note that if the annotated class instance has a `per_ds` member variable which is set to `False`, or has outputs containing the `\|` character, then a normal (non-dataset-aware) instance will be returned instead. |

Source code in fastestimator/fastestimator/trace/meta/_per_ds.py
def per_ds(clz):
    """A class decorator which converts regular traces into dataset-sensitive traces.

    The returned class subclasses both `clz` and `PerDSTrace`. Each instance keeps its normal `clz` behavior (via
    `super()`) and additionally owns a second, private `clz` instance stored as `fe_per_ds_trace`. That inner
    instance receives:
        * `on_begin` / `on_end`, with the raw `data`;
        * batch begin/end, wrapped in `DSData`, whenever `self.system.ds_id` is non-empty;
        * dataset begin/end, wrapped in `DSData`, forwarded as its epoch begin/end.

    `get_outputs` advertises the base outputs plus one "<output>|<ds_id>" key per requested dataset id.

    Args:
        clz: The base class to be converted.

    Returns:
        A dataset aware version of the class. Note that if the annotated class instance has a 'per_ds' member variable
        which is set to False, or has outputs containing the '|' character, then a normal (non-ds-aware) instance will
        be returned instead.
    """
    @functools.wraps(clz, updated=())
    class PerDS(clz, PerDSTrace):
        def __new__(cls, *args, **kwargs):
            # We dynamically determine whether to return a base object or a PerDS variant, so first build a plain
            # `clz` instance that we can inspect.
            base_obj = clz.__new__(clz)
            base_obj.__init__(*args, **kwargs)
            # Outputs that already use the '|' character would clash with the "<output>|<ds_id>" keys produced by
            # get_outputs, so this cannot be made a PerDS variant.
            if any('|' in output for output in base_obj.outputs):
                return base_obj
            # The user set per_ds to False in the constructor, so do not make this a PerDS variant.
            if getattr(base_obj, 'per_ds', None) is False:
                return base_obj
            # Otherwise we are good to go with the PerDS variant. Because base_obj (returned above) is not an instance
            # of `cls`, Python only runs PerDS.__init__ when we reach this line.
            return super().__new__(cls)

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # An independent clz instance, built from the same constructor arguments, which only receives the
            # per-dataset events forwarded by the hooks below.
            # NOTE(review): nothing here assigns `system` on this inner instance; confirm it is populated elsewhere
            # before the wrapped hooks need it.
            self.fe_per_ds_trace = clz.__new__(clz)
            self.fe_per_ds_trace.__init__(*args, **kwargs)

        def get_outputs(self, ds_ids: Union[None, str, List[str]]) -> List[str]:
            """Return the base outputs followed by an "<output>|<ds_id>" key for every (output, ds_id) pair.

            Args:
                ds_ids: The dataset id(s) to generate per-dataset keys for (normalized with `to_list`).

            Returns:
                All base outputs, then the per-dataset keys, ordered by output first and ds_id second.
            """
            ids = to_list(ds_ids)
            per_ds_keys = [f"{output}|{ds_id}" for output in self.outputs for ds_id in ids]
            return list(self.outputs) + per_ds_keys

        def on_begin(self, data: Data) -> None:
            super().on_begin(data)
            self.fe_per_ds_trace.on_begin(data)

        def on_ds_begin(self, data: Data) -> None:
            # The start of a dataset is presented to the inner trace as the start of an epoch.
            if self.system.ds_id != '':
                self.fe_per_ds_trace.on_epoch_begin(DSData(self.system.ds_id, data))

        def on_batch_begin(self, data: Data) -> None:
            super().on_batch_begin(data)
            if self.system.ds_id != '':
                self.fe_per_ds_trace.on_batch_begin(DSData(self.system.ds_id, data))

        def on_batch_end(self, data: Data) -> None:
            # Remember the caller's setting so it is restored exactly, instead of being forced to True, even if the
            # base trace raises.
            previously_enabled = getattr(data, 'per_instance_enabled', True)
            if self.system.ds_id != '':
                self.fe_per_ds_trace.on_batch_end(DSData(self.system.ds_id, data))
                # Block the main process from writing per-instance info since we already have the more detailed key
                data.per_instance_enabled = False
            try:
                super().on_batch_end(data)
            finally:
                data.per_instance_enabled = previously_enabled

        def on_ds_end(self, data: Data) -> None:
            # The end of a dataset is presented to the inner trace as the end of an epoch.
            if self.system.ds_id != '':
                self.fe_per_ds_trace.on_epoch_end(DSData(self.system.ds_id, data))

        def on_end(self, data: Data) -> None:
            super().on_end(data)
            self.fe_per_ds_trace.on_end(data)

    return PerDS