from typing import Optional
import math
import torch

from pado.core import PadoModule

# Names exported by `from ... import *`; the private base class is exported as well.
__all__ = [
    "SinusoidalPositionalEncoding",
    "BidirectionalSinusoidalPositionalEncoding",
    "_SinusoidalPositionalEncodingBase",
]


class _SinusoidalPositionalEncodingBase(PadoModule):
    """Shared state for sinusoidal positional encodings.

    Stores the encoding width, an optional clamp on position magnitude, the
    ``inverse`` ordering flag, and a non-persistent ``inv_freq`` buffer holding
    the per-pair inverse frequencies ``1 / 10000 ** (2i / embed_dim)``.
    Subclasses implement :meth:`forward`.
    """

    def __init__(self,
                 embed_dim: int,
                 clamp_length: Optional[int] = None, *,
                 inverse: bool = False):
        """
        :param embed_dim:     width of the generated encoding; must be a positive even number
        :param clamp_length:  optional clamp; positions are limited to magnitude ``clamp_length - 1``.
                              Must be >= 1 when given.
        :param inverse:       flag consumed by subclasses to reverse the generated position order
        :raises ValueError:   if ``embed_dim`` is not positive or is odd, or ``clamp_length`` < 1
        """
        super().__init__()

        if embed_dim <= 0:
            raise ValueError(f"SinusoidalPE requires a positive embed_dim, got {embed_dim}.")
        if embed_dim % 2 != 0:
            raise ValueError(f"Currently SinusoidalPE only supports even embed_dim, got {embed_dim}.")
        if clamp_length is not None and clamp_length < 1:
            # clamp_length - 1 would be <= -1, mapping every position onto a meaningless value.
            raise ValueError(f"clamp_length must be >= 1 when given, got {clamp_length}.")
        self.embed_dim = embed_dim
        self.clamp_length = clamp_length
        self.inverse = inverse

        # Equivalent to 1 / (10000 ** (arange(0, embed_dim, 2) / embed_dim)), evaluated in log space.
        inv_freq = torch.exp(
            torch.arange(0.0, embed_dim, 2.0, dtype=torch.float32).mul_(-math.log(10000.0) / embed_dim))
        # persistent=False: derived purely from embed_dim, so it is kept out of the state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)  # (embed_dim // 2,)

    def forward(self, length: int) -> torch.Tensor:
        """Generate the encoding table for ``length`` positions; implemented by subclasses."""
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Summarize the constructor arguments for ``repr(module)``."""
        s = f"{self.embed_dim}"
        if self.clamp_length is not None:
            s += f", clamp_length={self.clamp_length}"
        if self.inverse:
            s += ", inverse=True"
        return s


class SinusoidalPositionalEncoding(_SinusoidalPositionalEncodingBase):
    """Sinusoidal positional encoding over non-negative positions ``0 .. length - 1``."""

    def forward(self, length: int) -> torch.Tensor:
        """
        Sinusoidal PE, for left-to-right or right-to-left context.
        :param length:    sequence length to generate
        :return:          (1, seq_len, embed_dim), float32, on the device of ``inv_freq``

        Generated PE will be in indices of:
                [0, 1, 2, 3, ... s-1] if not inverse
                [s-1, s-2, ...  1, 0] if inverse
        When ``clamp_length`` is set, positions are capped at ``clamp_length - 1``.
        Even feature columns hold sin(pos * inv_freq); odd columns hold cos(pos * inv_freq).
        """
        with torch.no_grad():
            device = self.inv_freq.device
            # Always evaluate the angles in float32: if the module was cast (e.g. .half() or
            # .to(torch.bfloat16)) the buffer is down-cast too, and pos * inv_freq would lose
            # enough precision to give visibly wrong phases at large positions.
            inv_freq = self.inv_freq.to(torch.float32)

            pos = torch.arange(0, length, dtype=torch.float32, device=device)
            if self.clamp_length is not None:
                pos = pos.clamp_max_(self.clamp_length - 1)
            if self.inverse:
                # [0, 1, ... s-1] -> [s-1, ... 1, 0]. Each row depends only on its own position,
                # so reversing the 1-D positions equals reversing the finished table, minus the copy.
                pos = torch.flip(pos, dims=(0,))

            dot = torch.outer(pos, inv_freq)  # (seq_len, embed_dim // 2)
            emb = torch.zeros(length, self.embed_dim, dtype=torch.float32, device=device)
            emb[:, 0::2] = torch.sin(dot)
            emb[:, 1::2] = torch.cos(dot)

            emb = emb.unsqueeze(0)  # (1, seq_len, embed_dim)
        return emb


class BidirectionalSinusoidalPositionalEncoding(_SinusoidalPositionalEncodingBase):
    """Sinusoidal positional encoding over signed positions ``-(length - 1) .. length - 1``."""

    def forward(self, length: int) -> torch.Tensor:
        """
        Sinusoidal PE, for bidirectional context.
        :param length:      sequence length to generate
        :return:            (1, 2 * seq_len - 1, embed_dim), float32, on the device of ``inv_freq``;
                            (1, 0, embed_dim) when ``length`` is 0

        Generated PE will be in indices of:
                [s-1, s-2, ... 1, 0, -1, -2, ... -s+1] if not inverse
                [-s+1, -s+2, ... 0, 1, 2, 3, ... s-1]  if inverse (normally not needed for bidirectional use)
        When ``clamp_length`` is set, positions are limited to ``[-(clamp_length - 1), clamp_length - 1]``.
        Even feature columns hold sin(pos * inv_freq); odd columns hold cos(pos * inv_freq).
        """
        with torch.no_grad():
            device = self.inv_freq.device
            if length == 0:
                # 2 * 0 - 1 rows is undefined; previously this crashed inside torch.zeros(-1, ...).
                # Return an empty table, consistent with SinusoidalPositionalEncoding(0).
                return torch.zeros(1, 0, self.embed_dim, dtype=torch.float32, device=device)

            # Always evaluate the angles in float32: if the module was cast (e.g. .half() or
            # .to(torch.bfloat16)) the buffer is down-cast too, and pos * inv_freq would lose
            # enough precision to give visibly wrong phases at large positions.
            inv_freq = self.inv_freq.to(torch.float32)

            pos = torch.arange(length - 1, -length, -1, dtype=torch.float32, device=device)
            if self.clamp_length is not None:
                pos = pos.clamp_(-self.clamp_length + 1, self.clamp_length - 1)
            if self.inverse:
                # Each row depends only on its own position, so reversing the 1-D positions
                # equals reversing the finished table, minus the copy.
                pos = torch.flip(pos, dims=(0,))

            dot = torch.outer(pos, inv_freq)  # (2 * seq_len - 1, embed_dim // 2)
            emb = torch.zeros(length * 2 - 1, self.embed_dim, dtype=torch.float32, device=device)
            emb[:, 0::2] = torch.sin(dot)
            emb[:, 1::2] = torch.cos(dot)

            emb = emb.unsqueeze(0)  # (1, 2 * seq_len - 1, embed_dim)
        return emb
