Source code for torchio.data.dataset
from __future__ import annotations
import copy
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Sequence
from torch.utils.data import Dataset
from ..utils import get_subjects_from_batch
from .subject import Subject
[docs]
class SubjectsDataset(Dataset):
    """Base TorchIO dataset.

    Reader of 3D medical images that directly inherits from the PyTorch
    :class:`~torch.utils.data.Dataset`. It can be used with a
    :class:`~torchio.SubjectsLoader` for efficient loading and
    augmentation. It receives a list of instances of :class:`~torchio.Subject`
    and an optional transform applied to the volumes after loading.

    Args:
        subjects: List of instances of :class:`~torchio.Subject`.
        transform: An instance of :class:`~torchio.transforms.Transform`
            that will be applied to each subject.
        load_getitem: Load all subject images before returning the subject in
            :meth:`__getitem__`. Set it to ``False`` if some of the images
            will not be needed during training.

    Raises:
        TypeError: If ``subjects`` is not iterable or contains an element
            that is not a :class:`~torchio.Subject`.
        ValueError: If ``subjects`` is empty, or if ``transform`` is neither
            ``None`` nor callable.

    Example:
        >>> import torchio as tio
        >>> subject_a = tio.Subject(
        ...     t1=tio.ScalarImage('t1.nrrd',),
        ...     t2=tio.ScalarImage('t2.mha',),
        ...     label=tio.LabelMap('t1_seg.nii.gz'),
        ...     age=31,
        ...     name='Fernando Perez',
        ... )
        >>> subject_b = tio.Subject(
        ...     t1=tio.ScalarImage('colin27_t1_tal_lin.minc',),
        ...     t2=tio.ScalarImage('colin27_t2_tal_lin_dicom',),
        ...     label=tio.LabelMap('colin27_seg1.nii.gz'),
        ...     age=56,
        ...     name='Colin Holmes',
        ... )
        >>> subjects_list = [subject_a, subject_b]
        >>> transforms = [
        ...     tio.RescaleIntensity(out_min_max=(0, 1)),
        ...     tio.RandomAffine(),
        ... ]
        >>> transform = tio.Compose(transforms)
        >>> subjects_dataset = tio.SubjectsDataset(subjects_list, transform=transform)
        >>> subject = subjects_dataset[0]

    .. _NiBabel: https://nipy.org/nibabel/#nibabel
    .. _SimpleITK: https://itk.org/Wiki/ITK/FAQ#What_3D_file_formats_can_ITK_import_and_export.3F
    .. _DICOM: https://www.dicomstandard.org/
    .. _affine matrix: https://nipy.org/nibabel/coordinate_systems.html

    .. tip:: To quickly iterate over the subjects without loading the images,
        use :meth:`dry_iter()`.
    """

    def __init__(
        self,
        subjects: Sequence[Subject],
        transform: Callable | None = None,
        load_getitem: bool = True,
    ):
        # Validate before storing so a bad list never ends up on the instance.
        self._parse_subjects_list(subjects)
        # The caller's container is stored as-is (no copy is made).
        self._subjects = subjects
        self._transform: Callable | None
        self.set_transform(transform)
        self.load_getitem = load_getitem

    def __len__(self):
        """Return the number of subjects in the dataset."""
        return len(self._subjects)

    def __getitem__(self, index: int) -> Subject:
        """Return a copy of the subject at ``index``, loaded and transformed.

        Args:
            index: Position of the subject. Anything accepted by :func:`int`
                is converted to ``int`` before indexing.

        Returns:
            A deep copy of the stored subject. If ``load_getitem`` is true its
            ``load()`` method is called, and if a transform is set the result
            of calling the transform on the copy is returned.

        Raises:
            ValueError: If ``index`` cannot be converted to ``int``.
        """
        try:
            index = int(index)
        except (RuntimeError, TypeError, ValueError) as err:
            # ValueError is handled too so that inputs such as 'abc' get the
            # same descriptive message instead of the bare int() error.
            message = (
                f'Index "{index}" must be int or compatible dtype,'
                f' but an object of type "{type(index)}" was passed'
            )
            raise ValueError(message) from err

        subject = self._subjects[index]
        # Work on a copy so loading and transforming never touch the stored
        # subject.
        subject = copy.deepcopy(subject)  # cheap since images not loaded yet
        if self.load_getitem:
            subject.load()

        # Apply transform (this is usually the bottleneck)
        if self._transform is not None:
            subject = self._transform(subject)
        return subject

    @classmethod
    def from_batch(cls, batch: dict) -> SubjectsDataset:
        """Instantiate a dataset from a batch generated by a data loader.

        The dataset is created with the default ``transform`` and
        ``load_getitem`` arguments.

        Args:
            batch: Dictionary generated by a data loader, containing data that
                can be converted to instances of :class:`~torchio.Subject`.
        """
        subjects: list[Subject] = get_subjects_from_batch(batch)
        return cls(subjects)

    def dry_iter(self):
        """Return the internal list of subjects.

        This can be used to iterate over the subjects without loading the data
        and applying any transforms::

            >>> names = [subject.name for subject in dataset.dry_iter()]

        The returned object is the container stored on the dataset, not a
        copy.
        """
        return self._subjects

    def set_transform(self, transform: Callable | None) -> None:
        """Set the transform applied to each subject in :meth:`__getitem__`.

        Args:
            transform: Callable that receives a subject and returns the
                transformed subject, or ``None`` to apply no transform.

        Raises:
            ValueError: If ``transform`` is neither ``None`` nor callable.
        """
        if transform is not None and not callable(transform):
            message = (
                'The transform must be a callable object,'
                f' but it has type {type(transform)}'
            )
            raise ValueError(message)
        self._transform = transform

    @staticmethod
    def _parse_subjects_list(subjects_list: Iterable[Subject]) -> None:
        """Validate that ``subjects_list`` is a non-empty iterable of subjects.

        Args:
            subjects_list: Object to validate.

        Raises:
            TypeError: If ``subjects_list`` is not iterable, or if any element
                is not an instance of :class:`~torchio.Subject`.
            ValueError: If ``subjects_list`` is empty (falsy).
        """
        # Check that it's an iterable
        try:
            iter(subjects_list)
        except TypeError as e:
            message = f'Subject list must be an iterable, not {type(subjects_list)}'
            raise TypeError(message) from e

        # Check that it's not empty
        if not subjects_list:
            raise ValueError('Subjects list is empty')

        # Check each element
        for subject in subjects_list:
            if not isinstance(subject, Subject):
                message = (
                    'Subjects list must contain instances of torchio.Subject,'
                    f' not "{type(subject)}"'
                )
                raise TypeError(message)