import math

import torch
from torch import nn

from .linear import MVLinear
from .normalization import NormalizationLayer


class ElementWiseGP(nn.Module):
    """Feature-wise weighted geometric product with an optional linear term.

    With ``x`` the input multivector features, the layer evaluates

        x * Norm(linear_right(x))                                 (include_first_order=False)
        (linear_left(x) + x * Norm(linear_right(x))) / sqrt(2)    (include_first_order=True)

    where ``*`` contracts both operands with ``algebra.table`` scaled by one
    learnable weight per (feature, grade path) selected by
    ``algebra._geometric_product_paths``.

    (N_nodes, F, ..., 2**N) -> (N_nodes, F, ..., 2**N)
    """

    def __init__(
        self,
        algebra,
        features,
        include_first_order=True,
        use_eigenvalue=False,
        eigval_encoder=None,
        restrict_grade=None,
        share_weights_from=None,
        normalization_init=0,
    ):
        """Build the layer.

        Args:
            algebra: Object providing ``_geometric_product_paths``, ``grades``,
                ``table`` and ``num_bases``.
            features: Number of feature channels (``F``); input and output
                widths are equal.
            include_first_order: If True, add ``linear_left(x)`` to the product
                and divide the sum by ``sqrt(2)``. If False, the linear term is
                omitted and ``linear_left`` is never created.
            use_eigenvalue: Forwarded to the ``MVLinear`` constructors.
            eigval_encoder: Forwarded to the ``MVLinear`` constructors.
            restrict_grade: If given, only product paths whose three grade
                indices are all below ``min(restrict_grade, num_bases + 1)``
                receive a learnable weight; all others are fixed at zero.
                Also forwarded to ``MVLinear`` / ``NormalizationLayer``.
            share_weights_from: Another ``ElementWiseGP``-like object. When
                given, the linear layers are obtained through
                ``<layer>.share_weights_with(algebra)`` instead of being
                created. The product weight of this layer is NOT shared.
            normalization_init: Passed to ``NormalizationLayer``; ``None``
                disables normalization (``nn.Identity`` is used instead).
        """
        super().__init__()

        self.algebra = algebra
        self.features = features
        self.include_first_order = include_first_order
        self.use_eigenvalue = use_eigenvalue

        if normalization_init is not None:
            self.normalization = NormalizationLayer(
                algebra, features, normalization_init, restrict_grade
            )
        else:
            self.normalization = nn.Identity()

        if share_weights_from is None:
            linear_kwargs = dict(
                use_eigenvalue=use_eigenvalue,
                eigval_encoder=eigval_encoder,
                restrict_grade=restrict_grade,
            )
            self.linear_right = MVLinear(
                algebra, features, features, bias=False, **linear_kwargs
            )
            if include_first_order:
                self.linear_left = MVLinear(
                    algebra, features, features, bias=True, **linear_kwargs
                )
        else:
            # The right-hand layer is always needed; the left-hand one only
            # when the first-order term is enabled.
            self.linear_right = share_weights_from.linear_right.share_weights_with(
                self.algebra
            )
            if include_first_order:
                self.linear_left = share_weights_from.linear_left.share_weights_with(
                    self.algebra
                )

        self.product_paths = algebra._geometric_product_paths
        if restrict_grade is not None:
            # Cap at the number of grades (num_bases + 1).
            self.restrict_grade = min(restrict_grade, algebra.num_bases + 1)
            g = self.restrict_grade
            num_paths = self.product_paths[:g, :g, :g].sum()
        else:
            self.restrict_grade = None
            num_paths = self.product_paths.sum()
        # One learnable scalar per (feature, active grade path).
        self.weight = nn.Parameter(torch.empty(features, int(num_paths)))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``weight`` from N(0, 1/G), G being the number of grades in use."""
        if self.restrict_grade is not None:
            num_grades = self.restrict_grade
        else:
            num_grades = self.algebra.num_bases + 1
        torch.nn.init.normal_(self.weight, std=1 / math.sqrt(num_grades))

    def to(self, *args, **kwargs):
        """Move/cast the layer and return ``self``, like ``nn.Module.to``.

        Accepts every calling form of ``nn.Module.to`` (``to(device)``,
        ``to(dtype)``, ``to(device=..., dtype=...)``, ...).
        """
        super().to(*args, **kwargs)
        # Kept from the original implementation: sub-modules are also moved
        # through their own ``to``. NOTE(review): presumably they override
        # ``to`` to move state that is not registered as parameters/buffers
        # -- confirm; for plain modules these calls are no-ops.
        self.normalization.to(*args, **kwargs)
        self.linear_right.to(*args, **kwargs)
        if self.include_first_order:
            self.linear_left.to(*args, **kwargs)
        # Returning self is required for the ``module = module.to(device)`` idiom.
        return self

    def _get_weight(self):
        """Expand the compact parameter into a dense per-blade product tensor.

        Returns:
            ``algebra.table`` multiplied element-wise by the learnable weights,
            broadcast from grade indices to blade indices; the leading
            dimension is ``features``.
        """
        grade_weight = torch.zeros(
            self.features,
            *self.product_paths.size(),
            dtype=self.weight.dtype,
            device=self.weight.device,
        )

        if self.restrict_grade is not None:
            # Keep only paths whose three grade indices are below the limit.
            g = self.restrict_grade
            active_paths = torch.zeros_like(self.product_paths)
            active_paths[:g, :g, :g] = self.product_paths[:g, :g, :g]
        else:
            active_paths = self.product_paths

        # Scatter the compact parameters onto the active (grade, grade, grade) slots.
        grade_weight[:, active_paths] = self.weight

        # Index each of the three grade axes with ``algebra.grades`` so that
        # every blade picks up the weight belonging to its grade.
        grades = self.algebra.grades
        blade_weight = grade_weight[..., grades, :, :]
        blade_weight = blade_weight[..., grades, :]
        blade_weight = blade_weight[..., grades]
        return self.algebra.table * blade_weight

    def forward(self, input, eigenvalue=None):
        """Apply the layer.

        Args:
            input: Tensor laid out as (batch, features, ..., blades).
            eigenvalue: Optional extra argument forwarded unchanged to the
                linear layers.

        Returns:
            Tensor with the same layout as ``input``.
        """
        input_right = self.normalization(self.linear_right(input, eigenvalue))
        weight = self._get_weight()

        # Weighted geometric product of the input with its transformed copy.
        product = torch.einsum(
            "bn...i, nijk, bn...k -> bn...j", input, weight, input_right
        )

        if not self.include_first_order:
            return product

        # Two terms are summed, hence the 1/sqrt(2) rescaling.
        return (self.linear_left(input, eigenvalue) + product) / math.sqrt(2)

