import fnmatch
import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.special import gamma, factorial

from args_parser import generate_config_from_kw
from fast_marl_graphex import FastGraphexEnv
from utils import eval_curr_reward
from utils import get_curr_mf, get_curr_mf_k


def find(pattern, path):
    """Recursively collect files below ``path`` whose name matches ``pattern``.

    Args:
        pattern: Shell-style wildcard (``fnmatch`` syntax) tested against the
            bare file name, not against the full path.
        path: Directory at which the ``os.walk`` traversal starts.

    Returns:
        List of matching paths, each the walked directory joined with the
        file name, in the order ``os.walk`` yields them.
    """
    return [
        os.path.join(folder, filename)
        for folder, _subdirs, filenames in os.walk(path)
        for filename in fnmatch.filter(filenames, pattern)
    ]


def run_once(env, action_probs):
    """Roll out one episode of ``env`` under the policy ``action_probs``.

    At every time step each agent draws one uniform sample and takes the
    action whose cumulative-probability interval contains it. The cumulative
    probabilities of agents with ``env.degrees <= 0`` are zeroed, so no
    interval ever matches for them and they keep action 0.

    Args:
        env: Environment exposing ``reset``, ``step``, ``time_steps``,
            ``num_agents``, ``degrees``, ``observation_space.n`` and
            ``action_space.n``.
        action_probs: Action probabilities, indexed by time step, then by
            state, with the actions on the last axis.

    Returns:
        Tuple ``(val, mf_states, xss, degrees)``: the sum of the rewards
        returned by ``env.step``, an array holding the per-state agent
        fractions at reset and after every step, the list of agent states at
        those same instants, and ``env.degrees``.
    """

    def state_fractions(states):
        # Fraction of agents currently in each state.
        return [np.mean(states == s) for s in range(env.observation_space.n)]

    xs = env.reset()
    xss = [xs]
    mf_states = [state_fractions(xs)]
    val = 0

    cum_ps = np.cumsum(np.array(action_probs), axis=-1)

    for t in range(env.time_steps):
        print(fr'Time {t}', flush=True)

        # One row of cumulative action probabilities per agent, selected by
        # the agent's current state and zeroed for agents without neighbours.
        connected = np.expand_dims(env.degrees > 0, axis=1)
        cum_ps_i = cum_ps[t][xs] * connected

        actions = np.zeros((env.num_agents,), dtype=int)
        draws = np.random.uniform(0, 1, size=env.num_agents)

        # Inverse-CDF sampling: action a is taken when lower <= draw < upper.
        lower = 0.0
        for a in range(env.action_space.n):
            upper = cum_ps_i[:, a]
            actions += a * np.logical_and(draws >= lower, draws < upper)
            lower = upper

        xs, rewards, _done, _info = env.step(actions)
        val += rewards

        xss.append(xs)
        mf_states.append(state_fractions(xs))

    return val, np.array(mf_states), xss, env.degrees


def evaluate_N(dataset, sigma, game, fp_iterations, temperature, variant, num_return_trials, degrees, degrees_to_plot, degree_cutoff, rerun=False, all_same_policy=True):
    """Measure how far simulated state distributions are from the mean-field prediction.

    Loads the saved policy from the experiment directory, computes (or loads
    from cache) its mean field and the per-degree mean fields, simulates (or
    loads from cache) ``num_return_trials`` episodes on ``dataset`` and
    compares the empirical state fractions with the predicted ones.

    Args:
        dataset: Dataset name passed to ``generate_config_from_kw`` when
            building the simulation environment; also part of the cache
            file names.
        sigma: Parameter of the degree weights
            ``sigma * Gamma(k - sigma) / (k! * Gamma(1 - sigma))``.
        game: Game identifier passed to ``generate_config_from_kw``.
        fp_iterations: Passed to ``generate_config_from_kw``.
        temperature: Passed to ``generate_config_from_kw``.
        variant: Passed to ``generate_config_from_kw``.
        num_return_trials: Number of episodes to simulate.
        degrees: Degrees ``k`` for which ``get_curr_mf_k`` is evaluated
            (``k == 0`` is evaluated as ``k == 1``). The first entry is
            left out of the predicted mixture.
        degrees_to_plot: Unused; kept for interface compatibility.
        degree_cutoff: Unused; kept for interface compatibility.
        rerun: When true, ignore cached results and recompute everything.
        all_same_policy: Only used as part of the cache file names.

    Returns:
        Tuple ``(diff_all_mu, diff_all_mu_Lp)``, each a list with one entry
        per trial: the deviation of the empirical state fractions of all
        agents from the degree-weighted mixture of mean fields, and the
        deviation of the state fractions among agents with positive degree
        from the mean field ``mus``.
    """
    config = generate_config_from_kw(**{
        'game': game,
        'fp_iterations': fp_iterations,
        'temperature': temperature,
        'variant': variant,
    })
    env: FastGraphexEnv = config['game'](**config, time_var=1, graphex_cutoff=1)
    exp_dir = config['exp_dir']

    action_probs = np.load(exp_dir + "action_probs.npy")
    # NOTE(review): best_response is never read below; the load is kept so
    # that a missing file still fails here as before -- confirm whether it
    # can be dropped.
    best_response = np.load(exp_dir + "best_response.npy")

    # Compute the best responses and their mean fields for the low degree agents
    mus = get_curr_mf(env, action_probs)
    # NOTE(review): v_curr_1 is never read below; the call is kept in case
    # eval_curr_reward has side effects -- confirm whether it can be dropped.
    v_curr_1 = np.vdot(env.mu_0, eval_curr_reward(env, action_probs, mus)[0])

    mus_ks_path = exp_dir + "mus_ks_all_same.npy"
    if not rerun and os.path.exists(mus_ks_path):
        mus_ks_all_same = np.load(mus_ks_path)
    else:
        mus_ks_all_same = np.array([
            get_curr_mf_k(env, action_probs, 1 if k == 0 else k, mus, action_probs)
            for k in degrees
        ])
        np.save(mus_ks_path, mus_ks_all_same)

    # Every trial cache file shares this suffix (the str() of the settings tuple).
    tag = str((game, fp_iterations, temperature, variant, num_return_trials, dataset, all_same_policy))
    vals_path = exp_dir + f"vals_{tag}.npz"
    mf_statess_path = exp_dir + f"mf_statess_{tag}.npz"
    xsss_path = exp_dir + f"xsss_{tag}.npz"
    run_degreess_path = exp_dir + f"run_degreess_{tag}.npz"

    # Besides the vals marker file, require the two archives that are
    # actually loaded, so an interrupted earlier run triggers a re-simulation
    # instead of a crash.
    cache_complete = all(os.path.exists(p) for p in (vals_path, xsss_path, run_degreess_path))
    if not rerun and cache_complete:
        # Read inside ``with`` so the underlying zip files are closed again.
        with np.load(xsss_path) as archive:
            xsss = [archive[key] for key in archive.files]
        with np.load(run_degreess_path) as archive:
            run_degreess = [archive[key] for key in archive.files]
    else:
        dataset_config = generate_config_from_kw(**{
            'game': game,
            'fp_iterations': fp_iterations,
            'temperature': temperature,
            'variant': variant,
            'dataset': dataset,
        })
        # Rebinds env: the state count used below then comes from the dataset env.
        env: FastGraphexEnv = dataset_config['game'](**dataset_config, time_var=1, graphex_cutoff=1)

        vals = []
        mf_statess = []
        xsss = []
        run_degreess = []
        for _ in range(num_return_trials):
            val, mf_states, xss, run_degrees = run_once(env, action_probs)
            vals.append(val)
            # Report after appending so the trial that just finished is included.
            print(f'{len(run_degrees)}: {vals}', flush=True)
            mf_statess.append(mf_states)
            xsss.append(np.array(xss))
            run_degreess.append(run_degrees)
        np.savez(vals_path, *vals)
        np.savez(mf_statess_path, *mf_statess)
        np.savez(xsss_path, *xsss)
        np.savez(run_degreess_path, *run_degreess)

    num_states = env.observation_space.n
    prob_ks = [sigma * gamma(k - sigma) / factorial(k, exact=True) / gamma(1 - sigma) for k in degrees]

    # Predicted mixture: per-degree mean fields weighted by prob_ks (first
    # entry skipped), with the remaining weight on mus. It does not depend
    # on the trial, so it is computed once.
    predicted_mu = np.sum(
        [mus_k.T * prob_k for mus_k, prob_k in zip(mus_ks_all_same[1:], prob_ks[1:])]
        + [mus.T * (1 - np.sum(prob_ks[1:]))],
        axis=0,
    )
    diff_all_mu = [
        np.sum(np.mean(np.abs(
            predicted_mu - np.array([np.mean((xss == x), -1) for x in range(num_states)])
        ), axis=1)) / 2
        for xss in xsss
    ]

    # The LP graphon (for normalized neighborhood mfs) is given simply by the mean
    diff_all_mu_Lp = [
        np.sum(np.mean(np.abs(
            mus.T - np.array([
                # Fractions among agents with positive degree; 1e-10 guards
                # against division by zero when there are none.
                np.sum((xss == x) * (run_degrees > 0), -1) / (np.sum((run_degrees > 0), -1) + 1e-10)
                for x in range(num_states)
            ])
        ), axis=1)) / 2
        for xss, run_degrees in zip(xsss, run_degreess)
    ]

    # normalized between 0 and 1, as the expected total variation
    return diff_all_mu, diff_all_mu_Lp


def plot():
    """Print mean and standard deviation of the evaluation errors.

    Configures the matplotlib rcParams, then calls ``evaluate_N`` for every
    game/dataset combination and prints, for both returned error lists, the
    mean and standard deviation of their absolute values across trials.
    """
    # Plot figures
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "sans-serif",
        "font.size": 24,
        "font.sans-serif": ["Helvetica"],
    })

    num_return_trials = 5
    # Each dataset paired with the sigma value used for it.
    dataset_sigmas = (
        ("prosper-loans", 0.058),
        ("petster-friendships-dog", 0.071),
        ("soc-pokec-relationships", 0.108),
        ("livemocha", 0.075),
        ("flickr-growth", 0.506),
        ("loc-brightkite_edges", 0.376),
        ("facebook-wosn-wall", 0.252),
        ("hyves", 0.582),
    )

    for game in ('SIS', 'SIR', 'RS'):
        degree_cutoff = 6 if game == 'RS' else 8
        degrees = range(degree_cutoff + 1)
        degrees_to_plot = [2, 3, 4]

        for dataset, sigma in dataset_sigmas:
            diff_all_mu, diff_all_mu_Lp = evaluate_N(
                dataset, sigma, game, 5000, 50., "omd", num_return_trials,
                degrees, degrees_to_plot, degree_cutoff, rerun=False,
            )

            magnitudes = np.abs(diff_all_mu)
            mean_returns = np.mean(magnitudes, axis=0)
            std_returns = np.std(magnitudes, axis=0)
            print(fr"{game} {dataset}: all mu {mean_returns:.4} +- {std_returns:.4}", flush=True)

            magnitudes = np.abs(diff_all_mu_Lp)
            mean_returns = np.mean(magnitudes, axis=0)
            std_returns = np.std(magnitudes, axis=0)
            print(fr"{game} {dataset}: all mu Lp {mean_returns:.4} +- {std_returns:.4}", flush=True)


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    plot()
