from typing import List, Dict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from .recurrent.recurrent_model import RecurrentModel
from .rnn_cell import RNNCell


class RNNModel(RecurrentModel):
    """Recurrent language model: an ``RNNCell`` plus a weight-tied LM head.

    Args:
        config: model configuration. Must expose at least ``pad_token_id``,
            ``encoder_memory_len``, ``decoder_cache_len``, ``dim_model`` and
            ``num_decoder_layers``; it is also forwarded to ``RNNCell`` and
            ``LMHead``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pad_token_id = config.pad_token_id
        self.encoder_memory_len = config.encoder_memory_len

        self.recurrent_cell = RNNCell(config)

        # The LM head has no weights of its own; it projects onto the cell's
        # word-embedding matrix (see ``compute_outputs``).
        self.lm_head = LMHead(config)
        # self.mlm_head = MLMHead(config)

        # Training loss: mean cross-entropy over non-pad target positions.
        self.lm_criterion = nn.CrossEntropyLoss(ignore_index=self.pad_token_id)
        self.mlm_criterion = nn.CrossEntropyLoss(ignore_index=-100)
        # Evaluation loss: per-position cross-entropy (pad positions yield 0).
        # NOTE: the attribute name is misspelled ("criteron"); it is kept as-is
        # for backward compatibility with any external references.
        self.eval_criteron = nn.CrossEntropyLoss(reduction="none", ignore_index=self.pad_token_id)

    def compute_outputs(
        self, recurrent_outputs, recurrent_inputs, training: bool = True
    ) -> Dict[str, torch.Tensor]:
        """Compute the LM loss from decoder hidden states.

        Args:
            recurrent_outputs: mapping containing ``"decoder_hidden_states"``,
                whose last dimension must match the word-embedding size
                (presumably ``(batch, seq_len, dim_model)`` — TODO confirm).
            recurrent_inputs: mapping containing ``"decoder_target_ids"`` with
                one target id per hidden-state position.
            training: if True, return the mean training loss; otherwise return
                the per-token losses of all non-pad positions.

        Returns:
            ``{"loss": scalar}`` when ``training`` is True, otherwise
            ``{"word_loss": 1-D tensor}`` with one entry per non-pad target.
        """
        decoder_target_ids = recurrent_inputs["decoder_target_ids"]
        decoder_hidden_states = recurrent_outputs["decoder_hidden_states"]
        word_embedding = self.recurrent_cell.encoder_decoder.word_embedding.weight

        logits = self.lm_head(decoder_hidden_states, word_embedding)
        # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_targets = decoder_target_ids.reshape(-1)

        if training:
            return {"loss": self.lm_criterion(flat_logits, flat_targets)}

        word_loss = self.eval_criteron(flat_logits, flat_targets)
        # Select real tokens by target id rather than by ``loss != 0``: a
        # genuine token whose loss underflows to exactly 0.0 must still count.
        word_loss = word_loss[flat_targets != self.pad_token_id]
        return {"word_loss": word_loss}

    def construct_memory(self, batch_size: int) -> tuple:
        """Build freshly initialized encoder and decoder memory.

        Args:
            batch_size: number of sequences in the batch.

        Returns:
            A tuple ``(encoder_memory, decoder_memory)``:

            * ``encoder_memory`` — dict with ``"encoder_cross_hidden"`` of shape
              ``(batch_size, encoder_memory_len, dim_model)`` drawn from
              N(0, 0.02^2), and an all-True bool ``"encoder_cross_mask"`` of shape
              ``(batch_size, encoder_memory_len)``.
            * ``decoder_memory`` — list with one dict per decoder layer, each
              holding ``"past_hidden_states"`` of shape
              ``(batch_size, decoder_cache_len, dim_model)`` drawn from N(0, 0.02^2).
        """
        # Allocate on whatever device the model's parameters live on.
        device = next(self.parameters()).device

        decoder_memory = [
            {
                "past_hidden_states":
                    torch.randn(batch_size, self.config.decoder_cache_len, self.config.dim_model, device=device) * 0.02
            } for _ in range(self.config.num_decoder_layers)
        ]

        encoder_memory = {
            "encoder_cross_hidden": torch.randn(batch_size, self.encoder_memory_len, self.config.dim_model, device=device) * 0.02,
            "encoder_cross_mask": torch.ones(batch_size, self.encoder_memory_len, device=device).bool()
        }

        memory = encoder_memory, decoder_memory
        return memory

    def reset_memory(self, memory, memory_reset_signal: torch.Tensor):
        """Return ``memory`` unchanged; this model does not reset memory.

        Args:
            memory: the ``(encoder_memory, decoder_memory)`` tuple.
            memory_reset_signal: ignored.
        """
        return memory


class LMHead(nn.Module):
    """Parameter-free language-modeling head with tied embeddings.

    Logits are produced by projecting hidden states onto the rows of a
    word-embedding matrix that the caller supplies, so this module owns no
    weights of its own.
    """

    def __init__(self, config):
        # ``config`` is unused; it is kept for constructor parity with MLMHead.
        super(LMHead, self).__init__()

    def forward(self, hidden_states, word_embedding):
        """Compute ``hidden_states @ word_embedding.T`` without a bias.

        Args:
            hidden_states: tensor whose last dimension equals the embedding
                size (presumably ``(batch, seq_len, hidden)`` — TODO confirm).
            word_embedding: matrix of shape ``(vocab_size, hidden)``.

        Returns:
            Tensor with the last dimension replaced by ``vocab_size``.
        """
        projected = F.linear(hidden_states, word_embedding, bias=None)
        return projected


class MLMHead(nn.Module):
    """Masked-LM head: a dense layer followed by a tied vocabulary projection.

    Args:
        config: object exposing ``dim_model``, used as both the input and
            output size of the dense layer.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.dim_model, config.dim_model)
        self.reset_parameters()

    def forward(self, hidden_states, word_embedding):
        """Apply ``dense`` and then project onto the vocabulary.

        Args:
            hidden_states: tensor whose last dimension is ``dim_model``
                (presumably ``(batch, seq_len, dim_model)`` — TODO confirm).
            word_embedding: matrix of shape ``(vocab_size, dim_model)``.

        Returns:
            Logits with the last dimension replaced by ``vocab_size``.
        """
        hidden_states = self.dense(hidden_states)
        return F.linear(hidden_states, word_embedding)

    def reset_parameters(self):
        """Xavier-uniform initialize the dense weight and zero its bias."""
        # nn.init functions already run under torch.no_grad(); going through
        # ``.data`` is unnecessary and bypasses autograd's version tracking.
        nn.init.xavier_uniform_(self.dense.weight)
        nn.init.zeros_(self.dense.bias)
