Skip to content

mendeley

load_data

Load and return the Mendeley dataset.

Kermany, Daniel; Zhang, Kang; Goldbaum, Michael (2018), "Labeled Optical Coherence Tomography (OCT) and Chest X-Ray Images for Classification", Mendeley Data, v2 http://dx.doi.org/10.17632/rscbjbr9sj.2

CC BY 4.0 licence: https://creativecommons.org/licenses/by/4.0/

Parameters:

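| Name | Type | Description | Default |
|------|------|-------------|---------|
| `root_dir` | `Optional[str]` | The path to store the downloaded data (inside a `Mendeley` subfolder). When `root_dir` is not provided, the data will be saved into `fastestimator_data` under the user's home directory. | `None` |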

Returns:

| Type | Description |
|------|-------------|
| `Tuple[LabeledDirDataset, LabeledDirDataset]` | `(train_data, test_data)` |

Source code in fastestimator/fastestimator/dataset/data/mendeley.py
def load_data(root_dir: Optional[str] = None) -> Tuple[LabeledDirDataset, LabeledDirDataset]:
    """Load and return the Mendeley dataset.

    Kermany, Daniel; Zhang, Kang; Goldbaum, Michael (2018), "Labeled Optical Coherence Tomography (OCT) and Chest X-Ray
    Images for Classification", Mendeley Data, v2 http://dx.doi.org/10.17632/rscbjbr9sj.2

    CC BY 4.0 licence:
    https://creativecommons.org/licenses/by/4.0/

    If needed, the archive is downloaded to `<root>/Mendeley/ChestXRay2017.zip`. Its `chest_xray/` folder is then
    extracted next to it. Later calls reuse the extracted folder and skip both steps.

    Args:
        root_dir: The path to store the downloaded data. When `root_dir` is not provided, the data will be saved into
            `fastestimator_data` under the user's home directory.

    Returns:
        (train_data, test_data)

    Raises:
        requests.HTTPError: If the download request returns an HTTP error status.
        zipfile.BadZipFile: If the downloaded archive is corrupt. The corrupt archive is deleted first, so the next
            call downloads it again.
    """
    url = 'https://data.mendeley.com/public-files/datasets/rscbjbr9sj/files/f12eaf6d-6023-432f-acc9-80c9d7393433/' \
          'file_downloaded'
    user_agent = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/' \
                 '70.0.3538.77 Safari/537.36'
    home = str(Path.home())

    if root_dir is None:
        root_dir = os.path.join(home, 'fastestimator_data', 'Mendeley')
    else:
        root_dir = os.path.join(os.path.abspath(root_dir), 'Mendeley')
    os.makedirs(root_dir, exist_ok=True)

    data_compressed_path = os.path.join(root_dir, 'ChestXRay2017.zip')
    data_folder_path = os.path.join(root_dir, 'chest_xray')

    if not os.path.exists(data_folder_path):
        # download
        if not os.path.exists(data_compressed_path):
            print("Downloading data to {}".format(root_dir))
            # Stream into a temporary file and only move it into place once complete. An interrupted download then
            # never leaves a truncated archive that the existence check above would mistake for a finished one.
            partial_path = data_compressed_path + '.part'
            # python wget does not work; requests is used with a browser-like User-Agent instead
            with requests.get(url, stream=True, headers={'User-Agent': user_agent}) as stream:
                # Fail loudly on HTTP errors instead of saving an error page as the "zip" file
                stream.raise_for_status()
                total_size = int(stream.headers.get('content-length', 0))
                block_size = int(1e6)  # 1 MB
                with tqdm(total=total_size, unit='B', unit_scale=True) as progress, \
                        open(partial_path, 'wb') as outfile:
                    for data in stream.iter_content(block_size):
                        progress.update(len(data))
                        outfile.write(data)
            os.replace(partial_path, data_compressed_path)

        # extract
        print("\nExtracting file ...")
        try:
            with zipfile.ZipFile(data_compressed_path, 'r') as zip_file:
                # There's some garbage data from macOS in the zip file that gets filtered out here
                zip_file.extractall(root_dir, filter(lambda x: x.startswith("chest_xray/"), zip_file.namelist()))
        except zipfile.BadZipFile:
            # Without this, every later call would skip the download and hit the same corrupt file.
            os.remove(data_compressed_path)
            raise

    label_mapping = {'NORMAL': 0, 'PNEUMONIA': 1}
    train_data = LabeledDirDataset(os.path.join(data_folder_path, "train"),
                                   label_mapping=label_mapping,
                                   file_extension=".jpeg")
    test_data = LabeledDirDataset(os.path.join(data_folder_path, "test"),
                                  label_mapping=label_mapping,
                                  file_extension=".jpeg")
    return train_data, test_data