from typing import Iterable

import torch.nn as nn
import numpy as np

from dataset import ShuffleDataset
from .model_interactor_callbacks import CallbackInterface
from .minibatch_stats import RewardModelStats

class RewardModelInteractor:
    """
    Unifying class between model training and validation.

    Runs a model over a dataset one minibatch at a time, evaluates the
    criterion on each minibatch and hands the resulting per-sample
    statistics to every registered callback.
    """
    def __init__(
            self,
            model: nn.Module,
            dataset: ShuffleDataset,
            criterion: "LossFunction",
            callbacks: Iterable[CallbackInterface] = (),
    ):
        """
        :param model: The model to validate.
        :param dataset: The dataset which we'll use to validate the said model.
        Iterating it must yield argument tuples for ``minibatch``.
        :param criterion: The loss function which we would like to validate
        with respect to.
        :param callbacks: All callbacks which will be used in this pipeline.
        Defaults to no callbacks.
        """
        self.model = model
        self.dataset = dataset
        self.criterion = criterion
        # Snapshot into a fresh list: instances must not share a default
        # container, and a one-shot iterable (e.g. a generator) would
        # otherwise be exhausted after the first minibatch.
        self.callbacks = list(callbacks)

    def minibatch(self,
                  states: "State",
                  reward_arrays: np.ndarray,
                  ) -> "list[RewardModelStats]":
        """
        Interacts with the model on one minibatch.
        :param states: The states in our minibatch.
        :param reward_arrays: Arrays which are essentially a mapping from
        action to reward for each state in states.
        :return: One RewardModelStats per entry of reward_arrays; every
        entry carries the same (minibatch-level) loss value.
        """
        # The model output is flattened to one dimension so that it can be
        # indexed per sample below.
        outputs = self.model(states).reshape(-1)
        loss = self.criterion(outputs, reward_arrays)
        # Convert the loss to a Python number once rather than once per
        # sample; the value is identical for the whole minibatch.
        loss_value = loss.detach().item()
        # NOTE(review): outputs[i] is passed on without detaching, exactly as
        # before -- confirm whether callbacks rely on that.
        minibatch_stats = [
            RewardModelStats(
                (states[0][0][i], states[1][0][i]),
                reward_arrays[i],
                outputs[i],
                loss_value,
            )
            for i in range(len(reward_arrays))
        ]
        for callback in self.callbacks:
            callback.after_minibatch_callback(minibatch_stats)
        return minibatch_stats

    def epoch(self) -> None:
        """
        Interacts with the model on one epoch of the data.

        Every item yielded by the dataset is unpacked into ``minibatch``;
        afterwards each callback's ``after_epoch_callback`` is invoked.
        """
        for inp in self.dataset:
            self.minibatch(*inp)
        for callback in self.callbacks:
            callback.after_epoch_callback()

    def cleanup(self) -> None:
        """
        Runs all of the callback cleanup functions.
        """
        for callback in self.callbacks:
            callback.cleanup_callback()
