
op_dataset

OpDataset

Bases: Dataset

A wrapper for datasets which allows operators to be applied to them in a pipeline.

This class should not be instantiated directly by the end user. fe.Pipeline will automatically wrap datasets within an OpDataset as needed.
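
For orientation, here is a minimal sketch of the usual flow, assuming the standard fe.Pipeline API, the MNIST loader shipped with FastEstimator, and the Minmax NumpyOp; the OpDataset wrapping happens behind the scenes when the Pipeline builds its loaders:

    import fastestimator as fe
    from fastestimator.dataset.data import mnist
    from fastestimator.op.numpyop.univariate import Minmax

    train_data, eval_data = mnist.load_data()  # ordinary FastEstimator datasets whose items are dicts
    pipeline = fe.Pipeline(train_data=train_data,
                           eval_data=eval_data,
                           batch_size=32,
                           ops=[Minmax(inputs="x", outputs="x")])
    # When the pipeline constructs its data loaders, each dataset is wrapped in an OpDataset
    # so that the Minmax op runs on every sample fetched by __getitem__.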

Parameters:

    dataset (Dataset, required):
        The base dataset to wrap.

    ops (List[NumpyOp], required):
        A list of ops to be applied after the base dataset's __getitem__ is invoked.

    mode (str, required):
        What mode the system is currently running in ('train', 'eval', 'test', or 'infer').

    output_keys (Optional[Set[str]], default: None):
        Which keys may be produced from the pipeline. If None or empty, all keys will be considered.

    deep_remainder (bool, default: True):
        Whether data which is not modified by Ops should be deep copied. This argument is used to help with
        RAM management, but end users can almost certainly ignore it.
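
Although end users normally let fe.Pipeline perform the wrapping, a small sketch of a direct construction can help clarify the arguments above. This is for illustration only, under the assumption of the standard NumpyDataset and NumpyOp APIs; the AddOne op is a hypothetical example op, not part of FastEstimator:

    import numpy as np
    from fastestimator.dataset.numpy_dataset import NumpyDataset
    from fastestimator.dataset.op_dataset import OpDataset
    from fastestimator.op.numpyop import NumpyOp

    class AddOne(NumpyOp):
        """A hypothetical op which adds 1 to its input."""
        def forward(self, data, state):
            return data + 1

    ds = NumpyDataset({"x": np.zeros((4, 2), dtype="float32")})
    op_ds = OpDataset(dataset=ds,
                      ops=[AddOne(inputs="x", outputs="x")],
                      mode="train",
                      output_keys={"x"},
                      deep_remainder=True)
    print(op_ds[0]["x"])  # the op has already been applied to the fetched sample
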
Source code in fastestimator/fastestimator/dataset/op_dataset.py
@traceable(blacklist='lock')
class OpDataset(Dataset):
    """A wrapper for datasets which allows operators to be applied to them in a pipeline.

    This class should not be directly instantiated by the end user. The fe.Pipeline will automatically wrap datasets
    within an Op dataset as needed.

    Args:
        dataset: The base dataset to wrap.
        ops: A list of ops to be applied after the base `dataset` `__getitem__` is invoked.
        mode: What mode the system is currently running in ('train', 'eval', 'test', or 'infer').
        output_keys: What keys can be produced from pipeline. If None or empty, all keys will be considered.
        deep_remainder: Whether data which is not modified by Ops should be deep copied or not. This argument is used to
            help with RAM management, but end users can almost certainly ignore it.
    """
    def __init__(self,
                 dataset: Dataset,
                 ops: List[NumpyOp],
                 mode: str,
                 output_keys: Optional[Set[str]] = None,
                 deep_remainder: bool = True) -> None:
        # Track whether this dataset returns batches or not (useful for pipeline and traceability)
        if not hasattr(dataset, "fe_batch"):
            sample_item = dataset[0]
            dataset.fe_batch = len(sample_item) if isinstance(sample_item, list) else 0
        self.dataset = dataset
        self.fe_batch = dataset.fe_batch
        if hasattr(dataset, "fe_reset_ds"):
            self.fe_reset_ds = dataset.fe_reset_ds
        if hasattr(dataset, "fe_batch_indices"):
            self.fe_batch_indices = dataset.fe_batch_indices
        self.ops = ops
        self.mode = mode
        self.output_keys = output_keys
        self.deep_remainder = deep_remainder
        self.lock = Lock()
        self.to_warn: Set[str] = set()
        if not hasattr(OpDataset, 'warned'):
            # Declaring this outside the init would trigger mac multi-processing to pick a non-fork start method
            OpDataset.warned = Array(ctypes.c_char, 200, lock=False)

    def __getitem__(self, index: int) -> Union[Dict[str, Any], List[Dict[str, Any]], FilteredData]:
        """Fetch a data instance at a specified index, and apply transformations to it.

        Args:
            index: Which datapoint to retrieve.

        Returns:
            The data dictionary from the specified index, with transformations applied OR an indication that this index
            should be thrown out.
        """
        item = self.dataset[index]
        if isinstance(item, list):
            # BatchDataset may randomly sample the same elements multiple times, so need to avoid reprocessing
            unique_samples = {}  # id: idx
            results = []
            for idx, data in enumerate(item):
                data_id = id(data)
                if data_id not in unique_samples:
                    data = _DelayedDeepDict(data)
                    filter_data = forward_numpyop(self.ops, data, {'mode': self.mode})
                    if filter_data:
                        results.append(filter_data)
                    else:
                        data.finalize(retain=self.output_keys, deep_remainder=self.deep_remainder)
                        if data.warn:
                            self.to_warn |= data.to_warn
                        results.append(data.as_dict())
                    unique_samples[data_id] = idx
                else:
                    results.append(results[unique_samples[data_id]])
        else:
            results = _DelayedDeepDict(item)
            filter_data = forward_numpyop(self.ops, results, {'mode': self.mode})
            if filter_data:
                return filter_data
            results.finalize(retain=self.output_keys, deep_remainder=self.deep_remainder)
            if results.warn:
                self.to_warn |= results.to_warn
            results = results.as_dict()
        if self.to_warn and self.lock.acquire(block=False):
            self.handle_warning(self.to_warn)
            self.to_warn.clear()
            # We intentionally never release the lock so that during multi-threading only 1 message can be printed
        return results

    @classmethod
    def handle_warning(cls, candidates: Set[str]) -> None:
        """A function which prints warning messages about unused keys if such messages haven't already been printed.

        Args:
            candidates: Unused keys which you might need to print a warning message about.
        """
        if not candidates:
            return
        # Keys can't contain the ":" or ";" character due to check_io_names base_util function
        warned = set((str(cls.warned.value, 'utf8') or "").split(":"))
        if ";" not in warned:
            # We use ; as a special character to indicate that the warned buffer was overflowed by too many keys
            to_warn = candidates - warned
            if to_warn:
                warn("The following key(s) are being pruned since they are unused outside of the "
                     "Pipeline. To prevent this, you can declare the key(s) as inputs to Traces or TensorOps: "
                     f"{', '.join(humansorted(to_warn))}")
                warned |= to_warn
                warned = bytes(":".join(warned), 'utf8')
                if len(warned) > 198:
                    # This would overflow the warning buffer, so disable the warning mechanism in the future
                    # Note that the warning will still happen the first time the overly-long keys appear.
                    # 198 rather than 199 to allow for a null terminator at the end of the array.
                    warn("Any further key pruning warnings in subsequent epochs will not be printed.")
                    warned = bytes(";", 'utf8')
                cls.warned.value = warned

    def __len__(self):
        return len(self.dataset)

handle_warning classmethod

A function which prints warning messages about unused keys if such messages haven't already been printed.

Parameters:

    candidates (Set[str], required):
        Unused keys which you might need to print a warning message about.
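
The warning bookkeeping is shared across dataloader worker processes through the 200-byte multiprocessing.Array allocated in __init__. The sketch below illustrates only the encoding convention used there (":"-separated keys that have already been warned about, with ";" as the overflow sentinel); it is a simplified single-process illustration, not the library's actual code path:

    import ctypes
    from multiprocessing import Array

    warned_buf = Array(ctypes.c_char, 200, lock=False)  # zero-initialized shared byte buffer

    def maybe_warn(candidates):
        # Read back the keys that have already been warned about
        warned = set(str(warned_buf.value, 'utf8').split(":")) if warned_buf.value else set()
        if ";" in warned:  # overflow sentinel: the warning mechanism has been disabled
            return
        to_warn = candidates - warned
        if to_warn:
            print(f"pruning warning for: {sorted(to_warn)}")
            encoded = bytes(":".join(warned | to_warn), 'utf8')
            # Keep the buffer within 198 bytes (leaving room for a null terminator), else disable
            warned_buf.value = encoded if len(encoded) <= 198 else bytes(";", 'utf8')

    maybe_warn({"y_pred"})  # prints a warning and records the key in the shared buffer
    maybe_warn({"y_pred"})  # silent: the key is already recorded
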
Source code in fastestimator/fastestimator/dataset/op_dataset.py
@classmethod
def handle_warning(cls, candidates: Set[str]) -> None:
    """A function which prints warning messages about unused keys if such messages haven't already been printed.

    Args:
        candidates: Unused keys which you might need to print a warning message about.
    """
    if not candidates:
        return
    # Keys can't contain the ":" or ";" character due to check_io_names base_util function
    warned = set((str(cls.warned.value, 'utf8') or "").split(":"))
    if ";" not in warned:
        # We use ; as a special character to indicate that the warned buffer was overflowed by too many keys
        to_warn = candidates - warned
        if to_warn:
            warn("The following key(s) are being pruned since they are unused outside of the "
                 "Pipeline. To prevent this, you can declare the key(s) as inputs to Traces or TensorOps: "
                 f"{', '.join(humansorted(to_warn))}")
            warned |= to_warn
            warned = bytes(":".join(warned), 'utf8')
            if len(warned) > 198:
                # This would overflow the warning buffer, so disable the warning mechanism in the future
                # Note that the warning will still happen the first time the overly-long keys appear.
                # 198 rather than 199 to allow for a null terminator at the end of the array.
                warn("Any further key pruning warnings in subsequent epochs will not be printed.")
                warned = bytes(";", 'utf8')
            cls.warned.value = warned