import torch
from torch import nn

from .utils import unsqueeze_like

# Small constant added to the normalization denominator in NormalizationLayer.forward
# so that a zero (interpolated) norm does not cause a division by zero.
EPS = 1e-6


class NormalizationLayer(nn.Module):
    """Per-feature, per-grade learnable normalization of multivector features.

    Shape: (N_nodes, F_in, ..., 2**N) -> (N_nodes, F_in, ..., 2**N)

    For every input feature and every grade slot, a logit ``a`` is squashed with a
    sigmoid into ``s`` in (0, 1). Each blade is then divided by
    ``s * (norm - 1) + 1 + EPS``, a value that interpolates between 1 (``s -> 0``,
    no normalization) and the grade's norm (``s -> 1``, full normalization).

    Args:
        algebra: object providing ``num_bases``, ``norms(input)`` and ``grades``.
            ``norms`` is concatenated on the last axis into ``num_bases + 1`` slots.
            ``grades`` indexes those slots per blade (presumably blade index -> grade;
            TODO confirm).
        features: number of input features ``F_in``; must equal ``input.shape[1]``.
        init: initial value of every learnable logit. The default 0 gives an initial
            interpolation factor of ``sigmoid(0) == 0.5``.
        restrict_grade: if given, only the first ``restrict_grade`` slots (clamped to
            ``num_bases + 1``) get learnable logits, stored as ``_a``. The remaining
            slots use a fixed logit of 0. Otherwise all slots are learnable, stored
            as ``a``.

    NOTE(review): the fixed padding logit of 0 means the non-learnable grades are still
    half-normalized (sigmoid(0) = 0.5). If they were meant to be left untouched, the
    padding would need a large negative logit instead -- confirm intent.
    """

    def __init__(self, algebra, features, init: float = 0, restrict_grade=None):
        super().__init__()
        self.algebra = algebra
        self.in_features = features
        n_grades = algebra.num_bases + 1

        # (F_in, N+1) logits in total.
        if restrict_grade is not None:
            restrict_grade = min(restrict_grade, n_grades)
            self._a = nn.Parameter(torch.zeros(self.in_features, restrict_grade) + init)
            # Fixed logits for the remaining grade slots. Registered as a buffer so that
            # .to()/.cuda()/.double() -- including when invoked on a parent module, which
            # never calls this module's own .to() -- and DataParallel replication carry
            # it along with `_a`. It is non-persistent so state_dict keys are unchanged
            # and existing checkpoints still load.
            self.register_buffer(
                "_a_padding",
                torch.zeros(self.in_features, n_grades - restrict_grade),
                persistent=False,
            )
        else:
            self.a = nn.Parameter(torch.zeros(self.in_features, n_grades) + init)

    # No custom `to()` is needed: with `_a_padding` registered as a buffer,
    # nn.Module.to moves/casts it, returns `self`, and accepts every overload.

    def forward(self, input):
        """Divide every blade of ``input`` by its grade's interpolated norm.

        Args:
            input: multivector features whose dim 1 must equal ``in_features``
                (expected layout (N_nodes, F_in, ..., 2**N)).

        Returns:
            ``input / (s * (norm - 1) + 1 + EPS)``, where ``s = sigmoid(logits)`` is
            chosen per feature and grade slot.
        """
        assert input.shape[1] == self.in_features, (
            f"expected {self.in_features} features along dim 1, got {input.shape[1]}"
        )

        # Full (F_in, N+1) logit table: learnable part, plus fixed padding if restricted.
        if hasattr(self, "_a_padding"):
            _a = torch.cat([self._a, self._a_padding], dim=-1)
        else:
            _a = self.a

        # Per-grade norms concatenated on the last axis: (N_nodes, F_in, ..., N+1).
        norms = torch.cat(self.algebra.norms(input), dim=-1)
        # Shape the (F_in, N+1) factors like one node's norms so they broadcast over nodes.
        s_a = unsqueeze_like(torch.sigmoid(_a), norms[0], dim=1)
        norms = s_a * (norms - 1) + 1  # Interpolates between 1 and the norm.
        # Gather, for each blade, the divisor of the grade slot it belongs to.
        norms = norms[..., self.algebra.grades]
        return input / (norms + EPS)
