import numpy as np
import torch
import torch.nn as nn
from src.utils import fourier_position_encoder
from src.nn.fusion import fusion_factory
from src.nn.mlp import FFN


# Names exported by `from <module> import *`. The abstract
# `BasePositionalInjection` is intentionally left out of this list
__all__ = [
    'CatInjection',
    'AdditiveInjection',
    'AdditiveMLPInjection',
    'FourierInjection',
    'LearnableFourierInjection',
]


class BasePositionalInjection(nn.Module):
    def __init__(self, dim=None, x_dim=None, fusion='additive', **kwargs):
        """Base class for positional information injection. Takes care
        of the fusion of the positional encodings with a potential
        feature embedding vector.

        Child classes are expected to overwrite the `_encode` method.

        :param dim: int
            Positional encoding dimension
        :param x_dim: int
            If provided along with `dim`, the input feature embeddings
            will undergo a linear projection `x_dim -> dim` before being
            fused with the positional encodings. If either `x_dim` or
            `dim` is None, the features are left untouched
        :param fusion: str
            Fusion mechanism to merge positional encodings with feature
            embeddings. Passed as is to `fusion_factory`
        :param kwargs:
            Accepted for signature compatibility and ignored
        """
        super().__init__()
        self.dim = dim

        # Optional projection of the features to the encoding dimension
        if x_dim is None or dim is None:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Linear(x_dim, dim)

        # Fusion operator
        self.fusion = fusion_factory(fusion)

    def _encode(self, pos):
        """Compute the positional encodings for `pos`. Must be
        implemented by child classes.

        :param pos:
            Positions to encode, as received by `forward`
        """
        # The signature matches the `self._encode(pos)` call made in
        # `forward`, so that a child class missing this method fails
        # with NotImplementedError rather than an unrelated TypeError
        raise NotImplementedError

    def forward(self, pos, x):
        """Encode the positions and fuse them with the features.

        :param pos:
            Positions, passed to `_encode`
        :param x:
            Feature embeddings. If not None, they are first passed
            through the projection layer. If None, None is passed to
            the fusion operator as is
        :return:
            Output of the fusion operator called on the positional
            encodings and the (projected) features
        """
        if x is not None:
            x = self.proj(x)
        return self.fusion(self._encode(pos), x)


class CatInjection(BasePositionalInjection):
    def __init__(self, **kwargs):
        """Positional injection feeding the raw positions to the 'cat'
        fusion operator, making it equivalent to a CatFusion.

        :param kwargs:
            Accepted for signature compatibility and ignored
        """
        # No encoding dimension and no feature projection are needed,
        # the parent is only in charge of building the fusion operator
        super().__init__(fusion='cat', x_dim=None, dim=None)

    def _encode(self, pos):
        """Identity encoding: the positions are used as is."""
        encoding = pos
        return encoding


class AdditiveInjection(BasePositionalInjection):
    def __init__(self, **kwargs):
        """Positional injection feeding the raw positions to the
        'additive' fusion operator, making it equivalent to an
        AdditiveFusion.

        :param kwargs:
            Accepted for signature compatibility and ignored
        """
        # Fixed parent configuration: no encoding dimension, no feature
        # projection, additive fusion
        parent_config = dict(dim=None, x_dim=None, fusion='additive')
        super().__init__(**parent_config)

    def _encode(self, pos):
        """Identity encoding: the positions are returned untouched."""
        return pos


class AdditiveMLPInjection(BasePositionalInjection):
    def __init__(self, dim=None, pos_dim=3, **kwargs):
        """Simple child class of BasePositionalInjection equivalent to
        an MLP followed by AdditiveFusion.

        :param dim: int
            Positional encoding dimension, used as output dimension of
            the MLP
        :param pos_dim: int
            Dimension of the input positions, used as input dimension
            of the MLP. Defaults to 3
        :param kwargs:
            Accepted for signature compatibility and ignored
        """
        super().__init__(dim=dim, x_dim=None, fusion='additive')

        # Size of the positions fed to the MLP
        self.pos_dim = pos_dim

        # MLP mapping the positions to their encodings
        self.ffn = FFN(
            self.pos_dim, out_dim=self.dim, activation=nn.LeakyReLU())

    def _encode(self, pos):
        """Encode the positions with the MLP.

        :param pos:
            Positions whose last dimension is expected to be `pos_dim`
        """
        return self.ffn(pos)


class FourierInjection(BasePositionalInjection):
    def __init__(
            self, dim=None, x_dim=None, fusion='additive', f_min=1e-1,
            f_max=1e1, **kwargs):
        """Sine and cosine encoding of [N, M] M-dimensional positions
        into [N, dim] encodings, with a decomposition along each axis.
        For each of the M axes to be granted the same number of encoding
        dimensions, dim is expected to be a multiple of 2*M.

        Positions must be normalized in [-1, 1] beforehand. This matters
        because positions outside of this range produce ambiguous
        encodings, where two distinct positions share the same encoding.

        :param dim: int
            Positional encoding dimension. Must not be None
        :param x_dim: int
            Passed to BasePositionalInjection
        :param fusion: str
            Passed to BasePositionalInjection
        :param f_min: float
            Passed as `f_min` to `fourier_position_encoder`
        :param f_max: float
            Passed as `f_max` to `fourier_position_encoder`
        :param kwargs:
            Passed to BasePositionalInjection
        """
        assert dim is not None
        super().__init__(dim=dim, x_dim=x_dim, fusion=fusion, **kwargs)

        # Settings forwarded to the encoder at each `_encode` call
        self.f_min, self.f_max = f_min, f_max

    def _encode(self, pos):
        """Encode the positions with `fourier_position_encoder`."""
        encoder_settings = dict(f_min=self.f_min, f_max=self.f_max)
        return fourier_position_encoder(pos, self.dim, **encoder_settings)


class LearnableFourierInjection(BasePositionalInjection):
    def __init__(self, M: int, F_dim: int, H_dim: int, D: int, gamma: float):
        """Learnable Fourier Features from:
            https://arxiv.org/pdf/2106.02795.pdf (Algorithm 1)

        Implementation of Algorithm 1: compute the Fourier feature
        positional encoding of a multi-dimensional position. Computes
        the [N, D] positional encoding of a tensor of shape [N, M].

        NB: contrary to the other injections of this module, `forward`
        only takes positions as input and returns their encodings,
        without fusing them with feature embeddings. The parent class is
        initialized with its default settings.

        :param M: each point has a M-dimensional positional values
        :param F_dim: depth of the Fourier feature dimension. Must be a
            positive even number, since half of the features are cosines
            and the other half are sines
        :param H_dim: hidden layer dimension
        :param D: positional encoding dimension
        :param gamma: parameter to initialize Wr
        :raises ValueError: if F_dim is not a positive even number
        """
        # Wr outputs F_dim // 2 values, which are turned into as many
        # cosines and sines. An odd F_dim would hence produce F_dim - 1
        # Fourier features while the FFN is built for F_dim inputs
        if F_dim <= 0 or F_dim % 2 != 0:
            raise ValueError(
                f"F_dim must be a positive even number, got {F_dim}")

        super().__init__()
        self.M = M
        self.F_dim = F_dim
        self.H_dim = H_dim
        self.D = D
        self.gamma = gamma

        # Projection matrix on learned lines (used in eq. 2)
        self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
        # MLP: GeLU(F @ W1 + B1) @ W2 + B2 (eq. 6)
        self.ffn = FFN(
            self.F_dim, hidden_dim=self.H_dim, out_dim=self.D,
            activation=nn.GELU(), drop=None)
        self.init_weights()

    def init_weights(self):
        """Initialize the weights of Wr from a normal distribution."""
        # NOTE(review): the paper samples Wr from N(0, gamma^-2). If
        # gamma^-2 denotes a variance there, the std should be
        # gamma^-1 - TODO confirm. Left unchanged to preserve the
        # behavior of already-trained configurations
        nn.init.normal_(self.Wr.weight, mean=0, std=self.gamma ** -2)

    def forward(self, x):
        """Produce positional encodings from x.

        :param x: tensor of shape [N, M] or [N, G, M] that represents N
            positions. In the latter case, each position is in the
            shape of [G, M], where G is the positional group and each
            group has M-dimensional positional values. A [N, M] input is
            treated as a single group, i.e. as [N, 1, M].
            NB: the output is reshaped to [N, D], which presumably only
            holds for G = 1 if the FFN outputs D values per group -
            TODO confirm
        :raises ValueError: if x is neither 2D nor 3D

        :return: positional encoding for x, of shape [N, D]
        """
        # Treat [N, M] positions as a single positional group
        if x.dim() == 2:
            x = x.unsqueeze(1)
        elif x.dim() != 3:
            raise ValueError(
                f"Expected x of shape [N, M] or [N, G, M], got "
                f"{tuple(x.shape)}")
        N = x.shape[0]

        # Step 1. Compute Fourier features (eq. 2)
        projected = self.Wr(x)
        cosines = torch.cos(projected)
        sines = torch.sin(projected)
        fourier = 1 / np.sqrt(self.F_dim) * torch.cat(
            [cosines, sines], dim=-1)

        # Step 2. Compute projected Fourier features (eq. 6)
        Y = self.ffn(fourier)

        # Step 3. Reshape to [N, D]
        return Y.reshape((N, self.D))
