import geoopt
import torch
from torch import nn

from manifolds import StiefelT


class Concatenate(nn.Module):
    """Evaluate several argument-free sub-modules and concatenate their outputs.

    Each wrapped module is called with no arguments and must return a tensor.
    The returned tensors are joined with ``torch.cat`` using its default
    dimension (the first one), in the order the modules were passed in.
    """

    def __init__(self, *modules: nn.Module):
        super().__init__()
        # Stored in a ModuleList so that the sub-modules are registered as
        # children of this module.
        self._modules_ls = nn.ModuleList(modules)

    def forward(self) -> torch.Tensor:
        """Return the concatenation of the outputs of all wrapped modules."""
        outputs = []
        for submodule in self._modules_ls:
            outputs.append(submodule())
        return torch.cat(outputs)


class TorchEmbeddingStiefelParameter(nn.Module):
    def __init__(
        self,
        data: torch.Tensor | nn.Parameter,
        *,
        num_variables: int,
        num_repetitions: int,
        num_units: int,
        num_states: int,
        fold_idx: torch.Tensor,
    ):
        # For orthogonal non-structured squared PCs we would like
        # to parameterize a batch of (semi-)unitary matrices of shape (V, S * K, N),
        # where V is the number of variables
        #       S is the number of repetitions (i.e., number of embedding layers per variable)
        #       K is the number of units
        #       N is the number of states (e.g., 256 for grayscale pixels)
        #   and S * K <= N
        # Moreover, 'data' is a tensor of shape (V * S, K, N) that is here used to initialize
        # the wrapper. We reshape it to be (V, S * K, N) as to accommodate the optimizers over
        # the Stiefel manifold.
        assert num_repetitions * num_units <= num_states
        assert fold_idx.shape[0] == num_variables * num_repetitions
        assert fold_idx.shape[1] == 1
        super().__init__()
        self.num_variables = num_variables
        self.num_repetitions = num_repetitions
        self.num_units = num_units
        self.num_states = num_states
        # data: (V * S, K, N) -> order fold dimension -> reshape to (V, S * K, N)
        with torch.no_grad():
            data = data[fold_idx.squeeze(dim=1).argsort(stable=True)]
            data = data.view(num_variables, num_repetitions * num_units, num_states)
        self._parameter = geoopt.ManifoldParameter(data, manifold=StiefelT())
        with torch.no_grad():
            self._parameter.proj_()
        # After reshaping the above parameter to be of shape (F, K, N) (see forward method),
        # we need to re-order the matrices along axis=0 as to recover the fold index vector
        # of the embedding layer.
        # To do this re-ordering, we simply need to invert the argsort of the fold index vector
        # i.e., call argsort twice
        self._fold_permutation: torch.Tensor
        self.register_buffer('_fold_permutation', fold_idx.squeeze(dim=1).argsort(stable=True).argsort(stable=True))

    def forward(self) -> torch.Tensor:
        # p: (V, S * K, N)
        p = self._parameter
        # Reshape from (V, S * K, N) to (F, K, N), where F = V * S
        q = p.view(self.num_variables * self.num_repetitions, self.num_units, self.num_states)
        # Permute the folds as to match the fold idx of the embedding layer
        q = q[self._fold_permutation]
        return q
