import numpy as np
import torch
import torch.nn as nn
from torchfly.metrics import Average
from torchfly.training.flymodel import FlyModel
from typing import Any, Dict, List, Tuple, Union
from apex.parallel import DistributedDataParallel, Reducer
from apex import amp

from .compressive_transformer_model import CTransformerRecurrentModel
from .recurrent.recurrent_flymodel import RecurrentFlyModel

# Register torch.einsum with apex.amp as a "float function" so amp runs it in float32
# even under mixed precision. NOTE(review): presumably this protects the einsum-based
# attention in the compressive transformer from fp16 overflow / precision loss — confirm.
# Apex expects such registrations before amp.initialize(), which is presumably why this
# runs at import time.
amp.register_float_function(torch, 'einsum')


class CTransFlyModel(RecurrentFlyModel):
    """FlyModel wrapper around a compressive-transformer recurrent model.

    Owns a ``CTransformerRecurrentModel`` built from ``config.model`` and threads its
    memory between calls through the working-memory helpers inherited from
    ``RecurrentFlyModel`` (``self._working_memory`` / ``set_current_working_memory``).
    Validation losses produced by ``predict`` are accumulated in an ``Average`` metric
    and reported by ``get_metrics`` under the key ``"perplexity"``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.recurrent_model = CTransformerRecurrentModel(config.model)
        # Accumulates the per-segment ``word_loss`` values passed in by ``predict``.
        self._perplexity = Average()

    def forward(self, batch):
        """
        This forward implementation is used for direct training with the official gradient checkpointing.

        Runs ``self.recurrent_model`` on ``batch`` using the current working memory, then
        stores the returned ``"memory"`` entry as the new working memory.

        Returns:
            The output dict produced by ``self.recurrent_model`` (it contains ``"memory"``).
        """
        outputs = self.recurrent_model(batch, self._working_memory)
        # Carry the updated memory over to the next call.
        self.set_current_working_memory(outputs["memory"])
        return outputs

    def predict(self, batch):
        """Run a single validation segment and accumulate its loss.

        Args:
            batch: A ``(rollout, memory_reset_signals)`` pair. ``rollout`` must hold
                exactly one segment: a mapping with ``"source_ids"`` and
                ``"target_ids"`` entries. ``memory_reset_signals`` is not used here.

        Raises:
            ValueError: if ``rollout`` does not contain exactly one segment.
        """
        # For validation, one segment is processed at a time.
        rollout, _memory_reset_signals = batch
        # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
        # Without the check, extra segments would be silently ignored.
        if len(rollout) != 1:
            raise ValueError(f"predict expects exactly one segment per rollout, got {len(rollout)}")
        step_inputs = rollout[0]
        target_ids = step_inputs["target_ids"]

        # The recurrent cell also returns an attention loss, which is unused in this method.
        hidden_states, current_memory, _attn_loss = self.recurrent_model.recurrent_cell(
            input_ids=step_inputs["source_ids"], all_mems=self._working_memory
        )
        # Advance the working memory so the next segment sees this segment's context.
        self.set_current_working_memory(current_memory)

        outputs = self.recurrent_model.compute_outputs(
            {"hidden_states": hidden_states}, {"target_ids": target_ids}, training=False
        )

        loss = outputs["word_loss"]
        self._perplexity(loss.tolist())

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accumulated validation metric.

        Args:
            reset: Forwarded to ``Average.get_metric``; when True the accumulator is reset.

        Returns:
            ``{"perplexity": value}``, where ``value`` is what the ``Average`` metric
            reports for the ``word_loss`` values collected in ``predict``.
            NOTE(review): no ``exp`` is applied here, so this is presumably the mean
            loss rather than exp(mean loss) — confirm which is intended.
        """
        ppl = self._perplexity.get_metric(reset)
        metrics = {"perplexity": ppl}
        return metrics

    def reset(self, batch_size=None) -> None:
        """Replace the working memory with a freshly constructed one for ``batch_size``."""
        memory = self.recurrent_model.construct_memory(batch_size)
        self.set_current_working_memory(memory)

    def detach_working_memory(self) -> Any:
        """Intentionally a no-op: Transformer XL detaches memory at each time step."""
        pass
