
EnsureShapeMultiple

Bases: SpatialTransform

Ensure that all values in the image shape are divisible by \(n\).

Some convolutional neural network architectures require the size of the input along all spatial dimensions to be divisible by \(2^k\), where \(k\) is the number of downsampling operations.

For example, the canonical 3D U-Net from Çiçek et al. includes three downsampling (pooling) and upsampling operations:

[Figure: the 3D U-Net architecture from Çiçek et al.]

Pooling operations in PyTorch round down the output size: with kernel size \(2\), each spatial dimension of size \(s\) becomes \(\lfloor s/2 \rfloor\), so \(31\) becomes \(15\):

>>> import torch
>>> x = torch.rand(3, 10, 20, 31)
>>> x_down = torch.nn.functional.max_pool3d(x, 2)
>>> x_down.shape
torch.Size([3, 5, 10, 15])

If we upsample this tensor, the original shape is lost. Note that interpolate expects a batch dimension and, given a 4D input, would rescale only the last two dimensions, so we add a temporary batch dimension:

>>> x_down_up = torch.nn.functional.interpolate(x_down.unsqueeze(0), scale_factor=2).squeeze(0)
>>> x_down_up.shape
torch.Size([3, 10, 20, 30])
>>> x.shape
torch.Size([3, 10, 20, 31])

If we try to concatenate x and x_down_up to create a skip connection, we will get an error, because their shapes differ along the last dimension (\(31\) vs. \(30\)). It is therefore good practice to ensure that the size of our images is such that concatenations will be safe.
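The failure can be reproduced directly with the tensors above (a minimal sketch; the exact error message depends on the PyTorch version):

>>> try:
...     torch.cat((x, x_down_up), dim=0)  # dim 0 is the channel dimension of these unbatched tensors
... except RuntimeError as error:
...     print(type(error).__name__)
RuntimeError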

Note

In these examples, it's assumed that all convolutions in the U-Net use padding so that the output size is the same as the input size.

The figure above shows \(3\) downsampling operations, so the input size along all dimensions should be a multiple of \(2^3 = 8\).
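To check which sizes are safe, we can compute the next multiple of \(8\) for each dimension; a minimal sketch (not TorchIO's implementation):

>>> import math
>>> def next_multiple(size, n):
...     return math.ceil(size / n) * n
...
>>> next_multiple(181, 8), next_multiple(217, 8)
(184, 224)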

Example (assuming pip install unet has been run before):

>>> import torchio as tio
>>> import unet
>>> net = unet.UNet3D(padding=1)
>>> t1 = tio.datasets.Colin27().t1
>>> tensor_bad = t1.data.unsqueeze(0)
>>> tensor_bad.shape
torch.Size([1, 1, 181, 217, 181])
>>> net(tensor_bad).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/unet/unet.py", line 122, in forward
    x = self.decoder(skip_connections, encoding)
  File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/unet/decoding.py", line 61, in forward
    x = decoding_block(skip_connection, x)
  File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/fernando/miniconda3/envs/resseg/lib/python3.7/site-packages/unet/decoding.py", line 131, in forward
    x = torch.cat((skip_connection, x), dim=CHANNELS_DIMENSION)
RuntimeError: Sizes of tensors must match except in dimension 1. Got 45 and 44 in dimension 2 (The offending index is 1)
>>> num_poolings = 3
>>> fix_shape_unet = tio.EnsureShapeMultiple(2**num_poolings)
>>> t1_fixed = fix_shape_unet(t1)
>>> tensor_ok = t1_fixed.data.unsqueeze(0)
>>> tensor_ok.shape
torch.Size([1, 1, 184, 224, 184])  # as expected
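If the original field of view is needed afterwards, the padded image can be cropped back to its initial shape, for example with tio.CropOrPad (a sketch, assuming the Colin27 shape above):

>>> restore = tio.CropOrPad((181, 217, 181))
>>> restore(t1_fixed).shape
(1, 181, 217, 181)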

Parameters:

target_multiple (int | TypeTripletInt), required:
    Tuple \((n_w, n_h, n_d)\), so that the size of the output along axis \(i\) is a multiple of \(n_i\). If a single value \(n\) is provided, then \(n_w = n_h = n_d = n\).

method (str), default 'pad':
    Either 'crop' or 'pad'.

**kwargs, default {}:
    See Transform for additional keyword arguments.

Examples:

>>> import torchio as tio
>>> image = tio.datasets.Colin27().t1
>>> image.shape
(1, 181, 217, 181)
>>> transform = tio.EnsureShapeMultiple(8, method='pad')
>>> transformed = transform(image)
>>> transformed.shape
(1, 184, 224, 184)
>>> transform = tio.EnsureShapeMultiple(8, method='crop')
>>> transformed = transform(image)
>>> transformed.shape
(1, 176, 216, 176)
>>> image_2d = image.data[..., :1]
>>> image_2d.shape
torch.Size([1, 181, 217, 1])
>>> transformed = transform(image_2d)
>>> transformed.shape
torch.Size([1, 176, 216, 1])
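
Since target_multiple also accepts one value per axis, the 2D case above can be expressed explicitly; a minimal sketch (the output shape assumes the same Colin27 slice):

>>> transform_2d = tio.EnsureShapeMultiple((8, 8, 1), method='crop')
>>> transform_2d(image_2d).shape
torch.Size([1, 176, 216, 1])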

__call__(data)

Transform data and return a result of the same type.

Parameters:

data (InputType), required:
    Instance of torchio.Subject, 4D torch.Tensor or numpy.ndarray with dimensions \((C, W, H, D)\), where \(C\) is the number of channels and \(W, H, D\) are the spatial dimensions. If the input is a tensor, the affine matrix will be set to identity. Other valid input types are a SimpleITK image, a torchio.Image, a NiBabel Nifti1 image or a dict. The output type is the same as the input type.

get_base_args()

Provides easy access to the arguments used to instantiate the base class (Transform) of any transform.

This method is particularly useful when a new transform can be represented as a variant of an existing one (e.g., all random transforms), allowing the existing transform to be instantiated seamlessly with the same arguments during apply_transform.

Note

The p argument (probability of applying the transform) is excluded, to avoid compounding the probabilities of the existing and the new transform.

add_base_args(arguments, overwrite_on_existing=False)

Add the base-class init args to existing arguments, optionally overwriting entries that are already present.

validate_keys_sequence(keys, name) staticmethod

Ensure that the input is not a string but a sequence of strings.

to_hydra_config()

Return a dictionary representation of the transform for Hydra instantiation.
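
Such a dictionary is typically consumed by hydra.utils.instantiate; a hypothetical sketch (the exact keys of the returned configuration depend on the implementation):

>>> from hydra.utils import instantiate
>>> config = tio.EnsureShapeMultiple(8).to_hydra_config()
>>> transform = instantiate(config)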