Source code for torchio.transforms.preprocessing.spatial.to_reference_space

import numpy as np
import torch

from ....data.image import Image
from ....data.subject import Subject
from ...spatial_transform import SpatialTransform
from .resample import Resample


class ToReferenceSpace(SpatialTransform):
    """Modify the spatial metadata so it matches a reference space.

    This is useful, for example, to set meaningful spatial metadata of a
    neural network embedding, for visualization or further processing such
    as resampling a segmentation output.

    Args:
        reference: Image defining the reference space. It must be an
            instance of :class:`Image`.
        **kwargs: Forwarded unchanged to the parent class constructor.

    Raises:
        TypeError: If ``reference`` is not an :class:`Image`.

    Example:
        >>> import torchio as tio
        >>> image = tio.datasets.FPG().t1
        >>> embedding_tensor = my_network(image.tensor)  # we lose metadata here
        >>> embedding_image = tio.ToReferenceSpace.from_tensor(embedding_tensor, image)
    """

    def __init__(self, reference: Image, **kwargs):
        super().__init__(**kwargs)
        is_image = isinstance(reference, Image)
        if not is_image:
            message = 'The reference must be a TorchIO image'
            raise TypeError(message)
        self.reference = reference

    def apply_transform(self, subject: Subject) -> Subject:
        """Replace the affine of the images selected by ``get_images``.

        For each selected image, a new image is built from its data and the
        stored reference; the data is then set again on the original image
        and its affine is overwritten with the one of the new image.

        Args:
            subject: Subject whose images are modified in place.

        Returns:
            The same ``subject`` instance that was passed in.
        """
        reference = self.reference
        for target in self.get_images(subject):
            rebuilt = build_image_from_reference(target.data, reference)
            target.set_data(rebuilt.data)
            target.affine = rebuilt.affine
        return subject

    @staticmethod
    def from_tensor(tensor: torch.Tensor, reference: Image) -> Image:
        """Build a TorchIO image from a tensor and a reference image.

        Thin wrapper around :func:`build_image_from_reference` that can be
        used without instantiating the transform.

        Args:
            tensor: Data of the image to build.
            reference: Image from which the spatial metadata is derived.

        Returns:
            The image returned by :func:`build_image_from_reference`.
        """
        return build_image_from_reference(tensor, reference)
def build_image_from_reference(tensor: torch.Tensor, reference: Image) -> Image:
    """Wrap a tensor in an image whose affine is derived from a reference.

    The spacing of ``reference`` is multiplied, axis by axis, by the ratio
    between the spatial shape of ``reference`` and the last three dimensions
    of ``tensor``. The reference is then resampled to that spacing (using
    nearest-neighbor interpolation) and the affine of the resampled reference
    is given to the new image.

    Args:
        tensor: Data of the new image. Its last three dimensions are taken
            as the spatial shape.
        reference: Image providing the original spatial shape and spacing.
            The object passed in is not rebound or replaced by this function.

    Returns:
        An instance of the same class as the resampled reference, built from
        ``tensor`` and the affine of the resampled reference.

    Raises:
        ValueError: If ``tensor`` has fewer than three dimensions, or if any
            of its last three dimensions has size zero.
    """
    # Fail fast: with fewer than three dimensions the division below would
    # either broadcast silently or fail with an unrelated NumPy error.
    num_dims = len(tensor.shape)
    if num_dims < 3:
        message = (
            'The tensor must have at least 3 dimensions to define a spatial'
            f' shape, but it has {num_dims}'
        )
        raise ValueError(message)

    output_shape = np.array(tensor.shape[-3:])
    # An empty spatial dimension would make the downsampling factor infinite.
    if not output_shape.all():
        message = (
            'All spatial dimensions of the tensor must be non-zero,'
            f' but the spatial shape is {tuple(output_shape.tolist())}'
        )
        raise ValueError(message)

    input_shape = np.array(reference.spatial_shape)
    downsampling_factor = input_shape / output_shape
    input_spacing = np.array(reference.spacing)
    output_spacing = input_spacing * downsampling_factor

    # The resampled reference is only needed for its class and its affine.
    downsample = Resample(output_spacing, image_interpolation='nearest')
    resampled_reference = downsample(reference)
    class_ = resampled_reference.__class__
    result = class_(tensor=tensor, affine=resampled_reference.affine)
    return result