import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, Transformer
import pdb
import math

"""
Codes x To Codes y

This file contains helper models that predict codes y from codes x.

* CodeEncoder: A transformer model that takes as input a sequence of length Tin (Tin, B, Din)
                and predicts a sequence of length Tout (Tout, B, Dout). This is an encoder-only
                model, which means that the final output is predicted from the flattened output
                of a transformer encoder via a linear layer

* CodeTransformer: Does the same as the CodeEncoder model but using a full transformer encoder-decoder, 
                where decoding happens autoregressively
"""


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    A table of shape ``(maxlen, 1, emb_size)`` is precomputed once and stored
    as the (non-trainable) buffer ``pos_embedding``. Even feature columns hold
    ``sin(pos * freq)`` and odd feature columns hold ``cos(pos * freq)``.

    Args:
        emb_size: Size of the last (feature) dimension. Odd sizes are
            supported; the last column then only carries a sine term.
        dropout: Dropout probability applied after adding the encoding.
        maxlen: Number of positions in the precomputed table.
    """

    def __init__(self, emb_size: int, dropout: float = 0.0, maxlen: int = 5000):
        super().__init__()
        # One frequency per sin/cos pair: ceil(emb_size / 2) entries.
        freqs = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        positions = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = positions * freqs  # (maxlen, ceil(emb_size / 2))

        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # There are only floor(emb_size / 2) odd columns, so drop the surplus
        # frequency when emb_size is odd (a no-op for even emb_size).
        pos_embedding[:, 1::2] = torch.cos(angles[:, : emb_size // 2])
        # Insert a singleton axis so the table broadcasts over dimension 1.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding: torch.Tensor):
        """Return ``dropout(token_embedding + positional_encoding)``.

        Args:
            token_embedding: Tensor whose dimension 0 is the position index
                and whose last dimension is ``emb_size``, e.g.
                ``(seq_len, batch, emb_size)``. ``seq_len`` must not exceed
                ``maxlen``.
        """
        return self.dropout(
            token_embedding + self.pos_embedding[: token_embedding.size(0), :]
        )


class CodeEncoder(nn.Module):
    """Encoder-only transformer that maps a code sequence to one vector per sample.

    The input is optionally projected to ``d_model``, positionally encoded,
    passed through a ``TransformerEncoder``, flattened over the sequence and
    feature dimensions, and mapped to ``d_out`` by a single linear layer.

    Args:
        d_in: Feature size of the input sequence.
        d_out: Size of the output vector produced for each sample.
        d_model: Transformer model width. A linear input projection is
            created only when ``d_in != d_model``.
        nhead: Number of attention heads.
        d_hid: Width of the feed-forward layer in each encoder layer.
        nlayers: Number of encoder layers.
        dropout: Dropout probability (positional encoding and encoder layers).
        seq_len: Input sequence length. The output layer has
            ``d_model * seq_len`` input features, so ``forward`` requires
            inputs of exactly this length.
        batch_first: If True, ``forward`` expects ``(B, seq_len, d_in)``
            instead of ``(seq_len, B, d_in)``.
        norm_first: Forwarded to ``TransformerEncoderLayer``.
    """

    def __init__(
        self,
        d_in: int,
        d_out: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        dropout: float = 0.5,
        seq_len: int = 5000,
        batch_first: bool = False,
        norm_first: bool = False,
    ):
        super().__init__()
        self.model_type = "Transformer"
        self.d_model = d_model
        # Remembered so that forward() can interpret the input layout.
        self.batch_first = batch_first

        self.pos_encoder = PositionalEncoding(d_model, dropout, seq_len)

        encoder_layers = TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward=d_hid,
            dropout=dropout,
            batch_first=batch_first,
            norm_first=norm_first,
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.has_linear_in = d_in != d_model
        if self.has_linear_in:
            self.linear_in = nn.Linear(d_in, d_model)
        # Operates on the flattened (seq_len * d_model) encoder output.
        self.linear_out = nn.Linear(d_model * seq_len, d_out)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor = None) -> torch.Tensor:
        """Encode ``src`` and project the flattened encoding to ``d_out``.

        Arguments:
            src: Tensor of shape ``[seq_len, batch_size, d_in]``, or
                ``[batch_size, seq_len, d_in]`` when ``batch_first`` is True.
            src_mask: Optional attention mask of shape ``[seq_len, seq_len]``.

        Returns:
            Tensor of shape ``[batch_size, d_out]``.
        """
        if self.has_linear_in:
            src = self.linear_in(src)  # (..., d_model)

        if self.batch_first:
            # The positional encoding indexes positions along dimension 0,
            # so apply it in sequence-first layout and switch back.
            src = self.pos_encoder(src.transpose(0, 1)).transpose(0, 1)  # (B, T, D)
        else:
            src = self.pos_encoder(src)  # (T, B, D)

        output = self.transformer_encoder(src, src_mask)

        if not self.batch_first:
            output = torch.permute(output, (1, 0, 2))  # (T, B, D) -> (B, T, D)
        output = output.flatten(start_dim=1)  # (B, T * D)
        output = self.linear_out(output)  # (B, d_out)
        return output


class CodeTransformer(nn.Module):
    """Encoder-decoder transformer that predicts an output code sequence.

    Inputs are optionally projected to ``d_model`` (only when
    ``d_in != d_model``), positionally encoded and fed to a
    ``torch.nn.Transformer``; a linear ``generator`` maps decoder outputs to
    ``d_out`` values per step. All tensors are sequence-first.

    Args:
        d_in: Feature size of source and target codes.
        d_out: Output size of the generator for each decoded step.
        seq_in: Source sequence length.
        seq_out: Target sequence length. The positional-encoding table has
            ``max(seq_in, seq_out)`` positions.
        d_model: Transformer model width.
        d_ffn: Width of the transformer feed-forward layers.
        nhead: Number of attention heads.
        n_enc_layers: Number of encoder layers.
        n_dec_layers: Number of decoder layers.
        dropout: Dropout probability.
    """

    def __init__(
        self,
        d_in: int,
        d_out: int,
        seq_in: int,
        seq_out: int,
        d_model: int,
        d_ffn: int,
        nhead: int,
        n_enc_layers: int,
        n_dec_layers: int,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.model_type = "Transformer"

        # The start token is a fixed all-zero vector stored as a buffer, so it
        # follows the module across devices but is not trained.
        self.register_buffer("start_token", torch.zeros(1, 1, d_model))
        self.has_linear_in = d_in != d_model
        if self.has_linear_in:
            self.linear_in = nn.Linear(d_in, d_model)
        self.transformer = Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=n_enc_layers,
            # Previously n_enc_layers was passed here, silently ignoring
            # n_dec_layers.
            num_decoder_layers=n_dec_layers,
            dim_feedforward=d_ffn,
            dropout=dropout,
        )
        self.generator = nn.Linear(d_model, d_out)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max(seq_in, seq_out))

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None,
        src_padding_mask: torch.Tensor = None,
        tgt_padding_mask: torch.Tensor = None,
        memory_key_padding_mask: torch.Tensor = None,
    ):
        """Run the transformer with teacher forcing.

        The decoder input is ``tgt`` shifted right by one step: the start
        token is prepended and the last target step is dropped. No causal
        mask is built here; pass one as ``tgt_mask`` if required.

        Args:
            src: tensor of shape (seq_in, B, d_in)
            tgt: tensor of shape (seq_out, B, d_in)
            src_mask: optional attention mask for the encoder
            tgt_mask: optional attention mask for the decoder
            src_padding_mask: optional key-padding mask for ``src``
            tgt_padding_mask: optional key-padding mask for ``tgt``
            memory_key_padding_mask: optional key-padding mask for the memory

        Returns:
            Tensor of shape (seq_out, B, d_out).
        """
        # linear projection if d_in is not the same as d_model
        if self.has_linear_in:
            src = self.linear_in(src)  # (..., d_model)
            tgt = self.linear_in(tgt)  # (..., d_model)

        # src: positional encoding
        src_emb = self.pos_encoder(src)

        # tgt: add start token, remove last tgt token, add positional encoding
        start = self.start_token.expand(-1, tgt.shape[1], -1)
        tgt_wstart = torch.cat((start, tgt[:-1]), dim=0)
        tgt_wstart_emb = self.pos_encoder(tgt_wstart)

        outs = self.transformer(
            src_emb,
            tgt_wstart_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        # outs has shape (seq_out, B, d_model)

        return self.generator(outs)  # (seq_out, B, d_out)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor):
        """Positionally encode ``src`` and run the transformer encoder.

        ``src`` must already have feature size ``d_model``; the input
        projection is not applied here.
        """
        return self.transformer.encoder(self.pos_encoder(src), src_mask)

    def decode(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: torch.Tensor):
        """Positionally encode ``tgt`` and run the transformer decoder.

        ``tgt`` must already have feature size ``d_model``; the input
        projection is not applied here.
        """
        return self.transformer.decoder(self.pos_encoder(tgt), memory, tgt_mask)

    def greedy_decode(self, src, src_mask, codebook, max_len):
        """Autoregressively decode ``max_len`` steps, taking the argmax each step.

        Args:
            src: source tensor of shape (seq_in, B, d_in). A tensor that was
                already projected to ``d_model`` is also accepted.
            src_mask: attention mask for the encoder (may be None).
            codebook: tensor indexed by the predicted class id; the selected
                row is the code fed back as the next decoder input.
            max_len: number of decoding steps.

        Returns:
            Tensor of shape (B, max_len, d_out) holding the generator output
            of every step.
        """
        # forward() projects src before encoding; do the same here. The
        # feature-size guard leaves already-projected inputs untouched
        # (has_linear_in implies d_in != d_model, so the check is unambiguous).
        if self.has_linear_in and src.shape[-1] == self.linear_in.in_features:
            src = self.linear_in(src)

        memory = self.encode(src, src_mask)
        B = src.shape[1]
        ys = self.start_token.clone().expand(-1, B, -1)
        ys_logits = []
        for _ in range(max_len):
            tgt_mask = generate_square_subsequent_mask(ys.shape[0], src.device)
            out = self.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)  # (B, T, d_model)
            # Only the most recent step is needed to predict the next code.
            logits = self.generator(out[:, -1])  # (B, d_out)
            ys_logits.append(logits)
            _, next_word = torch.max(logits, dim=1)
            new_code = codebook[next_word].view(1, B, codebook.shape[-1])
            if self.has_linear_in:
                new_code = self.linear_in(new_code)
            ys = torch.cat([ys, new_code], dim=0)
        ys_logits = torch.stack(ys_logits, dim=1)  # (B, max_len, d_out)
        return ys_logits


def generate_square_subsequent_mask(sz, device):
    """Return a ``(sz, sz)`` float mask for causal attention.

    Entries strictly above the main diagonal are ``-inf``; the diagonal and
    everything below it are ``0``. The tensor is created on ``device``.
    """
    blocked = torch.full((sz, sz), float("-inf"), device=device)
    # triu with diagonal=1 keeps only the strictly-upper triangle and zeroes
    # the rest.
    return blocked.triu(diagonal=1)


def create_mask(src_seq_len, tgt_seq_len, device):
    """Build the attention masks for a source/target pair.

    Returns:
        A 4-tuple ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``:
        ``src_mask`` is an all-False boolean tensor of shape
        ``(src_seq_len, src_seq_len)``, ``tgt_mask`` is the result of
        ``generate_square_subsequent_mask(tgt_seq_len, device)``, and both
        padding masks are ``None``.
    """
    causal_tgt = generate_square_subsequent_mask(tgt_seq_len, device)
    open_src = torch.zeros(
        (src_seq_len, src_seq_len), dtype=torch.bool, device=device
    )
    return open_src, causal_tgt, None, None
