
numpy_dataset

NumpyDataset

Bases: InMemoryDataset

A dataset constructed from a dictionary whose values are Numpy arrays or lists. Element i of every value is grouped into sample i of the dataset.

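Parameters:

data (Dict[str, Union[ndarray, List]], required): A dictionary of data like {"key1": <numpy array>, "key2": [list]}. Every value must be a Numpy array or a list, and all values must contain the same number of elements (first-axis length for arrays).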
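Each value in `data` is treated as a column whose first axis indexes samples, so sample i collects element i of every array or list. The sketch below illustrates that layout with plain NumPy. `get_sample` is a hypothetical helper written for illustration; it is not part of the FastEstimator API.

```python
from typing import Any, Dict, List, Union

import numpy as np


def get_sample(data: Dict[str, Union[np.ndarray, List[Any]]], index: int) -> Dict[str, Any]:
    """Hypothetical helper: gather element `index` from every column in `data`."""
    # Indexing a NumPy array along its first axis drops that axis,
    # so a (4, 28, 28) image array yields a single (28, 28) image.
    return {key: value[index] for key, value in data.items()}


data = {
    "x": np.zeros((4, 28, 28), dtype=np.float32),  # 4 images of 28x28 pixels
    "y": [0, 1, 1, 0],                             # 4 labels as a plain list
}
sample = get_sample(data, 2)
print(sample["x"].shape, sample["y"])  # (28, 28) 1
```

Arrays and lists can be mixed freely as long as they agree on the number of elements.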

Raises:

AssertionError: If the Numpy arrays or lists do not all have the same number of elements (first-axis length for arrays, length for lists).
ValueError: If any dictionary value is not a Numpy array or a list.
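The two documented failure modes can be reproduced outside FastEstimator with a small standalone check. `count_elements` below is a hypothetical helper that mirrors the rules listed above. Unlike the source, it raises `AssertionError` explicitly instead of using an `assert` statement, because `assert` statements are skipped when Python runs with `-O`.

```python
from typing import Any, Dict, List, Optional, Union

import numpy as np


def count_elements(data: Dict[str, Union[np.ndarray, List[Any]]]) -> Optional[int]:
    """Hypothetical helper mirroring the documented checks.

    Returns the element count shared by every value, or None for an empty dict.
    """
    size: Optional[int] = None
    for value in data.values():
        if isinstance(value, np.ndarray):
            current = value.shape[0]  # first axis is the element (sample) axis
        elif isinstance(value, list):
            current = len(value)
        else:
            # Anything else (tuples, scalars, ...) is rejected.
            raise ValueError("Please ensure you are passing numpy array or list in the data dictionary.")
        if size is None:
            size = current
        elif current != size:
            # Raised explicitly so the check still runs under `python -O`.
            raise AssertionError("All data arrays must have the same number of elements")
    return size


print(count_elements({"x": np.ones((3, 2)), "y": [1, 2, 3]}))  # 3
for bad in ({"x": np.ones((3, 2)), "y": [1, 2]}, {"x": (1, 2, 3)}):
    try:
        count_elements(bad)
    except (AssertionError, ValueError) as err:
        print(type(err).__name__)  # AssertionError, then ValueError
```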

Source code in fastestimator/fastestimator/dataset/numpy_dataset.py
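from typing import Dict, List, Union

import numpy as np

from fastestimator.dataset.dataset import InMemoryDataset
from fastestimator.util.traceability_util import traceable


@traceable()
class NumpyDataset(InMemoryDataset):
    """A dataset constructed from a dictionary of Numpy data or list of data.

    Args:
        data: A dictionary of data like {"key1": <numpy array>, "key2": [list]}.

    Raises:
        AssertionError: If any of the Numpy arrays or lists have differing numbers of elements.
        ValueError: If any dictionary value is not an instance of a Numpy array or list.
    """
    def __init__(self, data: Dict[str, Union[np.ndarray, List]]) -> None:
        size = None  # element count shared by all values; None until the first value is seen
        for val in data.values():
            # The first axis of an array (or the length of a list) is its element count.
            if isinstance(val, np.ndarray):
                current_size = val.shape[0]
            elif isinstance(val, list):
                current_size = len(val)
            else:
                raise ValueError("Please ensure you are passing numpy array or list in the data dictionary.")
            if size is not None:
                # Every value must agree with the first one.
                assert size == current_size, "All data arrays must have the same number of elements"
            else:
                size = current_size
        # Transpose {key: column} into {index: {key: element}}; empty or zero-length input yields {}.
        super().__init__({i: {k: v[i] for k, v in data.items()} for i in range(size)} if size else {})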
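The last line of `__init__` transposes the column-oriented input into a row-oriented dictionary keyed by sample index, and that dictionary is what gets handed to `InMemoryDataset`. The sketch below reproduces the same transformation in a hypothetical stand-in class, `MiniNumpyDataset`, with minimal `__len__` and `__getitem__` methods. It is not FastEstimator's `InMemoryDataset`; it assumes, rather than reproduces, how the real base class stores and serves samples, and it skips validation.

```python
from typing import Any, Dict, List, Union

import numpy as np


class MiniNumpyDataset:
    """Hypothetical stand-in; not FastEstimator's InMemoryDataset."""

    def __init__(self, data: Dict[str, Union[np.ndarray, List[Any]]]) -> None:
        # Assumes the input already passed the size/type checks in NumpyDataset.__init__.
        sizes = {v.shape[0] if isinstance(v, np.ndarray) else len(v) for v in data.values()}
        size = sizes.pop() if sizes else 0
        # Transpose {key: column} into {index: {key: element}}, as NumpyDataset does.
        self.samples: Dict[int, Dict[str, Any]] = {
            i: {k: v[i] for k, v in data.items()} for i in range(size)
        }

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return self.samples[index]


x = np.arange(6).reshape(3, 2)  # 3 samples, 2 features each
ds = MiniNumpyDataset({"x": x, "y": ["a", "b", "c"]})
print(len(ds))                   # 3
print(ds[1]["x"], ds[1]["y"])    # [2 3] b
```

One consequence of building samples with `v[i]`: for arrays with two or more dimensions, each `v[i]` is a NumPy view, so the per-sample dictionaries in this sketch share memory with the original array instead of copying it. The sketch does not show what the real `InMemoryDataset` does with those views afterwards.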