import torch
from torch import nn

from .base import Encoder


class SONEncoder(Encoder):
    """Deterministic SO(n) encoder / decoder.

    The transform is a learnable rotation R = exp(A), where A is an n x n
    skew-symmetric matrix (an element of the Lie algebra so(n)). Its
    n * (n - 1) / 2 free entries are the trainable ``coeffs``: they fill the
    strict upper triangle of A, and the lower triangle holds their negatives.
    Because exp of a skew-symmetric matrix has determinant exp(tr A) = 1,
    encoding and decoding are volume preserving and the reported log-det is 0.

    Args:
        coeffs: Optional initial Lie-algebra coefficients of shape
            ``(d,)`` with ``d = n * (n - 1) // 2`` and ``n = output_features``.
            Accepts a tensor or anything ``torch.as_tensor`` understands.
            If ``None``, coefficients are drawn from N(0, coeff_std^2).
        input_features: Forwarded to ``Encoder``. NOTE(review): the rotation
            is n x n with n = output_features, so inputs must have
            ``output_features`` features in their last dimension.
        output_features: Dimension n of the rotation group SO(n).
        coeff_std: Standard deviation of the random initialization.

    Raises:
        ValueError: If ``coeffs`` is given with a shape other than ``(d,)``.
    """

    def __init__(self, coeffs=None, input_features=2, output_features=2, coeff_std=0.05):
        super().__init__(input_features, output_features)

        # Number of independent entries of an n x n skew-symmetric matrix.
        d = self.output_features * (self.output_features - 1) // 2
        if coeffs is None:
            self.coeffs = nn.Parameter(torch.zeros((d,)))  # (d)
            nn.init.normal_(self.coeffs, std=coeff_std)
        else:
            coeffs = torch.as_tensor(coeffs)
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if coeffs.shape != (d,):
                raise ValueError(
                    f"coeffs must have shape ({d},) for output_features="
                    f"{self.output_features}, got {tuple(coeffs.shape)}"
                )
            self.coeffs = nn.Parameter(coeffs)

        # Generators
        # NOTE(review): not read by any method of this class; kept for backward
        # compatibility. It is a plain tensor attribute (not a registered
        # buffer), so it does not follow `.to(device)` — confirm whether any
        # external code relies on it.
        self.generators = torch.zeros((d, self.output_features, self.output_features))

    def forward(self, inputs, deterministic=False):
        """Given observed data, returns latent representation; i.e. encoding.

        Args:
            inputs: Tensor whose last dimension has size ``output_features``;
                any leading (batch) dimensions are supported.
            deterministic: Unused; kept for interface compatibility.

        Returns:
            ``(z, logdet)`` where ``z = R @ x`` for every vector x in
            ``inputs`` and ``logdet`` is a scalar zero on z's device/dtype.
        """
        z = torch.einsum("ij,...j->...i", self._rotation_matrix(), inputs)
        # |det R| = 1 for a rotation, so the log-determinant is exactly 0.
        logdet = z.new_zeros(())
        return z, logdet

    def inverse(self, inputs, deterministic=False):
        """Given latent representation, returns observed version; i.e. decoding.

        Args:
            inputs: Tensor whose last dimension has size ``output_features``;
                any leading (batch) dimensions are supported.
            deterministic: Unused; kept for interface compatibility.

        Returns:
            ``(x, logdet)`` where ``x = R^{-1} @ z`` for every vector z in
            ``inputs`` and ``logdet`` is a scalar zero on x's device/dtype.
        """
        x = torch.einsum("ij,...j->...i", self._rotation_matrix(inverse=True), inputs)
        logdet = x.new_zeros(())
        return x, logdet

    def _rotation_matrix(self, inverse=False):
        """
        Low-level function to generate an element of SO(n) by exponentiating the Lie algebra
        (skew-symmetric matrices).

        Args:
            inverse: If True, return R^{-1} = exp(-A) instead of R = exp(A).

        Returns:
            An (n, n) tensor on the same device and with the same dtype as
            ``self.coeffs``.
        """
        n = self.output_features
        # exp(-A) is the inverse of exp(A), so the inverse only flips the sign.
        coeffs = -self.coeffs if inverse else self.coeffs

        # Place the coefficients in the strict upper triangle of a fresh matrix.
        # Using coeffs' dtype avoids silently downcasting float64 parameters.
        upper = torch.zeros(n, n, device=coeffs.device, dtype=coeffs.dtype)
        rows, cols = torch.triu_indices(n, n, offset=1, device=coeffs.device)
        upper[rows, cols] = coeffs

        # U - U^T is skew-symmetric: +coeffs above the diagonal, -coeffs below.
        generator = upper - upper.T
        return torch.matrix_exp(generator)