import warnings
from typing import Callable, Iterable, List, Union

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from dataset import ShuffleDataset
from .minibatch_stats import RewardModelStats, StateTuple
from .model_interactor import RewardModelInteractor
from .model_interactor_callbacks import CallbackInterface
from .model_validator import RewardModelValidator


class RewardModelTrainer(RewardModelInteractor):
    """
    Handles all of the reward model training.

    Extends RewardModelInteractor with an optimizer so that each minibatch
    performs a forward pass, a backward pass and an optimizer step, and then
    reports per-sample statistics to every registered callback.
    """
    def __init__(
            self,
            model: nn.Module,
            dataset: ShuffleDataset,
            criterion: "LossFunction",
            optimizer: optim.Optimizer,
            callbacks: Union[Iterable[CallbackInterface], None] = None,
    ):
        """
        :params model, dataset, criterion, callbacks: See RewardModelValidator.
        :param optimizer: The optimizer which we'll use on our model.
        """
        # A fresh list per instance: a literal `[]` default would be created
        # once and shared by every trainer constructed without callbacks.
        if callbacks is None:
            callbacks = []
        super().__init__(model, dataset, criterion, callbacks)
        self.optimizer = optimizer

    def minibatch(self, states: torch.Tensor, reward_arrays: np.ndarray) -> None:
        """
        Trains the model on one minibatch.
        :param states: The states in the minibatch.
        :param reward_arrays: Arrays which are essentially a mapping from
        action to reward for each state in states.
        """
        # Standard optimisation step: switch to train mode, clear stale
        # gradients, forward, compute loss, backpropagate, update weights.
        self.model.train()
        self.optimizer.zero_grad()
        outputs = self.model(states).reshape(-1)
        loss = self.criterion(outputs, reward_arrays)
        loss.backward()
        self.optimizer.step()

        # Move results off the autograd graph (and onto the CPU) once, so the
        # stats objects below hold plain numpy / float values.
        predictions = outputs.cpu().detach().numpy()
        # The loss is a single value for the whole minibatch; every sample's
        # stats entry carries that same number.
        loss_value = loss.detach().item()

        # NOTE(review): the indexing below treats `states` as a pair of pairs
        # of per-sample sequences, which does not obviously match the
        # `torch.Tensor` annotation -- confirm against the caller.
        minibatch_stats = [
            RewardModelStats(
                StateTuple(
                    states[0][0][i],
                    states[0][1][i],
                    states[1][0][i],
                    states[1][1][i],
                ),
                reward_arrays[i],
                predictions[i],
                loss_value,
            )
            for i in range(len(reward_arrays))
        ]
        for callback in self.callbacks:
            callback.after_minibatch_callback(minibatch_stats)

    def epoch(self) -> None:
        """
        Trains the model on one epoch of the data.
        """
        super().epoch()

    def cleanup(self, save_model_filename: Union[str, None] = None) -> nn.Module:
        """
        Saves the model and runs all of the callback cleanup functions.
        :param save_model_filename: The filename to store the model at, if desired.
        :returns: The trained model.
        """
        if save_model_filename:
            # Serialises the whole module object rather than only its
            # state_dict.
            torch.save(self.model, save_model_filename)
        super().cleanup()
        # The signature and docstring promise the trained model; previously
        # this method fell off the end and returned None.
        return self.model
