from __future__ import annotations
import numpy as np
import networkx as nx
from distances import pairwise_distance_matrix
from diversity import summarize
from typing import List, Tuple
import igraph as ig
import pickle
from joblib import Parallel, delayed

# Edge probabilities for the ER baseline: a grid symmetric about 1/2
# (1/16, 1/8, 1/4, 1/2 and their complements 1 - 1/4, 1 - 1/8, 1 - 1/16).
P_SET = [1 / 16, 1 / 8, 1 / 4, 1 / 2, 1 - 1 / 4, 1 - 1 / 8, 1 - 1 / 16]
# Nodes per generated graph.
NUM_NODES = 8
# Graphs per generated set.
NUM_GRAPHS = 100


def er_fixed_p_set(
    n_graphs: int, n: int, seed: int = 0
) -> Tuple[List[nx.Graph], np.ndarray]:
    """Sample ``n_graphs`` Erdos-Renyi G(n, p) graphs spread over ``P_SET``.

    Graphs are split as evenly as possible across the edge probabilities in
    ``P_SET``. When ``n_graphs`` is not a multiple of ``len(P_SET)``, the
    first ``n_graphs % len(P_SET)`` probabilities get one extra graph.

    Each graph has nodes ``0 .. n-1``. Each of the ``n * (n - 1) / 2``
    candidate undirected edges (i, j), i < j, is included independently with
    probability ``p``, so there are no self-loops and no multi-edges.

    Args:
        n_graphs: Total number of graphs to generate.
        n: Number of nodes per graph.
        seed: Seed for the NumPy generator that drives edge sampling; the
            same ``(n_graphs, n, seed)`` always yields the same graphs.

    Returns:
        ``(graphs, ps)`` where ``ps[i]`` is the edge probability used to
        generate ``graphs[i]``.
    """
    # Edges are drawn from this local generator rather than igraph's
    # process-global RNG; otherwise ``seed`` would have no effect on the output.
    rng = np.random.default_rng(seed)
    ps = np.array(P_SET, dtype=float)

    base, rem = divmod(n_graphs, len(ps))
    counts = np.full(len(ps), base, dtype=int)
    counts[:rem] += 1  # no-op when rem == 0

    # Candidate edges (i, j) with i < j, shared by every graph.
    rows, cols = np.triu_indices(n, k=1)

    graphs: List[nx.Graph] = []
    used_ps: List[float] = []
    for p, k in zip(ps, counts):
        for _ in range(k):
            keep = rng.random(rows.size) < p
            g = nx.Graph()
            # Add all nodes explicitly so isolated vertices are kept.
            g.add_nodes_from(range(n))
            g.add_edges_from(zip(rows[keep].tolist(), cols[keep].tolist()))
            graphs.append(g)
            used_ps.append(float(p))

    return graphs, np.array(used_ps, dtype=float)


def run_multiple(
    distance: str,
    workers: Parallel,
    orca_path: str = "orca/orca",
    sanity: bool = False,
    runs=5,
    seed: int = 0,
):
    """Average diversity scores of ``distance`` over several random ER graph sets.

    Each run draws ``NUM_GRAPHS`` graphs on ``NUM_NODES`` nodes with
    ``er_fixed_p_set``. It then computes their pairwise distance matrix with
    ``pairwise_distance_matrix`` and scores it with ``summarize``. Running
    means are printed every 10 accepted runs (except the last), and the final
    means are printed at the end.

    Args:
        distance: Distance name forwarded to ``pairwise_distance_matrix``.
        workers: joblib ``Parallel`` instance forwarded to
            ``pairwise_distance_matrix``.
        orca_path: Path forwarded to ``pairwise_distance_matrix``.
        sanity: If True, a graph set whose ``"gcd"`` energy is ``>= 50`` is
            discarded and redrawn; discarded sets do not count as runs.
        runs: Number of accepted runs to average over.
        seed: Base seed. Draw ``i`` (counting discarded draws) uses
            ``seed + i``.

    Returns:
        ``(mean_energy, mean_average)`` over the accepted runs.

    Raises:
        ValueError: If ``runs`` is not positive.
    """
    if runs <= 0:
        # Previously: ZeroDivisionError for 0, infinite loop for negatives.
        raise ValueError(f"runs must be positive, got {runs!r}")

    avg_sum = 0.0
    energy_sum = 0.0
    idx = 0  # accepted runs
    attempt = 0  # graph sets drawn, including ones rejected by the sanity check

    # NOTE: with sanity=True this keeps drawing until enough sets pass the
    # check; it never terminates if no draw ever passes.
    while idx < runs:
        # Give every draw its own seed. If er_fixed_p_set is deterministic in
        # its seed, a constant seed would redraw the same graphs every time,
        # making the average meaningless and a failed sanity check repeat
        # forever.
        graphs, _ = er_fixed_p_set(
            n_graphs=NUM_GRAPHS, n=NUM_NODES, seed=seed + attempt
        )
        attempt += 1

        D = None
        if sanity:
            D_gcd = pairwise_distance_matrix(
                graphs, distance="gcd", orca_path=orca_path, workers=workers
            )
            if summarize(D_gcd).energy >= 50:
                continue
            if distance == "gcd":
                # Same arguments as the call below; reuse the matrix instead
                # of recomputing it (assumes the gcd computation is
                # deterministic).
                D = D_gcd

        if D is None:
            D = pairwise_distance_matrix(
                graphs, distance=distance, orca_path=orca_path, workers=workers
            )
        scores = summarize(D)
        avg_sum += scores.average
        energy_sum += scores.energy
        idx += 1

        if idx % 10 == 0 and idx != runs:
            print(
                f"Ran {idx}\t{distance}: Energy={(energy_sum/(idx)):.6f}  Average={avg_sum/(idx):.6f}"
            )

    print(
        f"Across {runs}, {distance}: Energy={(energy_sum/(idx)):.6f}  Average={avg_sum/(idx):.6f}"
    )
    return energy_sum / idx, avg_sum / idx


def run_from_pickle(
    pickle_path: str,
    distance: str,
    workers: Parallel,
    orca_path: str = "orca/orca",
    sanity: bool = False,
):
    """Score the diversity of a pickled collection of graphs under ``distance``.

    Args:
        pickle_path: Path to a pickle file whose contents are passed to
            ``pairwise_distance_matrix``. Presumably a list of networkx
            graphs; confirm against whatever produces the file.
        distance: Distance name forwarded to ``pairwise_distance_matrix``.
        workers: joblib ``Parallel`` instance forwarded to
            ``pairwise_distance_matrix``.
        orca_path: Path forwarded to ``pairwise_distance_matrix``.
        sanity: If True, print a warning when the resulting energy is
            ``>= 1``. This is only a warning and does not raise.

    Returns:
        The object returned by ``summarize`` for the distance matrix.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at files produced by a trusted source.
    with open(pickle_path, "rb") as f:
        graphs = pickle.load(f)

    print(f"Loaded {len(graphs)} graphs from {pickle_path}")
    print(f"Computing pairwise distances using {distance}...")

    D = pairwise_distance_matrix(
        graphs, distance=distance, orca_path=orca_path, workers=workers
    )

    scores = summarize(D)

    if sanity and scores.energy >= 1:
        print(f"WARNING: Energy={scores.energy:.6f} >= 1 (sanity check failed)")

    print(f"{distance}: Energy={scores.energy:.6f}  Average={scores.average:.6f}")

    return scores


if __name__ == "__main__":
    workers = Parallel(n_jobs=16)
    # Distances evaluated in both experiments, in reporting order.
    distances = ("gcd", "portrait_div", "netlsd_heat", "netlsd_wave")

    print("=" * 60)
    print("Running ER baseline experiments")
    print("=" * 60)
    for dist in distances:
        run_multiple(dist, sanity=False, runs=1, workers=workers)
        print()

    print("=" * 60)
    print("Running evaluation on neural dispersion graphs")
    print("=" * 60)

    # Relative path: resolved against the current working directory.
    pickle_path = "../data/diverse_graphs.pkl"

    for dist in distances:
        run_from_pickle(
            pickle_path=pickle_path, distance=dist, sanity=False, workers=workers
        )
        print()
