import fnmatch
import os
import string

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import pylab as pl

from args_parser import generate_config_from_kw
from fast_marl_graphex import FastGraphexEnv
from utils import eval_curr_reward
from utils import get_curr_mf, get_curr_probs_G_k, find_best_response_k, get_action_probs_from_Qs, get_curr_mf_k


def find(pattern, path):
    """Recursively collect files below ``path`` whose base name matches ``pattern``.

    ``pattern`` uses shell-style wildcards (as understood by :mod:`fnmatch`).
    Matches are returned as ``os.path.join(directory, filename)`` strings in
    the order ``os.walk`` visits them.
    """
    return [
        os.path.join(directory, filename)
        for directory, _subdirs, filenames in os.walk(path)
        for filename in fnmatch.filter(filenames, pattern)
    ]


def run_once(env, degrees, action_probs_ks, action_probs):
    """Simulate one finite-agent episode and record the empirical state distribution.

    Agents whose degree is at most ``max(degrees)`` act according to the
    degree-specific policy in ``action_probs_ks`` (selected by their degree);
    all other agents act according to the core policy ``action_probs``.

    Args:
        env: Environment exposing ``reset``, ``step``, ``degrees``, ``num_agents``,
            ``time_steps``, ``observation_space.n`` and ``action_space.n``.
        degrees: Degrees that have a policy in ``action_probs_ks``; only
            ``max(degrees)`` is used here.
        action_probs_ks: Per-degree policies; the indexing below requires the
            layout ``[k, t, state, action]`` once converted to an array.
        action_probs: Core policy with layout ``[t, state, action]``.

    Returns:
        ``(val, mf_states, xss, env.degrees)``: ``val`` is the running sum of
        the rewards returned by ``env.step``; ``mf_states`` is an array with one
        row of state frequencies per time step (including the initial one);
        ``xss`` is the list of per-agent state arrays.
    """
    xs = env.reset()
    n_states = env.observation_space.n
    mf_states = [[np.mean(xs == i) for i in range(n_states)]]
    xss = [xs]
    val = 0

    max_degree = max(degrees)
    is_low_degree = env.degrees <= max_degree
    is_high_degree = env.degrees > max_degree
    # Index into the per-degree policies. Agents above the cutoff are mapped to
    # index 0 here, but their contribution is masked out below.
    env_degrees = (env.degrees * is_low_degree).astype(int)
    cum_ps_k = np.cumsum(np.array(action_probs_ks), axis=-1)
    cum_ps = np.cumsum(np.array(action_probs), axis=-1)
    n_actions = env.action_space.n

    for t in range(env.time_steps):
        print(fr'Time {t}', flush=True)
        # Per-agent cumulative action distribution, shape (num_agents, num_actions).
        cum_ps_i = cum_ps_k[:, t][env_degrees, xs] * is_low_degree[:, None] \
                   + cum_ps[t][xs] * is_high_degree[:, None]

        # Inverse-CDF sampling: the action index equals the number of cumulative
        # thresholds at or below the sample. Clipping sends samples above the
        # last cumulative probability (round-off in a distribution that should
        # sum to 1) to the last action, instead of silently to action 0.
        uniform_samples = np.random.uniform(0, 1, size=env.num_agents)
        actions = np.minimum(
            (uniform_samples[:, None] >= cum_ps_i[:, :n_actions]).sum(axis=1),
            n_actions - 1,
        ).astype(int)

        xs, rewards, done, info = env.step(actions)
        val += rewards

        xss.append(xs)
        mf_states.append([np.mean(xs == i) for i in range(n_states)])

    return val, np.array(mf_states), xss, env.degrees


def _load_npz_list(path):
    """Load every array in the ``.npz`` archive at *path*, closing the archive afterwards."""
    with np.load(path) as archive:
        return list(archive.values())


def _compute_periphery_policies(env, degrees, mus, action_probs):
    """Compute best-response policies, values and mean fields for each degree in *degrees*.

    Degree 0 is handled as degree 1.

    Returns:
        ``(action_probs_ks, mus_ks, v_curr_ks)``: the first two are arrays
        stacked over *degrees*; ``v_curr_ks`` is a list of scalars.
    """
    action_probs_ks = []
    mus_ks = []
    v_curr_ks = []
    for k in degrees:
        if k == 0:
            k = 1
        probs_G_k = get_curr_probs_G_k(env, k, mus, action_probs)
        Q_k = find_best_response_k(env, probs_G_k, k)
        action_probs_k = get_action_probs_from_Qs(np.array([Q_k]))
        action_probs_ks.append(action_probs_k)
        v_curr_ks.append(np.vdot(env.mu_0, Q_k.max(axis=-1)[0, :]))
        mus_ks.append(get_curr_mf_k(env, action_probs_k, k, mus, action_probs))
    return np.array(action_probs_ks), np.array(mus_ks), v_curr_ks


def evaluate_N(dataset, game, fp_iterations, temperature, variant, num_return_trials, degrees, degrees_to_plot, degree_cutoff, rerun=False):
    """Simulate the learned policies on a finite graph, caching results on disk.

    Loads the core policy from the experiment directory, computes (or loads)
    the per-degree periphery policies, then runs ``num_return_trials``
    episodes with ``run_once`` on the environment configured for ``dataset``.

    Args:
        dataset: Dataset name forwarded to ``generate_config_from_kw`` for the
            simulation environment.
        game, fp_iterations, temperature, variant: Experiment identifiers used
            to build the config (and thus ``config['exp_dir']``).
        num_return_trials: Number of simulated episodes.
        degrees: Degrees for which periphery policies are computed.
        degrees_to_plot, degree_cutoff: Not used here; kept for interface
            compatibility.
        rerun: If true, ignore cached ``.npy``/``.npz`` files and recompute.

    Returns:
        ``(vals, mf_statess, xsss, run_degreess)``, one entry per trial.
    """
    config = generate_config_from_kw(**{
        'game': game,
        'fp_iterations': fp_iterations,
        'temperature': temperature,
        'variant': variant,
    })
    env: FastGraphexEnv = config['game'](**config, time_var=1, graphex_cutoff=1)
    exp_dir = config['exp_dir']

    action_probs = np.load(exp_dir + "action_probs.npy")
    # NOTE(review): best_response and v_curr_1 are not used below; kept so the
    # existing file requirement and evaluation call are unchanged.
    best_response = np.load(exp_dir + "best_response.npy")

    """ Compute the best responses and their mean fields for the low degree agents """
    print("Computing periphery ...", flush=True)
    mus = get_curr_mf(env, action_probs)
    v_curr_1 = np.vdot(env.mu_0, eval_curr_reward(env, action_probs, mus)[0])

    # All three cache files are loaded, so all three must exist to use the cache.
    periphery_paths = [exp_dir + name for name in ("action_probs_ks.npy", "mus_ks.npy", "v_curr_ks.npy")]
    if not rerun and all(os.path.exists(p) for p in periphery_paths):
        action_probs_ks, mus_ks, v_curr_ks = (np.load(p) for p in periphery_paths)
    else:
        action_probs_ks, mus_ks, v_curr_ks = _compute_periphery_policies(env, degrees, mus, action_probs)
        for p, data in zip(periphery_paths, (action_probs_ks, mus_ks, v_curr_ks)):
            np.save(p, data)

    print("Running trajectories ...", flush=True)
    # The tuple is formatted via its repr, matching the existing cache file names.
    run_key = f"{game, fp_iterations, temperature, variant, num_return_trials, dataset}"
    traj_paths = [exp_dir + f"{prefix}_{run_key}.npz" for prefix in ("vals", "mf_statess", "xsss", "run_degreess")]
    if not rerun and all(os.path.exists(p) for p in traj_paths):
        vals, mf_statess, xsss, run_degreess = (_load_npz_list(p) for p in traj_paths)
    else:
        dataset_config = generate_config_from_kw(**{
            'game': game,
            'fp_iterations': fp_iterations,
            'temperature': temperature,
            'variant': variant,
            'dataset': dataset,
        })
        env: FastGraphexEnv = dataset_config['game'](**dataset_config, time_var=1, graphex_cutoff=1)

        vals = []
        mf_statess = []
        xsss = []
        run_degreess = []
        for _ in range(num_return_trials):
            val, mf_states, xss, run_degrees = run_once(env, degrees, action_probs_ks, action_probs)
            print(f'{len(run_degrees)}: {vals}', flush=True)
            vals.append(val)
            mf_statess.append(mf_states)
            xsss.append(np.array(xss))
            run_degreess.append(run_degrees)
        for p, arrays in zip(traj_paths, (vals, mf_statess, xsss, run_degreess)):
            np.savez(p, *arrays)

    return vals, mf_statess, xsss, run_degreess


def plot():
    """Plot empirical vs. mean-field state trajectories and save the figure.

    For each (game, dataset) pair, ``evaluate_N`` supplies (cached or freshly
    simulated) finite-agent trajectories, and the degree-specific mean-field
    predictions are loaded from (or written to) the experiment directory.
    Each subplot shows, for one state ``x``, the empirical fraction of agents
    in ``x`` (solid lines) and the mean-field prediction (dashed lines) for
    high-degree agents and for each degree in ``degrees_to_plot``. The figure
    is saved to ``./figures/trajectories.{pdf,png}`` and then shown.
    """
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "sans-serif",
        "font.size": 24,
        "font.sans-serif": ["Helvetica"],
    })

    i = 1
    rerun = False
    lgd1 = None  # legend is attached to the first subplot only

    games = ['SIS', 'SIR', 'RS',]
    datasets = ["prosper-loans", "petster-friendships-dog", "soc-pokec-relationships",]
    num_return_trials = 1

    for game, dataset in zip(games, datasets):
        degree_cutoff = 8 if game != 'RS' else 6
        degrees_to_plot = [2, 3, 4, ]
        degrees = range(degree_cutoff + 1)

        _, _, xsss, run_degreess = evaluate_N(dataset, game, 5000, 50., "omd", num_return_trials, degrees, degrees_to_plot, degree_cutoff, rerun=rerun)
        # Only the first trial is plotted.
        xss = xsss[0]
        run_degrees = run_degreess[0]

        config = generate_config_from_kw(**{
            'game': game,
            'fp_iterations': 5000,
            'temperature': 50.,
            'variant': "omd",
        })
        env: FastGraphexEnv = config['game'](**config, time_var=1, graphex_cutoff=1)
        exp_dir = config['exp_dir']

        action_probs = np.load(exp_dir + "action_probs.npy")
        # NOTE(review): best_response and v_curr_1 are not used below; kept so the
        # existing file requirement and evaluation call are unchanged.
        best_response = np.load(exp_dir + "best_response.npy")

        """ Compute the best responses and their mean fields for the low degree agents """
        print("Computing periphery ...", flush=True)
        mus = get_curr_mf(env, action_probs)
        v_curr_1 = np.vdot(env.mu_0, eval_curr_reward(env, action_probs, mus)[0])

        # All three cache files are loaded, so all three must exist to use the cache.
        periphery_paths = [exp_dir + name for name in ("action_probs_ks.npy", "mus_ks.npy", "v_curr_ks.npy")]
        if not rerun and all(os.path.exists(p) for p in periphery_paths):
            action_probs_ks = np.load(periphery_paths[0])
            mus_ks = np.load(periphery_paths[1])
            v_curr_ks = np.load(periphery_paths[2])
        else:
            action_probs_ks = []
            mus_ks = []
            v_curr_ks = []
            for k in degrees:
                if k == 0:
                    k = 1
                probs_G_k = get_curr_probs_G_k(env, k, mus, action_probs)
                Q_k = find_best_response_k(env, probs_G_k, k)
                action_probs_k = get_action_probs_from_Qs(np.array([Q_k]))
                action_probs_ks.append(action_probs_k)
                v_curr_ks.append(np.vdot(env.mu_0, Q_k.max(axis=-1)[0, :]))
                mus_k = get_curr_mf_k(env, action_probs_k, k, mus, action_probs)
                mus_ks.append(mus_k)
            action_probs_ks = np.array(action_probs_ks)
            mus_ks = np.array(mus_ks)
            np.save(periphery_paths[0], action_probs_ks)
            np.save(periphery_paths[1], mus_ks)
            np.save(periphery_paths[2], v_curr_ks)

        # One panel per plotted state: state 1 for SIS, states 1 and 2 for SIR,
        # state 0 for RS (overridden below).
        for x in range(1, 2 if game != 'SIR' else 3):
            subplot = plt.subplot(1, 4, i)
            subplot.annotate('(' + string.ascii_lowercase[i - 1] + ')',
                             (1, 1),
                             xytext=(-36, -12),
                             xycoords='axes fraction',
                             textcoords='offset points',
                             fontweight='bold',
                             color='black',
                             alpha=0.7,
                             backgroundcolor='white',
                             ha='left', va='top')
            i += 1

            if game == "RS":
                x = 0
            cmap = pl.cm.plasma_r
            colors = cmap(np.linspace(0, 1, len(degrees) + 1))

            """ Plot empirical MFs """
            print("Plotting empirical MFs", flush=True)
            emp_high_mf = np.sum((np.array(xss) == x) * np.expand_dims(run_degrees > degree_cutoff, axis=0),
                                 axis=1) / np.sum(run_degrees > degree_cutoff)
            plt.plot(range(env.time_steps + 1), emp_high_mf, color=colors[-1], label=r"$k>k_{\max}$",
                     linewidth=2)
            for k, color in zip(degrees, colors):
                if k not in degrees_to_plot:
                    continue
                emp_k_mfs = np.sum((np.array(xss) == x) * np.expand_dims(run_degrees == k, axis=0), axis=1) / np.sum(
                    run_degrees == k)
                plt.plot(range(env.time_steps + 1), emp_k_mfs, color=color, label=f"$k={k}$", linewidth=2)

            """ Plot MFs """
            plt.plot(range(env.time_steps + 1), mus[:, x], color=colors[-1], label=f"MF core", linestyle='--', linewidth=2)

            print("Plotting periphery MF predictions", flush=True)
            for k, color, mus_k in zip(degrees, colors, mus_ks):
                if k not in degrees_to_plot:
                    continue
                plt.plot(range(env.time_steps + 1), mus_k[:, x], color=color, label=f"MF $k={k}$", linewidth=1,
                         linestyle='--', )

            plt.ylim([0, 1])
            if i == 2:
                # Raw string: "\m" is not a valid escape sequence.
                plt.ylabel(r"$\mu_t(x)$")
            else:
                plt.yticks([], [])
            plt.xlim([0, env.time_steps])
            plt.xlabel(fr"Time $t$")
            if i == 2:
                lgd1 = plt.legend(bbox_to_anchor=(0.2, 1.02, 1, 0.2), loc='lower left', ncol=4, fontsize="20")

    """ Finalize plot """
    plt.gcf().set_size_inches(13, 3)
    plt.tight_layout(w_pad=0.1)
    os.makedirs('./figures', exist_ok=True)
    extra_artists = (lgd1,) if lgd1 is not None else None
    plt.savefig(f'./figures/trajectories.pdf', bbox_extra_artists=extra_artists, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.savefig(f'./figures/trajectories.png', bbox_extra_artists=extra_artists, bbox_inches='tight', transparent=True, pad_inches=0)
    plt.show()


if __name__ == "__main__":
    # Script entry point: build and save the trajectory figure.
    plot()
