# Estimation of covariance matrices
import warnings
import torch
import torch.nn as nn
from typing import Callable, Optional, Tuple
from tqdm import tqdm
from .functions import InvSqrtmEigFunction
def normalize_trace(Sigma_batch: torch.Tensor) -> torch.Tensor:
"""Normalize covariance by the trace (trace is equal to n_features).
Parameters
----------
    Sigma_batch : torch.Tensor
        Batch of covariance matrices of shape (..., n_features, n_features),
        where `...` are the batch dimensions.
Returns
-------
torch.Tensor
Normalized batch of covariance matrices of shape
(..., n_features, n_features)
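    Examples
    --------
    A minimal sketch on random SPD matrices (shapes illustrative):

    >>> import torch
    >>> A = torch.randn(4, 5, 5)
    >>> Sigma = A @ A.transpose(-2, -1) + 5 * torch.eye(5)
    >>> traces = torch.einsum("...ii->...", normalize_trace(Sigma))
    >>> bool(torch.allclose(traces, torch.full((4,), 5.0)))
    True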
"""
traces = torch.einsum("...ii->...", Sigma_batch).unsqueeze(-1).unsqueeze(-1)
return Sigma_batch.shape[-2] * Sigma_batch / traces
def normalize_determinant(Sigma_batch: torch.Tensor) -> torch.Tensor:
"""Normalize covariance by the determinant (determinant=1).
Parameters
----------
    Sigma_batch : torch.Tensor
        Batch of covariance matrices of shape (..., n_features, n_features),
        where `...` are the batch dimensions.
Returns
-------
torch.Tensor
Normalized batch of covariance matrices of shape
(..., n_features, n_features)
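    Examples
    --------
    A minimal sketch on random SPD matrices (shapes illustrative):

    >>> import torch
    >>> A = torch.randn(4, 5, 5)
    >>> Sigma = A @ A.transpose(-2, -1) + 5 * torch.eye(5)
    >>> dets = torch.linalg.det(normalize_determinant(Sigma))
    >>> bool(torch.allclose(dets, torch.ones(4), atol=1e-4))
    True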
"""
det = torch.linalg.det(Sigma_batch).unsqueeze(-1).unsqueeze(-1)
return Sigma_batch / (torch.pow(det, 1 / Sigma_batch.shape[-2]))
def student_function(x: torch.Tensor, n_features: int, nu: float) -> torch.Tensor:
"""Student function.
Parameters
----------
x : torch.Tensor
Input tensor.
n_features : int
Number of features.
nu : float
Degrees of freedom.
Returns
-------
torch.Tensor
Computed Student function over input tensor.
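    Examples
    --------
    >>> import torch
    >>> student_function(torch.tensor([0.0, 1.0]), n_features=3, nu=1.0)
    tensor([4., 2.])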
"""
return (n_features + nu) / (nu + x)
def huber_function(x: torch.Tensor, delta: float, beta: float) -> torch.Tensor:
"""Huber function defined as:
* u(x) = 1/beta is x <= delta
* u(x) = delta/(beta*x) if x > delta
Parameters
----------
x : torch.Tensor
Input tensor.
delta : float
Threshold value.
beta : float
Scaling factor.
Returns
-------
torch.Tensor
Computed Huber function over input tensor.
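    Examples
    --------
    The weight is constant up to the threshold, then decays as 1/x:

    >>> import torch
    >>> huber_function(torch.tensor([0.5, 1.0, 2.0]), delta=1.0, beta=2.0)
    tensor([0.5000, 0.5000, 0.2500])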
"""
return torch.where(x <= delta, 1 / beta, delta / (beta * x))
def tyler_function(x: torch.Tensor, n_features: int) -> torch.Tensor:
"""Tyler function.
Parameters
----------
x : torch.Tensor
Input tensor.
n_features : int
Number of features.
Returns
-------
torch.Tensor
Computed Tyler function over input tensor.
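    Examples
    --------
    >>> import torch
    >>> tyler_function(torch.tensor([1.0, 2.0]), n_features=4)
    tensor([4., 2.])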
"""
return n_features / x
[docs]
class SCM(nn.Module):
"""Layer to compute SCM to estimate covariance matrix."""
def __init__(self, assume_centered: Optional[bool] = True) -> None:
super().__init__()
self.assume_centered = assume_centered
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""Compute SCM over data.
Parameters
----------
X : torch.Tensor of shape (..., n_samples, n_features)
Input tensor batch.
Returns
-------
torch.Tensor of shape (..., n_features, n_features)
Estimated covariance matrices (one per batch).
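        Examples
        --------
        A minimal sketch on toy data (shapes illustrative):

        >>> import torch
        >>> X = torch.randn(8, 100, 5)
        >>> Sigma = SCM(assume_centered=False)(X)
        >>> Sigma.shape
        torch.Size([8, 5, 5])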
"""
if self.assume_centered:
_X = X
n_samples = X.shape[-2]
else:
_X = X - X.mean(dim=-2, keepdim=True)
n_samples = X.shape[-2] - 1
Sigma = torch.einsum("...ij,...jk->...ik", _X.transpose(-2, -1), _X) / n_samples
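        # Symmetrize to compensate for numerical asymmetry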
return 0.5 * (Sigma + Sigma.transpose(-2, -1))
class Mestimation(nn.Module):
"""Torch implementation of M-estimators of covariance matrix."""
def __init__(
self,
m_estimation_function: Callable,
n_iter: int = 30,
tol: float = 1e-6,
verbose: bool = False,
assume_centered: bool = False,
normalize: Optional[Callable] = None,
) -> None:
"""
Initializes the M-estimation module.
Parameters
----------
m_estimation_function : Callable
The M-estimation function to use.
n_iter : int, optional (default=30)
            Maximum number of fixed-point iterations.
        tol : float, optional (default=1e-6)
Tolerance for stopping criterion.
verbose : bool, optional (default=False)
Whether to display a progress bar during estimation.
assume_centered : bool, optional (default=False)
Whether to assume that the data is already centered.
normalize : Callable, optional (default=None)
A function to normalize the covariance matrix.
If None, no normalization will be performed.
"""
super().__init__()
self.m_estimation_function = m_estimation_function
self.n_iter = n_iter
self.tol = tol
self.verbose = verbose
self.assume_centered = assume_centered
self.normalize = normalize
def _init_pbar(self) -> None:
"""Initialize progress bar.
Parameters
----------
n_iter : int
Number of iterations.
"""
if self.verbose:
self.pbar = tqdm(total=self.n_iter, desc="M-estimation", leave=True)
def _update_pbar(self, delta: float) -> None:
"""Update progress bar."""
if self.verbose:
self.pbar.set_postfix({"delta": f"{delta:.2e}"})
self.pbar.update(1)
def _iter_fixed_point(
self, isqrtm_Sigma_prev: torch.Tensor, X: torch.Tensor, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
"""One iteration of fixed point algorithm for M-estimation.
Parameters
----------
isqrtm_Sigma_prev : torch.Tensor of shape (..., n_features, n_features)
            Inverse square root of the previous estimate of the covariance
            matrices, where `...` are the batch dimensions.
X : torch.Tensor of shape (..., n_samples, n_features)
            Input tensor, where `...` are the batch dimensions.
**kwargs
Additional keyword arguments to pass to the M-estimation function.
Returns
-------
torch.Tensor of shape (..., n_features, n_features)
            Updated estimate of the covariance matrices.
        torch.Tensor of shape (..., n_features, n_features)
            Inverse matrix square root of the updated estimate.
"""
batches_dimensions = X.shape[:-2]
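        # Whitened samples Sigma_prev^{-1/2} @ X^T,
        # shape (..., n_features, n_samples)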
temp = torch.einsum(
"...ij,...jk->...ik", isqrtm_Sigma_prev, X.transpose(-2, -1)
)
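        # Per-sample quadratic forms x_i^T Sigma_prev^{-1} x_i, mapped
        # through the weight function u(.)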
quadratic = self.m_estimation_function(
torch.einsum("...ij,...ji->...i", temp.transpose(-2, -1), temp), **kwargs
)
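        # Scale each sample by sqrt(u(q_i)) so the outer product below
        # carries the weights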
temp = X.transpose(-2, -1) * torch.sqrt(
quadratic.unsqueeze(-2).repeat(
(1,) * len(batches_dimensions) + (X.shape[-1], 1)
)
)
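        # Weighted SCM update: (1/n_samples) * sum_i u(q_i) x_i x_i^T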
Sigma = (
torch.einsum("...ij,...jk->...ik", temp, temp.transpose(-2, -1))
/ X.shape[-2]
)
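        # Inverse matrix square root of the update
        # (eigendecomposition-based autograd function)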
isqrtm_Sigma = InvSqrtmEigFunction.apply(Sigma)
return Sigma, isqrtm_Sigma
def forward(
self, X: torch.Tensor, init: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
"""Compute M-estimator of covariance matrix on a batch of data.
Parameters
----------
X : torch.Tensor of shape (..., n_samples, n_features)
            Input tensor, where `...` are the batch dimensions.
init : torch.Tensor, optional (default=None)
            The initial estimate of the covariance matrix.
            If None, it is initialized to the identity matrix.
            The shape can be (n_features, n_features), in which case it is
            repeated over the batch dimensions, or (..., n_features,
            n_features), where `...` are the batch dimensions.
**kwargs
Additional keyword arguments to pass to the M-estimation function.
Returns
-------
torch.Tensor of shape (..., n_features, n_features)
Estimated covariance matrices (one per batch).
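        Examples
        --------
        A minimal sketch of Tyler's estimator on toy data, using the
        module's tyler_function and normalize_trace helpers:

        >>> import torch
        >>> X = torch.randn(100, 5)
        >>> tyler = Mestimation(tyler_function, normalize=normalize_trace)
        >>> Sigma = tyler(X, n_features=5)
        >>> Sigma.shape
        torch.Size([5, 5])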
"""
batches_dimensions = X.shape[:-2]
if init is None:
Sigma = torch.eye(X.shape[-1], device=X.device)
for dim in reversed(batches_dimensions):
Sigma = Sigma.unsqueeze(0).repeat((dim,) + (1,) * Sigma.ndim)
isqrtm_Sigma = Sigma.clone()
else:
assert init.shape[-1] == X.shape[-1], (
f"Size of initial covariance ({init.shape}) "
+ f"incompatible with data ({X.shape})!"
)
if init.ndim > 2:
assert batches_dimensions == init.shape[:-2], (
f"Size of initial covariance ({init.shape}) "
+ f"incompatible with data ({X.shape})!"
)
Sigma = init
isqrtm_Sigma = InvSqrtmEigFunction.apply(Sigma)
if init.ndim == 2:
for dim in reversed(batches_dimensions):
Sigma = Sigma.unsqueeze(0).repeat((dim,) + (1,) * Sigma.ndim)
isqrtm_Sigma = isqrtm_Sigma.unsqueeze(0).repeat(
(dim,) + (1,) * isqrtm_Sigma.ndim
)
if not self.assume_centered:
X = X - X.mean(dim=-2, keepdim=True)
self._init_pbar()
for _ in range(self.n_iter):
Sigma_new, isqrtm_Sigma = self._iter_fixed_point(isqrtm_Sigma, X, **kwargs)
delta = torch.norm(Sigma_new - Sigma, "fro") / torch.norm(Sigma, "fro")
Sigma = Sigma_new
self._update_pbar(delta)
if delta < self.tol:
break
else:
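            # for/else: reached only when the loop exhausted n_iter
            # without breaking (i.e. no convergence)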
if self.verbose:
warnings.warn("M-estimation didn't converge.")
if self.verbose:
self.pbar.close()
if self.normalize is not None:
Sigma_new = self.normalize(Sigma_new)
        # Symmetrize the result to compensate for numerical asymmetry
return 0.5 * (Sigma_new + Sigma_new.transpose(-2, -1))