Source code for torchio.transforms.preprocessing.spatial.pad

import warnings
from numbers import Number
from typing import Union

import numpy as np
import torch
from nibabel.affines import apply_affine

from ....data.image import Image
from ....data.subject import Subject
from .bounds_transform import BoundsTransform
from .bounds_transform import TypeBounds


class Pad(BoundsTransform):
    r"""Pad an image.

    Args:
        padding: Tuple
            :math:`(w_{ini}, w_{fin}, h_{ini}, h_{fin}, d_{ini}, d_{fin})`
            defining the number of values padded to the edges of each axis.
            If the initial shape of the image is
            :math:`W \times H \times D`, the final shape will be
            :math:`(w_{ini} + W + w_{fin}) \times (h_{ini} + H + h_{fin})
            \times (d_{ini} + D + d_{fin})`.
            If only three values :math:`(w, h, d)` are provided, then
            :math:`w_{ini} = w_{fin} = w`,
            :math:`h_{ini} = h_{fin} = h` and
            :math:`d_{ini} = d_{fin} = d`.
            If only one value :math:`n` is provided, then
            :math:`w_{ini} = w_{fin} = h_{ini} = h_{fin} =
            d_{ini} = d_{fin} = n`.
        padding_mode: See possible modes in `NumPy docs`_. If it is a number,
            the mode will be set to ``'constant'``. If it is ``'mean'``,
            ``'maximum'``, ``'median'`` or ``'minimum'``, the statistic will
            be computed from the whole volume, unlike in NumPy, which computes
            it along the padded axis.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. seealso:: If you want to pass the output shape instead, please use
        :class:`~torchio.transforms.CropOrPad` instead.

    .. _NumPy docs: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
    """

    # String modes accepted by ``check_padding_mode``; numbers and callables
    # are accepted as well.
    PADDING_MODES = (
        'empty',
        'edge',
        'wrap',
        'constant',
        'linear_ramp',
        'maximum',
        'mean',
        'median',
        'minimum',
        'reflect',
        'symmetric',
    )

    def __init__(
        self,
        padding: TypeBounds,
        padding_mode: Union[str, float] = 0,
        **kwargs,
    ):
        super().__init__(padding, **kwargs)
        self.padding = padding
        # Validate before storing so an invalid mode never ends up on the
        # instance.
        self.check_padding_mode(padding_mode)
        self.padding_mode = padding_mode
        self.args_names = ['padding', 'padding_mode']

    @classmethod
    def check_padding_mode(cls, padding_mode):
        """Validate a padding mode.

        Args:
            padding_mode: Candidate mode. Accepted values are any entry of
                ``PADDING_MODES``, any :class:`numbers.Number`, or any
                callable.

        Raises:
            KeyError: If ``padding_mode`` is none of the accepted kinds.
        """
        is_number = isinstance(padding_mode, Number)
        is_callable = callable(padding_mode)
        if not (padding_mode in cls.PADDING_MODES or is_number or is_callable):
            message = (
                f'Padding mode "{padding_mode}" not valid. Valid options are'
                f' {list(cls.PADDING_MODES)}, a number or a function'
            )
            raise KeyError(message)

    def _check_truncation(self, image: Image, mode: Union[str, float]) -> None:
        """Warn when a ``'mean'``/``'median'`` fill value may be truncated.

        Nothing happens for other modes or when ``image.data`` is already a
        floating point tensor.

        Args:
            image: Image about to be padded.
            mode: Padding mode that will be used for ``image``.
        """
        if mode not in ('mean', 'median'):
            return
        if torch.is_floating_point(image.data):
            return
        message = (
            f'The constant value computed for padding mode "{mode}" might'
            ' be truncated in the output, as the input image is not'
            ' floating point. Consider converting the image to a floating'
            ' point type before applying this transform.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)

    def apply_transform(self, subject: Subject) -> Subject:
        """Pad every image returned by ``get_images`` for ``subject``.

        Each image's data is replaced by its padded version and its affine
        origin is moved to account for the values added at the start of each
        axis.

        Args:
            subject: Subject whose images are padded in place.

        Returns:
            The same ``subject`` instance, after modification.
        """
        # Narrowing for type checkers; presumably populated by
        # ``BoundsTransform.__init__`` (not visible from this file).
        assert self.bounds_parameters is not None
        pad_params = self.bounds_parameters

        # These depend only on the bounds, so compute them once instead of
        # once per image.
        low = pad_params[::2]  # the "ini" amount of every axis
        negative_low = -np.array(low)
        # The leading (0, 0) leaves the first data axis (presumably channels)
        # unpadded.
        paddings = (0, 0), pad_params[:2], pad_params[2:4], pad_params[4:]

        for image in self.get_images(subject):
            self._check_truncation(image, self.padding_mode)

            # The new origin is the world position of voxel index ``-low``.
            new_origin = apply_affine(image.affine, negative_low)
            new_affine = image.affine.copy()
            new_affine[:3, 3] = new_origin

            mode: str | float = 'constant'
            constant: torch.Tensor | float | None = None
            kwargs: dict[str, str | float | torch.Tensor] = {}
            if isinstance(self.padding_mode, Number):
                constant = self.padding_mode  # type: ignore[assignment]
            # Statistic modes are turned into a constant fill computed over
            # the whole volume rather than along each padded axis.
            elif self.padding_mode == 'maximum':
                constant = image.data.max()
            elif self.padding_mode == 'mean':
                constant = image.data.float().mean()
            elif self.padding_mode == 'median':
                constant = torch.quantile(image.data.float(), 0.5)
            elif self.padding_mode == 'minimum':
                constant = image.data.min()
            else:
                # Any other mode (string or callable) is passed to NumPy
                # unchanged.
                constant = None
                mode = self.padding_mode

            if constant is not None:
                kwargs['constant_values'] = constant
            kwargs['mode'] = mode

            padded = np.pad(image.data, paddings, **kwargs)  # type: ignore[call-overload]
            image.set_data(torch.as_tensor(padded))
            image.affine = new_affine
        return subject

    def inverse(self):
        """Return a ``Crop`` built from the same bounds as this padding."""
        # Imported here rather than at module level, presumably to avoid a
        # circular import between the crop and pad modules.
        from .crop import Crop

        return Crop(self.padding)