import numpy as np
import torch
import torch.nn as nn
from torchfly.metrics import Average
from torchfly.training.flymodel import FlyModel
from torchfly.training.optimization.warmup_scheduler import WarmupCosineRestartWithDecay, FlyAnnealing
from torchfly.training.optimization.radam import RAdam
from torchfly.training.optimization.ranger import Ranger
from typing import Any, Dict, List, Tuple, Union
from overrides import overrides

from .rnn_model import RNNModel
from .recurrent.recurrent_flymodel import RecurrentFlyModel


class RNNFlymodel(RecurrentFlyModel):
    """FlyModel wrapper that drives an ``RNNModel`` through the recurrent training loop.

    Memory objects handled by the hooks below are unpacked as a 2-tuple
    ``(encoder_memory, decoder_memory)``. Only ``encoder_memory["encoder_cross_hidden"]``
    takes part in the memory-gradient hooks; ``decoder_memory`` is never touched here.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.recurrent_model = RNNModel(config.model)
        # Running average of the ``word_loss`` values collected in ``predict``.
        self._perplexity = Average()

    def predict(self, batch):
        """Run one validation step on a single segment and update the running metric.

        Args:
            batch: a ``(rollout, memory_reset_signals)`` pair. ``rollout`` must contain
                exactly one segment, because validation processes one segment at a time.

        Side effects:
            Feeds ``word_loss`` (as a Python list) into the running metric. Replaces the
            working memory with the cell's new memory after applying the reset signal.
        """
        rollout, memory_reset_signals = batch
        assert len(rollout) == 1
        step_inputs = rollout[0]
        memory_reset_signal = memory_reset_signals[0]

        hidden_states, new_memory = self.recurrent_model.recurrent_cell(step_inputs, self._working_memory)

        outputs = self.recurrent_model.compute_outputs(hidden_states, step_inputs, training=False)
        loss = outputs["word_loss"]
        self._perplexity(loss.tolist())

        # Apply the reset signal to the new memory, then carry it into the next segment.
        current_memory = self.recurrent_model.reset_memory(new_memory, memory_reset_signal)
        self.set_current_working_memory(current_memory)

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the validation metrics collected by ``predict``.

        Args:
            reset: forwarded to ``Average.get_metric``.

        Returns:
            ``{"perplexity": <running average of word_loss>}``.
        """
        # NOTE(review): this reports the running *average* of ``word_loss``, not its
        # exponential. If ``word_loss`` is a cross-entropy, true perplexity would be
        # exp(value). Confirm the intended meaning before comparing with other models.
        ppl = self._perplexity.get_metric(reset)
        metrics = {"perplexity": ppl}
        return metrics

    def reset(self, batch_size=None) -> None:
        """Replace the working memory with a freshly constructed one for ``batch_size``."""
        memory = self.recurrent_model.construct_memory(batch_size)
        self.set_current_working_memory(memory)

    @overrides
    def get_zero_memory_grad(self, memory):
        """Return a zero gradient shaped like the only memory tensor that receives gradients.

        Only ``encoder_memory["encoder_cross_hidden"]`` is tracked; ``decoder_memory`` is ignored.
        """
        encoder_memory, _ = memory
        return torch.zeros_like(encoder_memory["encoder_cross_hidden"])

    @overrides
    def set_requires_grad(self, memory):
        """Make ``encoder_memory["encoder_cross_hidden"]`` collect a ``.grad``.

        Only that tensor of the working memory gets gradients; ``decoder_memory`` is untouched.
        """
        encoder_memory, _ = memory
        cross_hidden = encoder_memory["encoder_cross_hidden"]

        if not cross_hidden.requires_grad:
            cross_hidden.requires_grad = True
        # retain_grad() makes .grad populated even when the tensor is a non-leaf.
        cross_hidden.retain_grad()

    @overrides
    def detach_working_memory(self) -> Any:
        """Cut the autograd history of ``encoder_cross_hidden`` in the current working memory.

        The encoder memory dict is updated in place and the tuple is re-stored.
        Returns None.
        """
        encoder_memory, decoder_memory = self._working_memory
        encoder_memory["encoder_cross_hidden"] = encoder_memory["encoder_cross_hidden"].detach()
        self._working_memory = encoder_memory, decoder_memory

    @overrides
    def get_memory_grad(self, memory):
        """Return the gradient accumulated on ``encoder_memory["encoder_cross_hidden"]``."""
        encoder_memory, _ = memory
        return encoder_memory["encoder_cross_hidden"].grad

    @overrides
    def backward_memory_grad(self, memory, memory_grad):
        """Backpropagate ``memory_grad`` through ``encoder_memory["encoder_cross_hidden"]``.

        ``retain_graph=True`` keeps the graph alive so it can be traversed again.
        """
        encoder_memory, _ = memory
        encoder_memory["encoder_cross_hidden"].backward(memory_grad, retain_graph=True)

    def configure_optimizers(self, total_num_update_steps) -> Tuple[List, List]:
        """Build the AdamW optimizer and the ``FlyAnnealing`` learning-rate scheduler.

        Args:
            total_num_update_steps: total number of optimizer update steps for training.

        Returns:
            ``([optimizer], [scheduler])``.
        """
        optimizer_grouped_parameters = self.get_optimizer_parameters()
        optimization_config = self.config.training.optimization
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters, lr=optimization_config.learning_rate, betas=(0.9, 0.99), eps=1e-6
        )

        warmup_steps = optimization_config.warmup.warmup_steps
        # NOTE(review): ``optimization.warmup.cycle_steps`` was previously read but never
        # used, and ``total_num_update_steps`` fills both positional slots after
        # ``warmup_steps``. Check FlyAnnealing's signature to see whether one of them was
        # meant to be ``cycle_steps``.
        scheduler = FlyAnnealing(
            optimizer, warmup_steps, total_num_update_steps, total_num_update_steps, power_decay=1, cycle_mult=1
        )

        return [optimizer], [scheduler]