import math

import torch
from torch import nn

from .utils import *


class MVLinear(nn.Module):
    def __init__(
        self,
        algebra,
        in_features,
        out_features,
        subspaces=True,
        use_eigenvalue=False,
        eigval_encoder=None,
        restrict_grade=None,
        bias=True,
    ):
        """
        Linear layer acting on the feature axis of multivector inputs.

        r_i^(0) <- W_0 @ r_i^(0) + b_0

        r_i^(k) <- W_k @ r_i^(k)    (for k != 0)

        (N_nodes, F_in, ..., 2**N) -> (N_nodes, F_out, ..., 2**N).

        Args:
            algebra: algebra providing `dim`, `num_bases`, `grades` and
                `embed`; the last input axis must have size `algebra.dim`
                (or the `dim` of an algebra registered via
                `share_weights_with`).
            in_features: number of input feature channels (F_in).
            out_features: number of output feature channels (F_out).
            subspaces: whether to use a different weight matrix for every
                grade (subspace). When `subspaces=False`, W_0 = ... = W_k = W.
            use_eigenvalue: whether to scale the weights by encoded
                eigenvalues (see below). Requires `eigval_encoder`.
            eigval_encoder: module applied to the eigenvalues when
                `use_eigenvalue=True`.
            restrict_grade: only used with `subspaces=True`. If given, only
                the first `restrict_grade` grades get learnable weights; the
                weights of the remaining grades are fixed at zero. Values
                larger than `algebra.num_bases + 1` are clipped.
            bias: when `bias=False`, b_0 = 0.

        When `use_eigenvalue=True`, W_0, ..., W_k will be multiplied by
        the encoded eigenvalues (the bias is not scaled).

        eigenvalues should be of shape (num_eigenspaces, ), where
        `N_nodes % num_eigenspaces == 0`.
        """
        super().__init__()

        self.algebra = algebra
        # Smaller algebras that reuse (a cropped view of) this layer's
        # weights; populated by `share_weights_with`.
        self.weights_shared_algebras = []
        self.in_features = in_features
        self.out_features = out_features
        self.subspaces = subspaces
        self.use_eigenvalue = use_eigenvalue
        self.restrict_grade = None

        if subspaces:
            if restrict_grade is not None:
                restrict_grade = min(restrict_grade, algebra.num_bases + 1)
                self.restrict_grade = restrict_grade
                self._weight = nn.Parameter(
                    torch.empty(out_features, in_features, restrict_grade)
                )
                # Fixed zero weights for the grades above `restrict_grade`.
                # A non-persistent buffer follows `.to()`, `.cuda()`,
                # `.double()`, ... together with the parameters while staying
                # out of the state dict (as the former plain attribute did).
                self.register_buffer(
                    "_weight_padding",
                    torch.zeros(out_features, in_features,
                                algebra.num_bases - restrict_grade + 1),
                    persistent=False,
                )
            else:
                self.weight = nn.Parameter(
                    torch.empty(out_features, in_features, algebra.num_bases + 1)
                )
            self._forward = self._forward_subspaces
        else:
            self.weight = nn.Parameter(torch.empty(out_features, in_features))

        if self.use_eigenvalue:
            assert eigval_encoder is not None, (
                "`eigval_encoder` is required when `use_eigenvalue=True`."
            )
            self.eigval_encoder = eigval_encoder

        if bias:
            # Scalar-only bias, one value per output feature; it is lifted
            # into the algebra with `algebra.embed(..., self.b_dims)`.
            self.bias = nn.Parameter(torch.empty(1, out_features, 1))
            self.b_dims = (0,)
        else:
            self.register_parameter("bias", None)
            self.b_dims = ()

        self.reset_parameters()

    def reset_parameters(self):
        """Draw weights from N(0, 1 / in_features) and zero the bias."""
        std = 1 / math.sqrt(self.in_features)
        if self.restrict_grade is not None:
            torch.nn.init.normal_(self._weight, std=std)
        else:
            torch.nn.init.normal_(self.weight, std=std)

        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def share_weights_with(self, algebra):
        """
        Allow weight sharing with multiple algebras.

        `algebra` must have a strictly smaller `dim` than `self.algebra`;
        inputs of that size then use the leading grades of this layer's
        weights. Returns `self` so calls can be chained.

        Raises:
            RuntimeError: if `algebra.dim >= self.algebra.dim`.
        """
        if algebra.dim >= self.algebra.dim:
            raise RuntimeError("Expect 'self.weights_shared_algebras' to contain "
                               "only smaller algebras.")
        self.weights_shared_algebras.append(algebra)
        return self

    def _find_algebra(self, size):
        """Return the algebra whose `dim == size`, or None if there is none.

        `self.algebra` takes precedence, then `self.weights_shared_algebras`
        in registration order (first match wins).
        """
        if size == self.algebra.dim:
            return self.algebra
        for algebra in self.weights_shared_algebras:
            if size == algebra.dim:
                return algebra
        return None

    def _forward(self, input):
        # Single weight matrix shared by every blade: contract the feature
        # axis `m` only, leaving the trailing blade axis `i` untouched.
        return torch.einsum("bm...i, nm->bn...i", input, self.weight)

    def _forward_subspaces(self, input):
        if self.restrict_grade is not None:
            # (out_features, in_features, num_bases + 1): learnable weights
            # for the low grades followed by zeros for the restricted ones.
            weight = torch.cat([self._weight, self._weight_padding], dim=-1)
        else:
            weight = self.weight

        size = input.shape[-1]
        algebra = self._find_algebra(size)
        if algebra is None:
            raise RuntimeError(f"No size-matching algebra found in "
                               f"'self.weights_shared_algebras', got size {size}.")

        if algebra is not self.algebra:
            # Smaller shared algebra: keep only grades 0..log2(size).
            weight = weight[..., :round(math.log2(size)) + 1]
        # Expand per-grade weights to per-blade weights:
        # (out_features, in_features, size).
        weight = weight[..., algebra.grades]
        return torch.einsum("bm...i, nmi->bn...i", input, weight)

    def forward(self, input, eigenvalue=None):
        """
        Apply the layer to `input` of shape (batch, in_features, ..., dim).

        Args:
            input: multivector features; `dim` must match `self.algebra.dim`
                or a shared algebra's `dim` when `subspaces=True`.
            eigenvalue: required when `use_eigenvalue=True`; its element
                count must divide the batch size.

        Returns:
            Tensor of shape (batch, out_features, ..., dim).
        """
        result = self._forward(input)
        # result is (batch, out_features, ..., dim)

        if self.use_eigenvalue:
            assert eigenvalue is not None, (
                "`eigenvalue` is required when `use_eigenvalue=True`."
            )
            num_eigenvalues = eigenvalue.numel()
            assert result.shape[0] % num_eigenvalues == 0
            # Each eigenvalue is shared by `repeats` consecutive batch rows.
            repeats = result.shape[0] // num_eigenvalues
            encoded = self.eigval_encoder(
                eigenvalue.reshape(-1).repeat_interleave(repeats).unsqueeze(-1)
            )
            # NOTE(review): presumably `unsqueeze_like` reshapes `encoded` to
            # broadcast against `result`; confirm against `.utils`.
            result *= unsqueeze_like(encoded, result, dim=2)

        if self.bias is not None:
            algebra = self._find_algebra(result.shape[-1])
            # NOTE(review): with `subspaces=False` an input whose size matches
            # no known algebra silently receives no bias (original behavior).
            if algebra is not None:
                # Added exactly once, even if several shared algebras match.
                bias = algebra.embed(self.bias, self.b_dims)
                result += unsqueeze_like(bias, result, dim=2)

        return result

