from typing import Iterable, Callable

import numpy as np
import torch.nn as nn

from dataset import ShuffleDataset
from .minibatch_stats import RewardModelStats
from .model_interactor import RewardModelInteractor
from .model_interactor_callbacks import CallbackInterface


class RewardModelValidator(RewardModelInteractor):
    """
    Handles all reward model validation.

    Every step is delegated to ``RewardModelInteractor``; the only behaviour
    added here is that the model is switched to evaluation mode before each
    minibatch is processed.
    """
    def __init__(
            self,
            model: nn.Module,
            dataset: ShuffleDataset,
            criterion: Callable,
            callbacks: Iterable[CallbackInterface],
    ):
        """
        :param model: The model to validate.
        :param dataset: The dataset which we'll use to validate the said model.
        :param criterion: The loss function which we would like to validate
        with respect to.
        :param callbacks: All callbacks which will be used in this pipeline.
        """
        super().__init__(model, dataset, criterion, callbacks)

    def minibatch(self, states: "Board", reward_arrays: np.ndarray) -> None:
        """
        Puts the model into evaluation mode and then hands the minibatch to
        the parent implementation.

        :param states: The board representation of the states.
        :param reward_arrays: Arrays which are essentially a mapping from
        action to reward for each state in states.
        """
        # eval() is re-applied on every minibatch so that validation is not
        # affected if something else put the model back into training mode
        # in between calls.
        # NOTE(review): gradient tracking is not disabled here; confirm
        # whether the parent's minibatch relies on / handles that itself.
        self.model.eval()
        super().minibatch(states, reward_arrays)

    def epoch(self) -> None:
        """
        Evaluates the model on one epoch of the data by delegating to the
        parent's epoch implementation.
        """
        super().epoch()

    def cleanup(self) -> None:
        """
        Delegates cleanup to the parent implementation; nothing
        validator-specific is released here.
        """
        super().cleanup()
