import fnmatch
import itertools
import os
import string
from multiprocessing import Pool

import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler

from args_parser import generate_config_from_kw
from fast_marl_graphex import FastGraphexEnv
from utils import get_curr_mf, get_curr_probs_G_k, find_best_response_k, get_action_probs_from_Qs, get_curr_mf_k, \
    eval_curr_reward


def find(pattern, path):
    """Return the paths of all files below ``path`` whose name matches ``pattern``.

    The directory tree rooted at ``path`` is walked recursively with
    ``os.walk``; every file name is tested against the shell-style wildcard
    ``pattern`` via ``fnmatch.fnmatch``. Matches are returned as a list of
    ``os.path.join(directory, name)`` strings in walk order.
    """
    return [
        os.path.join(folder, filename)
        for folder, _subdirs, filenames in os.walk(path)
        for filename in filenames
        if fnmatch.fnmatch(filename, pattern)
    ]


def run_once(env, degrees, action_probs_ks, action_probs, seed=None):
    """Simulate one episode of ``env`` and record the empirical state distribution.

    Agents whose degree is at most ``max(degrees)`` draw their action from the
    entry of ``action_probs_ks`` selected by their own degree; every other
    agent draws from ``action_probs``. The arrays are indexed as
    ``action_probs_ks[degree, t, state, action]`` and
    ``action_probs[t, state, action]``.

    Args:
        env: environment exposing ``reset``, ``step``, ``degrees``,
            ``time_steps``, ``num_agents``, ``observation_space.n`` and
            ``action_space.n``.
        degrees: degrees that have their own policy in ``action_probs_ks``.
        action_probs_ks: per-degree action probabilities.
        action_probs: action probabilities used for all remaining agents.
        seed: forwarded to ``env.reset``.

    Returns:
        Tuple ``(val, mf_states, xss, env.degrees)``: the rewards summed over
        all steps, an array holding the fraction of agents in each state
        after the reset and after every step, the list of state vectors
        (reset plus one per step), and the degrees stored on ``env``.
    """
    xs = env.reset(seed=seed)
    num_states = env.observation_space.n

    def state_fractions(states):
        # Fraction of agents currently occupying each state.
        return [np.mean(states == s) for s in range(num_states)]

    mf_states = [state_fractions(xs)]
    xss = [xs]
    val = 0

    k_max = max(degrees)
    # Agents above k_max are mapped to row 0; their row is masked out below.
    env_degrees = (env.degrees * (env.degrees <= k_max)).astype(int)
    # Cumulative sums over the action axis turn the probabilities into CDFs.
    cum_ps_k = np.cumsum(np.array(action_probs_ks), axis=-1)
    cum_ps = np.cumsum(np.array(action_probs), axis=-1)

    for t in range(env.time_steps):
        is_low = np.expand_dims(env.degrees <= k_max, axis=1)
        is_high = np.expand_dims(env.degrees > k_max, axis=1)
        cum_ps_i = cum_ps_k[:, t][env_degrees, xs] * is_low + cum_ps[t][xs] * is_high

        # Inverse-CDF sampling: action a is chosen when the uniform sample
        # falls into [cdf[a - 1], cdf[a]).
        # NOTE(review): the samples come from the global np.random state;
        # whether `seed` influences them depends on env.reset -- confirm.
        uniform_samples = np.random.uniform(0, 1, size=env.num_agents)
        actions = np.zeros((env.num_agents,), dtype=int)
        lower = 0.0
        for a in range(env.action_space.n):
            upper = cum_ps_i[:, a]
            actions += a * np.logical_and(uniform_samples >= lower, uniform_samples < upper)
            lower = upper

        xs, rewards, done, info = env.step(actions)
        val += rewards

        xss.append(xs)
        mf_states.append(state_fractions(xs))

    return val, np.array(mf_states), xss, env.degrees


def process_idx(rerun, config, game, fp_iterations, temperature, variant, num_return_trials, graphex_cutoff, nu, degrees, action_probs_ks, action_probs, idx):
    """Run (or load from disk) the simulated trajectory number ``idx``.

    Results are cached as four ``.npy`` files inside ``config['exp_dir']``
    whose names embed the tuple ``(game, fp_iterations, temperature, variant,
    num_return_trials, graphex_cutoff, nu, idx)``. Unless ``rerun`` is true,
    a complete set of cached files is loaded instead of simulating again.

    Args:
        rerun: when true, ignore cached files and simulate again.
        config: mapping providing ``'exp_dir'`` (path prefix for the cache
            files) and ``'game'`` (callable that builds the environment and
            is called with ``**config`` plus ``time_var`` and
            ``graphex_cutoff``).
        game, fp_iterations, temperature, variant, num_return_trials,
        graphex_cutoff, nu: experiment settings; they form the cache key.
        degrees, action_probs_ks, action_probs: forwarded to ``run_once``.
        idx: index of the trajectory; part of the cache key and of the seed.

    Returns:
        Tuple ``(val, mf_states, xss, run_degrees)`` as produced by
        ``run_once`` (arrays loaded from disk on a cache hit).
    """
    # Function-scope import so the block brings its own dependency with it.
    import zlib

    print(fr"Running trajectory {idx}", flush=True)

    run_key = (game, fp_iterations, temperature, variant, num_return_trials, graphex_cutoff, nu, idx)
    names = ("val", "mf_states", "xss", "run_degrees")
    # Formatting the tuple yields the same file names as the historical
    # f"val_{game, ..., idx}.npy" spelling, so existing caches stay valid.
    paths = {name: config['exp_dir'] + f"{name}_{run_key}.npy" for name in names}

    # Require every file: an interrupted earlier run may have written only
    # some of them, and loading a partial set would fail halfway through.
    if not rerun and all(os.path.exists(path) for path in paths.values()):
        val, mf_states, xss, run_degrees = (np.load(paths[name]) for name in names)
    else:
        env = config['game'](**config, time_var=nu, graphex_cutoff=graphex_cutoff)
        # The built-in hash() of a tuple containing a str is salted per
        # interpreter launch and may be negative, so it cannot serve as a
        # reproducible seed. CRC32 of the repr is stable across runs and is
        # always a non-negative 32-bit integer.
        seed_key = (fp_iterations, temperature, variant, num_return_trials, graphex_cutoff, nu, idx)
        seed = zlib.crc32(repr(seed_key).encode("utf-8"))
        val, mf_states, xss, run_degrees = run_once(env, degrees, action_probs_ks, action_probs, seed=seed)
        for name, data in zip(names, (val, mf_states, xss, run_degrees)):
            np.save(paths[name], data)

    return val, mf_states, xss, run_degrees


def evaluate_N(game, fp_iterations, temperature, variant, num_return_trials, graphex_cutoff, nu, degrees, degrees_to_plot, degree_cutoff, rerun=False):
    """Compare simulated finite-graph trajectories with the limiting solution.

    Loads the stored policy ``action_probs.npy`` from the experiment
    directory, computes (or loads from cache) a best-response policy, its
    value and its mean field for every degree in ``degrees``, simulates
    ``num_return_trials`` trajectories in parallel worker processes and
    measures how far the empirical returns and state distributions are from
    the computed ones.

    Args:
        game, fp_iterations, temperature, variant: passed to
            ``generate_config_from_kw`` to build the experiment config.
        num_return_trials: number of trajectories, also the pool size.
        graphex_cutoff, nu: forwarded to ``process_idx``.
        degrees: degrees that get their own policy; degree 0 is evaluated
            with ``k = 1``.
        degrees_to_plot: unused here; kept for interface compatibility.
        degree_cutoff: agents with a strictly larger degree form the
            "high degree" group.
        rerun: when true, ignore cached per-degree results and cached
            trajectories.

    Returns:
        Tuple ``(diff_high_J, diff_high_mu, diff_k_Js, diff_k_mus, Ns)``:
        per-trial return gap and mean absolute mean-field gap of the
        high-degree group, the same two quantities per degree in
        ``degrees``, and the number of agents of every trial.
    """
    config = generate_config_from_kw(**{
        'game': game,
        'fp_iterations': fp_iterations,
        'temperature': temperature,
        'variant': variant,
    })
    exp_dir = config['exp_dir']
    env: FastGraphexEnv = config['game'](**config, time_var=1, graphex_cutoff=1)

    action_probs = np.load(exp_dir + "action_probs.npy")
    # NOTE(review): loaded but never used below; kept so that a missing file
    # still fails early as before -- confirm whether it can be dropped.
    best_response = np.load(exp_dir + "best_response.npy")

    # Compute the best responses and their mean fields for the low degree agents
    print("Computing periphery ...", flush=True)
    mus = get_curr_mf(env, action_probs)
    v_curr_1 = np.vdot(env.mu_0, eval_curr_reward(env, action_probs, mus)[0])

    # All three cache files are loaded on a hit, so all three must exist;
    # previously v_curr_ks.npy was loaded without being checked.
    # NOTE(review): the cache file names do not encode `degrees`; a cache
    # written for a different degree list would be reused silently.
    cache_paths = [exp_dir + name for name in ("action_probs_ks.npy", "mus_ks.npy", "v_curr_ks.npy")]
    if not rerun and all(os.path.exists(path) for path in cache_paths):
        action_probs_ks, mus_ks, v_curr_ks = (np.load(path) for path in cache_paths)
    else:
        action_probs_ks = []
        mus_ks = []
        v_curr_ks = []
        for k in degrees:
            if k == 0:
                k = 1
            probs_G_k = get_curr_probs_G_k(env, k, mus, action_probs)
            Q_k = find_best_response_k(env, probs_G_k, k)
            action_probs_k = get_action_probs_from_Qs(np.array([Q_k]))
            action_probs_ks.append(action_probs_k)
            v_curr_ks.append(np.vdot(env.mu_0, Q_k.max(axis=-1)[0, :]))
            mus_ks.append(get_curr_mf_k(env, action_probs_k, k, mus, action_probs))
        action_probs_ks = np.array(action_probs_ks)
        mus_ks = np.array(mus_ks)
        # Same container type as on the cache-hit path.
        v_curr_ks = np.array(v_curr_ks)
        for path, data in zip(cache_paths, (action_probs_ks, mus_ks, v_curr_ks)):
            np.save(path, data)

    print("Running trajectories ...", flush=True)

    # One argument tuple per trajectory; only the trailing index differs.
    trial_args = [
        (rerun, config, game, fp_iterations, temperature, variant, num_return_trials,
         graphex_cutoff, nu, degrees, action_probs_ks, action_probs, idx)
        for idx in range(num_return_trials)
    ]
    with Pool(processes=num_return_trials) as pool:
        res = pool.starmap(process_idx, trial_args)

    vals = [r[0] for r in res]
    xsss = [np.array(r[2]) for r in res]
    run_degreess = [r[3] for r in res]

    num_states = env.observation_space.n

    def mean_return(val, mask):
        # Average return over the agents selected by `mask`; the epsilon
        # keeps the division finite when no agent is selected.
        return np.sum(val * mask) / (np.sum(mask) + 1e-10)

    def empirical_mf(xss, mask):
        # Per-state fraction of the selected agents along the trajectory.
        return np.array([
            np.sum((xss == x) * mask, -1) / (np.sum(mask, -1) + 1e-10)
            for x in range(num_states)
        ])

    diff_high_J = v_curr_1 - np.array([
        mean_return(val, run_degrees > degree_cutoff)
        for val, run_degrees in zip(vals, run_degreess)
    ])
    diff_high_mu = [
        np.mean(np.abs(mus.T - empirical_mf(xss, run_degrees > degree_cutoff)))
        for xss, run_degrees in zip(xsss, run_degreess)
    ]

    diff_k_Js = []
    diff_k_mus = []
    for k, v_curr_k, mus_k in zip(degrees, v_curr_ks, mus_ks):
        diff_k_Js.append(v_curr_k - np.array([
            mean_return(val, run_degrees == k)
            for val, run_degrees in zip(vals, run_degreess)
        ]))
        diff_k_mus.append([
            np.mean(np.abs(mus_k.T - empirical_mf(xss, run_degrees == k)))
            for xss, run_degrees in zip(xsss, run_degreess)
        ])
    Ns = [len(run_degrees) for run_degrees in run_degreess]

    print(fr"nu {nu} avg N: {np.mean(Ns)}", flush=True)

    return diff_high_J, diff_high_mu, diff_k_Js, diff_k_mus, Ns


def plot():
    """Create the mean-field error figure with one panel per game.

    For each game and each graph size in ``nus`` the errors are obtained
    from ``evaluate_N``. Each panel shows, on a logarithmic y-axis, the mean
    absolute mean-field error over the trials for the high-degree group
    (labelled with the infinity symbol) and for every degree listed in
    ``degrees_to_plot``, with error bars of two standard errors. The figure
    is written to ``./figures`` as PDF and PNG and then displayed.
    """
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "sans-serif",
        "font.size": 24,
        "font.sans-serif": ["Helvetica"],
    })

    skip_n = 1
    rerun = False

    games = ['SIS', 'SIR', 'RS', ]
    graphex_cutoff = 100
    num_return_trials = 20
    nus = [10, 25, 50, 100, 250, 500, 750, ]

    for panel, game in enumerate(games):
        is_first_panel = panel == 0
        degree_cutoff = 8 if game != 'RS' else 6
        degrees_to_plot = [2, 4, 6, ]
        degrees = range(degree_cutoff + 1)

        diff_high_mu_nu = []
        diff_k_mu_nu = []
        for nu in nus:
            # Only the mean-field errors are plotted; the return errors and
            # agent counts returned by evaluate_N are not needed here.
            _, diff_high_mu, _, diff_k_mu, _ = evaluate_N(
                game, 5000, 50., "omd", num_return_trials, graphex_cutoff, nu,
                degrees, degrees_to_plot, degree_cutoff, rerun=rerun)
            diff_high_mu_nu.append(diff_high_mu)
            diff_k_mu_nu.append(diff_k_mu)

        # Plot the empirical mu deviation
        clist = itertools.cycle(cycler(color='rbkgcmy'))
        linestyle_cycler = itertools.cycle(cycler('linestyle', ['-', '--', ':', '-.']))
        subplot = plt.subplot(1, len(games), panel + 1)
        subplot.annotate('(' + string.ascii_lowercase[panel] + ')',
                         (1, 1),
                         xytext=(-36, -12),
                         xycoords='axes fraction',
                         textcoords='offset points',
                         fontweight='bold',
                         color='black',
                         alpha=0.7,
                         backgroundcolor='white',
                         ha='left', va='top')

        # One curve for the high-degree group, then one per plotted degree.
        # Raw strings keep the LaTeX backslashes out of escape processing.
        abs_k_errors = np.abs(diff_k_mu_nu)
        curves = [(r"$\infty$", np.abs(diff_high_mu_nu))]
        curves += [
            (fr"$k={k}$", abs_k_errors[:, idx, :])
            for idx, k in enumerate(degrees)
            if k in degrees_to_plot
        ]

        for label, abs_errors in curves:
            # abs_errors has one row per graph size and one column per trial.
            std_returns = 2 * np.std(abs_errors, axis=1) / np.sqrt(num_return_trials)
            mean_returns = np.mean(abs_errors, axis=1)

            color = next(clist)['color']
            linestyle = next(linestyle_cycler)['linestyle']
            subplot.plot(nus[::skip_n], mean_returns[::skip_n], linestyle, color=color, label=label)
            # Labels starting with an underscore are left out of the legend.
            subplot.scatter(nus[::skip_n], mean_returns[::skip_n], color=color, label="__nolabel__")
            subplot.errorbar(nus[::skip_n], mean_returns[::skip_n], yerr=std_returns[::skip_n], color=color,
                             label="__nolabel__", alpha=0.85)

        if is_first_panel:
            plt.legend(bbox_to_anchor=(0.4, 1.02, 1, 0.2), loc='lower left', ncol=5, fontsize="20")
        plt.grid(True)
        plt.xlabel(r'Graph size $\nu$', fontsize=22)
        if is_first_panel:
            plt.ylabel(r'$\Delta \mu^k, \Delta \mu^\infty$', fontsize=22)
        plt.xlim([min(nus),  max(nus)])
        plt.yscale('log')

    # Finalize plot
    plt.gcf().set_size_inches(13, 3.5)
    # savefig does not create missing directories.
    os.makedirs('./figures', exist_ok=True)
    plt.savefig('./figures/propagation_of_chaos.pdf', bbox_inches='tight', transparent=True, pad_inches=0)
    plt.savefig('./figures/propagation_of_chaos.png', bbox_inches='tight', transparent=True, pad_inches=0)
    plt.show()


# Script entry point: build and save the figure when the file is run directly.
if __name__ == '__main__':
    plot()
