
medmnist

load_data

Download and load one of the 18 MedMNIST datasets. The available dataset names are:

'chestmnist', 'adrenalmnist3d', 'bloodmnist', 'breastmnist', 'dermamnist', 'fracturemnist3d', 'nodulemnist3d', 'octmnist', 'organamnist', 'organcmnist', 'organmnist3d', 'organsmnist', 'pathmnist', 'pneumoniamnist', 'retinamnist', 'synapsemnist3d', 'tissuemnist', 'vesselmnist3d'

For more details on the datasets, see https://medmnist.com and https://zenodo.org/record/6496656.

Parameters:

dataset_name (str, required): Name of the dataset to download and load.

root_dir (str, default None): The path to store the downloaded data. When root_dir is not provided, the data is saved into fastestimator_data/medmnist under the user's home directory; otherwise it is saved into a medmnist subdirectory of root_dir.

image_key (str, default 'x'): The key for the image.

label_key (str, default 'y'): The key for the label.
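
The way root_dir is turned into a storage location can be sketched as follows. This is a minimal standalone illustration that mirrors the path logic in the source listing on this page; the helper name resolve_medmnist_dir is hypothetical and is not part of FastEstimator, and unlike load_data it does not create the directory or download anything.

```python
import os
from pathlib import Path
from typing import Optional


def resolve_medmnist_dir(root_dir: Optional[str] = None) -> str:
    """Return the directory where the MedMNIST .npz files would be stored.

    Hypothetical helper: it only computes the path, nothing is created on disk.
    """
    if root_dir is None:
        # No root_dir given: fall back to ~/fastestimator_data
        base = os.path.join(str(Path.home()), "fastestimator_data")
    else:
        # A relative root_dir is resolved against the current working directory
        base = os.path.abspath(root_dir)
    # In both cases the files live in a "medmnist" subdirectory
    return os.path.join(base, "medmnist")


default_dir = resolve_medmnist_dir()
custom_dir = resolve_medmnist_dir("my_data")
# Each dataset is stored as a single "<dataset_name>.npz" archive
archive_path = os.path.join(custom_dir, "pathmnist.npz")
```

Because the medmnist subdirectory is always appended, passing the same root_dir for several FastEstimator datasets should keep the MedMNIST archives separate from the others.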

Raises:

ValueError: If dataset_name is not one of the supported dataset names.
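
Since the error raised by load_data does not list the accepted names, it can help to check a name before calling it. The snippet below is a sketch under assumptions: MEDMNIST_NAMES and check_dataset_name are hypothetical names introduced here, the tuple is copied from the list at the top of this page, and the exact lowercase match is assumed from the membership test in the source listing.

```python
# Dataset names copied from the list documented for load_data.
MEDMNIST_NAMES = (
    "chestmnist", "adrenalmnist3d", "bloodmnist", "breastmnist", "dermamnist",
    "fracturemnist3d", "nodulemnist3d", "octmnist", "organamnist", "organcmnist",
    "organmnist3d", "organsmnist", "pathmnist", "pneumoniamnist", "retinamnist",
    "synapsemnist3d", "tissuemnist", "vesselmnist3d",
)


def check_dataset_name(dataset_name: str) -> str:
    """Return dataset_name unchanged, or raise ValueError listing the valid names.

    Hypothetical helper; matching is assumed to be exact and case sensitive.
    """
    if dataset_name not in MEDMNIST_NAMES:
        raise ValueError(
            f"Invalid dataset_name {dataset_name!r}; expected one of: "
            + ", ".join(MEDMNIST_NAMES)
        )
    return dataset_name


name = check_dataset_name("pathmnist")  # passes through unchanged
```

Names that differ only in case, such as "PathMNIST", are rejected by this check.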

Returns:

Tuple[NumpyDataset, NumpyDataset, NumpyDataset]: The training, validation, and test datasets, in that order.
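
A typical call might look like the sketch below. It assumes FastEstimator is installed, that the module is importable as fastestimator.dataset.data.medmnist (inferred from the source path on this page), and that NumpyDataset supports len() and integer indexing that returns a dictionary keyed by image_key and label_key. The helper name summarize_medmnist is hypothetical. Calling it downloads data, so the import is kept inside the function.

```python
def summarize_medmnist(dataset_name: str = "pneumoniamnist") -> dict:
    """Load one MedMNIST dataset and report split sizes and sample shapes.

    Hypothetical helper: requires FastEstimator and network access when called.
    """
    # Imported lazily so that defining this helper does not need FastEstimator.
    from fastestimator.dataset.data import medmnist

    # Custom keys replace the default "x" / "y" entries in each sample.
    train_data, val_data, test_data = medmnist.load_data(
        dataset_name, image_key="image", label_key="label"
    )

    first = train_data[0]  # assumed to be a dict: {"image": ..., "label": ...}
    return {
        "train": len(train_data),
        "val": len(val_data),
        "test": len(test_data),
        "image_shape": first["image"].shape,
        "label_shape": first["label"].shape,
    }


# Usage (downloads the archive on first call):
#     print(summarize_medmnist("pneumoniamnist"))
```

The three datasets are returned in a fixed order, so unpacking them into train, validation, and test variables as above is the simplest way to consume the result.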

Source code in fastestimator/fastestimator/dataset/data/medmnist.py
def load_data(
    dataset_name: str,
    root_dir: str = None,
    image_key: str = "x",
    label_key: str = "y",
) -> Tuple[NumpyDataset, NumpyDataset, NumpyDataset]:
    """
    Download and load one of the 18 MedMNIST datasets. The available dataset names are:
    [
        'chestmnist',
        'adrenalmnist3d',
        'bloodmnist',
        'breastmnist',
        'dermamnist',
        'fracturemnist3d',
        'nodulemnist3d',
        'octmnist',
        'organamnist',
        'organcmnist',
        'organmnist3d',
        'organsmnist',
        'pathmnist',
        'pneumoniamnist',
        'retinamnist',
        'synapsemnist3d',
        'tissuemnist',
        'vesselmnist3d'
    ]
    For more details on the dataset, please check https://medmnist.com, https://zenodo.org/record/6496656

    Args:
        dataset_name (str): Name of the dataset to download and load.
        root_dir (str, optional): The path to store the downloaded data. When `root_dir` is not provided, the data will be
            saved into `fastestimator_data` under the user's home directory. Defaults to None.
        image_key (str, optional): The key for image. Defaults to "x".
        label_key (str, optional): The key for label. Defaults to "y".

    Raises:
        ValueError: if the dataset_name is invalid

    Returns:
        Tuple[NumpyDataset, NumpyDataset, NumpyDataset]: a tuple of training, validation, and test data.
    """
    if dataset_name not in dataset_ids:
        raise ValueError("Invalid value for dataset_name.")

    if root_dir is None:
        home = str(Path.home())
        root_dir = os.path.join(home, "fastestimator_data", "medmnist")
    else:
        root_dir = os.path.join(os.path.abspath(root_dir), "medmnist")

    os.makedirs(root_dir, exist_ok=True)

    download_path = os.path.join(root_dir, f"{dataset_name}.npz")

    print("Downloading data to {}".format(root_dir))
    download_file_from_google_drive(dataset_ids[dataset_name], download_path)

    npz_file = np.load(download_path)

    x_train = npz_file["train_images"]
    y_train = npz_file["train_labels"]

    x_val = npz_file["val_images"]
    y_val = npz_file["val_labels"]

    x_test = npz_file["test_images"]
    y_test = npz_file["test_labels"]

    train_data = NumpyDataset({image_key: x_train, label_key: y_train})
    val_data = NumpyDataset({image_key: x_val, label_key: y_val})
    test_data = NumpyDataset({image_key: x_test, label_key: y_test})
    return train_data, val_data, test_data