import fnmatch
import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import gamma, factorial

from args_parser import generate_config_from_kw
from fast_marl_graphex import FastGraphexEnv
from utils import eval_curr_reward
from utils import get_curr_mf, get_curr_probs_G_k, find_best_response_k, get_action_probs_from_Qs, get_curr_mf_k


def find(pattern, path):
    """Recursively collect the files below ``path`` whose name matches ``pattern``.

    Args:
        pattern: Shell-style wildcard (as understood by :mod:`fnmatch`) that is
            matched against each file's base name, not its full path.
        path: Directory at which the recursive walk starts.

    Returns:
        List of matching file paths, each joined with the directory it was
        found in, in the order produced by :func:`os.walk`.
    """
    matches = []
    for folder, _subdirs, filenames in os.walk(path):
        matches.extend(
            os.path.join(folder, filename)
            for filename in fnmatch.filter(filenames, pattern)
        )
    return matches


def run_once(env, degrees, action_probs_ks, action_probs):
    """Simulate one episode of ``env`` under degree-dependent policies.

    Agents whose degree is at most ``max(degrees)`` draw their action from the
    policy stored at index ``degree`` of ``action_probs_ks``; every other agent
    draws from ``action_probs``.

    Args:
        env: Environment exposing ``reset()``, ``step(actions)``, ``degrees``,
            ``time_steps``, ``num_agents``, ``observation_space.n`` and
            ``action_space.n``.
        degrees: Degrees that have a dedicated policy; only its maximum is
            used here.
        action_probs_ks: Per-degree policies, indexed as
            ``[degree, t, state, action]``.
        action_probs: Policy for the remaining agents, indexed as
            ``[t, state, action]``.

    Returns:
        Tuple ``(val, mf_states, xss, env.degrees)`` where ``val`` is the sum
        over time of the rewards returned by ``env.step`` (``0`` if there are
        no time steps), ``mf_states`` is an array with one row per visited
        time (initial state included) holding the fraction of agents in each
        state, and ``xss`` is the list of state vectors returned by
        ``env.reset`` and ``env.step``.
    """
    num_states = env.observation_space.n
    # Number of thresholds separating the action bins (all but the last bound).
    num_interior = max(env.action_space.n - 1, 0)

    xs = env.reset()
    mf_states = [[np.mean(xs == i) for i in range(num_states)]]
    xss = [xs]
    val = 0

    # Loop-invariant quantities: which agents use a per-degree policy and which
    # row of ``action_probs_ks`` they read. High-degree agents are mapped to
    # row 0 as a harmless placeholder; their row is discarded by the mask.
    max_degree = max(degrees)
    low_degree = env.degrees <= max_degree
    policy_idx = (env.degrees * low_degree).astype(int)
    cum_ps_k = np.cumsum(np.array(action_probs_ks), axis=-1)
    cum_ps = np.cumsum(np.array(action_probs), axis=-1)

    for t in range(env.time_steps):
        print(fr'Time {t}', flush=True)
        # Per-agent cumulative action distribution, shape (num_agents, num_actions).
        cum_ps_i = np.where(
            np.expand_dims(low_degree, axis=1),
            cum_ps_k[:, t][policy_idx, xs],
            cum_ps[t][xs],
        )

        # Inverse-CDF sampling: the action index is the number of interior
        # thresholds not exceeding the uniform draw. Leaving out the final
        # bound means a draw at or above it (possible when the cumulative sum
        # falls just short of 1 through rounding) selects the last action
        # instead of silently defaulting to action 0.
        uniform_samples = np.random.uniform(0, 1, size=env.num_agents)
        actions = np.sum(
            np.expand_dims(uniform_samples, axis=1) >= cum_ps_i[:, :num_interior],
            axis=1,
        )

        xs, rewards, _done, _info = env.step(actions)
        val += rewards

        xss.append(xs)
        mf_states.append([np.mean(xs == i) for i in range(num_states)])

    return val, np.array(mf_states), xss, env.degrees


def _load_npz_arrays(path):
    """Return every array stored in the ``.npz`` archive at ``path``, closing the file afterwards."""
    with np.load(path) as archive:
        return list(archive.values())


def _state_fractions(xss, num_states, mask=None):
    """Return, per state and time step, the fraction of (optionally masked) agents in that state.

    ``xss`` is indexed ``[t, agent]``; the result is indexed ``[state, t]``.
    With a boolean ``mask`` over agents only the selected agents are counted,
    and a small epsilon keeps the ratio finite when none are selected.
    """
    if mask is None:
        return np.array([np.mean((xss == x), -1) for x in range(num_states)])
    group_size = np.sum(mask, -1) + 1e-10
    return np.array([np.sum((xss == x) * mask, -1) / group_size for x in range(num_states)])


def _half_l1_gap(expected, empirical):
    """Return half the L1 distance between two ``[state, t]`` arrays, averaged over ``t``."""
    return np.sum(np.mean(np.abs(expected - empirical), axis=1)) / 2


def evaluate_N(dataset, sigma, game, fp_iterations, temperature, variant, num_return_trials, degrees, degrees_to_plot, degree_cutoff, rerun=False):
    """Compare simulated rollouts on ``dataset`` with the stored mean-field solution.

    The function loads ``action_probs.npy`` from the experiment directory,
    derives per-degree best responses and their mean fields (cached as
    ``action_probs_ks.npy``, ``mus_ks.npy`` and ``v_curr_ks.npy``), runs
    ``num_return_trials`` rollouts with :func:`run_once` (cached as four
    ``.npz`` archives), and measures how far the empirical state distributions
    and returns are from the mean-field predictions.

    A cache is only used when ``rerun`` is false and *all* of its files exist;
    otherwise it is recomputed and rewritten. The per-degree cache is not
    keyed on ``degrees``, so it must be refreshed with ``rerun=True`` when
    ``degrees`` changes.

    Args:
        dataset: Dataset name forwarded to ``generate_config_from_kw`` for the
            rollout environment; also part of the rollout cache file names.
        sigma: Parameter of the degree probabilities
            ``sigma * gamma(k - sigma) / k! / gamma(1 - sigma)``.
        game, fp_iterations, temperature, variant: Forwarded to
            ``generate_config_from_kw``.
        num_return_trials: Number of rollouts.
        degrees: Degrees that get a dedicated best response. A degree of 0 is
            evaluated as degree 1. The first entry is skipped when mixing the
            per-degree mean fields (presumably because ``degrees`` is expected
            to start at 0 -- confirm with callers).
        degrees_to_plot: Unused; kept for interface compatibility.
        degree_cutoff: Agents with a degree above this value are compared
            against the mean field ``mus`` and its value.
        rerun: If true, ignore both caches.

    Returns:
        Tuple ``(diff_all_mu, diff_high_J, diff_high_mu, diff_k_Js,
        diff_k_mus, Ns)``: per-trial distribution gaps for all agents, return
        gaps and distribution gaps for agents above ``degree_cutoff``, the
        same two quantities per entry of ``degrees``, and the number of
        agents in each trial.
    """
    config = generate_config_from_kw(**{
        'game': game,
        'fp_iterations': fp_iterations,
        'temperature': temperature,
        'variant': variant,
    })
    env: FastGraphexEnv = config['game'](**config, time_var=1, graphex_cutoff=1)
    exp_dir = config['exp_dir']

    action_probs = np.load(exp_dir + "action_probs.npy")

    """ Compute the best responses and their mean fields for the low degree agents """
    mus = get_curr_mf(env, action_probs)
    v_curr_1 = np.vdot(env.mu_0, eval_curr_reward(env, action_probs, mus)[0])

    policy_path = exp_dir + "action_probs_ks.npy"
    mean_field_path = exp_dir + "mus_ks.npy"
    value_path = exp_dir + "v_curr_ks.npy"
    # Require every cached file: a partial cache is recomputed rather than
    # failing on the first missing file.
    if not rerun and all(os.path.exists(p) for p in (policy_path, mean_field_path, value_path)):
        action_probs_ks = np.load(policy_path)
        mus_ks = np.load(mean_field_path)
        v_curr_ks = np.load(value_path)
    else:
        action_probs_ks = []
        mus_ks = []
        v_curr_ks = []
        for k in degrees:
            # Degree 0 is evaluated with the degree-1 model.
            k_eff = 1 if k == 0 else k
            probs_G_k = get_curr_probs_G_k(env, k_eff, mus, action_probs)
            Q_k = find_best_response_k(env, probs_G_k, k_eff)
            action_probs_k = get_action_probs_from_Qs(np.array([Q_k]))
            action_probs_ks.append(action_probs_k)
            v_curr_ks.append(np.vdot(env.mu_0, Q_k.max(axis=-1)[0, :]))
            mus_ks.append(get_curr_mf_k(env, action_probs_k, k_eff, mus, action_probs))
        action_probs_ks = np.array(action_probs_ks)
        mus_ks = np.array(mus_ks)
        np.save(policy_path, action_probs_ks)
        np.save(mean_field_path, mus_ks)
        np.save(value_path, v_curr_ks)

    # The rollout cache is keyed by the string form of this tuple.
    run_tag = str((game, fp_iterations, temperature, variant, num_return_trials, dataset))
    vals_path = exp_dir + f"vals_{run_tag}.npz"
    mf_statess_path = exp_dir + f"mf_statess_{run_tag}.npz"
    xsss_path = exp_dir + f"xsss_{run_tag}.npz"
    run_degreess_path = exp_dir + f"run_degreess_{run_tag}.npz"
    rollout_paths = (vals_path, mf_statess_path, xsss_path, run_degreess_path)

    if not rerun and all(os.path.exists(p) for p in rollout_paths):
        vals = _load_npz_arrays(vals_path)
        mf_statess = _load_npz_arrays(mf_statess_path)
        xsss = _load_npz_arrays(xsss_path)
        run_degreess = _load_npz_arrays(run_degreess_path)
    else:
        dataset_config = generate_config_from_kw(**{
            'game': game,
            'fp_iterations': fp_iterations,
            'temperature': temperature,
            'variant': variant,
            'dataset': dataset,
        })
        dataset_env: FastGraphexEnv = dataset_config['game'](**dataset_config, time_var=1, graphex_cutoff=1)

        vals = []
        mf_statess = []
        xsss = []
        run_degreess = []
        for _ in range(num_return_trials):
            val, mf_states, xss, run_degrees = run_once(dataset_env, degrees, action_probs_ks, action_probs)
            # NOTE(review): this logs the values of the *previous* trials, since
            # ``val`` is appended afterwards -- confirm that is intended.
            print(f'{len(run_degrees)}: {vals}', flush=True)
            vals.append(val)
            mf_statess.append(mf_states)
            xsss.append(np.array(xss))
            run_degreess.append(run_degrees)
        np.savez(vals_path, *vals)
        np.savez(mf_statess_path, *mf_statess)
        np.savez(xsss_path, *xsss)
        np.savez(run_degreess_path, *run_degreess)

    # NOTE(review): the state count is read from the base environment; this
    # assumes the dataset environment has the same state space -- confirm.
    num_states = env.observation_space.n

    # Expected population distribution: per-degree mean fields weighted by
    # their degree probability, with the remaining mass assigned to ``mus``.
    # It does not depend on the trial, so it is built once.
    prob_ks = [sigma * gamma(k - sigma) / factorial(k, exact=True) / gamma(1 - sigma) for k in degrees]
    low_degree_mass = np.sum(prob_ks[1:])
    expected_all = np.sum(
        [mus_k.T * prob_k for mus_k, prob_k in zip(mus_ks[1:], prob_ks[1:])]
        + [mus.T * (1 - low_degree_mass)],
        axis=0)
    diff_all_mu = [_half_l1_gap(expected_all, _state_fractions(xss, num_states)) for xss in xsss]

    # Agents above the cutoff are compared with the mean field ``mus``.
    diff_high_J = v_curr_1 - np.array([
        np.sum(val * (run_degrees > degree_cutoff)) / (np.sum((run_degrees > degree_cutoff)) + 1e-10)
        for val, run_degrees in zip(vals, run_degreess)
    ])
    diff_high_mu = [
        _half_l1_gap(mus.T, _state_fractions(xss, num_states, run_degrees > degree_cutoff))
        for xss, run_degrees in zip(xsss, run_degreess)
    ]

    # Agents of each listed degree are compared with that degree's prediction.
    diff_k_Js = []
    diff_k_mus = []
    for k, v_curr_k, mus_k in zip(degrees, v_curr_ks, mus_ks):
        diff_k_J = v_curr_k - np.array([
            np.sum(val * (run_degrees == k)) / (np.sum((run_degrees == k)) + 1e-10)
            for val, run_degrees in zip(vals, run_degreess)
        ])
        diff_k_mu = [
            _half_l1_gap(mus_k.T, _state_fractions(xss, num_states, run_degrees == k))
            for xss, run_degrees in zip(xsss, run_degreess)
        ]
        diff_k_Js.append(diff_k_J)
        diff_k_mus.append(diff_k_mu)
    Ns = [len(run_degrees) for run_degrees in run_degreess]

    # normalized between 0 and 1, as the expected total variation
    return diff_all_mu, diff_high_J, diff_high_mu, diff_k_Js, diff_k_mus, Ns


def plot():
    """Print mean-field approximation errors for every game/dataset pair.

    Despite its name this function draws no figure: it configures the
    matplotlib defaults, calls ``evaluate_N`` for each combination of game and
    dataset, and prints the mean and standard deviation (over trials) of the
    absolute distribution gaps for all agents, for the agents above the degree
    cutoff, and for each individual degree up to the cutoff.
    """
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "sans-serif",
        "font.size": 24,
        "font.sans-serif": ["Helvetica"],
    })

    num_return_trials = 5
    # Each dataset is paired with the sigma used for its degree probabilities.
    dataset_sigmas = (
        ("prosper-loans", 0.058),
        ("petster-friendships-dog", 0.071),
        ("soc-pokec-relationships", 0.108),
        ("livemocha", 0.075),
        ("flickr-growth", 0.506),
        ("loc-brightkite_edges", 0.376),
        ("facebook-wosn-wall", 0.252),
        ("hyves", 0.582),
    )

    for game in ('SIS', 'SIR', 'RS'):
        degree_cutoff = 6 if game == 'RS' else 8
        degrees_to_plot = [2, 3, 4]
        degrees = range(degree_cutoff + 1)

        for dataset, sigma in dataset_sigmas:
            all_gaps, _high_J, high_gaps, _k_Js, per_degree_gaps, _Ns = evaluate_N(
                dataset=dataset,
                sigma=sigma,
                game=game,
                fp_iterations=5000,
                temperature=50.,
                variant="omd",
                num_return_trials=num_return_trials,
                degrees=degrees,
                degrees_to_plot=degrees_to_plot,
                degree_cutoff=degree_cutoff,
                rerun=False,
            )

            # Whole population.
            magnitudes = np.abs(all_gaps)
            mean_returns, std_returns = magnitudes.mean(axis=0), magnitudes.std(axis=0)
            print(fr"{game} {dataset}: {mean_returns:.4} +- {std_returns:.4}", flush=True)

            # Agents above the degree cutoff.
            magnitudes = np.abs(high_gaps)
            mean_returns, std_returns = magnitudes.mean(axis=0), magnitudes.std(axis=0)
            print(fr"{game} {dataset}: infty mu {mean_returns:.4} +- {std_returns:.4}", flush=True)

            # One line per individual degree.
            for idx, k in enumerate(degrees):
                magnitudes = np.abs(per_degree_gaps[idx])
                mean_returns, std_returns = magnitudes.mean(axis=0), magnitudes.std(axis=0)
                print(fr"{game} {dataset}: {k} mu {mean_returns:.4} +- {std_returns:.4}", flush=True)


# Script entry point: run the full evaluation sweep only when executed directly.
if __name__ == '__main__':
    plot()
