import numpy as np
import torch
import torch.nn as nn
from torchfly.metrics import Average
from torchfly.training.flymodel import FlyModel
from typing import Any, Dict, List, Tuple, Union

from .transformer_xl_model import TransformerXLModel
from .recurrent.recurrent_flymodel import RecurrentFlyModel


class TransformerXLFlyModel(RecurrentFlyModel):
    """FlyModel wrapper that drives a ``TransformerXLModel`` with a working memory.

    The wrapped model is stored in ``self.recurrent_model``. The working memory
    is read from ``self._working_memory`` and written back through
    ``self.set_current_working_memory``; both are expected to be provided by
    ``RecurrentFlyModel`` (not visible in this file -- verify against the base
    class).
    """

    def __init__(self, config):
        """Build the wrapped Transformer-XL model and the evaluation metric.

        Args:
            config: Configuration object. It is forwarded unchanged to the base
                class, and ``config.model`` is passed to ``TransformerXLModel``.
        """
        super().__init__(config)
        self.config = config
        self.recurrent_model = TransformerXLModel(config.model)
        # Running average of the values recorded by `predict`.
        self._perplexity = Average()

    def forward(self, batch):
        """
        This forward implementation is used for direct training with the official gradient checkpointing

        Args:
            batch: Input passed straight to ``self.recurrent_model`` together
                with the current working memory.

        Returns:
            The output mapping produced by ``self.recurrent_model``. Its
            ``"memory"`` entry is also stored as the new working memory.
        """
        outputs = self.recurrent_model(batch, self._working_memory)
        # Carry the updated memory over to the next call.
        self.set_current_working_memory(outputs["memory"])
        return outputs

    def predict(self, batch):
        """Run one validation segment and record its word loss in the metric.

        Args:
            batch: A pair ``(rollout, memory_reset_signals)``. ``rollout`` must
                contain exactly one step; that step is indexed with the keys
                ``"source_ids"`` and ``"target_ids"``. The reset signals are
                not used here.

        Returns:
            None. The result is accumulated in ``self._perplexity`` and read
            back via ``get_metrics``.

        Raises:
            AssertionError: If ``rollout`` does not hold exactly one segment.
        """
        # The two-element unpack is kept so a malformed batch still fails here;
        # the reset signals themselves are unused in this method.
        rollout, _memory_reset_signals = batch
        # For validation, one segment each time.
        assert len(rollout) == 1, "predict expects exactly one segment per rollout"
        step_inputs = rollout[0]

        target_ids = step_inputs["target_ids"]
        hidden_states, current_memory = self.recurrent_model.recurrent_cell(
            input_ids=step_inputs["source_ids"], mems=self._working_memory
        )

        # Carry the updated memory over to the next segment.
        self.set_current_working_memory(current_memory)

        outputs = self.recurrent_model.compute_outputs(hidden_states, {"target_ids": target_ids}, training=False)

        # NOTE(review): the metric is named "perplexity" but receives the raw
        # `word_loss` values; whether they are exponentiated elsewhere is not
        # visible from this file -- confirm.
        loss = outputs["word_loss"]
        self._perplexity(loss.tolist())

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accumulated metric under the key ``"perplexity"``.

        Args:
            reset: Forwarded to ``Average.get_metric``.

        Returns:
            A dict with the single key ``"perplexity"``.
        """
        ppl = self._perplexity.get_metric(reset)
        return {"perplexity": ppl}

    def reset(self, batch_size=None) -> None:
        """Replace the working memory with a freshly constructed one.

        Args:
            batch_size: Forwarded to ``self.recurrent_model.construct_memory``;
                defaults to ``None``.
        """
        memory = self.recurrent_model.construct_memory(batch_size)
        self.set_current_working_memory(memory)

    def detach_working_memory(self) -> Any:
        """Do nothing: Transformer XL detaches memory at each time step."""
    