
combined_dataset

CombinedDataset

Bases: ConcatDataset

Combines a list of PyTorch Datasets.

Parameters:

Name      Type           Description                                Default
datasets  List[Dataset]  PyTorch (or FE) Datasets to be combined.   required
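Because CombinedDataset subclasses PyTorch's ConcatDataset, it inherits that class's indexing: a running total of the component dataset lengths is kept, and a binary search over those totals routes a global index to the right dataset and local offset. The dependency-free sketch below illustrates that mapping. It is not FastEstimator or PyTorch code: the class name `ConcatSketch` is hypothetical, and plain Python lists of dictionaries stand in for PyTorch Datasets.

```python
import bisect
from itertools import accumulate
from typing import Any, List, Sequence


class ConcatSketch:
    """Hypothetical stand-in illustrating ConcatDataset-style indexing over sequences."""

    def __init__(self, datasets: List[Sequence[Any]]) -> None:
        # Sketch choice: reject an empty list (ConcatDataset itself raises an AssertionError here).
        if len(datasets) == 0:
            raise ValueError("datasets should not be empty")
        self.datasets = list(datasets)
        # Running totals of lengths, e.g. lengths [3, 2] -> [3, 5].
        self.cumulative_sizes = list(accumulate(len(ds) for ds in self.datasets))

    def __len__(self) -> int:
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx: int) -> Any:
        # Support negative indices by counting back from the end.
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError("index out of range")
        # Index of the first dataset whose running total exceeds idx.
        ds_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        # Subtract the lengths of all preceding datasets to get the local index.
        offset = 0 if ds_idx == 0 else self.cumulative_sizes[ds_idx - 1]
        return self.datasets[ds_idx][idx - offset]


a = [{"x": 0}, {"x": 1}, {"x": 2}]
b = [{"x": 10}, {"x": 11}]
combined = ConcatSketch([a, b])
print(len(combined))  # 5
print(combined[3])    # {'x': 10}
print(combined[-1])   # {'x': 11}
```

Global index 3 falls past the first running total (3), so it maps to position 0 of the second dataset.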

Raises:

Type            Description
AssertionError  Raised if the input list is empty, if any dataset is an InterleaveDataset, or if any dataset is not a PyTorch Dataset that returns a dictionary.
KeyError        Raised if the datasets do not all have the same keys.
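The key-consistency rule can be illustrated without PyTorch. The sketch below mirrors the validation loop in the source code: like the original, it looks only at the first sample (`ds[0]`) of each dataset, and it compares `dict.keys()` views, which compare like sets, so key order does not matter. The helper name `check_same_keys` is hypothetical, plain lists stand in for Datasets, and the InterleaveDataset check is omitted.

```python
from typing import Any, List, Sequence


def check_same_keys(datasets: List[Sequence[Any]]) -> None:
    """Hypothetical helper mirroring CombinedDataset's key check.

    Only the first sample of each dataset is inspected, as in the original.
    """
    keys = None
    for ds in datasets:
        first = ds[0]
        if not isinstance(first, dict):
            raise AssertionError("Each dataset should return a dictionary.")
        if keys is None:
            # The first dataset defines the reference key set.
            keys = first.keys()
        elif first.keys() != keys:
            # dict views compare like sets, so only membership matters, not order.
            raise KeyError("All datasets should have the same keys.")


# Passes: same keys in a different order.
check_same_keys([[{"x": 1, "y": 0}], [{"y": 5, "x": 2}]])

# Rejected: the second dataset lacks the "y" key.
try:
    check_same_keys([[{"x": 1, "y": 0}], [{"x": 2}]])
except KeyError as err:
    print("rejected:", err)
```

Because only `ds[0]` is checked, a dataset whose later samples have different keys would not be caught at construction time.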

Source code in fastestimator/fastestimator/dataset/combined_dataset.py
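from typing import List

from torch.utils.data import ConcatDataset, Dataset

from fastestimator.dataset.interleave_dataset import InterleaveDataset
from fastestimator.util.traceability_util import traceable


@traceable()
class CombinedDataset(ConcatDataset):
    """Combines a list of PyTorch Datasets.

    Args:
        datasets: PyTorch (or FE) Datasets to be combined.

    Raises:
        AssertionError: If the input list is empty, if any dataset is an InterleaveDataset, or if any dataset is not a
            PyTorch Dataset that returns a dictionary.
        KeyError: If the datasets do not all have the same keys.
    """
    def __init__(self, datasets: List[Dataset]) -> None:
        # ConcatDataset stores the datasets and builds the cumulative sizes used for indexing.
        super().__init__(datasets)
        keys = None  # reference key set, taken from the first dataset's first sample

        for ds in datasets:
            # InterleaveDataset is explicitly rejected.
            if isinstance(ds, InterleaveDataset):
                raise AssertionError("CombinedDataset does not support InterleaveDataset")
            # Only the first sample of each dataset is inspected.
            if isinstance(ds, Dataset) and isinstance(ds[0], dict):
                if keys is None:
                    keys = ds[0].keys()
                elif ds[0].keys() != keys:
                    raise KeyError("All datasets should have the same keys.")
            else:
                raise AssertionError("Each dataset should be a type of PyTorch Dataset and should return a dictionary.")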