Source code for torchio.transforms.preprocessing.intensity.to
from __future__ import annotations
from typing import Any
import torch
from ....data.image import ScalarImage
from ....data.subject import Subject
from ...intensity_transform import IntensityTransform
[docs]
class To(IntensityTransform):
    """Convert the image tensor data type and/or device.

    This transform is a thin wrapper around :func:`torch.Tensor.to`.

    Args:
        target: First argument to :func:`torch.Tensor.to`: a
            :class:`torch.dtype`, a :class:`torch.device`, or a string.
        to_kwargs: Additional keyword arguments to pass to
            :func:`torch.Tensor.to`. A shallow copy is stored, so later
            changes to the caller's dictionary do not affect the transform.
        **kwargs: Forwarded unchanged to the parent
            :class:`~torchio.transforms.intensity_transform.IntensityTransform`
            constructor.

    Example:
        >>> import torch
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> clamp = tio.Clamp(out_min=-1000, out_max=1000)
        >>> ct_clamped = clamp(ct)
        >>> rescale = tio.RescaleIntensity(in_min_max=(-1000, 1000), out_min_max=(0, 255))
        >>> ct_rescaled = rescale(ct_clamped)
        >>> to_uint8 = tio.To(torch.uint8)
        >>> ct_uint8 = to_uint8(ct_rescaled)
    """

    def __init__(
        self,
        target: str | torch.dtype | torch.device,
        to_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target = target
        # Copy instead of aliasing: the caller's dict must not be able to
        # change this transform's behaviour after construction.
        self.to_kwargs = {} if to_kwargs is None else dict(to_kwargs)
        # Names of the constructor arguments recorded for this transform.
        self.args_names = ['target', 'to_kwargs']

    def apply_transform(self, subject: Subject) -> Subject:
        """Call ``Tensor.to`` on every image selected by ``get_images``.

        Each image's data is replaced in place via ``set_data``; the same
        ``subject`` object is returned.
        """
        for image in self.get_images(subject):
            # Narrowing check: this transform only expects scalar images here.
            assert isinstance(image, ScalarImage)
            converted = image.data.to(self.target, **self.to_kwargs)
            image.set_data(converted)
        return subject