Source code for torchio.datasets.rsna_spine_fracture
from __future__ import annotations
from pathlib import Path
from typing import Any
from typing import Union
from ..data import LabelMap
from ..data import ScalarImage
from ..data import Subject
from ..data import SubjectsDataset
from ..external.imports import get_pandas
from ..types import TypePath
from ..utils import normalize_path
# Bounding boxes of one study: one dict per row of the bounding boxes CSV,
# mapping each column name to the value of that cell.
TypeBoxes = list[dict[str, Union[str, float, int]]]
[docs]
class RSNACervicalSpineFracture(SubjectsDataset):
    """RSNA 2022 Cervical Spine Fracture Detection dataset.

    This is a helper class for the dataset used in the
    `RSNA 2022 Cervical Spine Fracture Detection`_ hosted on
    `kaggle <https://www.kaggle.com/>`_. The dataset must be downloaded before
    instantiating this class.

    Args:
        root_dir: Directory containing the downloaded dataset. It must
            contain ``train.csv`` and a ``train_images`` directory with one
            entry per study, named after its ``StudyInstanceUID``.
        add_segmentations: If ``True``, look in ``root_dir/segmentations``
            for a file named after each study (``.nii`` or ``.nii.gz``) and
            add it to the subject as a ``LabelMap`` under the key ``'seg'``.
            Studies without a matching file get no ``'seg'`` entry.
        add_bounding_boxes: If ``True``, read
            ``root_dir/train_bounding_boxes.csv`` and add the rows belonging
            to each study to the subject under the key ``'boxes'``. Studies
            without any row get no ``'boxes'`` entry.
        **kwargs: Passed to ``SubjectsDataset``.

    .. _RSNA 2022 Cervical Spine Fracture Detection: https://www.kaggle.com/competitions/rsna-2022-cervical-spine-fracture-detection/overview/evaluation
    """

    # Name of the CSV column that identifies a study.
    UID = 'StudyInstanceUID'

    def __init__(
        self,
        root_dir: TypePath,
        add_segmentations: bool = False,
        add_bounding_boxes: bool = False,
        **kwargs,
    ):
        self.root_dir = normalize_path(root_dir)
        subjects = self._get_subjects(
            add_segmentations,
            add_bounding_boxes,
        )
        super().__init__(subjects, **kwargs)

    @staticmethod
    def _strip_nifti_suffix(name: str) -> str:
        """Remove a trailing ``.gz`` and then a trailing ``.nii`` from a name.

        Only the end of the name is touched, so occurrences of ``.nii`` or
        ``.gz`` in the middle of a name are preserved.
        """
        return name.removesuffix('.gz').removesuffix('.nii')

    @staticmethod
    def _get_image_dirs_dict(images_dir: Path) -> dict[str, Path]:
        """Map the name of each entry of ``images_dir`` to its path.

        Entries are visited in sorted order.
        """
        return {path.name: path for path in sorted(images_dir.iterdir())}

    @staticmethod
    def _get_segs_paths_dict(segs_dir: Path) -> dict[str, Path]:
        """Map each segmentation file name, without extension, to its path.

        Entries of ``segs_dir`` are visited in sorted order. If two files
        share the same name once the extension is removed, the last one wins.
        """
        strip = RSNACervicalSpineFracture._strip_nifti_suffix
        return {
            strip(path.name): path for path in sorted(segs_dir.iterdir())
        }

    def _get_subjects(
        self,
        add_segmentations: bool,
        add_bounding_boxes: bool,
    ) -> list[Subject]:
        """Build one subject per row of ``train.csv``.

        Args:
            add_segmentations: Whether to attach the segmentation of each
                study, when one exists.
            add_bounding_boxes: Whether to attach the bounding boxes of each
                study, when there are any.

        Returns:
            The subjects, in the order of the rows of ``train.csv``.

        Raises:
            KeyError: If a study listed in ``train.csv`` has no entry in the
                images directory.
        """
        pd = get_pandas()
        from tqdm.auto import tqdm

        split_name = 'train'
        images_dir = self.root_dir / f'{split_name}_images'
        image_dirs_dict = self._get_image_dirs_dict(images_dir)

        # Optional inputs are only read when requested, so that a partial
        # download (e.g., images only) can still be used.
        seg_paths_dict: dict[str, Path] = {}
        if add_segmentations:
            segmentations_dir = self.root_dir / 'segmentations'
            seg_paths_dict = self._get_segs_paths_dict(segmentations_dir)

        grouped_boxes = None
        if add_bounding_boxes:
            bboxes_path = self.root_dir / 'train_bounding_boxes.csv'
            grouped_boxes = pd.read_csv(bboxes_path).groupby(self.UID)

        df = pd.read_csv(self.root_dir / f'{split_name}.csv')

        subjects = []
        # Iterate lazily; ``total`` lets the progress bar know the length.
        for _, row in tqdm(df.iterrows(), total=len(df)):
            uid = row[self.UID]
            image_dir = image_dirs_dict[uid]
            # Empty dict when segmentations were not requested -> None.
            seg_path = seg_paths_dict.get(uid)

            boxes: TypeBoxes = []
            if grouped_boxes is not None:
                try:
                    boxes_df = grouped_boxes.get_group(uid)
                except KeyError:
                    # This study has no rows in the bounding boxes CSV.
                    pass
                else:
                    boxes = [
                        dict(box_row) for _, box_row in boxes_df.iterrows()
                    ]

            subject = self._get_subject(
                dict(row),
                image_dir,
                seg_path,
                boxes,
            )
            subjects.append(subject)
        return subjects

    @staticmethod
    def _filter_list(iterable: list[Path], target: str) -> Path | None:
        """Return the only path in ``iterable`` whose name matches ``target``.

        Directories are compared by their full name. Any other path is
        compared by its name without the ``.nii``/``.nii.gz`` extension.

        Args:
            iterable: Candidate paths.
            target: Name to look for.

        Returns:
            The matching path, or ``None`` if nothing matches.

        Raises:
            AssertionError: If more than one path matches.
        """
        strip = RSNACervicalSpineFracture._strip_nifti_suffix

        def _matches(path: Path) -> bool:
            name = path.name if path.is_dir() else strip(path.name)
            return target == name

        found = [path for path in iterable if _matches(path)]
        if not found:
            return None
        if len(found) != 1:
            # Raised explicitly (instead of using ``assert``) so the check is
            # not stripped when Python runs with ``-O``.
            message = (
                f'Expected at most one path matching "{target}",'
                f' but found {len(found)}'
            )
            raise AssertionError(message)
        return found[0]

    def _get_subject(
        self,
        csv_row_dict: dict[str, str | int],
        image_dir: Path,
        seg_path: Path | None,
        boxes: TypeBoxes,
    ) -> Subject:
        """Create the subject for one study.

        Args:
            csv_row_dict: Row of ``train.csv`` for this study. Every column
                is stored in the subject under its column name.
            image_dir: Path of the images of this study, loaded as a
                ``ScalarImage`` under the key ``'ct'``.
            seg_path: Path to the segmentation, stored as a ``LabelMap``
                under the key ``'seg'``. Skipped if ``None``.
            boxes: Bounding boxes, stored under the key ``'boxes'``. Skipped
                if empty.
        """
        subject_dict: dict[str, Any] = {}
        subject_dict.update(csv_row_dict)
        subject_dict['ct'] = ScalarImage(image_dir)
        if seg_path is not None:
            subject_dict['seg'] = LabelMap(seg_path)
        if boxes:
            subject_dict['boxes'] = boxes
        return Subject(**subject_dict)