import torch
import torch.nn as nn
import torch.nn.functional as F
from models.custom_layers import get_residual_layer, act_registry
import math

from functools import partial

class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional encoding to the input, then apply dropout.

    Modified version from:
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    Args:
        dim_model: Feature size of the input; the encoding has the same width.
            Both even and odd values are supported.
        dropout_p: Dropout probability applied after the encoding is added.
        max_len: Number of precomputed positions, i.e. the longest sequence
            (dimension 1 of the input) this module can encode.
    """

    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)

        # PE(pos, 2i)     = sin(pos / 10000^(2i / dim_model))
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i / dim_model))
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)  # (max_len, 1): 0, 1, 2, ...
        # 1 / 10000^(2i / dim_model), computed in log space.
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)
        angles = positions_list * division_term  # (max_len, ceil(dim_model / 2))

        pos_encoding[:, 0::2] = torch.sin(angles)
        # Column 2i+1 uses the same frequency as column 2i. When dim_model is odd
        # there is one fewer odd column than even column, so the last frequency
        # is dropped; without this slice the assignment raises a shape error.
        pos_encoding[:, 1::2] = torch.cos(angles[:, : dim_model // 2])

        # Stored as a buffer: saved in state_dict and moved by .to(device), but not trained.
        # The leading singleton dimension broadcasts over the batch.
        self.register_buffer("pos_encoding", pos_encoding.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + PE)`` for the first ``x.shape[1]`` positions.

        Positions are read along dimension 1 of ``x`` (batch-first layout), and the
        last dimension must equal ``dim_model``. ``x.shape[1]`` must not exceed
        ``max_len``.
        """
        return self.dropout(x + self.pos_encoding[:, :x.shape[1], :])

class LearnablePositionalEncoding(nn.Module):
    """Add a trainable per-position embedding to the input, then apply dropout.

    Args:
        dim_model: Feature size of each position embedding.
        dropout_p: Dropout probability applied after the embedding is added.
        max_len: Number of positions with their own learned embedding.
    """

    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_p)

        # One learnable row per position, initialised from N(0, 1/dim_model).
        table = torch.zeros(max_len, dim_model)
        self.pos_encoding = nn.Parameter(table)
        nn.init.normal_(self.pos_encoding, mean=0, std=dim_model ** -0.5)

    def forward(self, x):
        """Add the embeddings for the first ``x.shape[1]`` positions to ``x``."""
        seq_len = x.shape[1]
        # `None` inserts a leading axis so the (seq_len, dim) table broadcasts over dim 0 of x.
        return self.dropout(x + self.pos_encoding[None, :seq_len])

class TransformerBlock(nn.Module):
    """Autoregressive sequence model built on ``nn.Transformer`` (batch-first).

    The same positionally-encoded sequence is fed as both ``src`` and ``tgt``,
    and every attention path (encoder self-attention, decoder self-attention and
    decoder cross-attention) is causally masked. As a result, output position i
    depends only on inputs 0..i.

    Args:
        d_model: Model/feature width.
        dropout: Dropout used by the transformer and the positional encoding.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        nhead: Number of attention heads.
        dim_feedforward_multiplier: Feed-forward width as a multiple of ``d_model``.
        max_len: Longest sequence supported by the positional encoding.
    """

    def __init__(self, d_model, dropout, num_encoder_layers, num_decoder_layers, nhead, dim_feedforward_multiplier, max_len):
        super(TransformerBlock, self).__init__()
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            dim_feedforward=dim_feedforward_multiplier * d_model,
            batch_first=True,
        )
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)

    def forward(self, x, batch_dt=None):
        """Run the causal transformer over ``x``.

        Args:
            x: Input sequence in batch-first layout; dimension 1 is time.
            batch_dt: Unused here; kept for interface compatibility with callers.

        Returns:
            A tuple ``(output, None)``.
        """
        x = self.pos_encoder(x)
        causal_mask = self.transformer.generate_square_subsequent_mask(x.size(1)).to(x.device)
        # memory_mask is required for autoregression: with a causal encoder,
        # memory[j] summarises x[0..j], so unmasked cross-attention would let
        # target position i read future tokens x[i+1..] through memory[j > i].
        # Call the module (not .forward) so registered hooks still run.
        out = self.transformer(
            x, x,
            src_mask=causal_mask,
            tgt_mask=causal_mask,
            memory_mask=causal_mask,
            src_is_causal=True,
            tgt_is_causal=True,
            memory_is_causal=True,
        )
        return out, None
                                        


    