Utils

This module contains a set of utility functions used by the other modules.

anotherspdnet.utils.initialize_weights_stiefel(W: geoopt.tensor.ManifoldParameter, seed: int | None = None) → geoopt.tensor.ManifoldParameter

Initialize the weights so that they lie on the Stiefel manifold.

Based on Theorem 2.2.1 in Chikuse (2003), Statistics on Special Manifolds. Note: the correctness of this construction has not yet been verified.

Parameters:
  • W (geoopt.tensor.ManifoldParameter of shape (..., n, k)) – weights to be initialized. n should be greater than k.

  • seed (int, optional) – random seed for reproducibility. If None, no seed is used.

Returns:

W – initialized weights (the same object as the input, with its data modified in place)

Return type:

geoopt.tensor.ManifoldParameter
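As a rough illustration of the cited construction: the result usually quoted from Chikuse's Theorem 2.2.1 is that if Z is an n × k matrix with i.i.d. standard normal entries, then Z(ZᵀZ)^(-1/2) is uniformly distributed on the Stiefel manifold. The sketch below applies that to a plain tensor, assuming this is the intended result. It does not use geoopt and is not the library's implementation; the helper name `stiefel_init_sketch` and its `batch_shape`/`dtype` parameters are hypothetical.

```python
import torch


def stiefel_init_sketch(n, k, batch_shape=(), seed=None, dtype=torch.float64):
    """Hypothetical helper: sample matrices of shape (*batch_shape, n, k) uniformly on St(n, k)."""
    if k > n:
        raise ValueError("the Stiefel manifold St(n, k) requires k <= n")
    # Seeded generator for reproducibility; None falls back to the global RNG
    gen = torch.Generator().manual_seed(seed) if seed is not None else None
    # Z has i.i.d. standard normal entries
    Z = torch.randn(*batch_shape, n, k, generator=gen, dtype=dtype)
    # (Z^T Z)^(-1/2) computed from the eigendecomposition of the k x k SPD Gram matrix
    gram = Z.transpose(-1, -2) @ Z
    evals, evecs = torch.linalg.eigh(gram)
    inv_sqrt = evecs @ torch.diag_embed(evals.rsqrt()) @ evecs.transpose(-1, -2)
    # W = Z (Z^T Z)^(-1/2) has orthonormal columns: W^T W = I_k
    return Z @ inv_sqrt


W = stiefel_init_sketch(5, 3, seed=0)
print(W.T @ W)  # approximately the 3 x 3 identity
```

To mirror the documented in-place behaviour on an existing parameter `W`, one could write the sampled values into it with `torch.no_grad()` and `W.copy_(...)`.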

anotherspdnet.utils.initialize_weights_sphere(W: geoopt.tensor.ManifoldParameter, seed: int | None = None) → geoopt.tensor.ManifoldParameter

Initialize the weights so that they lie on the sphere manifold.

Parameters:
  • W (geoopt.tensor.ManifoldParameter of shape (..., n, k)) – weights to be initialized

  • seed (int, optional) – random seed for reproducibility. If None, no seed is used.

Returns:

W – initialized weights (the same object as the input, with its data modified in place)

Return type:

geoopt.tensor.ManifoldParameter
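A common way to initialize on a sphere is to normalize Gaussian samples, since a standard Gaussian vector divided by its Euclidean norm is uniformly distributed on the unit sphere. The sketch below assumes that "on the sphere" means unit norm along one axis; which axis the library normalizes is not stated above, so the default `dim=-1` is an assumption, as is the hypothetical helper name `sphere_init_sketch`.

```python
import torch


def sphere_init_sketch(shape, dim=-1, seed=None, dtype=torch.float64):
    """Hypothetical helper: Gaussian samples scaled to unit Euclidean norm along `dim`."""
    # Seeded generator for reproducibility; None falls back to the global RNG
    gen = torch.Generator().manual_seed(seed) if seed is not None else None
    Z = torch.randn(*shape, generator=gen, dtype=dtype)
    # Dividing by the norm projects each vector along `dim` onto the unit sphere
    return Z / Z.norm(dim=dim, keepdim=True)


W = sphere_init_sketch((5, 3), seed=0)
print(W.norm(dim=-1))  # every entry is (numerically) 1
```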

anotherspdnet.utils.symmetrize(X: torch.Tensor) → torch.Tensor

Symmetrize a tensor along the last two dimensions.

Parameters:

X (torch.Tensor) – tensor of shape (…, n_features, n_features)

Returns:

sym_X – symmetrized tensor of shape (…, n_features, n_features)

Return type:

torch.Tensor
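Assuming the usual definition, symmetrizing averages each matrix with its transpose, (X + Xᵀ)/2, batch-wise over the last two dimensions. A minimal sketch (the name `symmetrize_sketch` is hypothetical):

```python
import torch


def symmetrize_sketch(X: torch.Tensor) -> torch.Tensor:
    # Average each matrix with its transpose; leading (batch) dimensions are untouched
    return 0.5 * (X + X.transpose(-1, -2))


X = torch.tensor([[1.0, 2.0], [4.0, 3.0]])
print(symmetrize_sketch(X))  # values: [[1., 3.], [3., 3.]]
```

This form leaves an already symmetric matrix unchanged and is the orthogonal projection onto symmetric matrices under the Frobenius inner product.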

anotherspdnet.utils.construct_eigdiff_matrix(eigvals: torch.Tensor) → torch.Tensor

Constructs the matrix whose off-diagonal entries are the inverse pairwise differences between eigenvalues and whose diagonal entries are 0.

Parameters:

eigvals (torch.Tensor of shape (..., n_features)) – eigenvalues of the SPD matrices

Returns:

eigdiff_matrix – matrix of the inverse pairwise differences between eigenvalues

Return type:

torch.Tensor of shape (…, n_features, n_features)
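Written out, entry (i, j) for i ≠ j is 1/(λ_i − λ_j) and the diagonal is 0. The sign convention λ_i − λ_j (rather than λ_j − λ_i) is an assumption here, and repeated eigenvalues would produce infinite entries, which an actual implementation may guard against. Matrices of this form typically appear in the backward pass of eigendecomposition-based layers. A broadcasting sketch (hypothetical name `eigdiff_matrix_sketch`):

```python
import torch


def eigdiff_matrix_sketch(eigvals: torch.Tensor) -> torch.Tensor:
    # diff[..., i, j] = eigvals[..., i] - eigvals[..., j]
    diff = eigvals.unsqueeze(-1) - eigvals.unsqueeze(-2)
    n = eigvals.shape[-1]
    diag = torch.eye(n, dtype=torch.bool, device=eigvals.device)
    # Put 1 on the diagonal before inverting so no division by zero happens there
    safe_diff = torch.where(diag, torch.ones_like(diff), diff)
    # Inverse differences off the diagonal, 0 on the diagonal
    return torch.where(diag, torch.zeros_like(diff), 1.0 / safe_diff)


print(eigdiff_matrix_sketch(torch.tensor([1.0, 2.0, 4.0])))
```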

anotherspdnet.utils.zero_offdiag(X: torch.Tensor) → torch.Tensor

Sets the off-diagonal elements of each matrix (the last two dimensions of the tensor) to 0.

Parameters:

X (torch.Tensor) – tensor of shape (…, n_features, n_features)

Returns:

X_zero – tensor of shape (…, n_features, n_features) with 0 on the off-diagonal

Return type:

torch.Tensor
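One way to express this operation is to extract the diagonal of each matrix and rebuild a diagonal matrix from it; multiplying by an identity mask would work equally well. A sketch (the name `zero_offdiag_sketch` is hypothetical):

```python
import torch


def zero_offdiag_sketch(X: torch.Tensor) -> torch.Tensor:
    # Take the diagonal of each matrix, shape (..., n), and embed it back as a diagonal matrix
    return torch.diag_embed(torch.diagonal(X, dim1=-2, dim2=-1))


X = torch.arange(9.0).reshape(3, 3)
print(zero_offdiag_sketch(X))  # values: diag(0., 4., 8.)
```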

anotherspdnet.utils.nd_tensor_to_3d(input: torch.Tensor) → torch.Tensor

Converts an n-dimensional tensor to a 3D tensor by flattening all leading dimensions (all but the last two) into one.

Parameters:

input (torch.Tensor) – tensor of shape (…, n, m)

Returns:

output – tensor of shape (d, n, m), where d is the product of the sizes of all leading dimensions of input

Return type:

torch.Tensor
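Since the last two dimensions are kept and everything before them is flattened, the conversion amounts to a single reshape. A minimal sketch, assuming the input has at least two dimensions (the name `nd_tensor_to_3d_sketch` is hypothetical):

```python
import torch


def nd_tensor_to_3d_sketch(input: torch.Tensor) -> torch.Tensor:
    n, m = input.shape[-2:]
    # -1 lets reshape infer d, the product of all leading dimensions
    return input.reshape(-1, n, m)


x = torch.randn(2, 5, 3, 4)
print(nd_tensor_to_3d_sketch(x).shape)  # torch.Size([10, 3, 4])
```

`reshape` returns a view when the memory layout allows it and a copy otherwise, and a plain 2D input of shape (n, m) becomes (1, n, m).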

anotherspdnet.utils.threed_tensor_to_nd(input: torch.Tensor, shape: Tuple) → torch.Tensor

Converts a 3D tensor to an n-dimensional tensor by splitting the first dimension.

Parameters:
  • input (torch.Tensor) – tensor of shape (d, n, m)

  • shape (torch.Size) – shape of the output tensor

Returns:

output – tensor of shape (…, n, m)

Return type:

torch.Tensor
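Assuming `shape` is the complete target shape (for example the `.shape` of the tensor before its leading dimensions were flattened), the inverse conversion is again a reshape. `torch.Size` is a subclass of `tuple`, so either of the types mentioned above can be passed. A round-trip sketch (the name `threed_tensor_to_nd_sketch` is hypothetical):

```python
import torch
from typing import Sequence


def threed_tensor_to_nd_sketch(input: torch.Tensor, shape: Sequence[int]) -> torch.Tensor:
    # Assumes `shape` is the full target shape whose leading sizes multiply to input.shape[0]
    return input.reshape(tuple(shape))


x = torch.randn(2, 5, 3, 4)
flat = x.reshape(-1, 3, 4)  # shape (10, 3, 4): leading dimensions flattened
restored = threed_tensor_to_nd_sketch(flat, x.shape)
print(torch.equal(restored, x))  # True
```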