import os
import pickle as pkl
from typing import Callable

import numpy as np

from .policy_evaluation import policy_evaluation


class CallbackInterface:
    """
    Abstract-by-convention interface for loop callbacks.

    Concrete callbacks are expected to override both hooks below; invoking
    either hook on this base class raises ``NotImplementedError``.
    """

    def after_loop_callback(self, *_args, **_kwargs):
        """Hook for work after a loop iteration. Subclasses must override."""
        raise NotImplementedError()

    def cleanup_callback(self, *_args, **_kwargs):
        """Hook for final work once the loop is done. Subclasses must override."""
        raise NotImplementedError()


class GraphCallback:
    """
    Base for callbacks that draw a graph on an axis object (presumably a
    matplotlib ``Axes``).

    Keeps the graph's title and axis labels and applies them to the axis
    when ``cleanup_callback`` runs.
    """

    def __init__(
            self,
            title: str,
            xlabel: str,
            ylabel: str,
            ax: "Axes",
    ):
        """
        :param title: The title of the graph.
        :param xlabel: The label of the graph's x axis.
        :param ylabel: The label of the graph's y axis.
        :param ax: The axis on which the graph will be drawn.
        """
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.ax = ax

    def cleanup_callback(self):
        """
        Applies the stored title, x label and y label to ``self.ax``
        (in that order).
        """
        axis = self.ax
        labelers = (
            (axis.set_title, self.title),
            (axis.set_xlabel, self.xlabel),
            (axis.set_ylabel, self.ylabel),
        )
        for apply_label, text in labelers:
            apply_label(text)


class ReturnGraphCallback(GraphCallback):
    """
    Records the expected return of one state after every loop iteration,
    plots the recorded sequence on cleanup, and pickles it to disk.
    """

    def __init__(
            self,
            title: str,
            xlabel: str,
            ylabel: str,
            ax: "Axes",
            reward_function: Callable = None,
            state: int = 0,
            save_dir: str = '../pkls/returns',
    ):
        """
        :params title, xlabel, ylabel, ax: See GraphCallback's __init__ docstring.
        :param reward_function: The reward function used to compute the
        expected return. If None is specified, the q function passed to
        ``after_loop_callback`` by the policy iterator is used as-is.
        :param state: The state for which the return is being monitored.
        Usually this is the start state (encoded as state 0), which is why
        0 is the default.
        :param save_dir: Directory in which the recorded returns are pickled
        as ``<title>.pkl`` during cleanup. It is created if it does not
        exist. Pass None to skip saving.
        """
        super(ReturnGraphCallback, self).__init__(title, xlabel, ylabel, ax)
        self.returns = []
        self.state = state
        self.reward_function = reward_function
        self.save_dir = save_dir

    def after_loop_callback(self,
                            policy_iterator: "PolicyIterator",
                            q: np.ndarray,
                            ) -> None:
        """
        Computes the expected return of ``self.state`` under the current
        policy and appends it to ``self.returns``.

        The value is ``sum(q[state] * pi[state])``: the policy-weighted sum
        of the state's action values (assumes ``q`` and ``pi`` are both
        indexed ``[state, action]`` — TODO confirm against PolicyIterator).

        :param policy_iterator: The running policy iterator; its policy,
        transition matrix, gamma and theta are read here.
        :param q: The q function supplied by the policy iterator. Ignored
        when a ``reward_function`` was given, because q is then recomputed.
        """
        if self.reward_function is not None:
            # Evaluate the current policy under the user-supplied reward
            # instead of using the q function handed in by the iterator.
            q = policy_evaluation(
                policy_iterator.policy.pi,
                policy_iterator.sparse_prob_trans_mat,
                self.reward_function,
                gamma=policy_iterator.gamma,
                theta=policy_iterator.theta,
            )
        pi = policy_iterator.policy.pi
        self.returns.append(np.sum(q[self.state] * pi[self.state]))

    def cleanup_callback(self):
        """
        Labels the axis, plots the recorded returns against the iteration
        index, and, unless ``save_dir`` is None, pickles the list of
        returns to ``<save_dir>/<title>.pkl``.
        """
        super(ReturnGraphCallback, self).cleanup_callback()
        self.ax.plot(list(range(len(self.returns))), self.returns)
        if self.save_dir is None:
            return
        os.makedirs(self.save_dir, exist_ok=True)
        path = os.path.join(self.save_dir, f'{self.title}.pkl')
        # Context manager guarantees the file is flushed and closed, even if
        # pickling raises; the original left the handle open.
        with open(path, 'wb') as f:
            pkl.dump(self.returns, f)


class DeltaGraphCallback(GraphCallback):
    """
    Records ``policy_iterator.delta`` after every loop iteration and plots
    the recorded sequence on cleanup.
    """

    def __init__(
            self,
            title: str,
            xlabel: str,
            ylabel: str,
            ax: "Axes",
    ):
        """
        :params title, xlabel, ylabel, ax: See GraphCallback's __init__ docstring.
        """
        super(DeltaGraphCallback, self).__init__(title, xlabel, ylabel, ax)
        # One entry per after_loop_callback call, in call order.
        self.deltas = []

    def after_loop_callback(self,
                            policy_iterator: "PolicyIterator",
                            q: np.ndarray,
                            ) -> None:
        """
        Stores the policy iterator's current ``delta`` (presumably its
        convergence measure for the latest iteration — confirm in
        PolicyIterator).

        :param policy_iterator: The running policy iterator; only its
        ``delta`` attribute is read.
        :param q: Accepted to match the callback signature; unused here.
        """
        self.deltas.append(policy_iterator.delta)

    def cleanup_callback(self):
        """
        Labels the axis and plots the recorded deltas against the iteration
        index.
        """
        super(DeltaGraphCallback, self).cleanup_callback()
        self.ax.plot(list(range(len(self.deltas))), self.deltas)
