import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import math
import networks
import time
import utils
import pandas as pd
import datetime
import matplotlib.ticker as ticker
from datasets import Dataset

import ht



def sample_nodes(G, node_strategy=None, seed=None, **kwargs):
    """
        Selects the pair of nodes (u, v) between which a hitting time is measured.

        Parameters
        ----------
        G : graph
            Only len(G), G.nodes and G.degree are read here; nx.pagerank(G) is
            called for the pagerank-based strategies.
        node_strategy : str
            How to choose the pair:
              - "first-last":   u = 1, v = len(G) - 1
              - "first-second": u = 1, v = 0
              - "uniform":      two distinct nodes drawn uniformly at random
              - any string containing "deg-prod" or "pagerank-prod": two
                distinct nodes drawn with weights given by the node degree /
                pagerank. The string must also contain "prop" (weight equal to
                the metric) or "invprop" (weight equal to 1 / metric).
        seed : int
            Accepted so that callers can forward their keyword arguments;
            it is not used here. Random draws come from numpy's global
            random state.
        kwargs : dict
            Ignored; present so that callers can forward extra arguments.

        Returns
        -------
        tuple
            The pair (u, v).

        Raises
        ------
        ValueError
            If node_strategy is missing or not one of the forms listed above.
    """
    if node_strategy is None:
        raise ValueError("node_strategy must be specified")

    if node_strategy == "first-last":
        return 1, len(G) - 1

    if node_strategy == "first-second":
        return 1, 0

    if node_strategy == "uniform":
        u, v = np.random.choice(list(G.nodes), 2, replace=False)
        return u, v

    use_degree = "deg-prod" in node_strategy
    if not use_degree and "pagerank-prod" not in node_strategy:
        raise ValueError(f"unknown node_strategy: {node_strategy!r}")

    # "invprop" contains "prop", so it has to be tested first.
    inverse = "invprop" in node_strategy
    if not inverse and "prop" not in node_strategy:
        raise ValueError(
            f"node_strategy {node_strategy!r} must contain 'prop' or 'invprop'")

    d = dict(G.degree) if use_degree else nx.pagerank(G)

    # Build the weights in the same order as the candidate list, so that each
    # probability is attached to its own node whatever the node labels are and
    # whatever order the graph stores them in.
    nodes = list(G.nodes)
    metric = np.array([d[node] for node in nodes], dtype=float)

    # NOTE(review): a metric of zero gives a non-finite inverse weight, which
    # np.random.choice is expected to reject -- confirm this is acceptable for
    # graphs with isolated nodes.
    p = 1 / metric if inverse else metric
    u, v = np.random.choice(nodes, 2, replace=False, p=p / np.sum(p))
    return u, v


@utils.cache
def test_hitting_time(seed=None, verbose=False,
        algorithm="local", num_random_walks=1000, max_len=10000,
        compute_truth=True, **kwargs):
    """
        Runs one hitting time estimation algorithm on one graph and compares
        the estimate against the exact value.

        Parameters
        ----------
        seed : int
            Forwarded to networks.get_graph and sample_nodes.
        verbose : bool
            Accepted for interface compatibility; not used in this function.
        algorithm : str
            The estimator to run. One of "exact", "exact-sparse", "cutoff",
            "local-delete" or "sampling". The default value "local" is NOT
            one of these names, so callers have to pass the algorithm
            explicitly.
        num_random_walks : int
            Forwarded to the "cutoff", "local-delete" and "sampling" estimators.
        max_len : int
            Forwarded to the "cutoff" and "local-delete" estimators.
        compute_truth : bool
            If True the reference value is computed with
            ht.exact_ht(..., sparse=True); otherwise it is set to -1.
        kwargs : dict
            Forwarded both to networks.get_graph and to sample_nodes.

        Returns
        -------
        dict
            Keys: "hitting_time", "true_hitting_time", "time",
            "absolute_error", "relative_error", "error" (same value as
            "relative_error"), "num_samples" (None for the exact algorithms),
            "deg_u", "deg_v", "ix_u", "ix_v".

        Raises
        ------
        ValueError
            If algorithm is not one of the supported names.
    """
    supported = ("exact", "exact-sparse", "cutoff", "local-delete", "sampling")
    if algorithm not in supported:
        # Fail before any expensive work. Previously an unknown name was only
        # noticed when the (never assigned) estimate was used further down.
        raise ValueError(
            f"unknown algorithm {algorithm!r}; expected one of {supported}")

    G = networks.get_graph(**kwargs, seed=seed)
    u, v = sample_nodes(G, **kwargs, seed=seed)

    # The ground truth.
    # We compute it using a linear system solver.
    if compute_truth:
        true_hitting_time = ht.exact_ht(G, u, v, sparse=True)
    else:
        true_hitting_time = -1

    walltime = None
    num_samples = None

    # The algorithm to test. The exact solvers are timed here; the other
    # estimators report their own sample count and wall time.
    if algorithm == "exact":
        start_time = time.time()
        hitting_time = ht.exact_ht(G, u, v)
        walltime = time.time() - start_time

    elif algorithm == "exact-sparse":
        start_time = time.time()
        hitting_time = ht.exact_ht(G, u, v, sparse=True)
        walltime = time.time() - start_time

    elif algorithm == "cutoff":
        # Cutoff algorithm
        hitting_time, num_samples, walltime = ht.ht_via_cutoff(G, u, v,
                max_len=max_len, num_random_walks=num_random_walks)

    elif algorithm == "local-delete":
        # Collision time algorithm
        hitting_time, num_samples, walltime = ht.estimate_local_ht(G, u, v,
                num_random_walks=num_random_walks, max_len=max_len)

    else:  # algorithm == "sampling"
        hitting_time, num_samples, walltime = ht.sampling_ht(G, u, v,
                num_random_walks=num_random_walks)

    # NOTE: with compute_truth=False the reference value is the placeholder -1,
    # so both error figures below are computed against -1 and carry no meaning.
    absolute_error = abs(hitting_time - true_hitting_time)
    relative_error = absolute_error / true_hitting_time

    return {"hitting_time": hitting_time, "true_hitting_time": true_hitting_time,
            "time": walltime, "absolute_error": absolute_error,
            "relative_error": relative_error, "error": relative_error,
            "num_samples": num_samples, "deg_u": G.degree(u),
            "deg_v": G.degree(v), "ix_u": u, "ix_v": v}


# Legend labels used by plot_hitting_time. The lookup key is str(grp_val),
# where grp_val is the group key yielded by df.groupby(groups); the keys below
# are therefore the string form of a 1-tuple holding a single value (these
# values match algorithm names handled by test_hitting_time).
map_grp_val_to_label = {
    "('cutoff',)": "Cutoff Algorithm",
    "('exact',)": "Exact Algorithm",
    "('local-delete',)": "Collision Time Algorithm",
    "('sampling',)": "Sampling Algorithm",
}

@utils.savefig
def plot_hitting_time(x, x_variable, y_variables, verbose=False,
                      savefig=None, groups=None, **kwargs):
    """
        Runs test_hitting_time on every row of x and plots the results.

        For every y variable, one curve per group is drawn showing the mean of
        the y variable against x_variable, with a shaded band of +- one
        standard deviation. One file is saved per y variable.

        Parameters
        ----------
        x : huggingface dataset
            The experiment configurations. Only x.map(...) and x.to_pandas()
            are used; every row is passed to test_hitting_time as keyword
            arguments.
        x_variable : str
            The variable to plot on the x axis. If None, the result table is
            printed and nothing is plotted.
        y_variables : list of str
            The variables to plot on the y axis.
        verbose : bool
            Forwarded to test_hitting_time.
        savefig : callable
            Called with the file name (without extension) of each plot.
            NOTE(review): presumably injected by the utils.savefig decorator;
            confirm against utils.
        groups : list of str
            The variables to group the results by (one curve per group).
        kwargs : dict
            Must contain "graph_type" and "n"; both are only used to build the
            file name.
    """
    x = x.map(lambda row: test_hitting_time(**row, verbose=verbose), num_proc=utils.NUM_PROC)
    df = x.to_pandas()

    if x_variable is None:
        # Just print the results.
        print(df)
        return

    graph_type = kwargs["graph_type"]
    n = kwargs["n"]

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    markers = ['o', 'x', 'D', 'v', '<', 's', 'H', '+', 'x', '*', '.', 'X']
    # zip truncates both sequences to the shorter one; set_prop_cycle is given
    # sequences of equal length.
    colors, markers = zip(*zip(colors, markers))
    plt.gca().set_prop_cycle(marker=markers, color=colors)

    #
    # Plot x_variable vs mean(y_variable) with fill between +- std(y_variable)
    #
    # NOTE(review): every y variable is drawn with the pyplot state machine on
    # the current axes; whether curves of an earlier y variable are still
    # present depends on what savefig does with the figure -- confirm.
    for y_variable in y_variables:
        for grp_val, grp in df.groupby(groups):
            value = grp.groupby(x_variable)[y_variable]
            mean = value.mean()
            std = value.std()

            # Fall back to the raw group key when there is no pretty label for
            # it, instead of aborting the whole plot with a KeyError.
            key = str(grp_val)
            plt.plot(mean, label=map_grp_val_to_label.get(key, key))
            plt.fill_between(mean.index, mean - std, mean + std, alpha=0.06)

        plt.xlabel(x_variable)
        plt.ylabel(y_variable)
        plt.legend()
        plt.tight_layout()

        savefig("local_hitting_time_estimation__" + y_variable + "__" + graph_type + "__n_" + str(n))


@utils.savefig
def plot_hitting_time2(x, x_variable, y_variable, verbose=False, savefig=None,
        groups=None, ax=None, single_plot=True, logscale=False):
    """
        Plots mean(y_variable) against x_variable with one curve per group and
        a shaded band of +- one standard deviation.

        A group whose x_variable takes a single value is drawn as a horizontal
        line at its mean instead of a curve.

        Parameters
        ----------
        x : huggingface dataset
            The experiment configurations. Only x.map(...) and x.to_pandas()
            are used. Rows are passed as keyword arguments first to
            test_hitting_time and then to get_graph_info.
            NOTE(review): get_graph_info requires ix_u / ix_v, which are among
            the values returned by test_hitting_time; this relies on x.map
            merging the returned dict into each row -- confirm.
        x_variable : str
            The column to plot on the x axis.
        y_variable : str
            The column to plot on the y axis.
        verbose : bool
            Forwarded to test_hitting_time and get_graph_info.
        savefig : callable
            Called with the file name (without extension) when single_plot is
            set. NOTE(review): presumably injected by the utils.savefig
            decorator; confirm against utils.
        groups : list of str
            The columns to group the results by (one curve per group).
        ax : matplotlib axes
            The axes to draw on. If None, a new figure is created and
            single_plot is forced to True.
        single_plot : bool
            If True, a summary line is printed, the legend is added and the
            figure is saved.
        logscale : bool
            If True, the y axis is logarithmic. Only applied when the axes are
            created here (ax is None).
    """
    x = x.map(lambda row: test_hitting_time(**row, verbose=verbose),
            num_proc=utils.NUM_PROC)
    x = x.map(lambda row: get_graph_info(**row, verbose=verbose),
            num_proc=utils.NUM_PROC)
    df = x.to_pandas()

    if ax is None:
        single_plot = True
        fig, ax = plt.subplots(1, figsize=(4.5, 3.0))
        if logscale:
            ax.set_yscale('log')
    else:
        # A caller-supplied axes combined with single_plot=True still needs
        # the figure for tight_layout below.
        fig = ax.figure

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    markers = ['o', 'x', 'D', 'v', '<', 's', 'H', '+', 'x', '*', '.', 'X']
    # zip truncates both sequences to the shorter one; set_prop_cycle is given
    # sequences of equal length.
    colors, markers = zip(*zip(colors, markers))
    # Configure the axes that are drawn on, which is not necessarily the
    # "current" pyplot axes when the caller supplied ax.
    ax.set_prop_cycle(marker=markers, color=colors)

    # Coordinates of the first point of the most recent curve; reused as the
    # position of the throwaway artists in the single-value branch.
    x_val, y_val = None, None

    for grp_val, grp in df.groupby(groups, sort=False):
        value = grp.groupby(x_variable)[y_variable]
        mean = value.mean()
        label = utils.get_label(grp_val)

        if grp[x_variable].nunique() > 1:
            std = value.std()
            ax.plot(mean.index, mean, label=label)
            ax.fill_between(mean.index, mean - std, mean + std, alpha=0.2)

            x_val = mean.index[0]
            y_val = mean.iloc[0]

        else:
            # Draw and immediately remove a line and a fill: the line supplies
            # the next color / marker of the property cycle for the horizontal
            # line (presumably the fill is drawn so that its color cycle stays
            # in step with the curves).
            line, = ax.plot([x_val], [y_val])
            fill = ax.fill_between([x_val], [y_val], [y_val])
            line.remove()
            fill.remove()
            ax.axhline(mean.iloc[0], label=label,
                    color=line.get_color(), marker=line.get_marker())

    ax.set_xlabel(utils.get_label(x_variable))
    ax.set_ylabel(utils.get_label(y_variable))

    if single_plot:
        graph_type = df["graph_type"].iloc[0]
        print(f"{graph_type}: n = {int(df['num_nodes'].mean())}, m = {df['num_edges'].mean():.1f} +- {df['num_edges'].std():.1f}")
        ax.legend()
        fig.tight_layout()
        savefig(f"ht_{graph_type}_{x_variable}_{y_variable}")


@utils.cache
def get_graph_info(ix_u, ix_v, verbose=False, seed=None, **kwargs):
    """
        Builds the graph described by kwargs / seed and reports the pagerank of
        two of its nodes together with the size of the graph.

        Parameters
        ----------
        ix_u, ix_v : node
            The two nodes whose pagerank is looked up.
        verbose : bool
            Accepted for interface compatibility; not used in this function.
        seed : int
            Forwarded to networks.get_graph.
        kwargs : dict
            Forwarded to networks.get_graph.

        Returns
        -------
        dict
            Keys: "pagerank_u", "pagerank_v", "num_edges", "num_nodes".
    """
    graph = networks.get_graph(**kwargs, seed=seed)
    ranks = nx.pagerank(graph)

    info = {}
    info["pagerank_u"] = ranks[ix_u]
    info["pagerank_v"] = ranks[ix_v]
    info["num_edges"] = len(graph.edges)
    info["num_nodes"] = len(graph)
    return info


@utils.savefig
def plot_hitting_time_distr(x, x_variable, y_variable, verbose=False,
        savefig=None, groups=None, ax=None, single_plot=True,
        logscale=True):
    """
        Scatter plot of mean(y_variable) against x_variable, one point set per
        group.

        In addition to the columns produced by test_hitting_time and
        get_graph_info, the derived columns "deg_prod", "deg_ratio",
        "pagerank_prod" and "pagerank_ratio" (products and ratios of the
        degrees / pageranks of the two nodes) can be used as variables.

        Parameters
        ----------
        x : huggingface dataset
            The experiment configurations. Only x.map(...) and x.to_pandas()
            are used. Rows are passed as keyword arguments first to
            test_hitting_time and then to get_graph_info.
        x_variable : str
            The column to plot on the x axis.
        y_variable : str
            The column to plot on the y axis.
        verbose : bool
            Forwarded to test_hitting_time.
        savefig : callable
            Called with the file name (without extension) when single_plot is
            set. NOTE(review): presumably injected by the utils.savefig
            decorator; confirm against utils.
        groups : list of str
            The columns to group the results by.
        ax : matplotlib axes
            The axes to draw on. If None, a new figure is created and
            single_plot is forced to True.
        single_plot : bool
            If True, the legend is added and the figure is saved.
        logscale : bool
            If True, both axes are logarithmic. Only applied when the axes are
            created here (ax is None).
    """
    x = x.map(lambda row: test_hitting_time(**row, verbose=verbose),
            num_proc=utils.NUM_PROC)
    x = x.map(lambda row: get_graph_info(**row), num_proc=utils.NUM_PROC)
    df = x.to_pandas()

    if ax is None:
        single_plot = True
        fig, ax = plt.subplots(1, figsize=(4.5, 3.0))
        if logscale:
            ax.set_yscale('log')
            ax.set_xscale('log')
    else:
        # A caller-supplied axes combined with single_plot=True still needs
        # the figure for tight_layout below.
        fig = ax.figure

    df["deg_prod"] = df["deg_u"] * df["deg_v"]
    df["deg_ratio"] = df["deg_u"] / df["deg_v"]
    df["pagerank_prod"] = df["pagerank_u"] * df["pagerank_v"]
    df["pagerank_ratio"] = df["pagerank_u"] / df["pagerank_v"]

    # Configure the axes that are drawn on, which is not necessarily the
    # "current" pyplot axes when the caller supplied ax.
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    ax.set_prop_cycle(color=colors)

    for grp_val, grp in df.groupby(groups, sort=False):
        mean = grp.groupby(x_variable)[y_variable].mean()
        ax.scatter(mean.index, mean, label=utils.get_label(grp_val))

    ax.set_xlabel(utils.get_label(x_variable))
    ax.set_ylabel(utils.get_label(y_variable))

    if single_plot:
        graph_type = df["graph_type"].iloc[0]
        ax.legend()
        fig.tight_layout()
        savefig(f"htd_{graph_type}_{x_variable}_{y_variable}")


@utils.savefig
def print_hitting_time_info(x, variables, verbose=False,
        savefig=None, groups=None, ax=None, single_plot=True,
        logscale=True, hist="true_hitting_time", n_bins=10, print_stats=True):
    """
        Prints summary statistics of the hitting time experiments and draws,
        for every variable, a grouped bar chart of the variable binned by the
        value of the hist column.

        Parameters
        ----------
        x : huggingface dataset
            The experiment configurations. Only x.map(...) and x.to_pandas()
            are used; every row is passed to test_hitting_time as keyword
            arguments.
        variables : list of str
            The columns to summarise and plot (one chart per variable).
        verbose : bool
            Forwarded to test_hitting_time.
        savefig : callable
            Called with the file name (without extension) of each chart when
            single_plot is set. NOTE(review): presumably injected by the
            utils.savefig decorator; confirm against utils.
        groups : list of str
            The columns to group the results by (one bar series per group).
        ax : matplotlib axes
            The axes to draw on. If None, a new figure is created for every
            variable and single_plot is forced to True. If given, all
            variables are drawn on these axes.
        single_plot : bool
            If True, the legend is added and the figure is saved.
        logscale : bool
            Accepted for interface compatibility; not used in this function.
        hist : str
            The column whose values are cut into n_bins equal-width bins.
        n_bins : int
            The number of bins.
        print_stats : bool
            If True, per-group mean and standard deviation are printed,
            followed by LaTeX table cells with the relative error per
            algorithm and node strategy.
    """
    x = x.map(lambda row: test_hitting_time(**row, verbose=verbose),
            num_proc=utils.NUM_PROC)
    df = x.to_pandas()
    df["bin"] = pd.cut(df[hist], bins=n_bins)

    by_group = df.groupby(groups, sort=False)

    if print_stats:
        stats = by_group[variables + [hist]]
        print("MEAN ", "-" * 100)
        print(stats.mean())
        print("STD  ", "-" * 100)
        print(stats.std())
        graph_type = df["graph_type"].iloc[0]
        print("LATEX", "-" * 100, graph_type)
        for algo, g1 in df.groupby("algorithm"):
            print(utils.get_label(algo))
            for node_strategy, g2 in g1.groupby("node_strategy"):
                val = g2["relative_error"]
                print("&", f"${val.mean():.3f} \\pm {val.std():.3f}$")
            print("\\\\")

    # The bin edges are identical for every variable and every group, so the
    # tick labels are computed once. The edges are scaled by a power of ten
    # that is reported in the axis label.
    intervals = df["bin"].cat.categories
    edges = np.append(intervals.left.to_numpy(dtype=float),
            float(intervals.right[-1]))
    try:
        l10 = int(np.min(np.log10(edges)))
    except (ArithmeticError, ValueError):
        # A zero or negative lowest edge gives -inf / nan, which cannot be
        # converted to int; fall back to no scaling.
        l10 = 0
    tick_labels = np.round(edges / 10**l10, 0).astype(int).astype(str)

    bins = np.arange(0, n_bins)
    wid = 0.8 / max(len(by_group), 1)
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    own_axes = ax is None
    if own_axes:
        single_plot = True

    for variable in variables:
        if own_axes:
            # One fresh figure per variable, so that charts saved under
            # different names do not accumulate each other's bars.
            fig, ax = plt.subplots(1, figsize=(4.5, 3.0))
        else:
            fig = ax.figure

        ax.set_prop_cycle(color=colors)

        # Iterate the groupby without column selection, so that every group
        # frame carries the "bin" column regardless of whether iterating a
        # column-selected groupby honours the selection.
        for i, (grp_val, y) in enumerate(by_group):
            # observed=False keeps one row per bin, including empty bins, so
            # that the bars line up with the bin positions.
            summary = y.groupby("bin", observed=False)[variable].agg(["mean", "std"])

            ax.bar(0.1 + bins + (i + 0.5) * wid, summary["mean"],
                    yerr=summary["std"], edgecolor="black",
                    label=utils.get_label(grp_val), width=wid)

        ax.set_xticks(list(range(n_bins + 1)))
        ax.set_xticklabels(tick_labels)
        # The exponent is wrapped in braces so that exponents with more than
        # one character (e.g. 12 or -1) are rendered completely as superscript.
        ax.set_xlabel(f"{utils.get_label(hist)} ($\\times 10^{{{l10}}}$)")
        ax.set_ylabel(utils.get_label(variable))
        ax.set_ylim(bottom=0)

        if single_plot:
            graph_type = df["graph_type"].iloc[0]
            node_strategy = df["node_strategy"].iloc[0]
            ax.legend()
            fig.tight_layout()
            savefig(f"hist_{graph_type}_{node_strategy}_{variable}")


def hitting_time_multicore(x, groups=None, verbose=False, num_proc=32):
    """
        Measures the running time of experiments whose work is split over
        several rows that differ only in the "cores" column.

        Rows that agree on every column except "cores" form one experiment.
        All rows of an experiment are run through test_hitting_time in
        parallel, and the time of the experiment is the largest "time" among
        its rows. The mean and the standard deviation of these times are
        printed per group.

        Parameters
        ----------
        x : huggingface dataset
            The experiment configurations. Must contain a "cores" column.
        groups : list of str
            The columns to group the experiments by when printing.
        verbose : bool
            Forwarded to test_hitting_time.
        num_proc : int
            The number of worker processes passed to Dataset.map.

        Raises
        ------
        KeyError
            If x has no "cores" column.
    """
    df = x.to_pandas()
    if "cores" not in df.columns:
        raise KeyError("cores")

    # Keep the original column order, so that the grouping (and the column
    # order after reset_index) does not depend on set iteration order.
    cols = [col for col in df.columns if col != "cores"]

    grps = df.groupby(cols)
    grpd_df = grps.max()
    grpd_df["time"] = 0.0
    for ix, grp in grps:
        y = Dataset.from_pandas(grp)
        y = y.map(lambda row: test_hitting_time(**row, verbose=verbose),
                num_proc=num_proc)
        grp = y.to_pandas()
        # The slowest row determines the time of the whole experiment.
        grpd_df.loc[ix, "time"] = grp["time"].max()

    df = grpd_df.reset_index()
    grp = df.groupby(groups)["time"]
    print(grp.mean())
    print(grp.std())

