import warnings

import sparse
import numpy as np

from typing import Callable, Union
from itertools import product


def compress(dense_dense_sparse: np.ndarray):
    """
    Compresses a matrix whose last dimension is sparse
    using the dictionary of keys representation.
    :param dense_dense_sparse: A numpy array that is dense
    in the first two dimensions, and sparse in the last one.
    :return: A nested dict where res[i][j] is a list of (k, x) pairs,
    one for every nonzero entry x = dense_dense_sparse[i][j][k],
    in increasing order of k. Rows without nonzero entries map to
    an empty list.
    """
    res = {}
    for i, plane in enumerate(dense_dense_sparse):
        res[i] = {}
        for j, fibre in enumerate(plane):
            # Keep every nonzero value. Testing `x > 0` instead would
            # silently discard negative entries from the compressed form.
            res[i][j] = [(k, x) for k, x in enumerate(fibre) if x != 0]
    return res

def construct_reward_matrix(reward_matrix_shape, reward_function):
    """
    Tabulates a reward function into a matrix.
    :param reward_matrix_shape: The shape of the matrix to build; the first
    entry is the number of states and the second the number of actions.
    :param reward_function: A mapping from (state, action) to reward; it is
    called once per pair, states in the outer loop, actions in the inner one.
    :return: A zero-initialised NumPy array of the given shape whose
    [state][action] entry holds reward_function(state, action).
    """
    table = np.zeros(reward_matrix_shape)
    for state in range(reward_matrix_shape[0]):
        for action in range(reward_matrix_shape[1]):
            reward = reward_function(state, action)
            # Index the row after the call, then write into that view.
            row = table[state]
            row[action] = reward
    return table

def policy_evaluation(
        pi: np.array,
        sparse_prob_trans_mat: Union[sparse.COO, np.array],
        reward_fcn: Callable[[int, int], float],
        gamma: float = 0.99,
        theta: float = 10e-6,
        prev_q: np.array = None,
) -> np.array:
    """
    Iteratively evaluates a policy from the MDP dynamics, given as a
    probability transition matrix together with a reward function.
    :param pi: A matrix where pi[s][a] is the probability of taking action
    a given we are in state s.
    :param sparse_prob_trans_mat: A probability transition matrix where
    prob_trans_mat[s][a][s_prime] is the probability that we will transition
    to s_prime given that we were in state s doing action a.
    Expected to be a sparse.COO; a NumPy array is converted on the fly, but
    since that conversion is costly when repeated, a warning is emitted.
    :param reward_fcn: A mapping from (state, action) to reward.
    :param gamma: The decay rate for the return.
    :param theta: Convergence tolerance: iteration stops once the largest
    absolute change between two successive estimates is no longer above
    theta. The default 10e-6 (i.e. 1e-5) is usually fine.
    :param prev_q: Optional starting estimate for q; when omitted the
    iteration starts from zeros shaped like pi. It is not modified in place.
    :return: The last computed estimate of q.
    """
    if prev_q is None:
        q = np.zeros_like(pi)
    else:
        q = prev_q

    if isinstance(sparse_prob_trans_mat, np.ndarray):
        sparse_prob_trans_mat = sparse.COO(sparse_prob_trans_mat)
        warnings.warn(
            "Constructing sparse_prob_trans_mat as a "
            "sparse matrix; typically costly; should call "
            "sparse.COO(prob_trans_mat) "
            "before passing the prob_trans_mat in."
        )
    rewards = construct_reward_matrix(pi.shape, reward_fcn)

    delta = float('inf')
    while delta > theta:
        # State values under pi: sum over actions of pi[s][a] * q[s][a].
        state_values = np.sum(pi * q, axis=1)
        expected_next = np.dot(sparse_prob_trans_mat, state_values)

        # Start from a fresh copy of the rewards so neither the reward
        # table nor the previous estimate is overwritten.
        updated_q = np.copy(rewards)
        updated_q += gamma * expected_next

        delta = np.max(np.abs(updated_q - q))
        q = updated_q
    return q
