import os
import pickle as pkl
from typing import Union, Callable, Iterable

import matplotlib.pyplot as plt

from config import Config
from .minibatch_stats import RewardModelStats


class CallbackInterface:
    """
    Base class listing the hooks a training callback exposes.

    Every hook here simply raises NotImplementedError; concrete callbacks
    override the hooks they support.
    """

    def after_minibatch_callback(self, *args, **kwargs):
        """Per-minibatch hook; must be overridden by subclasses."""
        raise NotImplementedError()

    def after_epoch_callback(self, *args, **kwargs):
        """Per-epoch hook; must be overridden by subclasses."""
        raise NotImplementedError()

    def cleanup_callback(self, *args, **kwargs):
        """Final cleanup hook; must be overridden by subclasses."""
        raise NotImplementedError()


class AverageMetricGraphCallback:
    """
    Tracks the per-epoch average of a metric derived from minibatch stats,
    then plots and pickles the resulting history on cleanup.
    """
    def __init__(
            self,
            title: str,
            xlabel: str,
            ylabel: str,
            ax: "Axes",
            model_name: str,
            stats_to_metric: Callable,
    ):
        """
        :param title: The title of the graph; also used in the pickle file name.
        :param xlabel: The label of the x axis of the graph.
        :param ylabel: The label of the y axis of the graph.
        :param ax: The axis on which the graph will be plotted.
        :param model_name: The name of the model; also used in the pickle
        file name.
        :param stats_to_metric: A function which converts a single stats
        object into a numeric value that can be averaged and graphed.
        """
        self.title = title
        self.xlabel, self.ylabel = xlabel, ylabel
        self.ax = ax
        self.model_name = model_name
        # One averaged value per completed epoch.
        self.metric = []
        # Running sum of per-item metric values and item count for the
        # current epoch.
        self.running_metric = 0
        self.num_items = 0
        self.stats_to_metric = stats_to_metric

    def after_epoch_reset(self):
        """
        Resets the running metric sum and the number of items seen.
        """
        self.running_metric = 0
        self.num_items = 0

    def reset(self):
        """
        Clears the metric history and the per-epoch running totals.
        """
        self.metric = []
        self.after_epoch_reset()

    def after_minibatch_callback(self, minibatch_stats: "Iterable[RewardModelStats]"):
        """
        Adds every stats object in the minibatch to the current epoch's
        running metric sum and item count.
        """
        for stats in minibatch_stats:
            self.running_metric += self.stats_to_metric(stats)
            self.num_items += 1

    def after_epoch_callback(self):
        """
        Appends the current epoch's average metric to the history and resets
        the running totals.

        An epoch in which no items were seen (e.g. every item was filtered
        out by a subclass) records NaN instead of raising
        ZeroDivisionError, so the history keeps exactly one entry per epoch.
        """
        if self.num_items:
            self.metric.append(self.running_metric / self.num_items)
        else:
            self.metric.append(float('nan'))
        self.after_epoch_reset()

    def plot_progress(self):
        """
        Scatter-plots the metric history against the epoch index.
        """
        self.ax.set_title(self.title)
        self.ax.set_xlabel(self.xlabel)
        self.ax.set_ylabel(self.ylabel)
        self.ax.scatter(list(range(len(self.metric))), self.metric)

    def cleanup_callback(self):
        """
        Plots the graph, pickles the metric history and resets all state.

        The history is written to
        ``<config.pkls_dir>/model_loss/<title>_<model_name>.pkl``.
        """
        config = Config()
        self.plot_progress()
        out_dir = os.path.join(config.pkls_dir, 'model_loss')
        # Create the output directory on first use rather than failing with
        # FileNotFoundError.
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(out_dir, f'{self.title}_{self.model_name}.pkl')
        # Context manager guarantees the handle is closed and the pickle
        # is flushed to disk.
        with open(out_path, 'wb') as out_file:
            pkl.dump(self.metric, out_file)
        self.reset()

class LossGraphCallback(AverageMetricGraphCallback):
    """
    Graph of the per-epoch average of ``stats.loss``.
    """

    def __init__(
            self,
            title: str,
            xlabel: str,
            ylabel: str,
            ax: "Axes",
            model_name: str,
    ):
        """
        :param title: The title of the loss graph.
        :param xlabel: The label of the x axis of the loss graph.
        :param ylabel: The label of the y axis of the loss graph.
        :param ax: The axis on which the graph will be plotted.
        :param model_name: The name of the model.
        """
        def loss_of(stats):
            # The graphed metric is each stats object's ``loss`` attribute.
            return stats.loss

        super().__init__(
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            ax=ax,
            model_name=model_name,
            stats_to_metric=loss_of,
        )

class TruncatedLossGraphCallback(LossGraphCallback):
    """
    Loss graph that leaves the first ``truncation_amount`` epochs out of the
    plot. The full history in ``self.metric`` is left untouched.
    """
    def __init__(
            self,
            title: str,
            xlabel: str,
            ylabel: str,
            ax: "Axes",
            model_name: str,
            truncation_amount: int,
    ):
        """
        :param truncation_amount: Number of leading epochs omitted from the
        plot.
        See LossGraphCallback for the remaining parameters.
        """
        super().__init__(
            title,
            xlabel,
            ylabel,
            ax,
            model_name,
        )
        self.truncation_amount = truncation_amount

    def plot_progress(self):
        """
        Plots the loss progression, skipping the first
        ``truncation_amount`` epochs.
        """
        # The parent plots ``self.metric``, so swap in the truncated view for
        # the duration of the call. The ``finally`` restores the full history
        # even if plotting raises; otherwise it would be left truncated.
        full_metric = self.metric
        self.metric = full_metric[self.truncation_amount:]
        try:
            super().plot_progress()
        finally:
            self.metric = full_metric

class FilterGraphCallback(AverageMetricGraphCallback):
    """
    Average-metric graph computed only over the stats objects that pass a
    filter, applied to each incoming minibatch.
    """

    def __init__(
            self,
            minibatch_filter: Callable,
            title: str,
            xlabel: str,
            ylabel: str,
            ax: "Axes",
            model_name: str,
            stats_to_metric: Callable,
    ):
        """
        :param minibatch_filter: Predicate applied to each stats object; only
        objects for which it is truthy contribute to the metric.
        See AverageMetricGraphCallback for the remaining parameters.
        """
        super().__init__(
            title, xlabel, ylabel, ax, model_name, stats_to_metric,
        )
        self.minibatch_filter = minibatch_filter

    def after_minibatch_callback(self, minibatch_stats: Iterable[RewardModelStats]):
        """
        Forwards only the filtered subset of the minibatch to the parent's
        running average.
        """
        kept = filter(self.minibatch_filter, minibatch_stats)
        super().after_minibatch_callback(list(kept))

class MonitorLossCallback(CallbackInterface):
    """
    Realtime text based monitoring of the loss of the model.
    Mainly used for quick debugging.

    Accumulates ``loss`` over every stats object seen during an epoch and,
    every ``print_every`` epochs (starting with the first), prints ``name``
    followed by that epoch's mean loss.
    """
    def __init__(self, name, print_every=1):
        """
        :param name: Label printed before the average loss.
        :param print_every: Print on epochs whose index is a multiple of
        this value.
        """
        self.totloss = self.epoch_sz = self.curr_epoch = 0
        self.print_every = print_every
        self.name = name

    def after_minibatch_callback(self, minibatch_stats: Iterable[RewardModelStats]):
        """
        Adds the minibatch's losses and item count to the epoch totals.
        """
        # Materialize once: the argument is only an Iterable, and a generator
        # would be exhausted by sum() and has no len().
        minibatch_stats = list(minibatch_stats)
        self.totloss += sum(x.loss for x in minibatch_stats)
        self.epoch_sz += len(minibatch_stats)

    def after_epoch_callback(self):
        """
        Prints the epoch's mean loss (on scheduled epochs) and resets the
        epoch totals.
        """
        if not self.curr_epoch % self.print_every:
            # An epoch with no stats prints NaN instead of raising
            # ZeroDivisionError.
            avg_loss = (self.totloss / self.epoch_sz
                        if self.epoch_sz else float('nan'))
            print(self.name, avg_loss, flush=True)
        self.curr_epoch += 1
        self.totloss = self.epoch_sz = 0

    def cleanup_callback(self, *args, **kwargs):
        """
        No-op: this callback holds no resources. Overridden so that a
        cleanup pass does not hit CallbackInterface's NotImplementedError.
        """
