# Source code for torchio.transforms.preprocessing.spatial.resample

from __future__ import annotations

from collections.abc import Iterable
from collections.abc import Sized
from numbers import Number
from pathlib import Path
from typing import Union

import numpy as np
import SimpleITK as sitk
import torch

from ....data.image import Image
from ....data.image import ScalarImage
from ....data.io import get_sitk_metadata_from_ras_affine
from ....data.io import sitk_to_nib
from ....data.subject import Subject
from ....types import TypePath
from ....types import TypeTripletFloat
from ...spatial_transform import SpatialTransform

# Output spacing in mm: a single value (isotropic) or one value per axis.
TypeSpacing = Union[float, tuple[float, float, float]]
# Values accepted as the ``target`` argument of :class:`Resample`: an output
# spacing, a reference image path or :class:`Image`, the name of an image in
# the subject (``str``), or a ``(spatial_shape, affine)`` pair.
TypeTarget = Union[
    TypeSpacing,
    str,
    Path,
    Image,
    tuple[tuple[int, int, int], np.ndarray],
    None,
]
# Default ``target`` of :class:`Resample`, interpreted as an isotropic
# spacing in mm. (The misspelled name is kept for backward compatibility.)
ONE_MILLIMITER_ISOTROPIC = 1


class Resample(SpatialTransform):
    """Resample image to a different physical space.

    This is a powerful transform that can be used to change the image shape
    or spatial metadata, or to apply a spatial transformation.

    Args:
        target: Argument to define the output space. Can be one of:

            - Output spacing :math:`(s_w, s_h, s_d)`, in mm. If only one value
              :math:`s` is specified, then :math:`s_w = s_h = s_d = s`.

            - Path to an image that will be used as reference.

            - Instance of :class:`~torchio.Image`.

            - Name of an image key in the subject.

            - Tuple ``(spatial_shape, affine)`` defining the output space.

        pre_affine_name: Name of the *image key* (not subject key) storing an
            affine matrix that will be applied to the image header before
            resampling. If ``None``, the image is resampled with an identity
            transform. See usage in the example below.
        image_interpolation: See :ref:`Interpolation`.
        label_interpolation: See :ref:`Interpolation`.
        scalars_only: Apply only to instances of :class:`~torchio.ScalarImage`.
            Used internally by :class:`~torchio.transforms.RandomAnisotropy`.
        antialias: If ``True``, apply Gaussian smoothing before downsampling
            along any dimension that will be downsampled. For example, if the
            input image has spacing (0.5, 0.5, 4) and the target spacing is
            (1, 1, 1), the image will be smoothed along the first two
            dimensions before resampling. Label maps are not smoothed. The
            standard deviations of the Gaussian kernels are computed according
            to the method described in Cardoso et al., `Scale factor point
            spread function matching: beyond aliasing in image resampling
            <https://link.springer.com/chapter/10.1007/978-3-319-24571-3_81>`_,
            MICCAI 2015.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> transform = tio.Resample()  # resample all images to 1mm isotropic
        >>> transform = tio.Resample(2)  # resample all images to 2mm isotropic
        >>> transform = tio.Resample('t1')  # resample all images to 't1' image space
        >>> # Example: using a precomputed transform to MNI space
        >>> ref_path = tio.datasets.Colin27().t1.path  # this image is in the MNI space, so we can use it as reference/target
        >>> affine_matrix = tio.io.read_matrix('transform_to_mni.txt')  # from a NiftyReg registration. Would also work with e.g. .tfm from SimpleITK
        >>> image = tio.ScalarImage(tensor=torch.rand(1, 256, 256, 180), to_mni=affine_matrix)  # 'to_mni' is an arbitrary name
        >>> transform = tio.Resample(ref_path, pre_affine_name='to_mni')  # nearest neighbor interpolation is used for label maps
        >>> transformed = transform(image)  # "image" is now in the MNI space

    .. note::
        The ``antialias`` option is recommended when large (e.g. > 2×)
        downsampling factors are expected, particularly for offline (before
        training) preprocessing, when run times are not a concern.

    .. plot::

        import torchio as tio
        subject = tio.datasets.FPG()
        subject.remove_image('seg')
        resample = tio.Resample(8)
        t1_resampled = resample(subject.t1)
        subject.add_image(t1_resampled, 'Antialias off')
        resample = tio.Resample(8, antialias=True)
        t1_resampled_antialias = resample(subject.t1)
        subject.add_image(t1_resampled_antialias, 'Antialias on')
        subject.plot()
    """

    def __init__(
        self,
        target: TypeTarget = ONE_MILLIMITER_ISOTROPIC,
        image_interpolation: str = 'linear',
        label_interpolation: str = 'nearest',
        pre_affine_name: str | None = None,
        scalars_only: bool = False,
        antialias: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target = target
        self.image_interpolation = self.parse_interpolation(
            image_interpolation,
        )
        self.label_interpolation = self.parse_interpolation(
            label_interpolation,
        )
        self.pre_affine_name = pre_affine_name
        self.scalars_only = scalars_only
        self.antialias = antialias
        self.args_names = [
            'target',
            'image_interpolation',
            'label_interpolation',
            'pre_affine_name',
            'scalars_only',
            'antialias',
        ]

    @staticmethod
    def _parse_spacing(spacing: TypeSpacing) -> tuple[float, float, float]:
        """Normalize a spacing argument to a triplet of floats.

        Args:
            spacing: A single number (isotropic spacing) or an iterable of
                three numbers.

        Raises:
            ValueError: If ``spacing`` is neither a number nor a length-3
                iterable, or if any value is not strictly positive.
        """
        result: Iterable
        if isinstance(spacing, Iterable) and len(spacing) == 3:
            result = spacing
        elif isinstance(spacing, Number):
            result = 3 * (spacing,)
        else:
            message = (
                'Target must be a string, a positive number'
                f' or a sequence of positive numbers, not {type(spacing)}'
            )
            raise ValueError(message)
        if np.any(np.array(spacing) <= 0):
            message = f'Spacing must be strictly positive, not "{spacing}"'
            raise ValueError(message)
        # Return a plain tuple of floats, as the annotation promises
        return tuple(float(value) for value in result)  # type: ignore[return-value]

    @staticmethod
    def check_affine(affine_name: str, image: Image):
        """Validate the pre-affine matrix stored in ``image``, if present.

        Raises:
            TypeError: If ``affine_name`` is not a string, or the stored
                matrix is not a NumPy array or PyTorch tensor.
            ValueError: If the stored matrix is not of shape (4, 4).
        """
        if not isinstance(affine_name, str):
            message = f'Affine name argument must be a string, not {type(affine_name)}'
            raise TypeError(message)
        if affine_name in image:
            matrix = image[affine_name]
            if not isinstance(matrix, (np.ndarray, torch.Tensor)):
                message = (
                    'The affine matrix must be a NumPy array or PyTorch'
                    f' tensor, not {type(matrix)}'
                )
                raise TypeError(message)
            if matrix.shape != (4, 4):
                message = f'The affine matrix shape must be (4, 4), not {matrix.shape}'
                raise ValueError(message)

    @staticmethod
    def check_affine_key_presence(affine_name: str, subject: Subject):
        """Raise ``ValueError`` unless at least one image stores ``affine_name``."""
        for image in subject.get_images(intensity_only=False):
            if affine_name in image:
                return
        message = (
            f'An affine name was given ("{affine_name}"), but it was not found'
            ' in any image in the subject'
        )
        raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample every selected image in ``subject`` in place.

        The image used as target (if any) is skipped. Label maps are skipped
        when ``scalars_only`` is set; otherwise they use the label
        interpolation and are never antialiased.
        """
        use_pre_affine = self.pre_affine_name is not None
        if use_pre_affine:
            assert self.pre_affine_name is not None  # for mypy
            self.check_affine_key_presence(self.pre_affine_name, subject)

        for image in self.get_images(subject):
            # If the current image is the reference, don't resample it
            if self.target is image:
                continue

            # The target may be the name of an image in the subject. If it is
            # not a valid key (e.g. a spacing or a path), the lookup fails and
            # we simply carry on.
            try:
                target_image = subject[self.target]
            except (KeyError, TypeError, RuntimeError):
                pass
            else:
                if target_image is image:
                    continue

            # Choose interpolation
            if not isinstance(image, ScalarImage):
                if self.scalars_only:
                    continue
                interpolation = self.label_interpolation
            else:
                interpolation = self.image_interpolation
            interpolator = self.get_sitk_interpolator(interpolation)

            # Apply given affine matrix to the header if found in image
            if use_pre_affine and self.pre_affine_name in image:
                assert self.pre_affine_name is not None  # for mypy
                self.check_affine(self.pre_affine_name, image)
                matrix = image[self.pre_affine_name]
                if isinstance(matrix, torch.Tensor):
                    # detach/cpu so tensors on GPU or requiring grad also work
                    matrix = matrix.detach().cpu().numpy()
                image.affine = matrix @ image.affine

            floating_sitk = image.as_sitk(force_3d=True)

            resampler = self._get_resampler(
                interpolator,
                floating_sitk,
                subject,
                self.target,
            )
            if self.antialias and isinstance(image, ScalarImage):
                downsampling_factor = self._get_downsampling_factor(
                    floating_sitk,
                    resampler,
                )
                sigmas = self._get_sigmas(
                    downsampling_factor,
                    floating_sitk.GetSpacing(),
                )
                floating_sitk = self._smooth(floating_sitk, sigmas)

            resampled = resampler.Execute(floating_sitk)

            array, affine = sitk_to_nib(resampled)
            image.set_data(torch.as_tensor(array))
            image.affine = affine
        return subject

    @staticmethod
    def _smooth(
        image: sitk.Image,
        sigmas: np.ndarray,
        epsilon: float = 1e-9,
    ) -> sitk.Image:
        """Smooth the image with a Gaussian kernel.

        Args:
            image: Image to be smoothed.
            sigmas: Standard deviations of the Gaussian kernel for each
                dimension. If a value is NaN, no smoothing is applied in that
                dimension. The array is not modified.
            epsilon: Small value to replace NaN values in sigmas, to avoid
                division-by-zero errors.
        """
        # Replace NaNs without mutating the caller's array
        sigmas = np.where(np.isnan(sigmas), epsilon, sigmas)
        gaussian = sitk.SmoothingRecursiveGaussianImageFilter()
        gaussian.SetSigma(sigmas.tolist())
        smoothed = gaussian.Execute(image)
        return smoothed

    @staticmethod
    def _get_downsampling_factor(
        floating: sitk.Image,
        resampler: sitk.ResampleImageFilter,
    ) -> np.ndarray:
        """Get the downsampling factor for each dimension.

        The downsampling factor is the ratio between the output spacing and
        the input spacing. If the output spacing is smaller than or equal to
        the input spacing, the factor is set to NaN, meaning downsampling is
        not applied in that dimension.

        Args:
            floating: The input image to be resampled.
            resampler: The resampler that will be used to resample the image.
        """
        input_spacing = np.array(floating.GetSpacing())
        output_spacing = np.array(resampler.GetOutputSpacing())
        factors = output_spacing / input_spacing
        no_downsampling = factors <= 1
        factors[no_downsampling] = np.nan
        return factors

    def _get_resampler(
        self,
        interpolator: int,
        floating: sitk.Image,
        subject: Subject,
        target: TypeTarget,
    ) -> sitk.ResampleImageFilter:
        """Instantiate a SimpleITK resampler."""
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(interpolator)
        self._set_resampler_reference(
            resampler,
            target,  # type: ignore[arg-type]
            floating,
            subject,
        )
        return resampler

    def _set_resampler_reference(
        self,
        resampler: sitk.ResampleImageFilter,
        target: TypeSpacing | TypePath | Image,
        floating_sitk,
        subject,
    ):
        """Configure the output space of ``resampler`` from ``target``.

        Target can be:

        1) An instance of torchio.Image
        2) An instance of pathlib.Path
        3) A string, which could be a path or an image in subject
        4) A number or sequence of three numbers for spacing
        5) A tuple of (shape, affine)

        Cases 1-3 and 5 fully define the output grid. Case 4 only defines the
        spacing; the grid is then derived from the floating image.

        Raises:
            ValueError: If a string target is neither an existing file nor an
                image name in the subject.
            RuntimeError: If the target cannot be interpreted, or a
                ``(shape, affine)`` target is malformed.
        """
        if isinstance(target, (str, Path, Image)):
            if isinstance(target, Image):
                # It's a TorchIO image
                image = target
            elif Path(target).is_file():
                # It's an existing file
                path = target
                image = ScalarImage(path)
            else:  # assume it's the name of an image in the subject
                try:
                    image = subject[target]
                except KeyError as error:
                    message = (
                        f'Image name "{target}" not found in subject.'
                        f' If "{target}" is a path, it does not exist or'
                        ' permission has been denied'
                    )
                    raise ValueError(message) from error
            self._set_resampler_from_shape_affine(
                resampler,
                image.spatial_shape,
                image.affine,
            )
        elif isinstance(target, Number):  # one number for target was passed
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        elif isinstance(target, Iterable) and len(target) == 2:
            assert not isinstance(target, str)  # for mypy
            shape, affine = target
            if not (isinstance(shape, Sized) and len(shape) == 3):
                message = (
                    'Target shape must be a sequence of three integers, but'
                    f' "{shape}" was passed'
                )
                raise RuntimeError(message)
            if not affine.shape == (4, 4):
                # Report the offending affine, not the shape
                message = (
                    'Target affine must have shape (4, 4) but the following'
                    f' was passed:\n{affine}'
                )
                raise RuntimeError(message)
            self._set_resampler_from_shape_affine(
                resampler,
                shape,
                affine,
            )
        elif isinstance(target, Iterable) and len(target) == 3:
            self._set_resampler_from_spacing(resampler, target, floating_sitk)
        else:
            raise RuntimeError(f'Target not understood: "{target}"')

    def _set_resampler_from_shape_affine(self, resampler, shape, affine):
        """Set the output grid from a spatial shape and a RAS affine."""
        origin, spacing, direction = get_sitk_metadata_from_ras_affine(affine)
        resampler.SetOutputDirection(direction)
        resampler.SetOutputOrigin(origin)
        resampler.SetOutputSpacing(spacing)
        # Cast to Python ints so shapes holding NumPy integers are accepted
        resampler.SetSize(tuple(int(size) for size in shape))

    def _set_resampler_from_spacing(self, resampler, target, floating_sitk):
        """Set the output grid from a target spacing and the floating image."""
        target_spacing = self._parse_spacing(target)
        reference_image = self.get_reference_image(
            floating_sitk,
            target_spacing,
        )
        resampler.SetReferenceImage(reference_image)

    @staticmethod
    def get_reference_image(
        floating_sitk: sitk.Image,
        spacing: TypeTripletFloat,
    ) -> sitk.Image:
        """Build an empty reference image with the requested spacing.

        The new grid keeps the direction of ``floating_sitk``, shares its
        center (midpoint between the first and last voxel centers), and has an
        extent equal to or smaller than the original along each axis.
        """
        old_spacing = np.array(floating_sitk.GetSpacing(), dtype=float)
        new_spacing = np.array(spacing, dtype=float)
        old_size = np.array(floating_sitk.GetSize())
        old_last_index = old_size - 1
        old_last_index_lps = np.array(
            floating_sitk.TransformIndexToPhysicalPoint(old_last_index.tolist()),
            dtype=float,
        )
        old_origin_lps = np.array(floating_sitk.GetOrigin(), dtype=float)
        center_lps = (old_last_index_lps + old_origin_lps) / 2
        # We use floor to avoid extrapolation by keeping the extent of the
        # new image the same or smaller than the original.
        new_size = np.floor(old_size * old_spacing / new_spacing)
        # We keep singleton dimensions to avoid e.g. making 2D images 3D
        new_size[old_size == 1] = 1
        # If the new spacing exceeds the extent along an axis, floor gives 0;
        # keep at least one voxel so the reference image is valid
        new_size = np.maximum(new_size, 1)

        direction = np.asarray(floating_sitk.GetDirection(), dtype=float).reshape(3, 3)
        half_extent = (new_size - 1) / 2 * new_spacing
        new_origin_lps = (center_lps - direction @ half_extent).tolist()

        reference = sitk.Image(
            new_size.astype(int).tolist(),
            floating_sitk.GetPixelID(),
            floating_sitk.GetNumberOfComponentsPerPixel(),
        )
        reference.SetDirection(floating_sitk.GetDirection())
        reference.SetSpacing(new_spacing.tolist())
        reference.SetOrigin(new_origin_lps)
        return reference

    @staticmethod
    def _get_sigmas(downsampling_factor: np.ndarray, spacing: np.ndarray) -> np.ndarray:
        """Compute optimal standard deviation for Gaussian kernel.

        From Cardoso et al., `Scale factor point spread function matching:
        beyond aliasing in image resampling
        <https://link.springer.com/chapter/10.1007/978-3-319-24571-3_81>`_,
        MICCAI 2015.

        Args:
            downsampling_factor: Array with the downsampling factor for each
                dimension. NaN entries propagate to NaN sigmas.
            spacing: Array (or sequence) with the spacing of the input image
                in mm.
        """
        k = downsampling_factor
        # Equation from top of page 678 of proceedings (4/9 in the PDF)
        variance = (k**2 - 1) * (2 * np.sqrt(2 * np.log(2))) ** (-2)
        sigma = np.asarray(spacing, dtype=float) * np.sqrt(variance)
        return sigma