import time
import matplotlib.pyplot as plt
import numpy as np

import utils
import networks
import eff_res



@utils.cache
def test_eff_res(delete=None, node_strategy="first-last", seed=None,
        verbose=False, num_random_walks=1000, max_len=100, **kwargs):
    """Compare the local effective-resistance estimate with the exact value.

    Builds a graph via ``networks.get_graph(**kwargs, seed=seed)``, picks a
    node pair according to ``node_strategy``, computes the exact effective
    resistance between them, then times the random-walk based local
    estimator on the same pair.

    Args:
        delete: Not used by this function body. NOTE(review): presumably
            accepted so rows carrying this column can be splatted in (and so
            it participates in the ``utils.cache`` key) — confirm.
        node_strategy: How to pick the node pair. Only ``"first-last"`` is
            supported: ``u = 0`` and ``v = len(G) - 1``.
        seed: Forwarded to ``networks.get_graph``.
        verbose: Forwarded to ``eff_res.estimate_local_eff_res``.
        num_random_walks: Forwarded to ``eff_res.estimate_local_eff_res``.
        max_len: Forwarded to ``eff_res.estimate_local_eff_res``.
        **kwargs: Graph-construction options forwarded to
            ``networks.get_graph``.

    Returns:
        dict with keys ``"true_eff_res"``, ``"est_eff_res"``,
        ``"num_samples"``, ``"time"`` (seconds spent in the estimator only)
        and ``"error"`` (absolute difference between true and estimate).

    Raises:
        ValueError: if ``node_strategy`` is not supported.
    """
    # Validate up front (and with a real exception rather than `assert`,
    # which is stripped under `python -O`) so an invalid strategy never
    # reaches the solver with undefined endpoints.
    if node_strategy != "first-last":
        raise ValueError(f"Unsupported node_strategy: {node_strategy!r}")

    G = networks.get_graph(**kwargs, seed=seed)

    # NOTE(review): assumes get_graph labels nodes 0..len(G)-1 — confirm.
    u, v = 0, len(G) - 1

    true_eff_res = eff_res.exact_eff_res(G, u, v)

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with wall-clock adjustments and is not meant for interval timing.
    start_time = time.perf_counter()
    est_eff_res, num_samples = eff_res.estimate_local_eff_res(G, u, v,
            num_random_walks=num_random_walks, max_len=max_len,
            verbose=verbose)
    est_time = time.perf_counter() - start_time

    return {"true_eff_res": true_eff_res, "est_eff_res": est_eff_res,
            "num_samples": num_samples, "time": est_time,
            "error": np.abs(true_eff_res - est_eff_res)}


@utils.savefig
def plot_eff_res(x, x_variable, y_variable, verbose=False, savefig=None,
        groups=None):
    """Run ``test_eff_res`` over every row of ``x`` and plot mean +/- std.

    Each row of ``x`` is splatted into ``test_eff_res`` as keyword
    arguments. The resulting table is grouped by ``x_variable`` and each
    group's ``y_variable`` is drawn as a mean line with a shaded +/- one
    standard deviation band. When ``groups`` is given, one such series is
    drawn per distinct value of ``groups``; when it is ``None`` a single
    ungrouped series is drawn and no legend is added.

    Args:
        x: Collection of experiment configurations exposing
            ``.map(fn, num_proc=...)`` and ``.to_pandas()``. NOTE(review):
            presumably a Hugging Face ``datasets.Dataset`` — confirm.
        x_variable: Column used for the horizontal axis.
        y_variable: Column aggregated for the vertical axis.
        verbose: Forwarded to ``test_eff_res``.
        savefig: Callable taking a file stem. NOTE(review): presumably
            injected by ``@utils.savefig`` — confirm.
        groups: Optional column (or columns) to split the data into
            separate series.
    """
    x = x.map(lambda row: test_eff_res(**row, verbose=verbose),
            num_proc=utils.NUM_PROC)
    df = x.to_pandas()

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    # Every marker is distinct so series stay distinguishable without colour.
    markers = ['o', 'x', 'D', 'v', '<', 's', 'H', '+', '^', '*', '.', 'X']
    # Truncate both lists to the shorter length so the cycles line up.
    colors, markers = zip(*zip(colors, markers))
    plt.gca().set_prop_cycle(marker=markers, color=colors)

    # `df.groupby(None)` raises TypeError, so treat "no grouping" as a
    # single series covering the whole frame.
    series = [(None, df)] if groups is None else df.groupby(groups)

    for grp_val, grp in series:
        value = grp.groupby(x_variable)[y_variable]
        mean = value.mean()
        std = value.std()
        label = (None if groups is None
                 else f"{utils.get_label(groups)} = {grp_val}")
        (line,) = plt.plot(mean.index, mean, label=label)
        # Use the line's own colour so the band always matches its series,
        # independent of the separate patch colour cycle.
        plt.fill_between(mean.index, mean - std, mean + std, alpha=0.2,
                color=line.get_color())

    plt.xlabel(utils.get_label(x_variable))
    plt.ylabel(utils.get_label(y_variable))

    if groups is not None:
        plt.legend()
    plt.tight_layout()
    savefig("local_eff_res")

