Source code for torchio.transforms.preprocessing.intensity.pca

from __future__ import annotations

from typing import Any

import numpy as np
from einops import rearrange

from ....data.image import ScalarImage
from ....data.subject import Subject
from ....external.imports import get_sklearn
from ...intensity_transform import IntensityTransform


class PCA(IntensityTransform):
    """Compute principal component analysis (PCA) of an image.

    PCA can be useful to visualize embeddings generated by a neural network.
    See for example Figure 8 in `Cluster and Predict Latent Patches for
    Improved Masked Image Modeling <https://arxiv.org/abs/2502.08769>`_.

    Args:
        num_components: Number of components to compute.
        keep_components: Number of components to keep in the output image.
            If ``None``, all components are kept.
        whiten: If ``True``, the components are normalized to have unit
            variance.
        normalize: If ``True``, all components are divided by the standard
            deviation of the first component.
        make_skewness_positive: If ``True``, the skewness of each component
            is made positive by multiplying the component by -1 if its
            skewness is negative.
        values_range: If not ``None``, these values are linearly mapped to
            :math:`[0, 1]`. If ``None``, the minimum and maximum over all
            computed components are used instead.
        clip: If ``True``, the output values are clipped to :math:`[0, 1]`.
        pca_kwargs: Additional keyword arguments to pass to
            :class:`sklearn.decomposition.PCA`.
        **kwargs: Forwarded unchanged to the parent class constructor.

    Example:
        >>> import torchio as tio
        >>> from torchio.visualization import build_image_from_reference
        >>> ct = my_preprocessed_ct_image  # Assume this is a preprocessed CT image
        >>> ct
        ScalarImage(shape: (1, 240, 480, 480); spacing: (1.50, 0.75, 0.75); orientation: SLP+; dtype: torch.FloatTensor; memory: 210.9 MiB)
        >>> embedding_tensor = model(ct.data[None])[0]  # `model` is some pre-trained neural network
        >>> embedding_image = ToReferenceSpace(ct)(embedding_tensor)
        >>> embedding_image
        ScalarImage(shape: (512, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 27.0 MiB)
        >>> pca = tio.PCA()(embedding_image)
        >>> pca
        ScalarImage(shape: (3, 24, 24, 24); spacing: (15.00, 15.00, 15.00); orientation: SLP+; dtype: torch.FloatTensor; memory: 162.0 KiB)
    """

    def __init__(
        self,
        num_components: int = 6,
        *,
        keep_components: int | None = 3,
        whiten: bool = True,
        normalize: bool = True,
        make_skewness_positive: bool = True,
        values_range: tuple[float, float] | None = (-2.3, 2.3),
        clip: bool = True,
        pca_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Decomposition settings.
        self.num_components = num_components
        self.keep_components = keep_components
        self.whiten = whiten

        # Post-processing applied to the projected components.
        self.normalize = normalize
        self.make_skewness_positive = make_skewness_positive
        self.values_range = values_range
        self.clip = clip

        # Extra options handed to the scikit-learn estimator.
        self.pca_kwargs = pca_kwargs

        # Names of the constructor arguments stored on this instance.
        self.args_names = [
            'num_components',
            'keep_components',
            'whiten',
            'normalize',
            'make_skewness_positive',
            'values_range',
            'clip',
            'pca_kwargs',
        ]

    def apply_transform(self, subject: Subject) -> Subject:
        """Replace the data of each selected image with its PCA projection.

        Args:
            subject: Subject whose images (as returned by ``get_images``)
                are transformed in place.

        Returns:
            The same ``subject`` instance, with updated image data.
        """
        # ``None`` means "no extra estimator options".
        estimator_options = self.pca_kwargs
        if estimator_options is None:
            estimator_options = {}

        for image in self.get_images(subject):
            projected_image = _compute_pca(
                image,
                num_components=self.num_components,
                keep_components=self.keep_components,
                whiten=self.whiten,
                normalize=self.normalize,
                make_skewness_positive=self.make_skewness_positive,
                values_range=self.values_range,
                clip=self.clip,
                **estimator_options,
            )
            image.set_data(projected_image.data)

        return subject
def _compute_pca(
    embeddings: ScalarImage,
    num_components: int,
    keep_components: int | None,
    whiten: bool,
    normalize: bool,
    make_skewness_positive: bool,
    values_range: tuple[float, float] | None,
    clip: bool,
    **pca_kwargs,
) -> ScalarImage:
    """Project the channels of an image onto its principal components.

    Adapted from
    https://github.com/facebookresearch/capi/blob/main/eval_visualizations.py

    Every voxel is treated as one sample and every channel as one feature.
    Degenerate inputs (a constant first component, an all-zero component or
    a constant output grid) are handled without dividing by zero, so they
    no longer turn the result into NaN/inf.

    Args:
        embeddings: Image whose data has shape ``(c, x, y, z)``.
        num_components: Number of principal components to compute.
        keep_components: Number of leading components kept in the output.
            If ``None``, all computed components are kept.
        whiten: Passed to :class:`sklearn.decomposition.PCA` as ``whiten``.
        normalize: If ``True``, divide all components by the standard
            deviation of the first one (skipped if that deviation is zero).
        make_skewness_positive: If ``True``, negate every component whose
            skewness is negative.
        values_range: ``(vmin, vmax)`` mapped linearly to ``[0, 1]``. If
            ``None``, the minimum and maximum over all computed components
            are used. For reference, ``(-2.3, 2.3)`` would contain about
            98% of the values of a standard-normal variable.
        clip: If ``True``, clip the mapped values to ``[0, 1]``.
        **pca_kwargs: Extra keyword arguments for
            :class:`sklearn.decomposition.PCA`.

    Returns:
        A new image holding the processed components as channels, with the
        same affine as ``embeddings``.
    """
    sklearn = get_sklearn()
    # Aliased so the local name does not shadow the ``PCA`` transform class
    # defined in this module.
    SklearnPCA = sklearn.decomposition.PCA

    data = embeddings.numpy()
    _, size_x, size_y, size_z = data.shape
    samples = rearrange(data, 'c x y z -> (x y z) c')

    pca = SklearnPCA(n_components=num_components, whiten=whiten, **pca_kwargs)
    # Transposed so that each row holds one component: (components, voxels).
    projected: np.ndarray = pca.fit_transform(samples).T

    if normalize:
        first_std = projected[0].std()
        # A constant first component has zero spread; dividing by it would
        # fill the whole output with NaN/inf, so leave the values untouched.
        if first_std > 0:
            projected /= first_std

    if make_skewness_positive:
        for component in projected:
            second_moment = np.mean(component**2)
            if second_moment == 0:
                # All-zero component: skewness is undefined, nothing to flip.
                continue
            third_moment = np.mean(component**3)
            skewness = third_moment / second_moment ** (3 / 2)
            if skewness < 0:
                # ``component`` is a view, so this flips the row in place.
                component *= -1

    grid: np.ndarray = rearrange(
        projected,
        'c (x y z) -> c x y z',
        x=size_x,
        y=size_y,
        z=size_z,
    )

    if values_range is not None:
        vmin, vmax = values_range
        span = vmax - vmin
    else:
        vmin, vmax = grid.min(), grid.max()
        # A constant grid has no spread; a unit span maps it to zeros
        # instead of producing 0 / 0.
        span = vmax - vmin if vmax > vmin else 1
    grid = (grid - vmin) / span

    if clip:
        grid = np.clip(grid, 0, 1)

    # Trimmed last so that a data-driven range is computed over all
    # components, not only over the ones that are kept.
    if keep_components is not None:
        grid = grid[:keep_components]

    return ScalarImage(tensor=grid, affine=embeddings.affine)