import copy
import os
import random
import sys
from typing import Dict, Callable, Iterable

import sparse
import torch
import numpy as np
import matplotlib.pyplot as plt

from .policy import RegretMatchingPolicy
from .policy_evaluation import policy_evaluation
from .policy_iterator_callbacks import ReturnGraphCallback, DeltaGraphCallback
from environment_generator import EnvironmentDataset
from driving_gridworld.actions import ACTIONS

# pytorch module loading workaround (pickle limitation): torch.load unpickles
# whole model objects, so the modules defining their classes must be importable.
# A relative sys.path entry is resolved against the current working directory,
# not against this file's location.
sys.path.append('../model')

class PolicyIteration:
    """
    Standard policy iteration with a single reward function.

    Each call runs one policy evaluation step followed by one policy
    improvement step, records how much the policy changed in `delta`,
    and then runs the after-iteration callbacks.
    """
    def __init__(
            self,
            policy: "Policy",
            prob_trans_mat: np.array,
            reward_function: Callable,
            gamma: float = 0.99,
            theta: float = 1e-6,
            after_iteration_callbacks: Iterable[Callable] = (),
    ):
        """
        :param policy: The policy to improve; must expose `pi` and
        `policy_improvement(q)`.
        :param prob_trans_mat: Transition probabilities; converted to a
        `sparse.COO` matrix that is handed to policy_evaluation.
        :param reward_function: Passed unchanged to policy_evaluation.
        :params gamma, theta: See policy_evaluation.
        :param after_iteration_callbacks: Objects whose
        `after_loop_callback(self, q)` runs after every iteration and whose
        `cleanup_callback()` runs in `cleanup`. Copied into a list.
        """
        self.policy = policy
        self.sparse_prob_trans_mat = sparse.COO(prob_trans_mat)
        self.reward_function = reward_function
        self.gamma = gamma
        self.theta = theta
        # A fresh list per instance: no shared mutable default, and one-shot
        # iterables (e.g. generators) are still usable on every iteration.
        self.callbacks = list(after_iteration_callbacks)
        # Largest absolute change in pi during the last iteration; infinite
        # until the first iteration has run (as in K_Of_N_PolicyIteration).
        self.delta = float('inf')
        # q values from the previous evaluation, passed back to
        # policy_evaluation as `prev_q` (presumably a warm start).
        self.prev_q = None

    def policy_evaluation(self) -> np.array:
        """
        This does policy evaluation in a standard way.

        The resulting q values are stored in `self.prev_q` and returned.
        """
        q = policy_evaluation(
            self.policy.pi,
            self.sparse_prob_trans_mat,
            self.reward_function,
            gamma=self.gamma,
            theta=self.theta,
            prev_q=self.prev_q,
        )
        self.prev_q = q
        return q

    def __call__(self):
        """
        Does a loop of policy evaluation followed by policy improvement.

        Sets `self.delta` to the largest absolute change in `pi`, then
        calls every callback's `after_loop_callback(self, q)`.
        """
        q = self.policy_evaluation()
        # Snapshot pi. A plain reference would alias the new policy if
        # policy_improvement updates pi in place, making delta always 0.
        pi_before = copy.deepcopy(self.policy.pi)
        self.policy.policy_improvement(q)
        self.delta = np.max(np.abs(pi_before - self.policy.pi))
        for callback in self.callbacks:
            callback.after_loop_callback(self, q)

    def cleanup(self):
        """Run every callback's `cleanup_callback()`."""
        for callback in self.callbacks:
            callback.cleanup_callback()


class K_Of_N_PolicyIteration:
    """
    Policy iteration where the reward function is
    set to the k-of-n logic, whereby, for each
    policy evaluation -> policy improvement step:

    1. We calculate the q values for each of the
    n sampled reward functions.

    2. We update the policy using the average of
    the lowest k of these q values, where the ordering
    is determined by the expected return from the start
    state (state index 0) given the current policy.

    By default the reward models are sampled without
    replacement: every loaded model is used in at most one
    evaluation, taken from the end of the filename-sorted list.
    With `resample_models=True`, each of the n draws is instead
    a uniformly random choice (with replacement) from all
    loaded models. As of writing this docstring, the models
    are trained offline.

    See
    https://webdocs.cs.ualberta.ca/~bowling/papers/12nips-kofn.pdf
    for more details about the above process.
    """
    def __init__(
            self,
            k: int,
            n: int,
            policy: "Policy",
            prob_trans_mat: np.array,
            models_dir: "Filename",
            model_to_reward_function: Callable[[torch.nn.Module], Callable],
            gamma: float = 0.99,
            theta: float = 1e-6,
            after_iteration_callbacks: Iterable[Callable] = (),
            resample_models=False,
    ):
        """
        :params k, n: See class docstring. Must satisfy 1 <= k <= n.
        :param policy: The policy to improve; must expose `pi`,
        `policy_improvement(q)` and integer indexing (`policy[0]`).
        :param models_dir: The directory where all of the models are stored.
        Every file in it is loaded with torch.load.
        :params prob_trans_mat, gamma, theta: See policy_evaluation.
        :param model_to_reward_function: A convenience function to transform
        a Pytorch model into something that's more like a tabular reward function
        (a mapping from (int, int) to float).
        :param after_iteration_callbacks: A list of callbacks to run after each iteration.
        Passes in the entire object to the callback. Copied into a list.
        :param resample_models: If True, reuse the loaded models across
        iterations, drawing each of the n per iteration at random with
        replacement. Useful for iterating fast, should be False when
        generating truthful results.
        :raises ValueError: If k and n do not satisfy 1 <= k <= n.
        :attr reward_functions: Reward functions built from the files in
        models_dir, in sorted filename order. Without resample_models,
        used entries are popped off this list; this is how sampling
        without replacement is enforced.
        """
        # Validate before the (potentially slow) model loading. k > n would
        # silently rescale the averaged q values; k == 0 divides by zero.
        if not 1 <= k <= n:
            raise ValueError(
                'k-of-n requires 1 <= k <= n, got k={}, n={}'.format(k, n))
        self.k, self.n = k, n
        self.policy = policy
        self.prob_trans_mat = prob_trans_mat
        self.sparse_prob_trans_mat = sparse.COO(prob_trans_mat)
        self.gamma = gamma
        self.theta = theta
        self.model_to_reward_function = model_to_reward_function
        self.resample_models = resample_models
        # A fresh list per instance: no shared mutable default, and one-shot
        # iterables (e.g. generators) are still usable on every iteration.
        self.callbacks = list(after_iteration_callbacks)
        self.delta = float('inf')
        self.prev_q = None

        # torch.load unpickles whole model objects, so only point models_dir
        # at trusted files. NOTE: PyTorch >= 2.6 defaults to
        # weights_only=True, which refuses to unpickle full modules.
        self.reward_functions = [
            model_to_reward_function(
                torch.load(os.path.join(models_dir, filename)).eval())
            for filename in sorted(os.listdir(models_dir))
        ]

    def _next_reward_function(self) -> Callable:
        """
        Return the reward function for one of the n evaluations.

        With resample_models, a uniformly random loaded reward function
        (the pool is left intact). Otherwise, the last remaining entry is
        removed from `self.reward_functions` and returned.
        """
        if self.resample_models:
            return random.choice(self.reward_functions)
        return self.reward_functions.pop()

    def policy_evaluation(self) -> np.array:
        """
        This evaluates the current policy using the k-of-n metric,
        where the average of the lowest k (in terms of expected return
        given the start state) q values are :return:ed.

        The result is also stored in `self.prev_q`, which is passed to
        every evaluation of the next call.

        :raises IndexError: If resample_models is False and fewer than n
        unused reward functions remain. No reward function is consumed
        in that case.
        """
        if not self.resample_models and len(self.reward_functions) < self.n:
            # Fail before popping anything so the remaining pool stays intact.
            raise IndexError(
                'Need {} unused reward models for a k-of-n evaluation but '
                'only {} remain; provide more models or set '
                'resample_models=True.'.format(
                    self.n, len(self.reward_functions)))
        qs = [
            policy_evaluation(
                self.policy.pi,
                self.sparse_prob_trans_mat,
                self._next_reward_function(),
                gamma=self.gamma,
                theta=self.theta,
                prev_q=self.prev_q,
            )
            for _ in range(self.n)
        ]
        assert len(qs[0][0]) == len(ACTIONS)
        assert len(self.policy[0]) == len(ACTIONS)
        # Ascending by expected return from state 0 under the current policy,
        # so the first k entries are the k worst-case q tables.
        qs = sorted(qs, key=lambda q: np.sum(q[0] * self.policy[0]))
        self.prev_q = sum(qs[:self.k]) / self.k
        return self.prev_q

    def __call__(self):
        """
        Does a loop of policy evaluation followed by policy improvement.

        Sets `self.delta` to the largest absolute change in `pi`, then
        calls every callback's `after_loop_callback(self, q)`.
        """
        q = self.policy_evaluation()
        # Snapshot pi. A plain reference would alias the new policy if
        # policy_improvement updates pi in place, making delta always 0.
        pi_before = copy.deepcopy(self.policy.pi)
        self.policy.policy_improvement(q)
        self.delta = np.max(np.abs(pi_before - self.policy.pi))
        for callback in self.callbacks:
            callback.after_loop_callback(self, q)

    def cleanup(self):
        """Run every callback's `cleanup_callback()`."""
        for callback in self.callbacks:
            callback.cleanup_callback()
