import math

import torch
from torch import nn

from .linear import MVLinear
from .normalization import NormalizationLayer


class FullyConnectedGP(nn.Module):
    """Fully connected geometric-product layer over multivector features.

    Combines the input with a linearly mixed (and optionally normalized) copy
    of itself through a learned, per-path-weighted product, and optionally
    adds a first-order linear term.
    """

    def __init__(
        self,
        algebra,
        in_features,
        out_features,
        include_first_order=True,
        use_eigenvalue_1=False,
        use_eigenvalue_2=False,
        eigval_encoder_1=None,
        eigval_encoder_2=None,
        restrict_grade=None,
        share_weights_from=None,
        normalization_init=0,
    ):
        """
        FC(x * Norm(MVLinear1(x))) + MVLinear2(x)

        (N_nodes, F_in, ..., 2**N) -> (N_nodes, F_out, ..., 2**N)

        If `include_first_order=False`, MVLinear2 = 0.

        Args:
            algebra: Algebra object. This layer reads its
                `_geometric_product_paths`, `num_bases`, `grades` and `table`.
            in_features: Number of input channels (F_in).
            out_features: Number of output channels (F_out).
            include_first_order: If True, add a biased `MVLinear` term of the
                input (MVLinear2) and scale the sum by 1/sqrt(2).
            use_eigenvalue_1, eigval_encoder_1: Forwarded to MVLinear1, the
                `in_features -> in_features` layer feeding the product.
            use_eigenvalue_2, eigval_encoder_2: Forwarded to MVLinear2, the
                first-order `in_features -> out_features` layer.
            restrict_grade: If not None, only product paths inside the
                leading `restrict_grade` indices of every axis of
                `_geometric_product_paths` get learnable weights. For this
                layer's own weight it is capped at `num_bases + 1`. The raw
                value is forwarded to the linear and normalization layers.
            share_weights_from: Another `FullyConnectedGP`. When given, its
                linear layers are reused through `share_weights_with`, and the
                `use_eigenvalue_*` / `eigval_encoder_*` arguments are unused.
                The product weight and the normalization layer are always
                created fresh.
            normalization_init: Passed to `NormalizationLayer`. None disables
                normalization (an `nn.Identity` is used instead).
        """
        super().__init__()

        self.algebra = algebra
        self.in_features = in_features
        self.out_features = out_features
        self.include_first_order = include_first_order
        self.use_eigenvalue_1 = use_eigenvalue_1
        self.use_eigenvalue_2 = use_eigenvalue_2

        if normalization_init is not None:
            self.normalization = NormalizationLayer(
                algebra, in_features, normalization_init, restrict_grade
            )
        else:
            self.normalization = nn.Identity()

        if share_weights_from is None:
            self.linear_right = MVLinear(
                algebra, in_features, in_features, bias=False,
                use_eigenvalue=use_eigenvalue_1,
                eigval_encoder=eigval_encoder_1,
                restrict_grade=restrict_grade,
            )
            if include_first_order:
                self.linear_left = MVLinear(
                    algebra, in_features, out_features, bias=True,
                    use_eigenvalue=use_eigenvalue_2,
                    eigval_encoder=eigval_encoder_2,
                    restrict_grade=restrict_grade,
                )
        else:
            self.linear_right = share_weights_from.linear_right.share_weights_with(self.algebra)
            if include_first_order:
                self.linear_left = share_weights_from.linear_left.share_weights_with(self.algebra)

        # Mask over index triples. Presumably it marks which (grade, grade,
        # grade) products are nonzero. One learnable weight is allocated per
        # nonzero entry.
        self.product_paths = algebra._geometric_product_paths
        if restrict_grade is not None:
            self.restrict_grade = min(restrict_grade, algebra.num_bases + 1)
            r = self.restrict_grade
            num_paths = int(self.product_paths[:r, :r, :r].sum())
        else:
            self.restrict_grade = None
            num_paths = int(self.product_paths.sum())
        self.weight = nn.Parameter(torch.empty(out_features, in_features, num_paths))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the product weight from N(0, 1 / (in_features * G)).

        G is `restrict_grade` when set, otherwise `algebra.num_bases + 1`.
        """
        if self.restrict_grade is not None:
            num_grades = self.restrict_grade
        else:
            num_grades = self.algebra.num_bases + 1
        torch.nn.init.normal_(
            self.weight,
            std=1 / math.sqrt(self.in_features * num_grades),
        )

    def _get_weight(self):
        """Expand the compact weight into the dense tensor used by `forward`.

        Steps:
        1. Scatter `self.weight` into a zero tensor of shape
           `(out_features, in_features, *product_paths.size())`, at the
           nonzero entries of `product_paths`. The mask is limited to the
           leading `restrict_grade` cube when set.
        2. Expand each of the last three axes through `algebra.grades`
           (presumably a blade -> grade lookup).
        3. Multiply elementwise by `algebra.table`.
        """
        weight = torch.zeros(
            self.out_features,
            self.in_features,
            *self.product_paths.size(),
            dtype=self.weight.dtype,
            device=self.weight.device,
        )
        if self.restrict_grade is not None:
            r = self.restrict_grade
            # Only paths inside the leading r x r x r cube carry weights.
            mask = torch.zeros_like(self.product_paths)
            mask[:r, :r, :r] = self.product_paths[:r, :r, :r]
        else:
            mask = self.product_paths
        weight[:, :, mask] = self.weight

        grades = self.algebra.grades
        weight_repeated = weight[..., grades, :, :][..., grades, :][..., grades]
        return self.algebra.table * weight_repeated

    def to(self, *args, **kwargs):
        """Move and/or cast the module and return ``self``, like ``nn.Module.to``.

        Accepts the same arguments as ``nn.Module.to`` (device, dtype,
        non_blocking, ...).

        Submodules are also moved with their own ``.to``. ``nn.Module.to``
        recurses through ``_apply``, so any custom ``to`` override on a
        submodule would otherwise not run.
        """
        super().to(*args, **kwargs)
        self.normalization.to(*args, **kwargs)
        self.linear_right.to(*args, **kwargs)
        if self.include_first_order:
            self.linear_left.to(*args, **kwargs)
        return self

    def forward(self, input, eigenvalue=None):
        """Apply the layer.

        Args:
            input: Features laid out as (batch, in_features, ..., blades), as
                required by the einsum below.
            eigenvalue: Optional. Forwarded to the `MVLinear` layers.

        Returns:
            Tensor laid out as (batch, out_features, ..., blades).
        """
        input_right = self.normalization(self.linear_right(input, eigenvalue))
        weight = self._get_weight()

        # out[b, m, ..., j] = sum_{n, i, k} x[b, n, ..., i] * W[m, n, i, j, k]
        #                                    * y[b, n, ..., k]
        product = torch.einsum(
            "bn...i, mnijk, bn...k -> bm...j", input, weight, input_right
        )
        if not self.include_first_order:
            return product
        # The 1/sqrt(2) presumably keeps the two-term sum at a scale
        # comparable to each term.
        return (self.linear_left(input, eigenvalue) + product) / math.sqrt(2)
        