Source code for torchio.datasets.ct_rate

from __future__ import annotations

import enum
import multiprocessing
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Literal
from typing import Union

from tqdm.contrib.concurrent import thread_map

from ..data.dataset import SubjectsDataset
from ..data.image import ScalarImage
from ..data.subject import Subject
from ..external.imports import get_pandas
from ..types import TypePath

if TYPE_CHECKING:
    import pandas as pd


TypeSplit = Union[
    Literal['train'],
    Literal['valid'],
    Literal['validation'],
]

TypeParallelism = Literal['thread', 'process', None]


class MetadataIndexColumn(str, enum.Enum):
    SUBJECT_ID = 'subject_id'
    SCAN_ID = 'scan_id'
    RECONSTRUCTION_ID = 'reconstruction_id'
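
# The enum above names the index columns extracted from CT-RATE volume
# filenames. A quick illustration (regex copied verbatim from
# CtRate._get_metadata below; the filename is the example used in
# CtRate._get_image_path):
#
#     >>> import re
#     >>> re.match(r'\w+_(\d+)_(\w+)_(\d+)\.nii\.gz', 'train_2_a_1.nii.gz').groups()
#     ('2', 'a', '1')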


class CtRate(SubjectsDataset):
    """CT-RATE dataset.

    This class helps load the `CT-RATE dataset
    <https://huggingface.co/datasets/ibrahimhamamci/CT-RATE>`_, which contains
    chest CT scans with associated radiology reports and abnormality labels.

    The dataset must have been downloaded previously.

    Args:
        root: Root directory where the dataset has been downloaded.
        split: Dataset split to use, either ``'train'`` or ``'validation'``.
        num_subjects: Optional limit on the number of subjects to load (useful
            for debugging). If ``None``, all subjects in the split are loaded.
        report_key: Key to use for storing radiology reports in the Subject
            metadata.
        sizes: List of image sizes (in-plane, in voxels) to include.
        load_fixed: If ``True``, load the files with fixed spatial metadata
            added in `this pull request
            <https://huggingface.co/datasets/ibrahimhamamci/CT-RATE/discussions/85>`_.
            Otherwise, load the original files with incorrect spatial metadata.
        verify_paths: If ``True``, verify that the paths to the images exist
            during instantiation of the dataset. This might be slow for large
            datasets that are not stored locally.
        **kwargs: Additional arguments for SubjectsDataset.

    Example:
        >>> from torchio.datasets import CtRate
        >>> dataset = CtRate('/path/to/CT-RATE', sizes=[512])
    """

    _REPO_ID = 'ibrahimhamamci/CT-RATE'
    _FILENAME_KEY = 'VolumeName'
    _SIZES = [512, 768, 1024]
    ABNORMALITIES = [
        'Medical material',
        'Arterial wall calcification',
        'Cardiomegaly',
        'Pericardial effusion',
        'Coronary artery wall calcification',
        'Hiatal hernia',
        'Lymphadenopathy',
        'Emphysema',
        'Atelectasis',
        'Lung nodule',
        'Lung opacity',
        'Pulmonary fibrotic sequela',
        'Pleural effusion',
        'Mosaic attenuation pattern',
        'Peribronchial thickening',
        'Consolidation',
        'Bronchiectasis',
        'Interlobular septal thickening',
    ]
    REPORT_KEYS = [
        'ClinicalInformation_EN',
        'Findings_EN',
        'Impressions_EN',
        'Technique_EN',
    ]

    def __init__(
        self,
        root: TypePath,
        split: TypeSplit = 'train',
        *,
        num_subjects: int | None = None,
        report_key: str = 'report',
        sizes: list[int] | None = None,
        load_fixed: bool = True,
        verify_paths: bool = False,
        **kwargs,
    ):
        self._root_dir = Path(root)
        self._num_subjects = num_subjects
        self._report_key = report_key
        self._sizes = self._SIZES if sizes is None else sizes
        self._split = self._parse_split(split)
        self.metadata = self._get_metadata()
        self._load_fixed = load_fixed
        self._verify_paths = verify_paths
        subjects_list = self._get_subjects_list(self.metadata)
        super().__init__(subjects_list, **kwargs)

    @staticmethod
    def _parse_split(split: str) -> str:
        """Normalize the split name.

        Converts ``'validation'`` to ``'valid'`` and validates that the split
        name is one of the allowed values.

        Args:
            split: The split name to parse (``'train'``, ``'valid'``, or
                ``'validation'``).

        Returns:
            str: Normalized split name (``'train'`` or ``'valid'``).

        Raises:
            ValueError: If the split name is not one of the allowed values.
        """
        if split in ['valid', 'validation']:
            return 'valid'
        if split not in ['train', 'valid']:
            raise ValueError(f"Invalid split '{split}'. Use 'train' or 'valid'")
        return split

    def _get_csv(
        self,
        dirname: str,
        filename: str,
    ) -> pd.DataFrame:
        """Load a CSV file from the specified directory within the dataset.

        Args:
            dirname: Directory name within ``dataset/`` where the CSV is
                located.
            filename: Name of the CSV file to load.
        """
        subfolder = Path(f'dataset/{dirname}')
        path = Path(self._root_dir, subfolder, filename)
        pd = get_pandas()
        table = pd.read_csv(path)
        return table
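
    # Expected layout of the CSV files under the download root, as implied by
    # the dataset/{dirname}/{filename} paths built in _get_csv above
    # (illustrative; train split shown):
    #
    #     CT-RATE/
    #     └── dataset/
    #         ├── metadata/train_metadata.csv
    #         ├── radiology_text_reports/train_reports.csv
    #         └── multi_abnormality_labels/train_predicted_labels.csv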
""" subfolder = Path(f'dataset/{dirname}') path = Path(self._root_dir, subfolder, filename) pd = get_pandas() table = pd.read_csv(path) return table def _get_csv_prefix(self, expand_validation: bool = True) -> str: """Get the prefix for CSV filenames based on the current split. Returns the appropriate prefix for CSV filenames based on the current split. For the validation split, can either return 'valid' or 'validation' depending on the expand_validation parameter. Args: expand_validation: If ``True`` and split is ``'valid'``, return ``'validation'``. Otherwise, return the split name as is. """ if expand_validation and self._split == 'valid': prefix = 'validation' else: prefix = self._split return prefix def _get_metadata(self) -> pd.DataFrame: """Load and process the dataset metadata. Loads metadata from the appropriate CSV file, filters images by size, extracts subject, scan, and reconstruction IDs from filenames, and merges in reports and abnormality labels. """ dirname = 'metadata' prefix = self._get_csv_prefix() filename = f'{prefix}_metadata.csv' metadata = self._get_csv(dirname, filename) # Exclude images with size not in self._sizes rows_int = metadata['Rows'].astype(int) metadata = metadata[rows_int.isin(self._sizes)] index_columns = [ MetadataIndexColumn.SUBJECT_ID.value, MetadataIndexColumn.SCAN_ID.value, MetadataIndexColumn.RECONSTRUCTION_ID.value, ] pattern = r'\w+_(\d+)_(\w+)_(\d+)\.nii\.gz' metadata[index_columns] = metadata[self._FILENAME_KEY].str.extract(pattern) if self._num_subjects is not None: metadata = self._keep_n_subjects(metadata, self._num_subjects) # Add reports and abnormality labels to metadata, keeping only the rows for the # images in the metadata table metadata = self._merge(metadata, self._get_reports()) metadata = self._merge(metadata, self._get_labels()) metadata.set_index(index_columns, inplace=True) return metadata def _merge(self, base_df: pd.DataFrame, new_df: pd.DataFrame) -> pd.DataFrame: """Merge a new dataframe into the base dataframe using the filename as the key. This method performs a left join between ``base_df`` and ``new_df`` using the volume filename as the join key, ensuring that all records from ``base_df`` are preserved while matching data from ``new_df`` is added. Args: base_df: The primary dataframe to merge into. new_df: The dataframe containing additional data to be merged. Returns: pd.DataFrame: The merged dataframe with all rows from base_df and matching columns from new_df. """ pd = get_pandas() return pd.merge( base_df, new_df, on=self._FILENAME_KEY, how='left', ) def _keep_n_subjects(self, metadata: pd.DataFrame, n: int) -> pd.DataFrame: """Limit the metadata to the first ``n`` subjects. Args: metadata: The complete metadata dataframe. n: Maximum number of subjects to keep. """ unique_subjects = metadata['subject_id'].unique() selected_subjects = unique_subjects[:n] return metadata[metadata['subject_id'].isin(selected_subjects)] def _get_reports(self) -> pd.DataFrame: """Load the radiology reports associated with the CT scans. Retrieves the CSV file containing radiology reports for the current split (train or validation). """ dirname = 'radiology_text_reports' prefix = self._get_csv_prefix() filename = f'{prefix}_reports.csv' return self._get_csv(dirname, filename) def _get_labels(self) -> pd.DataFrame: """Load the abnormality labels for the CT scans. Retrieves the CSV file containing predicted abnormality labels for the current split. 
""" dirname = 'multi_abnormality_labels' prefix = self._get_csv_prefix(expand_validation=False) filename = f'{prefix}_predicted_labels.csv' return self._get_csv(dirname, filename) def _get_subjects_list(self, metadata: pd.DataFrame) -> list[Subject]: """Create a list of Subject instances from the metadata. Processes the metadata to create Subject objects, each containing one or more CT images. Processing is performed in parallel. Note: This method uses parallelization to improve performance when creating multiple Subject instances. """ df_no_index = metadata.reset_index() num_subjects = df_no_index['subject_id'].nunique() iterable = df_no_index.groupby('subject_id') subjects = thread_map( self._get_subject, iterable, max_workers=multiprocessing.cpu_count(), total=num_subjects, ) return subjects def _get_subject( self, subject_id_and_metadata: tuple[str, pd.DataFrame], ) -> Subject: """Create a Subject instance for a specific subject. Processes all images belonging to a single subject and creates a Subject object containing those images. Args: subject_id_and_metadata: A tuple containing the subject ID (string) and a DataFrame containing metadata for all images associated to that subject. """ subject_id, subject_df = subject_id_and_metadata subject_dict: dict[str, str | ScalarImage] = {'subject_id': subject_id} for _, image_row in subject_df.iterrows(): image = self._instantiate_image(image_row) scan_id = image_row['scan_id'] reconstruction_id = image_row['reconstruction_id'] image_key = f'scan_{scan_id}_reconstruction_{reconstruction_id}' subject_dict[image_key] = image return Subject(**subject_dict) # type: ignore[arg-type] def _instantiate_image(self, image_row: pd.Series) -> ScalarImage: """Create a ScalarImage object for a specific image. Processes a row from the metadata DataFrame to create a ScalarImage object, Args: image_row: A pandas Series representing a row from the metadata DataFrame, containing information about a single image. """ image_dict = image_row.to_dict() filename = image_dict[self._FILENAME_KEY] relative_image_path = self._get_image_path( filename, load_fixed=self._load_fixed, ) image_path = self._root_dir / relative_image_path report_dict = self._extract_report_dict(image_dict) image_dict[self._report_key] = report_dict image = ScalarImage(image_path, verify_path=self._verify_paths, **image_dict) return image def _extract_report_dict(self, subject_dict: dict[str, str]) -> dict[str, str]: """Extract radiology report information from the subject dictionary. Extracts the English radiology report components (clinical information, findings, impressions, and technique) from the subject dictionary and removes these keys from the original dictionary. Args: subject_dict: Image metadata including report fields. Note: This method modifies the input subject_dict by removing the report keys. """ report_dict = {} for key in self.REPORT_KEYS: report_dict[key] = subject_dict.pop(key) return report_dict @staticmethod def _get_image_path(filename: str, load_fixed: bool) -> Path: """Construct the relative path to an image file within the dataset structure. Parses the filename to determine the hierarchical directory structure where the image is stored in the CT-RATE dataset. Args: filename: The name of the image file (e.g., 'train_2_a_1.nii.gz'). Returns: Path: The relative path to the image file within the dataset directory. 

        Example:
            >>> CtRate._get_image_path('train_2_a_1.nii.gz', load_fixed=False)
            PosixPath('dataset/train/train_2/train_2_a/train_2_a_1.nii.gz')
        """
        parts = filename.split('_')
        base_dir = 'dataset'
        split_dir = parts[0]
        if load_fixed:
            split_dir = f'{split_dir}_fixed'
        level1 = f'{parts[0]}_{parts[1]}'
        level2 = f'{level1}_{parts[2]}'
        return Path(base_dir, split_dir, level1, level2, filename)
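

# A minimal usage sketch (illustrative, not part of the library): the root
# path is hypothetical and must point to a local copy of the dataset.
if __name__ == '__main__':
    dataset = CtRate(
        '/path/to/CT-RATE',  # hypothetical download location
        split='validation',
        sizes=[512],
        num_subjects=2,  # small subset for a quick smoke test
    )
    print(len(dataset), 'subjects')
    subject = dataset[0]
    print('Subject ID:', subject['subject_id'])
    # Image keys encode the scan and reconstruction IDs, e.g.
    # 'scan_a_reconstruction_1'; the report dict is attached to each image
    # under `report_key` ('report' by default).
    for key, value in subject.items():
        if isinstance(value, ScalarImage):
            print(key, value.path)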