Inference
Here is an example that uses a grid sampler and aggregator to perform dense inference across a 3D image using patches:
>>> import torch
>>> import torch.nn as nn
>>> import torchio as tio
>>> patch_overlap = 4, 4, 4 # or just 4
>>> patch_size = 88, 88, 60
>>> subject = tio.datasets.Colin27()
>>> subject
Colin27(Keys: ('t1', 'head', 'brain'); images: 3)
>>> grid_sampler = tio.inference.GridSampler(
... subject,
... patch_size,
... patch_overlap,
... )
>>> patch_loader = tio.SubjectsLoader(grid_sampler, batch_size=4)
>>> aggregator = tio.inference.GridAggregator(grid_sampler)
>>> model = nn.Identity().eval()
>>> with torch.no_grad():
... for patches_batch in patch_loader:
... input_tensor = patches_batch['t1'][tio.DATA]
... locations = patches_batch[tio.LOCATION]
... logits = model(input_tensor)
... labels = logits.argmax(dim=tio.CHANNELS_DIMENSION, keepdim=True)
... outputs = labels
... aggregator.add_batch(outputs, locations)
>>> output_tensor = aggregator.get_output_tensor()
Grid sampler
GridSampler
GridSampler
Bases: PatchSampler
Extract patches across a whole volume.
Grid samplers are useful to perform inference using all patches from a
volume. It is often used with a GridAggregator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
subject
|
Subject
|
Instance of |
required |
patch_size
|
TypeSpatialShape
|
Tuple of integers \((w, h, d)\) to generate patches of size \(w \times h \times d\). If a single number \(n\) is provided, \(w = h = d = n\). |
required |
patch_overlap
|
TypeSpatialShape
|
Tuple of even integers \((w_o, h_o, d_o)\) specifying the overlap between patches for dense inference. If a single number \(n\) is provided, \(w_o = h_o = d_o = n\). |
(0, 0, 0)
|
padding_mode
|
str | float | None
|
Same as |
None
|
Examples:
>>> import torchio as tio
>>> colin = tio.datasets.Colin27()
>>> sampler = tio.GridSampler(colin, patch_size=88)
>>> for i, patch in enumerate(sampler()):
... patch.t1.save(f'patch_{i}.nii.gz')
...
>>> # To figure out the number of patches beforehand:
>>> sampler = tio.GridSampler(colin, patch_size=88)
>>> len(sampler)
8
Note
Adapted from NiftyNet. See this NiftyNet tutorial
for more
information about patch based sampling. Note that
patch_overlap is twice border in NiftyNet
tutorial.
Grid aggregator
GridAggregator
GridAggregator
Aggregate patches for dense inference.
This class is typically used to build a volume made of patches after
inference of batches extracted by a GridSampler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
sampler
|
GridSampler
|
Instance of |
required |
overlap_mode
|
str
|
If |
'crop'
|
downsampling_factor
|
int
|
Factor by which the output volume is expected to be smaller than the input volume in each spatial dimension. This is useful when the model downsamples the input (e.g., with strided convolutions or pooling layers). Currently, only a single integer is supported, which applies the same downsampling factor to all spatial dimensions. |
1
|
Note
Adapted from NiftyNet. See this NiftyNet tutorial for more information about patch-based sampling.
add_batch(batch_tensor, locations)
Add batch processed by a network to the output prediction volume.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
batch_tensor
|
Tensor
|
5D tensor, typically the output of a convolutional
neural network, e.g. |
required |
locations
|
Tensor
|
2D tensor with shape \((B, 6)\) representing the
patch indices in the original image. They are typically
extracted using |
required |
get_output_tensor()
Get the aggregated volume after dense inference.