from __future__ import annotations

import warnings

import numpy as np
import torch

from ...constants import CHANNELS_DIMENSION
from ..sampler import GridSampler

class GridAggregator:
r"""Aggregate patches for dense inference.
This class is typically used to build a volume made of patches after
inference of batches extracted by a :class:`~torchio.data.GridSampler`.
Args:
sampler: Instance of :class:`~torchio.data.GridSampler` used to
extract the patches.
overlap_mode: If ``'crop'``, the overlapping predictions will be
cropped. If ``'average'``, the predictions in the overlapping areas
will be averaged with equal weights. If ``'hann'``, the predictions
in the overlapping areas will be weighted with a Hann window
function. See the `grid aggregator tests`_ for a raw visualization
of the three modes.
downsampling_factor: Factor by which the output volume is expected to
be smaller than the input volume in each spatial dimension. This is
useful when the model downsamples the input (e.g., with strided
convolutions or pooling layers). Currently, only a single integer
is supported, which applies the same downsampling factor to all
spatial dimensions.
.. _grid aggregator tests: https://github.com/TorchIO-project/torchio/blob/main/tests/data/inference/test_aggregator.py
.. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
<https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
information about patch-based sampling.
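
    Example:
        A minimal sketch of dense inference (``nn.Identity`` stands in for a
        real network; any model that preserves the spatial shape of its input
        would work the same way)::

            import torch
            import torch.nn as nn
            import torchio as tio

            subject = tio.datasets.Colin27()
            sampler = tio.inference.GridSampler(
                subject,
                patch_size=(88, 88, 60),
                patch_overlap=(4, 4, 4),
            )
            loader = torch.utils.data.DataLoader(sampler, batch_size=4)
            aggregator = tio.inference.GridAggregator(sampler)
            model = nn.Identity().eval()
            with torch.no_grad():
                for patches_batch in loader:
                    input_tensor = patches_batch['t1'][tio.DATA]
                    locations = patches_batch[tio.LOCATION]
                    logits = model(input_tensor)
                    aggregator.add_batch(logits, locations)
            output_tensor = aggregator.get_output_tensor()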
"""
    def __init__(
        self,
        sampler: GridSampler,
        overlap_mode: str = 'crop',
        downsampling_factor: int = 1,  # TODO: support one factor per dimension
    ):
        subject = sampler.subject
        self.volume_padded = sampler.padding_mode is not None
        self._output_tensor: torch.Tensor | None = None
        self.patch_overlap = sampler.patch_overlap
        self.patch_size = sampler.patch_size
        self._parse_overlap_mode(overlap_mode)
        self.overlap_mode = overlap_mode
        self._avgmask_tensor: torch.Tensor | None = None
        self._hann_window: torch.Tensor | None = None
        self._downsampling_factor = downsampling_factor
        # The output volume shrinks by the downsampling factor along each
        # spatial dimension
        shape_array = np.array(subject.spatial_shape) // self._downsampling_factor
        self.spatial_shape = tuple(shape_array.tolist())
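        # Illustrative: a subject with spatial shape (192, 192, 192) and
        # downsampling_factor 2 yields an output volume of shape (96, 96, 96)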

    @staticmethod
    def _parse_overlap_mode(overlap_mode):
        if overlap_mode not in ('crop', 'average', 'hann'):
            message = (
                'Overlap mode must be "crop", "average" or "hann", but'
                f' "{overlap_mode}" was passed'
            )
            raise ValueError(message)

    def _crop_patch(
        self,
        patch: torch.Tensor,
        location: np.ndarray,
        overlap: np.ndarray,
    ) -> tuple[torch.Tensor, np.ndarray]:
        half_overlap = overlap // 2  # overlap is always even in the grid sampler
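        # Illustrative: with patch_overlap (4, 4, 4), half_overlap is
        # (2, 2, 2), so an interior patch is trimmed by 2 voxels on each side
        # along every axis before being written into the output volume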
        index_ini, index_fin = location[:3], location[3:]
        # If the patch is not at a border, we crop half the overlap
        crop_ini: np.ndarray = half_overlap.copy()
        crop_fin: np.ndarray = half_overlap.copy()
        # If the volume has been padded, borders can be cropped like any
        # interior patch; otherwise, patches touching a border keep it intact
        if not self.volume_padded:
            crop_ini *= index_ini > 0
            crop_fin *= index_fin != self.spatial_shape
        # Update the location of the patch in the volume
        new_index_ini = index_ini + crop_ini
        new_index_fin = index_fin - crop_fin
        new_location = np.hstack((new_index_ini, new_index_fin))
        patch_size = np.asarray(patch.shape[-3:], dtype=int)
        crop_fin = crop_fin.astype(int)
        i_ini, j_ini, k_ini = crop_ini
        i_fin, j_fin, k_fin = patch_size - crop_fin
        # Cast to int to make type checkers happy
        i_ini, j_ini, k_ini = int(i_ini), int(j_ini), int(k_ini)
        i_fin, j_fin, k_fin = int(i_fin), int(j_fin), int(k_fin)
        cropped_patch = patch[:, i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]
        return cropped_patch, new_location

    def _initialize_output_tensor(self, batch: torch.Tensor) -> None:
        if self._output_tensor is not None:
            return
        num_channels = batch.shape[CHANNELS_DIMENSION]
        self._output_tensor = torch.zeros(
            num_channels,
            *self.spatial_shape,
            dtype=batch.dtype,
        )

    def _initialize_avgmask_tensor(self, batch: torch.Tensor) -> None:
        if self._avgmask_tensor is not None:
            return
        num_channels = batch.shape[CHANNELS_DIMENSION]
        self._avgmask_tensor = torch.zeros(
            num_channels,
            *self.spatial_shape,
            dtype=batch.dtype,
        )

    @staticmethod
    def _get_hann_window(patch_size) -> torch.Tensor:
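        """Build an N-D Hann window as the outer product of 1-D windows.

        Each 1-D window is sampled over ``size + 2`` points and its zero
        endpoints are trimmed, so every voxel keeps a nonzero weight.

        Illustrative example (values follow from ``torch.hann_window``):

            >>> GridAggregator._get_hann_window((1, 1, 3))
            tensor([[[0.5000, 1.0000, 0.5000]]])
        """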
        hann_window_3d = torch.as_tensor([1])
        # Build the N-D window by multiplying together one 1-D Hann window
        # per spatial dimension, each reshaped to broadcast along its axis
        for spatial_dim, size in enumerate(patch_size):
            window_shape = np.ones_like(patch_size)
            window_shape[spatial_dim] = size
            # Sample size + 2 points and trim the zero endpoints so border
            # voxels are down-weighted but never zeroed out
            hann_window_1d = torch.hann_window(
                size + 2,
                periodic=False,
            )
            hann_window_1d = hann_window_1d[1:-1].view(*window_shape)
            hann_window_3d = hann_window_3d * hann_window_1d
        return hann_window_3d

    def _initialize_hann_window(self) -> None:
        if self._hann_window is not None:
            return
        self._hann_window = self._get_hann_window(self.patch_size)

    def add_batch(
        self,
        batch_tensor: torch.Tensor,
        locations: torch.Tensor,
    ) -> None:
        """Add a batch processed by a network to the output prediction volume.

        Args:
            batch_tensor: 5D tensor, typically the output of a convolutional
                neural network, e.g. ``batch['image'][torchio.DATA]``.
            locations: 2D tensor with shape :math:`(B, 6)` representing the
                patch indices in the original image. They are typically
                extracted using ``batch[torchio.LOCATION]``.
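
        Example (a sketch continuing the class-level example; ``model`` is
        assumed to be any network that preserves spatial shape)::

            for patches_batch in loader:
                logits = model(patches_batch['t1'][tio.DATA])
                aggregator.add_batch(logits, patches_batch[tio.LOCATION])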
"""
        batch = batch_tensor.cpu()
        locations_array = locations.cpu().numpy() // self._downsampling_factor
        target_shapes = locations_array[:, 3:] - locations_array[:, :3]
        # All patches in the batch should have the same spatial shape
        assert len(np.unique(target_shapes, axis=0)) == 1
        input_spatial_shape = tuple(batch.shape[-3:])
        target_spatial_shape_array = target_shapes[0]
        target_spatial_shape = tuple(target_spatial_shape_array.tolist())
        if input_spatial_shape != target_spatial_shape:
            message = (
                f'The shape of the input batch, {input_spatial_shape},'
                ' does not match the shape of the target location,'
                f' which is {target_spatial_shape}'
            )
            raise RuntimeError(message)
        self._initialize_output_tensor(batch)
        assert isinstance(self._output_tensor, torch.Tensor)  # for mypy
        if self.overlap_mode == 'crop':
            for patch, location in zip(batch, locations_array, strict=True):
                cropped_patch, new_location = self._crop_patch(
                    patch,
                    location,
                    self.patch_overlap,
                )
                i_ini, j_ini, k_ini, i_fin, j_fin, k_fin = new_location
                self._output_tensor[
                    :,
                    i_ini:i_fin,
                    j_ini:j_fin,
                    k_ini:k_fin,
                ] = cropped_patch
        elif self.overlap_mode == 'average':
            self._initialize_avgmask_tensor(batch)
            assert isinstance(self._avgmask_tensor, torch.Tensor)  # for mypy
            # Use the downsampled locations so indexing stays consistent
            # with the output volume's spatial shape
            for patch, location in zip(batch, locations_array, strict=True):
                i_ini, j_ini, k_ini, i_fin, j_fin, k_fin = location
                self._output_tensor[
                    :,
                    i_ini:i_fin,
                    j_ini:j_fin,
                    k_ini:k_fin,
                ] += patch
                self._avgmask_tensor[
                    :,
                    i_ini:i_fin,
                    j_ini:j_fin,
                    k_ini:k_fin,
                ] += 1
        elif self.overlap_mode == 'hann':
            # To handle edges and corners and avoid numerical problems, we
            # accumulate the Hann window weights in a separate tensor.
            # At the end, it will contain ones (or close values) where patches
            # overlap and values < 1 where they do not.
            # When we divide by it, the window weighting cancels out in areas
            # with no overlap and forms a weighted average where there is
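            # Worked 1-D sketch (illustrative): two patches with weight
            # profile [0.5, 1, 0.5] overlapping by one voxel accumulate
            # weights [0.5, 1, 1, 1, 0.5]; dividing the weighted sum by these
            # accumulated weights yields a weighted average in the overlap
            # and cancels the window elsewhere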
            self._initialize_avgmask_tensor(batch)
            self._initialize_hann_window()
            if self._output_tensor.dtype != torch.float32:
                self._output_tensor = self._output_tensor.float()
            assert isinstance(self._avgmask_tensor, torch.Tensor)  # for mypy
            if self._avgmask_tensor.dtype != torch.float32:
                self._avgmask_tensor = self._avgmask_tensor.float()
            assert self._hann_window is not None  # for mypy
            for patch, location in zip(batch, locations_array, strict=True):
                i_ini, j_ini, k_ini, i_fin, j_fin, k_fin = location
                patch = patch * self._hann_window
                self._output_tensor[
                    :,
                    i_ini:i_fin,
                    j_ini:j_fin,
                    k_ini:k_fin,
                ] += patch
                self._avgmask_tensor[
                    :,
                    i_ini:i_fin,
                    j_ini:j_fin,
                    k_ini:k_fin,
                ] += self._hann_window

    def get_output_tensor(self) -> torch.Tensor:
        """Get the aggregated volume after dense inference."""
        assert isinstance(self._output_tensor, torch.Tensor)
        if self._output_tensor.dtype == torch.int64:
            message = (
                'Medical image frameworks such as ITK do not support int64.'
                ' Casting to int32...'
            )
            warnings.warn(message, RuntimeWarning, stacklevel=2)
            self._output_tensor = self._output_tensor.type(torch.int32)
        if self.overlap_mode in ('average', 'hann'):
            assert isinstance(self._avgmask_tensor, torch.Tensor)  # for mypy
            # true_divide is used instead of / in case the PyTorch version is
            # old and one of the operands is int:
            # https://github.com/TorchIO-project/torchio/issues/526
            output = torch.true_divide(
                self._output_tensor,
                self._avgmask_tensor,
            )
        else:
            output = self._output_tensor
        if self.volume_padded:
            # Crop the borders that were added by the sampler's padding
            from ...transforms import Crop

            border = self.patch_overlap // 2
            cropping = border.repeat(2)
            crop = Crop(cropping)  # type: ignore[arg-type]
            return crop(output)  # type: ignore[return-value]
        else:
            return output