import torch
import copy
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from sgcrl.models.gpt import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to sequence-first embeddings.

    Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            d_model: embedding size. May be odd; the last, unpaired channel then
                only receives a sine term.
            dropout: dropout probability applied after adding the encoding.
            max_len: number of positions precomputed in the ``pe`` buffer.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        # Frequencies 1 / 10000^(2i / d_model) for i = 0 .. ceil(d_model / 2) - 1.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # Only d_model // 2 odd channels exist. Slicing keeps odd d_model from
        # raising a shape mismatch; for even d_model this is all of div_term.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: moves with the module and is saved in its
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of the first ``x.size(0)`` positions, then apply dropout.

        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class Transformer(nn.Module):
    """Encoder-decoder transformer over discrete tokens, with autoregressive decoding.

    ``encode``/``decode`` take ``[batch, seq]`` token tensors. They permute them
    to the sequence-first layout expected by ``PositionalEncoding`` before
    embedding. ``set_special_tokens`` must be called before any of the
    encoding, decoding or prediction methods, because they read ``sos_token``,
    ``eos_token`` and ``padding_token``.
    """

    # Size of the token buffer used by the predict* methods. They also write it
    # to ``self.max_output_length``, overriding the constructor argument.
    _INFERENCE_MAX_OUTPUT_LENGTH = 128
    # Softmax temperature for sampling and beam scoring
    # (T < 1: sharper, T = 1: standard, T > 1: smoother).
    _SAMPLING_TEMPERATURE = 0.9

    def __init__(self, num_classes, max_output_length, embedding_dim=128, dim=128, num_layers=4, nhead=4, dropout=0.2):
        """
        Args:
            num_classes: number of regular token classes; the output layer
                produces this many logits.
            max_output_length: stored as ``self.max_output_length``. The
                ``predict*`` methods overwrite it with 128.
            embedding_dim: model width (d_model) shared by the embedding,
                positional encoding, encoder and decoder.
            dim: hidden size of the encoder/decoder feed-forward sublayers.
            num_layers: number of encoder layers and of decoder layers.
            nhead: number of attention heads.
            dropout: dropout used inside the encoder/decoder layers.
        """
        super().__init__()

        # Parameters
        self.num_classes = num_classes
        self.dim = embedding_dim
        self.max_output_length = max_output_length
        self.nhead = nhead
        self.num_layers = num_layers
        self.dim_feedforward = dim

        # Encoder
        # The embedding must emit d_model (= self.dim) features: it is added to
        # the positional encoding and fed to layers built with d_model=self.dim.
        # It previously used the feed-forward size ``dim``, which only worked
        # when embedding_dim == dim.
        self.embedding = nn.Embedding(num_classes + 3, self.dim)  # +3 for eos, sos, padding
        self.pos_encoder = PositionalEncoding(d_model=self.dim)
        self.transformer_encoder = TransformerEncoder(
            encoder_layer=TransformerEncoderLayer(d_model=self.dim, nhead=self.nhead, dim_feedforward=self.dim_feedforward, dropout=dropout),
            num_layers=self.num_layers
        )

        # Decoder
        self.transformer_decoder = TransformerDecoder(
            decoder_layer=TransformerDecoderLayer(d_model=self.dim, nhead=self.nhead, dim_feedforward=self.dim_feedforward, dropout=dropout),
            num_layers=self.num_layers
        )

        # NOTE(review): the output layer has num_classes logits while the
        # embedding has num_classes + 3 rows. Whether eos can ever be predicted
        # depends on the ids given to set_special_tokens; confirm.
        self.linear = nn.Linear(self.dim, self.num_classes)

        self.init_weights()

    def init_weights(self):
        """Uniformly initialise the embedding and output weights in [-0.1, 0.1]; zero the output bias."""
        initrange = .1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def set_special_tokens(self, sos_token, eos_token, padding_token):
        """Store the special token ids used for padding masks and decoding."""
        self.sos_token = sos_token
        self.eos_token = eos_token
        self.padding_token = padding_token

    def forward(self, x, y):
        """Teacher-forced pass: encode ``x``, decode ``y`` and return the logits.

        Args:
            x: ``[batch, src_len]`` source token ids.
            y: ``[batch, tgt_len]`` decoder input token ids.

        Returns:
            The decoder logits permuted with ``(1, 2, 0)``. Assuming ``decode``
            yields ``[tgt_len, batch, num_classes]``, this is
            ``[batch, num_classes, tgt_len]``.
        """
        encoded_x, padding_encoded_x = self.encode(x)
        output, _ = self.decode(y, encoded_x, padding_encoded_x)
        return output.permute(1, 2, 0)

    def _get_attention_weights(self):
        """Placeholder: currently a no-op that returns None."""
        pass

    def encode(self, x):
        """Embed and encode a batch of source token sequences.

        Args:
            x: ``[batch, src_len]`` token ids (permuted to sequence-first below).

        Returns:
            ``(memory, padding_mask)``. ``padding_mask`` is a ``[batch, src_len]``
            float mask holding ``-inf`` at ``padding_token`` positions and ``0``
            elsewhere; ``decode`` uses it as ``memory_key_padding_mask``.
        """
        # Additive float mask: -inf where the token is padding, 0 elsewhere.
        encoder_padding_mask = (x == self.padding_token).float()
        encoder_padding_mask = encoder_padding_mask.masked_fill(encoder_padding_mask == 1, float("-inf"))
        x = x.permute(1, 0)
        # Scale embeddings by sqrt(d_model) before adding positions.
        x = self.embedding(x) * np.sqrt(self.dim)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x, src_key_padding_mask=encoder_padding_mask)
        return x, encoder_padding_mask

    def decode(self, y, encoded_x, padding_encoded_x):
        """Decode ``y`` with a causal mask against the encoder memory.

        Args:
            y: ``[batch, tgt_len]`` decoder input token ids.
            encoded_x: encoder memory returned by ``encode``.
            padding_encoded_x: source padding mask returned by ``encode``.

        Returns:
            ``(logits, attention_weights_lists)``:
            - ``logits``: the output of ``self.linear`` over the decoder states
              (presumably ``[tgt_len, batch, num_classes]``; this depends on
              the project ``TransformerDecoder``).
            - ``attention_weights_lists``: whatever attention weights that
              decoder returns.
        """
        tgt_padding_mask = (y == self.padding_token).float()
        tgt_padding_mask = tgt_padding_mask.masked_fill(tgt_padding_mask == 1, float("-inf"))

        y = y.permute(1, 0)
        y = self.embedding(y) * np.sqrt(self.dim)
        y = self.pos_encoder(y)

        # Causal mask: position i may only attend to positions <= i.
        y_mask = nn.Transformer.generate_square_subsequent_mask(y.shape[0]).to(y.device)

        output, attention_weights_lists = self.transformer_decoder(y, memory=encoded_x, tgt_mask=y_mask, tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=padding_encoded_x)
        output = self.linear(output)

        return output, attention_weights_lists

    # ------------------------------------------------------------------ #
    # Inference helpers
    # ------------------------------------------------------------------ #

    def _prepare_inference(self, x, token_history):
        """Normalise the inputs shared by the ``predict*`` methods.

        This method:
        - sets ``self.max_output_length`` to the inference buffer size;
        - turns a missing history into ``[]``;
        - keeps at most ``max_output_length - 1`` history tokens.

        When a history is given, it returns a *copy* of ``x`` whose columns from
        index 2 onward hold the first history token. The caller's tensor is no
        longer modified in place.
        """
        self.max_output_length = self._INFERENCE_MAX_OUTPUT_LENGTH
        token_history = list(token_history) if token_history else []
        if len(token_history) >= self.max_output_length:
            token_history = token_history[-self.max_output_length + 1:]
        if token_history:
            x = x.clone()
            # NOTE(review): overwrites every source column from index 2 on with the
            # first history token; the meaning of those columns is not visible here.
            x[:, 2:] = token_history[0]
        return x, token_history

    def _select_next_token(self, last_logits, argmax):
        """Pick one token per row from ``[batch, vocab]`` logits (greedy or temperature sampling)."""
        if argmax:
            return torch.argmax(last_logits, dim=1)
        probabilities = F.softmax(last_logits / self._SAMPLING_TEMPERATURE, dim=1)
        return torch.multinomial(probabilities, 1).squeeze(-1)

    def _decode_autoregressive(self, x, encoded_x, padding_encoded_x, token_history, n_plan_steps, argmax, keep_attention):
        """Run one greedy or sampled decoding pass into a fresh token buffer.

        Returns:
            ``(output_tokens, attention_weights_by_step, last_step)``:
            - ``output_tokens``: ``[batch, max_output_length]`` long tensor
              holding the history, then ``sos_token``, then generated tokens.
            - ``attention_weights_by_step``: the per-step decoder attention
              (empty unless ``keep_attention``).
            - ``last_step``: the last buffer position decoded
              (``len(token_history)`` if nothing was generated).
        """
        batch_size = x.shape[0]
        history_length = len(token_history)
        output_tokens = torch.full((batch_size, self.max_output_length), self.sos_token, dtype=torch.long, device=x.device)
        if token_history:
            output_tokens[:, :history_length] = torch.tensor(token_history, dtype=torch.long, device=x.device)
        if not n_plan_steps:
            n_plan_steps = self.max_output_length - history_length - 1  # fill the whole buffer

        attention_weights_by_step = []
        last_step = history_length
        planned_steps = 0
        for step in range(history_length + 1, self.max_output_length):
            last_step = step
            logits, attention_weights_lists = self.decode(output_tokens[:, :step], encoded_x, padding_encoded_x)
            if keep_attention:
                attention_weights_by_step.append(attention_weights_lists)
            # Only the distribution at the last decoded position matters.
            next_token = self._select_next_token(logits[-1], argmax)
            output_tokens[:, step] = next_token

            # Early stop on eos is only done for a single sequence.
            if batch_size == 1 and next_token.item() == self.eos_token:
                break

            planned_steps += 1
            if planned_steps >= n_plan_steps:
                # Pad every remaining position with eos. A scalar fill avoids the
                # off-by-one shape mismatch of the previous explicit tensor.
                output_tokens[:, step + 1:] = self.eos_token
                break
        return output_tokens, attention_weights_by_step, last_step

    # ------------------------------------------------------------------ #
    # Public inference API
    # ------------------------------------------------------------------ #

    def predict(self, x, token_history=None, n_plan_steps=False, argmax=False, plot_attention=False):
        """Autoregressively generate one token sequence per batch row.

        Args:
            x: ``[batch, src_len]`` source token ids. Not modified; a copy is
                edited when ``token_history`` is given.
            token_history: optional list of tokens used as the decoder prefix.
                Only the last 127 are kept, and ``sos_token`` is placed after them.
            n_plan_steps: number of tokens to generate before padding the rest
                of the 128-token buffer with ``eos_token``. Falsy means "fill
                the buffer".
            argmax: greedy decoding if True, temperature sampling otherwise.
            plot_attention: also return the decoder attention weights per step.

        Returns:
            A ``[batch, 128 - len(history)]`` long tensor starting with
            ``sos_token`` (the history itself is stripped). When
            ``plot_attention`` is set, also returns the per-step attention list.
            Positions after an early eos stop keep ``sos_token``.
        """
        x, token_history = self._prepare_inference(x, token_history)
        encoded_x, padding_encoded_x = self.encode(x)
        output_tokens, attention_weights_lists_by_steps, _ = self._decode_autoregressive(
            x, encoded_x, padding_encoded_x, token_history, n_plan_steps, argmax, keep_attention=plot_attention)
        generated = output_tokens[:, len(token_history):]  # only return the output tokens without the history
        if plot_attention:
            return generated, attention_weights_lists_by_steps
        return generated

    def predict_k_sample(self, x, token_history=None, n_plan_steps=False, k_samples=10, argmax=False, plot_attention=False):
        """Decode ``k_samples`` times and keep the sequence that stopped earliest.

        Arguments and return layout are the same as for ``predict``. Ties go to
        the latest sample. The attention weights returned with
        ``plot_attention`` belong to the chosen sample. All samples share one
        source encoding; with ``argmax=True`` (and dropout disabled) every
        sample is identical.

        Raises:
            ValueError: if ``k_samples < 1``.
        """
        if k_samples < 1:
            raise ValueError(f'k_samples must be >= 1, got {k_samples}.')
        x, token_history = self._prepare_inference(x, token_history)
        # The source encoding does not depend on the sample: compute it once.
        encoded_x, padding_encoded_x = self.encode(x)

        shortest_output_tokens = None
        shortest_attention = None
        shortest_step = None
        for _ in range(k_samples):
            output_tokens, attention_weights, last_step = self._decode_autoregressive(
                x, encoded_x, padding_encoded_x, token_history, n_plan_steps, argmax, keep_attention=plot_attention)
            # Find shortest path ('<=' keeps the latest sample on ties).
            if shortest_output_tokens is None or last_step <= shortest_step:
                shortest_output_tokens = output_tokens
                shortest_attention = attention_weights
                shortest_step = last_step

        generated = shortest_output_tokens[:, len(token_history):]  # only return the output tokens without the history
        if plot_attention:
            return generated, shortest_attention
        return generated

    def predict_beam(self, x, token_history=None, n_plan_steps=False, argmax=False, beam_width=10, plot_attention=False, choice='best_score'):
        """Beam-search decoding.

        Hypotheses are ranked by cumulative log-probability under the
        temperature-scaled softmax. End-of-sequence checks and score comparisons
        look at batch row 0 only, so this effectively expects a batch of one.

        Args:
            x: ``[batch, src_len]`` source token ids (not modified).
            token_history: optional decoder prefix; only the last 127 tokens are kept.
            n_plan_steps: maximum number of expansion steps. Falsy means up to
                ``128 - len(history) - 1`` steps.
            argmax: unused; kept for signature compatibility with ``predict``.
            beam_width: number of hypotheses kept per step (must be >= 1).
            plot_attention: also return the per-step attention weights.
            choice: ``'best_score'`` returns the highest-scoring candidate;
                ``'best_length'`` returns the shortest one.

        Returns:
            The selected ``[batch, len]`` sequence without the history prefix.
            When ``plot_attention`` is set, also returns the attention weights:
            one list per step, one entry per expanded hypothesis.

        Raises:
            ValueError: for an unknown ``choice`` or for ``beam_width < 1``.
        """
        if choice not in ('best_score', 'best_length'):
            raise ValueError(f'Unknown sequence choice criteria: {choice}.')
        if beam_width < 1:
            raise ValueError(f'beam_width must be >= 1, got {beam_width}.')

        x, token_history = self._prepare_inference(x, token_history)
        B = x.shape[0]
        history_length = len(token_history)
        # Initialize the decoder prefix [B, T] on the same device as x.
        # NOTE(review): unlike predict/predict_k_sample, no sos_token follows the
        # history here; confirm which prefix format the model was trained on.
        if token_history:
            output_tokens = torch.tensor(token_history, dtype=torch.long, device=x.device).unsqueeze(0).repeat(B, 1)
        else:
            output_tokens = torch.full((B, 1), self.sos_token, dtype=torch.long, device=x.device)
        encoded_x, padding_encoded_x = self.encode(x)

        if not n_plan_steps:
            n_plan_steps = self.max_output_length - history_length - 1

        # Each hypothesis is (sequence, cumulative log-probability). A score of
        # 0 means probability 1 (conditional on the history, if any).
        beam = [(output_tokens, 0)]
        final_sequences = []
        attention_weights_lists_by_steps = []

        planned_steps = 0
        for _step in tqdm(range(history_length + 1, self.max_output_length), desc='Planning beam search'):
            candidates = []
            attention_weights_at_step = []

            for seq, score in beam:
                # A hypothesis that already emitted eos is finished: stop expanding it.
                if seq[0, -1] == self.eos_token:
                    final_sequences.append((seq, score))
                    continue

                output, attention_weights_lists = self.decode(seq, encoded_x, padding_encoded_x)
                attention_weights_at_step.append(attention_weights_lists)
                # Distribution at the last position only: [B, vocab].
                probabilities = F.softmax(output[-1] / self._SAMPLING_TEMPERATURE, dim=1)

                # topk cannot request more entries than there are classes.
                k = min(beam_width, probabilities.shape[1])
                top_k_probs, top_k_tokens = torch.topk(probabilities, k, dim=1)
                for i in range(k):
                    new_seq = torch.cat([seq, top_k_tokens[:, i:i + 1]], dim=1)
                    candidates.append((new_seq, score + torch.log(top_k_probs[:, i])))

            # Keep the `beam_width` best hypotheses (stable sort: ties keep expansion order).
            beam = sorted(candidates, key=lambda candidate: candidate[1], reverse=True)[:beam_width]
            attention_weights_lists_by_steps.append(attention_weights_at_step)

            # Stop at the step budget or once every hypothesis has ended.
            planned_steps += 1
            if planned_steps >= n_plan_steps or all(seq[0, -1] == self.eos_token for seq, _ in beam):
                break

        # Close every hypothesis still in the beam (appending eos where missing)
        # and make it a candidate alongside the ones that already finished.
        # Previously these were appended to the beam list itself and never
        # reached final_sequences.
        eos_column = torch.full((B, 1), self.eos_token, dtype=torch.long, device=x.device)
        for seq, score in beam:
            if seq[0, -1] != self.eos_token:
                seq = torch.cat([seq, eos_column], dim=1)
            final_sequences.append((seq, score))

        if choice == 'best_score':
            best_sequence = max(final_sequences, key=lambda candidate: candidate[1])[0]
        else:  # 'best_length'
            # Shortest sequence (tensor width, not len() which is the batch size).
            # Ties go to the earliest candidate; insertion order follows the
            # beam's score ranking.
            best_sequence = min(final_sequences, key=lambda candidate: candidate[0].shape[1])[0]

        if plot_attention:
            return best_sequence[:, len(token_history):], attention_weights_lists_by_steps
        return best_sequence[:, len(token_history):]
