from enum import Enum
from torch import nn
import torch
import math

class PosEnc(str, Enum):
    """Available positional-encoding schemes.

    Because the class mixes in ``str``, every member's value is a string
    (e.g. ``PosEnc.FIXED_TRIG_FUNCS.value == "1"``). Members also compare
    equal to that string. The values are written as strings here to make
    this explicit.
    """

    # Classical sinusoidal encoding described in "Attention Is All You Need".
    FIXED_TRIG_FUNCS = "1"
    # One-hot encoding of dimension seq_len concatenated to the embedding vector.
    FIXED_ONE_HOT_CONCAT = "2"
    # One-hot encoding of dimension seq_len added to the embedding vector.
    FIXED_ONE_HOT_ADD = "3"
    # Learnable non-augmented PE, described in Gabriel Peyré's paper
    # "How do Transformers Perform In-Context Autoregressive Learning?".
    LEARNABLE_NON_AUG = "4"

MAX_LEN = 5000  # default maximum number of positions (sequence length) the encodings cover

class FixedPosEncTrig(nn.Module):
    """Fixed sinusoidal positional encoding ("Attention Is All You Need").

    For position ``pos`` and pair index ``i`` (``0 <= i < ceil(model_dim / 2)``)::

        PE[pos, 2i]     = sin(pos / base ** (2i / model_dim))
        PE[pos, 2i + 1] = cos(pos / base ** (2i / model_dim))

    The table is precomputed once for ``max_len`` positions. It is registered
    as a buffer, so it is saved in ``state_dict`` and moves with
    ``module.to(...)``, but it is not returned by ``parameters()`` and is
    not trained.

    Args:
        model_dim: Embedding dimension (last axis of the input). An odd value
            is supported; the last channel then holds only a sine term.
        max_len: Maximum sequence length ("block size") the table covers.
        device: Device on which the encoding table is created.
        base: Base of the geometric wavelength progression (10000 in the paper).
    """

    def __init__(self, model_dim, max_len=MAX_LEN, device='cpu', base=10000.0):
        super().__init__()

        # Positions 0 .. max_len-1 as a column of shape (max_len, 1), so they
        # broadcast against the per-frequency row below.
        position = torch.arange(max_len, device=device).unsqueeze(1)

        # 1 / base^(2i / model_dim) for every even channel index 2i, computed
        # in log space. Shape: (ceil(model_dim / 2),).
        div_term = torch.exp(
            torch.arange(0, model_dim, 2, device=device) * (-math.log(base) / model_dim)
        )
        angles = position * div_term  # (max_len, ceil(model_dim / 2))

        pos_enc = torch.zeros(1, max_len, model_dim, device=device)
        # Even channels take every sine column.
        pos_enc[0, :, 0::2] = torch.sin(angles)
        # An odd model_dim has one fewer odd channel than even channels, so
        # only the first model_dim // 2 cosine columns are used.
        pos_enc[0, :, 1::2] = torch.cos(angles[:, : model_dim // 2])

        self.register_buffer('pos_enc', pos_enc)
        self.device = device

    def forward(self, X):
        """Add the positional encoding to ``X``.

        Args:
            X: Tensor whose axis 1 indexes token positions, presumably
                ``(batch, seq_len, model_dim)``. ``seq_len`` must not exceed
                ``max_len``; otherwise the addition fails to broadcast.

        Returns:
            ``X`` plus the first ``seq_len`` rows of the table, broadcast
            over the batch axis.
        """
        return X + self.pos_enc[:, : X.size(1), :]
    
class OneHotPosEnc(nn.Module):
    """Fixed one-hot positional encoding.

    Position ``p`` is encoded as the one-hot vector of length ``max_len``
    with a 1 at index ``p``.

    - With ``concat=True`` the encoding is concatenated to the last axis of
      the input, giving width ``embedding_dim + max_len``.
    - Otherwise it is added to the input, which requires the input's last
      axis to broadcast against ``max_len``.

    The module has no parameters or buffers.
    """

    def __init__(self, max_len=MAX_LEN, concat=False, device='cpu'):
        """
        Initializes the OneHotPosEnc module.

        Args:
            max_len (int, optional): Length of each one-hot vector, i.e. the
                maximum number of positions. Defaults to MAX_LEN.
            concat (bool, optional): Concatenate the encodings to the input
                embeddings instead of adding them. Defaults to False.
            device (optional): Stored for API compatibility. ``forward``
                builds the encoding on the input's own device.
        """
        super().__init__()

        self.max_len = max_len
        self.concat = concat
        self.device = device

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_length, embedding_dim]``.
                ``seq_length`` must be at most ``max_len``; otherwise
                ``one_hot`` raises.

        Returns:
            The input with the one-hot encodings concatenated on the last
            axis (``concat=True``) or added to it (``concat=False``).
        """
        batch_size, seq_len = x.size(0), x.size(1)
        # Build on x's device so inputs on any device work.
        positions = torch.arange(seq_len, device=x.device)
        pos = torch.nn.functional.one_hot(positions, num_classes=self.max_len).float()
        # (seq_len, max_len) -> (batch_size, seq_len, max_len) as a view;
        # no per-batch copy is made.
        pos = pos.unsqueeze(0).expand(batch_size, -1, -1)
        if self.concat:
            return torch.cat((x, pos), dim=-1)
        return x + pos
    
class LearnNonAugPosEnc(nn.Module):
    """Placeholder for the learnable non-augmented positional encoding.

    !!!!!! This is a bit of a dummy class: ``forward`` returns its input
    unchanged. Any learnable positional component is presumably handled
    elsewhere (not visible in this file).
    """

    def __init__(self, max_len=MAX_LEN, concat=False, device='cpu'):
        """
        Args:
            max_len (int, optional): Maximum number of positions; stored but
                unused here. Defaults to MAX_LEN.
            concat (bool, optional): Stored for interface parity with the
                other encodings; unused here. Defaults to False.
            device (optional): Stored but unused here. Defaults to 'cpu'.
        """
        super().__init__()
        self.max_len, self.concat, self.device = max_len, concat, device

    def forward(self, x):
        """Return ``x`` as-is; the non-augmented scheme leaves the input untouched."""
        unchanged = x
        return unchanged