Source code for torchio.transforms.augmentation.spatial.random_flip
import numpy as np
import torch
from ....data.subject import Subject
from ....utils import to_tuple
from ...spatial_transform import SpatialTransform
from .. import RandomTransform
[docs]
class RandomFlip(RandomTransform, SpatialTransform):
    """Randomly reverse the order of elements of an image along some axes.

    Args:
        axes: Spatial axis index, or tuple of indices, along which the image
            may be flipped. Integer indices must be ``0``, ``1`` or ``2``.
            Anatomical labels are accepted as well, such as ``'Left'``,
            ``'Right'``, ``'Anterior'``, ``'Posterior'``, ``'Inferior'``,
            ``'Superior'``, ``'Height'`` and ``'Width'``, or short forms like
            ``'AP'`` (antero-posterior), ``'lr'`` (lateral), ``'w'`` (width)
            and ``'i'`` (inferior). Only the first letter of each label is
            taken into account. For 2D images, ``'Height'`` and ``'Width'``
            may be used.
        flip_probability: Probability of flipping, drawn independently for
            each axis.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> fpg = tio.datasets.FPG()
        >>> flip = tio.RandomFlip(axes=('LR',))  # flip along lateral axis only

    .. tip:: Anatomical labels are convenient when the image orientation is
        not known in advance.
    """

    def __init__(
        self,
        axes: int | tuple[int, ...] = 0,
        flip_probability: float = 0.5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.axes = _parse_axes(axes)
        self.flip_probability = self.parse_probability(flip_probability)

    def apply_transform(self, subject: Subject) -> Subject:
        candidate_axes = _ensure_axes_indices(subject, self.axes)
        # One independent draw per spatial axis, always three of them, so the
        # random stream advances identically whatever axes are allowed
        coin_flips = self.get_params(self.flip_probability)
        selected = [
            axis
            for axis, heads in enumerate(coin_flips)
            if heads and axis in candidate_axes
        ]
        if not selected:
            # Nothing to flip: hand the subject back untouched
            return subject
        flip = Flip(**self.add_base_args({'axes': selected}))
        result = flip(subject)
        assert isinstance(result, Subject)
        return result

    @staticmethod
    def get_params(probability: float) -> list[bool]:
        draws = torch.rand(3)
        return (draws < probability).tolist()
class Flip(SpatialTransform):
    """Deterministically reverse the order of elements along given axes.

    Args:
        axes: Spatial axis index, or tuple of indices, along which the image
            is flipped. See
            :class:`~torchio.transforms.augmentation.spatial.random_flip.RandomFlip`
            for the accepted values.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. tip:: Anatomical labels are convenient when the image orientation is
        not known in advance.
    """

    def __init__(self, axes, **kwargs):
        super().__init__(**kwargs)
        self.axes = _parse_axes(axes)
        # Constructor argument names, presumably read by the base class to
        # rebuild or describe this transform — TODO confirm
        self.args_names = ['axes']

    def apply_transform(self, subject: Subject) -> Subject:
        indices = _ensure_axes_indices(subject, self.axes)
        images = self.get_images(subject)
        for each_image in images:
            _flip_image(each_image, indices)
        return subject

    def is_invertible(self):
        return True

    def inverse(self):
        # Reversing the same axes a second time restores the original order,
        # so the transform is its own inverse
        return self
def _parse_axes(axes: int | tuple[int, ...]):
axes_tuple = to_tuple(axes)
for axis in axes_tuple:
is_int = isinstance(axis, int)
is_string = isinstance(axis, str)
valid_number = is_int and axis in (0, 1, 2)
if not is_string and not valid_number:
message = (
f'All axes must be 0, 1 or 2, but found "{axis}" with type {type(axis)}'
)
raise ValueError(message)
return axes_tuple
def _ensure_axes_indices(subject, axes):
    """Resolve flip axes to sorted, unique spatial indices.

    Integer axes are kept as they are. String anatomical labels are converted
    using the first image of ``subject``, after checking that all images in
    the subject share the same orientation. Labels and integers may be mixed,
    and labels naming the same axis (e.g. ``'Left'`` and ``'Right'``) collapse
    to a single index. ``np.flip`` rejects repeated axes, so duplicates must
    not reach it.

    Args:
        subject: Subject providing ``check_consistent_orientation()`` and
            ``get_first_image()``. It is only used if a label is present.
        axes: Iterable of integer indices (0, 1 or 2) and/or string labels.

    Returns:
        A sorted list of unique spatial axis indices.
    """
    if not any(isinstance(axis, str) for axis in axes):
        return sorted(set(axes))
    subject.check_consistent_orientation()
    image = subject.get_first_image()
    indices = set()
    for axis in axes:
        if isinstance(axis, str):
            # axis_name_to_index presumably returns a negative index
            # (-3, -2 or -1); adding 3 maps it onto spatial index 0, 1 or 2
            indices.add(3 + image.axis_name_to_index(axis))
        else:
            indices.add(axis)
    return sorted(indices)
def _flip_image(image, axes):
    """Reverse ``image`` in place along the given spatial axes.

    Args:
        image: Object exposing ``numpy()`` and ``set_data()``. Its data is
            presumably laid out with a leading non-spatial (channel)
            dimension — TODO confirm.
        axes: Spatial axis indices (0, 1 or 2).
    """
    # Spatial axis k lives at data axis k + 1, past the leading dimension
    offset_axes = np.array(axes, int) + 1
    flipped = np.flip(image.numpy(), axis=offset_axes)
    # np.flip returns a view with negative strides, which torch tensors cannot
    # wrap, so copy into a contiguous buffer before converting
    image.set_data(torch.as_tensor(np.ascontiguousarray(flipped)))