Utils
This module contains a set of utility functions used by the other modules.
- anotherspdnet.utils.initialize_weights_stiefel(W: geoopt.tensor.ManifoldParameter, seed: int | None = None) → geoopt.tensor.ManifoldParameter
Initialize the weights so that they lie on the Stiefel manifold.
Reference: Theorem 2.2.1 in Chikuse (2003), Statistics on Special Manifolds. TODO: verify that this is correct.
- Parameters:
W (geoopt.tensor.ManifoldParameter of shape (..., n, k)) – weights to be initialized. n should be greater than k.
seed (int, optional) – random seed for reproducibility. If None, no seed is used.
- Returns:
W – initialized weights (the same object as the input, with its data modified in place)
- Return type:
geoopt.tensor.ManifoldParameter
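A standard construction, and presumably the one the Chikuse (2003) reference points to, draws Z with i.i.d. standard normal entries and returns Z(ZᵀZ)^(-1/2), which is uniformly distributed on the Stiefel manifold. The sketch below applies it to a plain torch.Tensor and computes Z(ZᵀZ)^(-1/2) as the polar factor U Vᵀ of a thin SVD. The helper name sample_stiefel is hypothetical, and it is an assumption that initialize_weights_stiefel works this way.

```python
import torch


def sample_stiefel(n: int, k: int, batch_shape=(), seed=None, dtype=torch.float64):
    """Hypothetical helper: uniform samples from the Stiefel manifold St(n, k).

    Assumes the construction X = Z (Z^T Z)^{-1/2} with Z having i.i.d. N(0, 1)
    entries; this is not necessarily how anotherspdnet implements it.
    """
    if n < k:
        raise ValueError("the Stiefel manifold St(n, k) requires n >= k")
    # Seeded private generator for reproducibility, otherwise the global RNG
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    # Z has i.i.d. standard normal entries, shape (*batch_shape, n, k)
    Z = torch.randn(*batch_shape, n, k, generator=generator, dtype=dtype)
    # Thin SVD Z = U diag(S) Vh, hence Z (Z^T Z)^{-1/2} = U Vh (the polar factor)
    U, _, Vh = torch.linalg.svd(Z, full_matrices=False)
    return U @ Vh


X = sample_stiefel(5, 3, batch_shape=(2,), seed=0)
print(X.shape)  # torch.Size([2, 5, 3])
```

To fill an existing geoopt.ManifoldParameter W in place, a sample of matching shape could be copied into its data, e.g. `W.data.copy_(sample_stiefel(*W.shape[-2:], batch_shape=W.shape[:-2], seed=seed))`.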
- anotherspdnet.utils.initialize_weights_sphere(W: geoopt.tensor.ManifoldParameter, seed: int | None = None) → geoopt.tensor.ManifoldParameter
Initialize the weights so that they lie on the sphere manifold.
- Parameters:
W (geoopt.tensor.ManifoldParameter of shape (..., n, k)) – weights to be initialized
seed (int, optional) – random seed for reproducibility. If None, no seed is used.
- Returns:
W – initialized weights (the same object as the input, with its data modified in place)
- Return type:
geoopt.tensor.ManifoldParameter
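The entry does not say along which axis of the (..., n, k) shape the unit-norm constraint applies. The sketch below assumes it applies along the last dimension and uses the fact that dividing an i.i.d. standard normal vector by its Euclidean norm gives a point uniformly distributed on the unit sphere. The helper name sample_sphere is hypothetical.

```python
import torch


def sample_sphere(shape, seed=None, dtype=torch.float64):
    """Hypothetical helper: uniform samples on the unit sphere.

    Assumption: the constraint ||x|| = 1 applies along the last dimension.
    """
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    Z = torch.randn(*shape, generator=generator, dtype=dtype)
    # A normalized standard Gaussian vector is uniformly distributed on the sphere
    return Z / torch.linalg.vector_norm(Z, dim=-1, keepdim=True)


W = sample_sphere((4, 3), seed=0)
norms = torch.linalg.vector_norm(W, dim=-1)  # four ones, up to floating-point rounding
```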
- anotherspdnet.utils.symmetrize(X: torch.Tensor) → torch.Tensor
Symmetrize a tensor along the last two dimensions.
- Parameters:
X (torch.Tensor) – tensor of shape (…, n_features, n_features)
- Returns:
sym_X – symmetrized tensor of shape (…, n_features, n_features)
- Return type:
torch.Tensor
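The usual definition of symmetrization is (X + Xᵀ)/2 taken over the last two dimensions. Assuming symmetrize follows it, a stand-in (hypothetically named symmetrize_sketch) is a single line, and it works on batched input because only the last two axes are transposed.

```python
import torch


def symmetrize_sketch(X: torch.Tensor) -> torch.Tensor:
    """Assumed behaviour: average X with its transpose over the last two dims."""
    return 0.5 * (X + X.transpose(-2, -1))


X = torch.tensor([[1.0, 2.0], [4.0, 3.0]])
print(symmetrize_sketch(X).tolist())  # [[1.0, 3.0], [3.0, 3.0]]
```

An already symmetric input is returned unchanged under this definition.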
- anotherspdnet.utils.construct_eigdiff_matrix(eigvals: torch.Tensor) → torch.Tensor
Constructs the matrix whose off-diagonal entries are the inverses of the pairwise differences between eigenvalues and whose diagonal entries are 0.
- Parameters:
eigvals (torch.Tensor of shape (..., n_features)) – eigenvalues of the SPD matrices
- Returns:
eigdiff_matrix – matrix of the inverse pairwise differences between eigenvalues
- Return type:
torch.Tensor of shape (…, n_features, n_features)
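Matrices of this form appear in backpropagation through an eigendecomposition. The sketch below assumes the sign convention K[i, j] = 1 / (λ_i − λ_j) and, to avoid division by zero, also sets entries for exactly equal eigenvalues to 0; neither detail is stated in the entry, and eigdiff_matrix_sketch is a hypothetical name.

```python
import torch


def eigdiff_matrix_sketch(eigvals: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for construct_eigdiff_matrix.

    Assumed convention: K[..., i, j] = 1 / (eigvals[..., i] - eigvals[..., j])
    for i != j, 0 on the diagonal, and 0 where two eigenvalues are equal.
    """
    n = eigvals.shape[-1]
    # diff[..., i, j] = lambda_i - lambda_j
    diff = eigvals.unsqueeze(-1) - eigvals.unsqueeze(-2)
    off_diag = ~torch.eye(n, dtype=torch.bool, device=eigvals.device)
    valid = off_diag & (diff != 0)
    # Put 1 in invalid entries before dividing so no inf/nan is produced
    safe_diff = torch.where(valid, diff, torch.ones_like(diff))
    return torch.where(valid, 1.0 / safe_diff, torch.zeros_like(diff))


K = eigdiff_matrix_sketch(torch.tensor([1.0, 2.0, 4.0]))
print(K[0, 1].item(), K[1, 2].item())  # -1.0 -0.5
```

Under this convention the result is antisymmetric, so K[j, i] = −K[i, j].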
- anotherspdnet.utils.zero_offdiag(X: torch.Tensor) → torch.Tensor
Sets the off-diagonal elements of a tensor to 0.
- Parameters:
X (torch.Tensor) – tensor of shape (…, n_features, n_features)
- Returns:
X_zero – tensor of shape (…, n_features, n_features) with 0 on the off-diagonal
- Return type:
torch.Tensor
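One way to obtain the same result is to extract the diagonal of the last two dimensions and embed it back into an otherwise zero matrix. This sketch (hypothetically named zero_offdiag_sketch) returns a new tensor and leaves X untouched; whether the library function works out of place is not stated.

```python
import torch


def zero_offdiag_sketch(X: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: keep the diagonal of the last two dims, zero the rest."""
    # torch.diagonal gives shape (..., n); diag_embed places it on a zero (..., n, n)
    return torch.diag_embed(torch.diagonal(X, dim1=-2, dim2=-1))


X = torch.arange(1.0, 10.0).reshape(3, 3)
print(zero_offdiag_sketch(X).tolist())  # [[1.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 9.0]]
```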
- anotherspdnet.utils.nd_tensor_to_3d(input: torch.Tensor) → torch.Tensor
Converts an n-dimensional tensor to a 3D tensor by flattening all leading dimensions (all but the last two) into one.
- Parameters:
input (torch.Tensor) – tensor of shape (…, n, m)
- Returns:
output – tensor of shape (d, n, m), where d is the product of the sizes of the leading dimensions of input
- Return type:
torch.Tensor
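Assuming the leading dimensions are merged in row-major order, as torch.reshape does, the conversion is a single reshape with −1 standing for d. The name nd_to_3d_sketch is hypothetical. With this approach a 2D input of shape (n, m) becomes (1, n, m).

```python
import torch


def nd_to_3d_sketch(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in: merge all leading dims of (..., n, m) into one."""
    n, m = x.shape[-2:]
    # -1 lets PyTorch infer d, the product of the leading dimension sizes
    return x.reshape(-1, n, m)


x = torch.zeros(2, 3, 4, 5)
print(nd_to_3d_sketch(x).shape)  # torch.Size([6, 4, 5])
```

torch.reshape returns a view when the memory layout allows it and a copy otherwise.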
- anotherspdnet.utils.threed_tensor_to_nd(input: torch.Tensor, shape: Tuple) → torch.Tensor
Converts a 3D tensor to an n-dimensional tensor by splitting the first dimension.
- Parameters:
input (torch.Tensor) – tensor of shape (d, n, m)
shape (Tuple or torch.Size) – shape of the output tensor
- Returns:
output – tensor of shape (…, n, m)
- Return type:
torch.Tensor
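Assuming shape is the full shape of the original n-dimensional tensor (for example its .shape, saved before flattening), splitting the first dimension back is again one reshape. The round trip below uses a hypothetical threed_to_nd_sketch and plain reshape calls, so pairing it with nd_tensor_to_3d this way is an assumption.

```python
import torch


def threed_to_nd_sketch(x3d: torch.Tensor, shape) -> torch.Tensor:
    """Hypothetical stand-in: reshape (d, n, m) back to `shape` = (..., n, m)."""
    if tuple(shape[-2:]) != tuple(x3d.shape[-2:]):
        raise ValueError("the last two entries of `shape` must equal (n, m)")
    return x3d.reshape(shape)


x = torch.randn(2, 3, 4, 5)
original_shape = x.shape                      # torch.Size is a subclass of tuple
flat = x.reshape(-1, *original_shape[-2:])    # shape (6, 4, 5)
flat = 2.0 * flat                             # some batched per-matrix operation
restored = threed_to_nd_sketch(flat, original_shape)
print(restored.shape)  # torch.Size([2, 3, 4, 5])
```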