import torch
import torch.nn as nn
import math

from torch.utils.checkpoint import checkpoint

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to embeddings, then apply dropout.

    The encoding table is computed once for ``max_len`` positions and stored
    with ``register_buffer`` (non-trainable, moves with ``.to(...)``, saved in
    ``state_dict`` under the key ``pe``).

    Args:
        d_model: Embedding width. Both even and odd widths are supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed. Inputs whose dim 1 exceeds
            this cannot be encoded.
    """

    def __init__(self, d_model, dropout = 0.1, max_len = 5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # (max_len, 1) column of positions. Cast to float explicitly so the
        # products below are float.
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # One frequency per even column index 2i: exp(-2i * ln(10000) / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column.
        # Use only the first d_model // 2 frequencies for the cosine half,
        # otherwise the assignment has mismatched shapes.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading singleton dim so the table broadcasts across the batch axis.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe)`` using the first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dim 1 is the sequence axis and whose last dim is
                ``d_model``. It is presumably ``(batch, seq, d_model)``, matching
                the ``batch_first`` layers elsewhere in this file.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class Transformer(nn.Module):
    """Stack of ``nn.TransformerDecoderLayer`` blocks used as a causal token model.

    Token ids are embedded, scaled by ``sqrt(d_model)``, given sinusoidal
    positional encodings, and passed through ``num_layers`` independent
    decoder layers under a causal mask. Each position is then projected to
    ``vocab_size`` outputs.

    Args:
        vocab_size: Size of the token vocabulary (embedding rows and output width).
        d_model: Model width.
        nhead: Number of attention heads per layer.
        num_layers: Number of decoder layers. Each has its own parameters.
        dim_feedforward: Hidden width of each layer's feed-forward block.
        dropout: Dropout used by the positional encoding and every layer.
    """

    def __init__(self, vocab_size, d_model = 768, nhead = 32, num_layers = 4, dim_feedforward = 2048, dropout=0.1):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        # Construct a fresh layer per depth. Repeating a single instance in the
        # ModuleList would make every "layer" share one set of weights.
        self.decoder = nn.ModuleList([
            nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
            for _ in range(num_layers)
        ])
        self.final_layer = nn.Linear(d_model, vocab_size)

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return an ``(sz, sz)`` float causal mask.

        Entries on and below the diagonal are ``0.0`` (attend). Entries above
        it are ``-inf`` (blocked).

        Args:
            sz: Sequence length.
            device: Optional device to allocate the mask on directly.
        """
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def forward(self, x, hidden = None, **kwargs):
        """Run the model over a batch of token ids.

        Args:
            x: Integer token ids with the sequence along dim 1. It is presumably
                ``(batch, seq)``, since the layers are ``batch_first``.
            hidden: Ignored.
            **kwargs: Ignored. Both arguments are presumably accepted for
                call-compatibility with recurrent models (TODO confirm).

        Returns:
            A tuple ``(outputs, (placeholder,))``. ``outputs`` has last dim
            ``vocab_size``. ``placeholder`` is ``torch.zeros(d_model)``,
            allocated on the default (CPU) device.
        """
        seq_len = x.shape[1]
        x = self.embed(x) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)

        mask = self.generate_square_subsequent_mask(seq_len, device=x.device)

        for layer in self.decoder:
            # The layer's input doubles as its cross-attention "memory". Without
            # memory_mask that attention could read future positions and break
            # causality, so apply the same causal mask there too.
            x = layer(x, x, tgt_mask=mask, memory_mask=mask)

        return self.final_layer(x), (torch.zeros(self.d_model), )
    
class TransEnc(nn.Module):
    """Transformer encoder with masked mean pooling and a 2-way output head.

    Token ids equal to ``2`` are treated as padding. Tokens are embedded,
    positionally encoded, passed through ``num_encoder_layers`` independent
    encoder layers with a key-padding mask, and mean-pooled over non-padding
    positions. The pooled vector is projected to 2 outputs (presumably
    binary-class logits; TODO confirm).

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Model width.
        nhead: Number of attention heads per layer.
        num_encoder_layers: Number of encoder layers. Each has its own parameters.
        device: Stored as ``self.device``. Not used by this module's own computation.
        dropout: Dropout used by the positional encoding and every encoder layer.
    """

    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, device, dropout = 0.1):
        super(TransEnc, self).__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.nhead = nhead
        self.device = device

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        # Fresh layer per depth, so layers do not share weights. The dropout
        # argument is forwarded here too, not only to the positional encoding.
        self.encoder = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model = d_model, nhead = nhead, dropout = dropout, batch_first = True)
            for _ in range(num_encoder_layers)
        ])

        self.output_layer = nn.Linear(d_model, 2)

    def forward(self, x, hidden = None, **kwargs):
        """Encode token ids and return pooled 2-way outputs.

        Args:
            x: Integer token ids with the sequence along dim 1. It is presumably
                ``(batch, seq)``. The value ``2`` marks padding.
            hidden: Ignored.
            **kwargs: Ignored. Both are presumably accepted for interface
                compatibility (TODO confirm).

        Returns:
            A tuple ``(outputs, (placeholder,))``. ``outputs`` has last dim 2.
            ``placeholder`` is ``torch.zeros(d_model)``.
        """
        pad_mask = (x == 2)
        x = self.embedding(x)
        x = self.pos_encoder(x)
        for layer in self.encoder:
            x = layer(x, src_key_padding_mask = pad_mask)

        # Zero padded positions with masked_fill instead of multiplying by a
        # 0/1 mask. A row that is entirely padding makes attention emit NaN,
        # and NaN * 0 is still NaN, which would leak into the pooled output.
        x = x.masked_fill(pad_mask.unsqueeze(-1), 0.0)
        non_pad_count = (~pad_mask).sum(dim = 1, keepdim = True)
        # The clamp avoids dividing by zero for all-padding rows. Those pool to 0.
        x = x.sum(dim = 1) / non_pad_count.clamp(min = 1)
        x = self.output_layer(x)
        return x, (torch.zeros(self.d_model), )