import copy
import itertools

import numpy as np
import scipy
import scipy.stats
from scipy.special import comb


def get_action_probs_from_Qs(Qs):
    """ Convert a stack of Q tables into averaged greedy action probabilities.

    The last axis of Qs indexes actions and the first axis indexes the Q tables.
    Every vector along the last axis is replaced by a one-hot vector marking its
    first maximal entry; the one-hot tables are then averaged over the first axis,
    so the result has the shape of a single Q table. """
    num_actions = Qs.shape[-1]
    flat_q = Qs.reshape((-1, num_actions))
    greedy_actions = flat_q.argmax(axis=1)

    # One-hot encode the greedy action of every flattened row.
    one_hot = np.zeros_like(flat_q)
    np.put_along_axis(one_hot, greedy_actions[:, np.newaxis], 1, axis=1)

    return one_hot.reshape(Qs.shape).mean(axis=0)


def get_new_action_probs_from_Qs(num_averages_yet, old_probs, Qs):
    """ Fold the greedy policy of a new stack of Q tables into a running average.

    The greedy action probabilities of Qs are computed as a one-hot argmax over the
    last axis, averaged over the first axis. They are then combined with old_probs,
    which is weighted as the mean of num_averages_yet earlier policies. """
    flat_q = Qs.reshape((-1, Qs.shape[-1]))
    row_ids = np.arange(flat_q.shape[0])

    greedy_mask = np.zeros_like(flat_q)
    greedy_mask[row_ids, flat_q.argmax(axis=1)] = 1
    new_probs = greedy_mask.reshape(Qs.shape).mean(axis=0)

    # Incremental mean: old mean carries weight n, the new sample weight 1.
    weighted_total = old_probs * num_averages_yet + new_probs
    return weighted_total / (num_averages_yet + 1)


def find_best_response(env, mus):
    """ Backward induction for the optimal Q tables against the mean field mus.

    For t = env.time_steps - 1 down to 0 the Q table is the reward from
    env.get_R_high(t, mus[t]) plus the expected next value under
    env.get_P_high(t, mus[t]), where the next value is the max of the following
    Q table over its last axis (zero after the final step).

    :return: array of the Q tables stacked in increasing time order """
    next_values = np.zeros((env.observation_space.n, ))
    horizon = env.time_steps
    q_by_time = [None] * horizon

    for t in reversed(range(horizon)):
        transition = env.get_P_high(t, mus[t])
        expected_next = np.einsum('ijk,k->ji', transition, next_values)
        q_t = env.get_R_high(t, mus[t]) + expected_next
        next_values = q_t.max(axis=-1)
        q_by_time[t] = q_t

    return np.array(q_by_time)


def find_best_response_k(env, probs_G, k):
    """ Backward induction for the optimal Q tables in the k-indexed model.

    Same recursion as a finite-horizon Bellman backup: at each t the reward
    env.get_R_k(t, k, probs_G[t]) is added to the expected continuation value
    under env.get_P_k(t, k, probs_G[t]); the continuation value is the max of the
    next Q table over its last axis and zero beyond the horizon.

    :return: array of the Q tables stacked in increasing time order """
    continuation = np.zeros((env.observation_space.n, ))
    backward_Qs = []

    for t in range(env.time_steps - 1, -1, -1):
        transition = env.get_P_k(t, k, probs_G[t])
        reward = env.get_R_k(t, k, probs_G[t])
        Q_t = reward + np.einsum('ijk,k->ji', transition, continuation)
        continuation = Q_t.max(axis=-1)
        backward_Qs.append(Q_t)

    # Tables were collected from the last step to the first; flip to time order.
    return np.array(backward_Qs[::-1])


def get_curr_mf(env, action_probs):
    """ Propagate the mean field forward under the policy action_probs.

    Starting from env.mu_0, each step forms the joint state-action weights
    (current distribution times action_probs[t]) and pushes them through the
    transition tensor env.get_P_high(t, current distribution).

    :return: array holding the env.time_steps + 1 distributions, initial one first """
    trajectory = [env.mu_0]

    for t in range(env.time_steps):
        state_dist = trajectory[-1]
        transition = env.get_P_high(t, state_dist)
        joint_xu = np.expand_dims(state_dist, axis=1) * action_probs[t]
        trajectory.append(np.einsum('ijk,ji->k', transition, joint_xu))

    return np.array(trajectory)


# Module-level memoisation tables.
# diophantine_cache maps (N, d) to the array returned by get_diophantine_solutions.
diophantine_cache = {}
# Gmat_cache maps (g, gp, k, num_states) to the list of matrices built by get_Gmats.
Gmat_cache = {}


def _iter_compositions(total, parts):
    """ Yield, in lexicographic order, every tuple of `parts` non-negative integers summing to `total`. """
    if total < 0:
        return
    if parts <= 0:
        # Only the empty tuple exists, and it sums to zero.
        if total == 0:
            yield ()
        return
    if parts == 1:
        yield (total,)
        return
    for head in range(total + 1):
        for tail in _iter_compositions(total - head, parts - 1):
            yield (head,) + tail


def get_diophantine_solutions(N, d):
    """ Return all non-negative integer solutions of x_1 + ... + x_d = N.

    The solutions are generated directly instead of filtering all (N + 1) ** d
    candidate tuples, but in the same lexicographic order (first coordinate
    varying slowest), so positions in the returned array keep their meaning.
    Results are memoised in diophantine_cache; the cached array itself is
    returned, so callers should not modify it in place.

    :param N: the required sum
    :param d: the number of coordinates
    :return: array with one solution per row """
    key = (N, d)
    if key not in diophantine_cache:
        diophantine_cache[key] = np.array(list(_iter_compositions(N, d)))
    return diophantine_cache[key]


def add_Gmats_ij(iminus, jminus, Gmat, all_Gmats, g, gp):
    """ Recursively fill the square matrix Gmat cell by cell and collect every valid completion.

    Cells are visited row by row starting from (iminus, jminus). Cells outside the
    last row and last column are tried with every value from 0 up to what the
    row budget g[iminus] and the column budget gp[jminus] still allow. The last
    column of a row is set to what remains of g for that row, the last row of a
    column to what remains of gp for that column, and the bottom-right cell to
    what remains of g for the last row. A copy of Gmat is appended to all_Gmats
    whenever that final cell is non-negative. Gmat is modified in place. """
    size = len(Gmat)
    last = size - 1
    in_last_row = iminus + 1 >= size
    in_last_col = jminus + 1 >= size

    if in_last_row and in_last_col:
        # Bottom-right corner: forced by the last row budget; keep only feasible fills.
        Gmat[last, last] = g[last] - np.sum(Gmat[last, :-1])
        if Gmat[last, last] >= 0:
            all_Gmats.append(copy.deepcopy(Gmat))
        return

    if in_last_col:
        # End of a non-final row: forced by the row budget, then move to the next row.
        Gmat[iminus, last] = g[iminus] - np.sum(Gmat[iminus, :-1])
        add_Gmats_ij(iminus + 1, 0, Gmat, all_Gmats, g, gp)
        return

    if in_last_row:
        # Final row, non-final column: forced by the column budget.
        Gmat[last, jminus] = gp[jminus] - np.sum(Gmat[:-1, jminus])
        add_Gmats_ij(iminus, jminus + 1, Gmat, all_Gmats, g, gp)
        return

    # Free cell: bounded by what is left in its column and in its row.
    column_left = gp[jminus] - np.sum(Gmat[:iminus, jminus])
    row_left = g[iminus] - np.sum(Gmat[iminus, :jminus])
    max_val = min(column_left, row_left)
    if max_val < 0:
        raise NotImplementedError
    for gij in range(max_val + 1):
        Gmat[iminus, jminus] = gij
        add_Gmats_ij(iminus, jminus + 1, Gmat, all_Gmats, g, gp)


def get_Gmats(g, gp, k, num_states):
    """ Return the memoised list of matrices enumerated by add_Gmats_ij for a pair of count vectors.

    g and gp are positions in get_diophantine_solutions(k, num_states); the
    vectors at those positions are handed to add_Gmats_ij as g and gp. The
    resulting list is stored in Gmat_cache under (g, gp, k, num_states). """
    cache_key = (g, gp, k, num_states)
    try:
        return Gmat_cache[cache_key]
    except KeyError:
        pass

    count_vectors = get_diophantine_solutions(k, num_states)
    found = []
    scratch = np.zeros((num_states, num_states), dtype=np.int64)
    add_Gmats_ij(0, 0, scratch, found, count_vectors[g], count_vectors[gp])
    Gmat_cache[cache_key] = found
    return found


def multinomial_coeff(params):
    """ Return the multinomial coefficient (sum of params)! / product of (param!).

    It is evaluated as the telescoping product of binomial coefficients
    C(p0 + p1, p1) * C(p0 + p1 + p2, p2) * ..., computed iteratively so that long
    inputs neither hit the recursion limit nor copy a slice per element.

    :param params: sequence (list, tuple or 1-D array) of group sizes
    :return: 1 for an empty or single-entry input, otherwise the product of the
        scipy.special.comb values """
    if len(params) == 0:
        # Empty product: there is exactly one way to arrange nothing.
        return 1

    running_total = params[0]
    coeff = 1
    for count in params[1:]:
        running_total = running_total + count
        coeff = comb(running_total, count) * coeff
    return coeff


def get_curr_mf_k(env, action_probs, k, mus, action_probs_high):
    """ Propagate the state distribution of the k-indexed model forward in time.

    Alongside the state distribution, the conditional distribution of the count
    vector G given the state x is tracked inductively, because it is the input of
    env.get_P_k_conditional at the next step.

    :param env: environment providing mu_0, time_steps, observation_space.n,
        action_space.n, get_P_k_conditional, get_P_high and get_P_k_G
    :param action_probs: policy indexed as action_probs[t]; its second axis is
        taken as the number of states
    :param k: total count of the vectors from get_diophantine_solutions(k, num_states)
    :param mus: distributions indexed by time, passed to env.get_P_high(t, mus[t])
    :param action_probs_high: policy indexed as action_probs_high[t, state, action]
    :return: array of the env.time_steps + 1 state distributions, initial one first """
    num_states = action_probs.shape[1]
    all_gs = get_diophantine_solutions(k, num_states)
    num_gs = len(all_gs)
    state_ids = range(env.observation_space.n)
    action_ids = range(env.action_space.n)

    # Initial probs of x and conditional G given x (same multinomial row for every x)
    curr_mf = env.mu_0
    curr_G_probs = np.array(num_states * [[scipy.stats.multinomial.pmf(g, k, p=env.mu_0) for g in all_gs]])
    mus_k = [curr_mf]

    for t in range(env.time_steps):
        # Compute mu at time t+1
        P_k = env.get_P_k_conditional(t, k, curr_G_probs)
        xu = np.expand_dims(curr_mf, axis=(1,)) * action_probs[t]
        new_mf = np.einsum('ijk,ji->k', P_k, xu)

        P_high = env.get_P_high(t, mus[t])

        # Everything below up to new_G_probs is independent of (x, xp, u), so it
        # is computed once per time step instead of once per innermost term.

        # Transition tensor for each previous count vector gp.
        # NOTE(review): assumes env.get_P_k_G is deterministic and side-effect
        # free, so one call per gp is equivalent to repeated calls -- confirm.
        P_k_G_by_gp = [env.get_P_k_G(t, k, all_gs[gp] / k) for gp in range(num_gs)]

        # step_prob[i][j]: probability of moving from state j to state i under action_probs_high.
        step_prob = [
            [
                np.sum([action_probs_high[t, j, up] * P_high[up, j, i] for up in action_ids])
                for j in state_ids
            ]
            for i in state_ids
        ]

        # G_transition[g][gp]: weight of count vector gp turning into count vector g,
        # summed over all matrices returned by get_Gmats(g, gp, ...).
        G_transition = [
            [
                np.sum([
                    np.prod([
                        multinomial_coeff(Gmat[:, j]) * np.prod([
                            step_prob[i][j] ** Gmat[i, j] for i in state_ids
                        ])
                        for j in state_ids
                    ])
                    for Gmat in get_Gmats(g, gp, k, num_states)
                ])
                for gp in range(num_gs)
            ]
            for g in range(num_gs)
        ]

        # Compute conditional probs of G given x inductively. The result is an
        # array, like the initial curr_G_probs, so env.get_P_k_conditional sees
        # the same type at every time step.
        new_G_probs = np.array([
            [
                # States with (numerically) zero mass get zero conditional probs.
                (1 / new_mf[x] if new_mf[x] > 1e-10 else 0) * np.sum([
                    np.sum([
                        xu[xp, u] * curr_G_probs[xp][gp] * P_k_G_by_gp[gp][u, xp, x]
                        * G_transition[g][gp]
                        for xp, u in itertools.product(state_ids, action_ids)
                    ])
                    for gp in range(num_gs)
                ])
                for g in range(num_gs)
            ]
            for x in state_ids
        ])

        # Save and update
        mus_k.append(new_mf)
        curr_mf = new_mf
        curr_G_probs = new_G_probs
        print(fr'mu_k k {k} Time {t}', flush=True)

    return np.array(mus_k)


def get_curr_probs_G_k(env, k, mus, action_probs_high):
    """ Propagate the distribution over count vectors G forward in time.

    The initial distribution is the multinomial pmf with k trials and cell
    probabilities env.mu_0, evaluated at every vector from
    get_diophantine_solutions(k, num_states). Each step multiplies it by the
    weight of every gp -> g transition, summed over the matrices from get_Gmats.

    :param env: environment providing mu_0, time_steps, observation_space.n,
        action_space.n and get_P_high
    :param k: total count of the vectors from get_diophantine_solutions(k, num_states)
    :param mus: distributions indexed by time, passed to env.get_P_high(t, mus[t])
    :param action_probs_high: policy indexed as action_probs_high[t, state, action];
        its second axis is taken as the number of states
    :return: array of the env.time_steps + 1 distributions over G, initial one first """
    num_states = action_probs_high.shape[1]
    all_gs = get_diophantine_solutions(k, num_states)
    num_gs = len(all_gs)
    state_ids = range(env.observation_space.n)
    action_ids = range(env.action_space.n)

    # Initial probs of G
    curr_G_probs = np.array([scipy.stats.multinomial.pmf(g, k, p=env.mu_0) for g in all_gs])
    probs_G_k = [curr_G_probs]

    for t in range(env.time_steps):
        P_high = env.get_P_high(t, mus[t])

        # step_prob[i][j]: probability of moving from state j to state i under
        # action_probs_high. It does not depend on g, gp or Gmat, so it is
        # computed once per time step rather than once per matrix cell.
        step_prob = [
            [
                np.sum([action_probs_high[t, j, up] * P_high[up, j, i] for up in action_ids])
                for j in state_ids
            ]
            for i in state_ids
        ]

        # Compute probs of G inductively
        new_G_probs = np.array([
            np.sum([
                curr_G_probs[gp]
                * np.sum([
                    np.prod([
                        multinomial_coeff(Gmat[:, j]) * np.prod([
                            step_prob[i][j] ** Gmat[i, j] for i in state_ids
                        ])
                        for j in state_ids
                    ])
                    for Gmat in get_Gmats(g, gp, k, num_states)
                ])
                for gp in range(num_gs)
            ])
            for g in range(num_gs)
        ])

        # Save and update
        probs_G_k.append(new_G_probs)
        curr_G_probs = new_G_probs
        print(fr'G_k k {k} Time {t}', flush=True)

    return np.array(probs_G_k)


def eval_curr_reward(env, action_probs, mus):
    """ Evaluate the policy action_probs against the mean field mus by backward induction.

    At each t the Q table is env.get_R_high(t, mus[t]) plus the expected next
    value under env.get_P_high(t, mus[t]); the value of a state is its Q row
    weighted by action_probs[t] (zero beyond the horizon).

    :return: tuple of (state values at the first time step, array of Q tables in
        increasing time order) """
    values = np.zeros((env.observation_space.n, ))
    horizon = env.time_steps
    q_by_time = [None] * horizon

    for t in reversed(range(horizon)):
        transition = env.get_P_high(t, mus[t])
        expected_next = np.einsum('ijk,k->ji', transition, values)
        q_t = env.get_R_high(t, mus[t]) + expected_next
        values = np.sum(action_probs[t] * q_t, axis=-1)
        q_by_time[t] = q_t

    return values, np.array(q_by_time)


def get_softmax_action_probs_from_Qs(Qs, temperature=1.0):
    """ Convert a stack of Q tables into averaged softmax action probabilities.

    The last axis of Qs indexes actions and the first axis indexes the Q tables.
    A softmax of Q / temperature is taken over the last axis (after subtracting
    the row maximum for numerical stability) and the resulting tables are
    averaged over the first axis. """
    flat_q = Qs.reshape((-1, Qs.shape[-1]))
    shifted = flat_q - flat_q.max(1, keepdims=True)
    weights = np.exp(shifted / temperature)
    normalised = weights / weights.sum(axis=1, keepdims=True)
    return normalised.reshape(Qs.shape).mean(0)


def get_softmax_new_action_probs_from_Qs(num_averages_yet, old_probs, Qs, temperature=1.0):
    """ Fold the softmax policy of a new stack of Q tables into a running average.

    The softmax of Q / temperature over the last axis of Qs (row maximum
    subtracted first for numerical stability) is averaged over the first axis
    and then combined with old_probs, which is weighted as the mean of
    num_averages_yet earlier policies. """
    flat_q = Qs.reshape((-1, Qs.shape[-1]))
    row_max = flat_q.max(1, keepdims=True)
    unnormalised = np.exp((flat_q - row_max) / temperature)
    softmax_rows = unnormalised / unnormalised.sum(axis=1, keepdims=True)
    new_probs = softmax_rows.reshape(Qs.shape).mean(0)

    # Incremental mean: old mean carries weight n, the new sample weight 1.
    weighted_total = old_probs * num_averages_yet + new_probs
    return weighted_total / (num_averages_yet + 1)
