# Source code for torchio.transforms.augmentation.intensity.random_swap

from __future__ import annotations

from collections import defaultdict
from collections.abc import Sequence
from typing import TypeVar

import numpy as np
import torch

from ....data.subject import Subject
from ....types import TypeTripletInt
from ....types import TypeTuple
from ....utils import to_tuple
from ...intensity_transform import IntensityTransform
from .. import RandomTransform

# A sequence of swaps: each entry pairs the start indices (one triplet per
# patch) of the two patches that are exchanged.
TypeLocations = Sequence[tuple[TypeTripletInt, TypeTripletInt]]
# Constrained type variable so that helpers annotated with it accept and return
# the same array kind (NumPy array or PyTorch tensor).
TensorArray = TypeVar('TensorArray', np.ndarray, torch.Tensor)


class RandomSwap(RandomTransform, IntensityTransform):
    r"""Randomly swap patches within an image.

    This is typically used in `context restoration for self-supervised learning
    <https://www.sciencedirect.com/science/article/pii/S1361841518304699>`_.

    The two patches of each swap are sampled so that they do not overlap
    whenever such a pair can be found within a bounded number of attempts.
    If the image is too small for that, two different (but overlapping)
    locations are used instead.

    Args:
        patch_size: Tuple of integers :math:`(w, h, d)` to swap patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided, :math:`w = h = d = n`.
        num_iterations: Number of times that two patches will be swapped.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    # Upper bound on how many candidates are drawn for the second patch of a
    # swap, so that sampling always terminates.
    _MAX_SAMPLING_ATTEMPTS = 100

    def __init__(
        self,
        patch_size: TypeTuple = 15,
        num_iterations: int = 100,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.patch_size = np.array(to_tuple(patch_size))
        self.num_iterations = self._parse_num_iterations(num_iterations)

    @staticmethod
    def _parse_num_iterations(num_iterations):
        """Validate and return the number of swaps.

        Raises:
            TypeError: If ``num_iterations`` is not an ``int``.
            ValueError: If ``num_iterations`` is negative.
        """
        if not isinstance(num_iterations, int):
            raise TypeError(
                f'num_iterations must be an int, not {num_iterations}',
            )
        # Zero is accepted (no swap is performed), so only negatives are errors
        if num_iterations < 0:
            raise ValueError(
                f'num_iterations must be non-negative, not {num_iterations}',
            )
        return num_iterations

    @staticmethod
    def _patches_overlap(
        first_ini: np.ndarray,
        first_fin: np.ndarray,
        second_ini: np.ndarray,
        second_fin: np.ndarray,
    ) -> bool:
        """Return whether two boxes ``[ini, fin)`` share at least one voxel."""
        # Axis-aligned boxes intersect only if their half-open intervals
        # intersect along every axis.
        first_starts_before_second_ends = np.all(first_ini < second_fin)
        second_starts_before_first_ends = np.all(second_ini < first_fin)
        return bool(
            first_starts_before_second_ends and second_starts_before_first_ends
        )

    @staticmethod
    def get_params(
        tensor: torch.Tensor,
        patch_size: np.ndarray,
        num_iterations: int,
    ) -> list[tuple[TypeTripletInt, TypeTripletInt]]:
        """Sample the start indices of the patch pairs to swap.

        Args:
            tensor: Tensor whose last three dimensions are the spatial ones.
            patch_size: Array with the patch size.
            num_iterations: Number of pairs to sample.

        Returns:
            One ``(first_ini, second_ini)`` pair of index triplets per swap.

        Raises:
            ValueError: If two different patch locations cannot be sampled,
                e.g., because the patch is as large as the image.
        """
        si, sj, sk = tensor.shape[-3:]
        spatial_shape = si, sj, sk  # for mypy
        patch_size_list = patch_size.tolist()  # loop invariant
        locations = []
        for _ in range(num_iterations):
            first_ini, first_fin = get_random_indices_from_shape(
                spatial_shape,
                patch_size_list,  # type: ignore[arg-type]
            )
            # First candidate that differs from the first patch; used only if
            # no disjoint candidate is found within the attempt budget.
            fallback_ini = None
            for _attempt in range(RandomSwap._MAX_SAMPLING_ATTEMPTS):
                second_ini, second_fin = get_random_indices_from_shape(
                    spatial_shape,
                    patch_size_list,  # type: ignore[arg-type]
                )
                overlap = RandomSwap._patches_overlap(
                    first_ini,
                    first_fin,
                    second_ini,
                    second_fin,
                )
                if not overlap:
                    break  # patches don't overlap
                if fallback_ini is None and np.any(second_ini != first_ini):
                    fallback_ini = second_ini
            else:
                # The budget was exhausted without finding disjoint patches
                if fallback_ini is None:
                    message = (
                        'Could not sample two different patch locations for'
                        f' patch size {patch_size_list} in an image with'
                        f' spatial shape {spatial_shape}'
                    )
                    raise ValueError(message)
                second_ini = fallback_ini
            # Store plain Python integers, as declared in the return type
            location = tuple(first_ini.tolist()), tuple(second_ini.tolist())
            locations.append(location)
        return locations  # type: ignore[return-value]

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample swap locations for each image and apply them with ``Swap``."""
        images_dict = self.get_images_dict(subject)
        if not images_dict:
            return subject

        arguments: dict[str, dict] = defaultdict(dict)
        for name, image in images_dict.items():
            locations = self.get_params(
                image.data,
                self.patch_size,
                self.num_iterations,
            )
            arguments['locations'][name] = locations
            arguments['patch_size'][name] = self.patch_size
        transform = Swap(**self.add_base_args(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed
class Swap(IntensityTransform):
    r"""Swap patches within an image.

    This is typically used in `context restoration for self-supervised learning
    <https://www.sciencedirect.com/science/article/pii/S1361841518304699>`_.

    Args:
        patch_size: Tuple of integers :math:`(w, h, d)` to swap patches
            of size :math:`w \times h \times d`, or a dictionary mapping
            image names to such tuples.
        locations: Sequence of pairs of start indices, one pair per swap,
            or a dictionary mapping image names to such sequences.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """

    def __init__(
        self,
        patch_size: TypeTripletInt | dict[str, TypeTripletInt],
        locations: TypeLocations | dict[str, TypeLocations],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.locations = locations
        self.patch_size = patch_size
        self.args_names = ['locations', 'patch_size']
        self.invert_transform = False

    def apply_transform(self, subject: Subject) -> Subject:
        """Swap the configured patches in every image of the subject."""
        for name, image in self.get_images_dict(subject).items():
            locations, patch_size = self.locations, self.patch_size
            if self.arguments_are_dict():
                assert isinstance(self.locations, dict)
                assert isinstance(self.patch_size, dict)
                locations = self.locations[name]
                patch_size = self.patch_size[name]
            if self.invert_transform:
                # Undo the swaps from last to first. Use a reversed copy:
                # reversing the stored list in place would corrupt the
                # transform state and flip the order again for the next image.
                locations = list(reversed(locations))
            swapped = _swap(image.data, patch_size, locations)  # type: ignore[arg-type]
            image.set_data(swapped)
        return subject


def _swap(
    tensor: torch.Tensor,
    patch_size: TypeTuple,
    locations: list[tuple[np.ndarray, np.ndarray]],
) -> torch.Tensor:
    """Return a copy of ``tensor`` in which each pair of patches is exchanged.

    Args:
        tensor: Tensor indexed as ``[:, i, j, k]``.
        patch_size: Size of the patches along the last three dimensions.
        locations: Pairs of start indices of the patches to exchange.
    """
    # Work on a copy so that the input tensor is not modified
    tensor = tensor.clone()
    patch_size_array = np.array(patch_size)
    for first_ini, second_ini in locations:
        first_ini = np.asarray(first_ini)
        second_ini = np.asarray(second_ini)
        first_fin = first_ini + patch_size_array
        second_fin = second_ini + patch_size_array
        # Clone both patches: a view would read from memory that is being
        # overwritten if the two patches happen to overlap.
        first_patch = _crop(tensor, first_ini, first_fin).clone()
        second_patch = _crop(tensor, second_ini, second_fin).clone()
        _insert(tensor, first_patch, second_ini)
        _insert(tensor, second_patch, first_ini)
    return tensor


def _insert(
    tensor: TensorArray,
    patch: TensorArray,
    index_ini: np.ndarray,
) -> None:
    """Write ``patch`` into ``tensor`` in place, starting at ``index_ini``."""
    index_fin = index_ini + np.array(patch.shape[-3:])
    i_ini, j_ini, k_ini = index_ini
    i_fin, j_fin, k_fin = index_fin
    tensor[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin] = patch


def _crop(
    image: TensorArray,
    index_ini: np.ndarray,
    index_fin: np.ndarray,
) -> TensorArray:
    """Return the region ``[index_ini, index_fin)`` of ``image`` (no copy)."""
    i_ini, j_ini, k_ini = index_ini
    i_fin, j_fin, k_fin = index_fin
    return image[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]


def get_random_indices_from_shape(
    spatial_shape: Sequence[int],
    patch_size: Sequence[int],
) -> tuple[np.ndarray, np.ndarray]:
    """Sample a random patch location inside a volume.

    Args:
        spatial_shape: The three spatial dimensions of the volume.
        patch_size: Patch size, with one or three values.

    Returns:
        Start (inclusive) and end (exclusive) indices of the patch.

    Raises:
        ValueError: If the patch is larger than the volume along any axis.
    """
    assert len(spatial_shape) == 3
    assert len(patch_size) in (1, 3)
    shape_array = np.array(spatial_shape)
    patch_size_array = np.array(patch_size)
    max_index_ini_unchecked = shape_array - patch_size_array
    if (max_index_ini_unchecked < 0).any():
        message = (
            f'Patch size {patch_size} cannot be'
            f' larger than image spatial shape {spatial_shape}'
        )
        raise ValueError(message)
    max_index_ini = max_index_ini_unchecked.astype(np.uint16)
    coordinates = []
    for max_coordinate in max_index_ini.tolist():
        if max_coordinate == 0:
            # Only one position is possible; do not draw a random number
            coordinate = 0
        else:
            # The upper bound of torch.randint is exclusive, so add one to
            # make the last valid start index reachable.
            coordinate = int(
                torch.randint(max_coordinate + 1, size=(1,)).item(),
            )
        coordinates.append(coordinate)
    index_ini = np.array(coordinates, np.uint16)
    index_fin = index_ini + patch_size_array
    return index_ini, index_fin