import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import numpy as np
from omegaconf import OmegaConf
from typing import List, Dict

from .modules.activations import ACT2FN
from .modules.normalizations import NORM2FN
from .recurrent.recurrent_model import RecurrentModel

from .compressive_transformer import CompressiveTransformerModel


class CTransformerRecurrentModel(RecurrentModel):
    """Language model that wraps a ``CompressiveTransformerModel`` as its recurrent cell.

    Logits are produced by ``LMHead`` using the cell's word-embedding matrix
    (``recurrent_cell.word_emb.weight``) as the output projection.

    Args:
        config: Model configuration. ``config.pad_token_id`` is read here and
            used as ``ignore_index`` for both loss functions; the whole config is
            forwarded to ``CompressiveTransformerModel`` and ``LMHead``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pad_token_id = config.pad_token_id

        self.recurrent_cell = CompressiveTransformerModel(config)
        self.lm_head = LMHead(config)

        # Mean token loss used by ``forward``; padding targets are ignored.
        self.lm_criterion = nn.CrossEntropyLoss(ignore_index=self.pad_token_id)
        # Per-token loss used by ``compute_outputs`` in evaluation.
        # NOTE: the attribute keeps its original (misspelled) name so that any
        # external references to ``eval_criteron`` keep working.
        self.eval_criteron = nn.CrossEntropyLoss(reduction="none", ignore_index=self.pad_token_id)

    def forward(self, batch_list, memory_hidden_states):
        """Run one step through the recurrent cell and compute the training loss.

        Args:
            batch_list: A ``(rollout, memory_reset_signals)`` pair. ``rollout``
                must contain exactly one step, a dict with ``"source_ids"`` and
                ``"target_ids"`` tensors. The reset signals are unpacked but
                ignored.
            memory_hidden_states: Memory passed to the cell as ``all_mems``.

        Returns:
            dict with ``"loss"`` (word loss + attention loss), ``"memory"``
            (the updated memory returned by the cell), ``"attn_loss"`` and
            ``"word_loss"``.

        Raises:
            ValueError: if ``rollout`` does not contain exactly one step.
        """
        rollout, _memory_reset_signals = batch_list
        # Explicit raise (not assert) so the check also runs under ``python -O``.
        if len(rollout) != 1:
            raise ValueError(f"expected a rollout with exactly 1 step, got {len(rollout)}")

        step_inputs = rollout[0]
        hidden_states, memory_hidden_states, attn_loss = self.recurrent_cell(
            input_ids=step_inputs["source_ids"], all_mems=memory_hidden_states
        )
        logits = self.lm_head(hidden_states, self.recurrent_cell.word_emb.weight)

        # ``reshape`` rather than ``view`` so non-contiguous targets are accepted too.
        word_loss = self.lm_criterion(
            logits.reshape(-1, logits.size(-1)), step_inputs["target_ids"].reshape(-1)
        )
        return {
            "loss": word_loss + attn_loss,
            "memory": memory_hidden_states,
            "attn_loss": attn_loss,
            "word_loss": word_loss,
        }

    def compute_outputs(self, recurrent_outputs, recurrent_inputs, training: bool = True) -> Dict[str, torch.Tensor]:
        """Compute per-token evaluation losses from precomputed recurrent outputs.

        Args:
            recurrent_outputs: dict containing ``"hidden_states"``.
            recurrent_inputs: dict containing ``"target_ids"``.
            training: must be False. The training path is not supported here;
                the training loss is computed in ``forward``.

        Returns:
            ``{"word_loss": t}`` where ``t`` is a 1-D tensor holding the loss of
            every target position whose id is not ``pad_token_id``.

        Raises:
            NotImplementedError: if ``training`` is True.
        """
        if training:
            # Explicit raise (not ``assert False``) so the guard survives ``python -O``.
            raise NotImplementedError(
                "compute_outputs does not support training; the training loss is computed in forward()"
            )

        flat_targets = recurrent_inputs["target_ids"].reshape(-1)
        hidden_states = recurrent_outputs["hidden_states"]
        logits = self.lm_head(hidden_states, self.recurrent_cell.word_emb.weight)

        word_loss = self.eval_criteron(logits.reshape(-1, logits.size(-1)), flat_targets)
        # Drop padding positions by target id. Filtering on ``loss != 0`` would
        # also discard real tokens whose loss underflows to exactly 0.0.
        word_loss = word_loss[flat_targets != self.pad_token_id]
        return {"word_loss": word_loss}

    def construct_memory(self, batch_size):
        """Return the initial memory produced by ``recurrent_cell.init_mems(batch_size)``."""
        return self.recurrent_cell.init_mems(batch_size)

    def reset_memory(self, memory, memory_reset_signal: torch.Tensor) -> torch.Tensor:
        """Return ``memory`` unchanged.

        Like Transformer-XL, this model does not reset its memory, so
        ``memory_reset_signal`` is ignored.
        """
        return memory

class LMHead(nn.Module):
    """Parameter-free output head mapping hidden states to vocabulary logits.

    The projection matrix is not owned by this module; the caller passes it to
    ``forward`` (in this file, the recurrent cell's word-embedding weight).

    Args:
        config: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, config):
        super().__init__()

    def forward(self, hidden_states, word_embedding):
        """Project ``hidden_states`` with ``word_embedding`` as a bias-free linear map.

        Args:
            hidden_states: Tensor whose last dimension is the hidden size,
                presumably (batch_size, seq_len, hidden_size).
            word_embedding: Weight of shape (num_embeddings, hidden_size).

        Returns:
            Logits with the same leading dimensions as ``hidden_states`` and a
            last dimension of ``num_embeddings``.
        """
        return F.linear(input=hidden_states, weight=word_embedding)