dataloader

FEDataLoader

Bases: DataLoader

A Data Loader that can handle filtering data.

This class is intentionally not @traceable.

Parameters:

dataset (MapDataset, required)
    The dataset to be drawn from. The dataset may optionally implement .fe_reset_ds(bool) and/or .fe_batch_indices(int) methods to modify the system's sampling behavior. See fe.dataset.BatchDataset for an example which uses both of these methods.

postprocess_fn (Optional[PostProcessFunction], default: None)
    A function to run on a collated batch of data before returning it. This function can return a FilteredData object in order to drop the given batch.

batch_size (Optional[int], default: 1)
    The batch size to use (or None if the dataset is already providing a batch).

steps_per_epoch (Optional[int], default: None)
    How many steps to run per epoch. If None, the loader will perform a single pass through the dataset (unless samples are filtered with replacement, in which case the dataset may be passed over multiple times). If steps_per_epoch is set, the dataset will be truncated or expanded until the specified number of steps is reached. When expanding datasets, they will be exhausted in their entirety before being re-sampled, equivalent to running multiple epochs of training one after the other (unless you are also filtering data, in which case at most one batch of data might be seen after the re-shuffling occurs).

shuffle (bool, default: False)
    Whether to shuffle the dataset.

num_workers (int, default: 0)
    How many multiprocessing worker processes to use (unix/mac only).

collate_fn (Optional[Callable], default: None)
    What function to use to collate a list of data into a batch. This should take care of any desired padding.

drop_last (bool, default: False)
    Whether to drop the last batch of data if that batch is incomplete. Note that this is meaningless for batched datasets, as well as when steps_per_epoch is set - in which case the dataset will be re-sampled as necessary until the specified number of steps has been completed in full.
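
Usage sketch (not taken from the library docs): the snippet below constructs the loader directly on a small in-memory dataset. The fe.dataset.NumpyDataset constructor, the import path (inferred from the source file shown below), the "x"/"y" key names, and the passthrough postprocess_fn are illustrative assumptions; in normal use the loader is built for you by fe.Pipeline.

import numpy as np
import fastestimator as fe
from fastestimator.dataset.dataloader import FEDataLoader  # path assumed from the source file below

# Hypothetical toy data: a dict of equal-length arrays wrapped by NumpyDataset.
ds = fe.dataset.NumpyDataset(data={"x": np.random.rand(10, 3).astype("float32"),
                                   "y": np.arange(10)})

def passthrough(batch):
    # postprocess_fn receives the already-collated batch. Returning a FilteredData
    # object here (see the description above) would drop the batch; we keep it as-is.
    return batch

loader = FEDataLoader(dataset=ds,
                      postprocess_fn=passthrough,
                      batch_size=4,
                      shuffle=True,
                      num_workers=0,  # single-process for the sketch
                      drop_last=False)

for batch in loader:
    print(batch["x"].shape, batch["y"].shape)  # tensors of up to 4 samples each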
Source code in fastestimator/fastestimator/dataset/dataloader.py
class FEDataLoader(DataLoader):
    """A Data Loader that can handle filtering data.

    This class is intentionally not @traceable.

    Args:
        dataset: The dataset to be drawn from. The dataset may optionally implement .fe_reset_ds(bool) and/or
            .fe_batch_indices(int) methods to modify the system's sampling behavior. See fe.dataset.BatchDataset for an
            example which uses both of these methods.
        postprocess_fn: A function to run on a collated batch of data before returning it. This function can return a
            FilteredData object in order to drop the given batch.
        batch_size: The batch size to use (or None if the dataset is already providing a batch).
        steps_per_epoch: How many steps to have per epoch. If None the loader will perform a single pass through the
            dataset (unless samples are filtered with replacement, in which case the dataset may be passed over multiple
            times). If `steps_per_epoch` is set, it will truncate or expand the dataset until the specified number of
            steps are reached. When expanding datasets, they will be exhausted in their entirety before being
            re-sampled, equivalent to running multiple epochs of training one after the other (unless you are also
            filtering data, in which case at most one batch of data might be seen after the re-shuffling occurs).
        shuffle: Whether to shuffle the dataset.
        num_workers: How many multiprocessing threads to use (unix/mac only).
        collate_fn: What function to use to collate a list of data into a batch. This should take care of any desired
            padding.
        drop_last: Whether to drop the last batch of data if that batch is incomplete. Note that this is meaningless for
            batched datasets, as well as when `steps_per_epoch` is set - in which case the dataset will be re-sampled as
            necessary until the specified number of steps has been completed in full.
    """
    _current_threads = []
    FE_LOADER_KIND = 7

    # The typing for 'dataset' should be an 'and' rather than 'or' but that feature is still under development:
    # https://github.com/python/typing/issues/213

    def __init__(self,
                 dataset: MapDataset,
                 postprocess_fn: Optional[PostProcessFunction] = None,
                 batch_size: Optional[int] = 1,
                 steps_per_epoch: Optional[int] = None,
                 shuffle: bool = False,
                 num_workers: int = 0,
                 collate_fn: Optional[Callable] = None,
                 drop_last: bool = False):
        reset_fn = dataset.fe_reset_ds if hasattr(dataset, 'fe_reset_ds') else None
        convert_fn = dataset.fe_batch_indices if hasattr(dataset, 'fe_batch_indices') else None
        sampler = InfiniteSampler(data_source=dataset, shuffle=shuffle, reset_fn=reset_fn, convert_fn=convert_fn)
        if batch_size is not None and batch_size < 1:
            raise ValueError(f"batch_size must be None or a positive integer, but got {batch_size}")
        # Figure out the real batch size. This is already done in OpDataset, but if user manually instantiates this
        # loader without using an OpDataset we still want to know the batch size
        if not hasattr(dataset, "fe_batch"):
            sample_item = dataset[0]
            dataset.fe_batch = len(sample_item) if isinstance(sample_item, list) else 0
        if dataset.fe_batch:
            # The batch size where torch is concerned is probably None, but we know that it is secretly batched
            self.fe_batch_size = dataset.fe_batch
        else:
            self.fe_batch_size = batch_size
        # Figure out how many samples should be returned during the course of 1 epoch
        if steps_per_epoch is not None:
            to_yield = steps_per_epoch * (batch_size or 1)
            # Note that drop_last is meaningless here since we will provide exactly the requested number of steps
        else:
            if isinstance(dataset, OpDataset) and isinstance(dataset.dataset, ExtendDataset):
                to_yield = dataset.dataset.spoof_length
            elif isinstance(dataset, ExtendDataset):
                to_yield = dataset.spoof_length
            else:
                to_yield = len(dataset)
            if drop_last:
                to_yield -= to_yield % (batch_size or 1)
        self.fe_samples_to_yield = to_yield
        self.fe_drop_last = drop_last
        self.fe_collate_fn = collate_fn or default_collate
        if self.fe_batch_size in (0, None) and batch_size is None and self.fe_collate_fn == default_collate:
            # The user did not provide a batched dataset nor a batch size, so default collate won't work. Have to try
            # convert instead.
            self.fe_collate_fn = default_convert
        self.fe_postprocess_fn = postprocess_fn

        # We could disable pre-collating when num_workers=0, but this would lead to inconsistent batch ordering between
        # single- and multi-processing.

        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=num_workers,
            persistent_workers=False,
            collate_fn=functools.partial(_pre_collate, try_fn=self.fe_collate_fn, postprocess_fn=postprocess_fn),
            worker_init_fn=lambda _: np.random.seed(random.randint(0, 2**32 - 1)))
        if self.batch_size is not None:
            # We need a special fetcher type later in order to build batches correctly
            self._dataset_kind = self.FE_LOADER_KIND

    def shutdown(self) -> None:
        """Close the worker threads used by this iterator.

        The hope is that this will prevent "RuntimeError: DataLoader worker (pid(s) XXXX) exited unexpectedly" during
        the test suites.
        """
        if isinstance(self._iterator, _MultiProcessingDataLoaderIter):
            self._iterator._shutdown_workers()
        self._iterator = None
        FEDataLoader._current_threads.clear()

    def __iter__(self) -> _BaseDataLoaderIter:
        # Similar to the original iter method, but we remember iterators in order to manually close them when new ones
        # are created
        self.shutdown()
        self._iterator = self._get_fe_iterator()
        if isinstance(self._iterator, _MultiProcessingDataLoaderIter):
            FEDataLoader._current_threads.extend([w.pid for w in self._iterator._workers])
        return self._iterator

    def _get_fe_iterator(self):
        if self.num_workers == 0:
            if self.batch_size is None:
                # We use 'fake' batch size here to identify datasets which perform their own batching
                return _SPPostBatchIter(self)
            return _SPPreBatchIter(self)
        else:
            with Suppressor(allow_pyprint=True):  # Prevent unnecessary warnings about resetting numbers of threads
                if self.batch_size is None:
                    # We use 'fake' batch size here to identify datasets which perform their own batching
                    return _MPPostBatchIter(self)
                return _MPPreBatchIter(self)

    def __len__(self):
        return self.fe_samples_to_yield

    def get_batch_size(self) -> int:
        return self.fe_batch_size

shutdown

Close the worker threads used by this iterator.

The hope is that this will prevent "RuntimeError: DataLoader worker (pid(s) XXXX) exited unexpectedly" during the test suites.
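
For context, a hedged sketch of calling shutdown explicitly; the loader arguments and the dataset ds are the same illustrative ones from the earlier sketch:

loader = FEDataLoader(dataset=ds, batch_size=4, num_workers=2)  # workers spawn on first iteration
first_batch = next(iter(loader))  # starts the multiprocessing workers
loader.shutdown()                 # closes them so they do not linger, e.g. between unit tests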

Source code in fastestimator/fastestimator/dataset/dataloader.py
def shutdown(self) -> None:
    """Close the worker threads used by this iterator.

    The hope is that this will prevent "RuntimeError: DataLoader worker (pid(s) XXXX) exited unexpectedly" during
    the test suites.
    """
    if isinstance(self._iterator, _MultiProcessingDataLoaderIter):
        self._iterator._shutdown_workers()
    self._iterator = None
    FEDataLoader._current_threads.clear()

InfiniteSampler

Bases: Sampler

A class which never stops sampling.

Parameters:

data_source (Sized, required)
    The dataset to be sampled.

shuffle (bool, default: True)
    Whether to shuffle when sampling.

reset_fn (Optional[Callable[[bool], None]], default: None)
    A function to be invoked (with the provided shuffle arg) every time the dataset has been fully traversed.

convert_fn (Optional[Callable[[int], Any]], default: None)
    A function to be invoked on every sample (with the current index) in order to convert an integer index into some arbitrary alternative index representation.
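
To illustrate the never-ending behavior, here is a minimal sketch; the plain Python list standing in for a dataset and the import path (inferred from the source file below) are assumptions:

import itertools
from fastestimator.dataset.dataloader import InfiniteSampler  # path assumed from the source file below

data = ["a", "b", "c"]  # any Sized object can serve as the data_source
sampler = InfiniteSampler(data_source=data, shuffle=False)

# The sampler wraps around instead of raising StopIteration, so cap it with islice:
print(list(itertools.islice(iter(sampler), 7)))  # -> [0, 1, 2, 0, 1, 2, 0]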
Source code in fastestimator/fastestimator/dataset/dataloader.py
class InfiniteSampler(Sampler):
    """A class which never stops sampling.

    Args:
        data_source: The dataset to be sampled.
        shuffle: Whether to shuffle when sampling.
        reset_fn: A function to be invoked (using the provided `shuffle` arg) every time the dataset has been fully
            traversed.
        convert_fn: A function to be invoked (using the current index) every sample in order to convert an integer index
            into some arbitrary alternative index representation.
    """
    def __init__(self,
                 data_source: Sized,
                 shuffle: bool = True,
                 reset_fn: Optional[Callable[[bool], None]] = None,
                 convert_fn: Optional[Callable[[int], Any]] = None):
        super().__init__(data_source=None)  # Arg is unused and triggers a warning in torch 2.1
        self.interleave_ds = isinstance(data_source, InterleaveDataset)
        self.ds_len = len(data_source)
        if self.ds_len < 1:
            raise ValueError("dataset length must be at least 1")
        self.indices = [i for i in range(self.ds_len)]
        self.shuffle = shuffle
        self.reset_fn = reset_fn
        self.convert_fn = convert_fn
        self.idx = 0

    def __len__(self):
        return self.ds_len

    def __iter__(self):
        self.idx = 0
        if self.reset_fn:
            self.reset_fn(self.shuffle)
        if self.shuffle and not self.interleave_ds:
            # interleave_ds requires unshuffled indices to work correctly with its repeating pattern
            random.shuffle(self.indices)
        return self

    def __next__(self):
        if self.idx == self.ds_len:
            self.idx = 0
            if self.reset_fn:
                self.reset_fn(self.shuffle)
            if self.shuffle and not self.interleave_ds:
                # interleave_ds requires unshuffled indices to work correctly with its repeating pattern
                random.shuffle(self.indices)
        elem = self.indices[self.idx]
        self.idx += 1
        if self.convert_fn:
            elem = self.convert_fn(elem)
        return elem