import numpy as np
import pandas as pd
import networkx as nx
from tqdm.autonotebook import tqdm
import pickle
import os

# Seed NumPy's global RNG so any downstream NumPy randomness is repeatable.
np.random.seed(1)

# Number of trials per configuration; used to synthesize a "Trial" column when
# the stored results do not already carry one.
ntrials = 5

# log_dirs = ["results/run_12_additional_data"]
# log_dirs += ["results/run_13_additional_data"]
# log_dirs += ["results/run_15", "results/run_16"]
log_dirs = ["results/"]


def _load_shi_counts(shi_dir):
    """Unpickle every file in ``shi_dir`` and return the objects as a list.

    Files are visited in sorted name order so the resulting row order (and thus
    the synthesized ``Trial`` labels) does not depend on filesystem enumeration
    order. Each file is opened in a ``with`` block so its handle is closed
    immediately.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this at
    trusted result directories.
    """
    records = []
    for fname in tqdm(sorted(os.listdir(shi_dir)), desc="Loading Results"):
        with open(os.path.join(shi_dir, fname), "rb") as fh:
            records.append(pickle.load(fh))
    return records


frames = []
for log_dir in log_dirs:
    print("Loading Results From:", log_dir)
    shi_dir = os.path.join(log_dir, "shi_counts")
    if os.path.isdir(shi_dir):
        # One pickled record per file -> one DataFrame row per file.
        new_df = pd.DataFrame(_load_shi_counts(shi_dir))
    else:
        # Fall back to a single pre-aggregated results pickle.
        new_df = pd.read_pickle(os.path.join(log_dir, "shi_count_results.pkl"))
    if "Trial" not in new_df.columns:
        # Assign trial indices round-robin across rows.
        new_df["Trial"] = np.arange(len(new_df)) % ntrials
    frames.append(new_df)
# Concatenate once rather than repeatedly growing an initially-empty frame;
# keep an empty frame when no log directories are configured.
df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
print("Number of Results:", len(df))


# Architecture columns derived from "Layer Widths".
# Assumes "Layer Widths" is [input_dim, hidden_1, ..., hidden_k, output_dim]
# (inferred from the slicing below) — TODO confirm against the producer.
df["d"] = df["Layer Widths"].apply(lambda w: w[0])
df["Width"] = df["Layer Widths"].apply(lambda w: w[1])
df["# Hidden"] = df["Layer Widths"].apply(lambda x: len(x) - 2)
df["# Relus"] = df["Layer Widths"].apply(lambda x: sum(x[1:-1]))


def _capped_mean(values, cap=10000):
    """Sum the entries that are not None and <= ``cap``, divided by ``len(values)``.

    Returns NaN for an empty sequence instead of raising ZeroDivisionError.

    NOTE(review): the denominator counts *all* entries, including the None /
    over-cap ones excluded from the numerator, so excluded entries act like
    zeros. Confirm this is intended rather than a mean over the kept entries.
    """
    if len(values) == 0:
        return float("nan")
    return sum(y for y in values if y is not None and y <= cap) / len(values)


def _finite_fraction(values):
    """Return the fraction of entries that are neither None nor +inf (NaN if empty)."""
    if len(values) == 0:
        return float("nan")
    return sum(y is not None and y != float("inf") for y in values) / len(values)


# na_action="ignore" leaves rows whose cell is missing as NaN instead of calling the helper.
df["Average Volume"] = df["Volumes"].map(_capped_mean, na_action="ignore")
df["Average Inradius"] = df["Inradii"].map(_capped_mean, na_action="ignore")
df["Average Polys / Second"] = df["# Regions"] / df["Search Time"]
df["Hours"] = df["Search Time"] / 3600
df["% Finite"] = df["Inradii"].map(_finite_fraction, na_action="ignore")
group_cols = ["d", "# Hidden", "Width"]


def get_diameter_lb(G, k=50):
    """Return the largest of ``k`` seeded diameter lower bounds for graph ``G``.

    ``nx.algorithms.approximation.diameter`` returns a lower bound on the
    diameter. Running it with seeds 0..k-1 and keeping the maximum tightens
    that bound.
    """
    lower_bounds = [
        nx.algorithms.approximation.diameter(G, seed=seed) for seed in range(k)
    ]
    return max(lower_bounds)


def get_diameter_ub(G, k=50):
    """Return an upper bound on the diameter of ``G`` built from BFS spanning trees.

    A BFS tree of a connected graph spans every node and uses only graph edges,
    so the tree's diameter is >= the graph's diameter. This function grows BFS
    trees from up to ``k`` distinct roots, visiting higher-degree nodes first,
    and returns the smallest tree diameter found.

    Assumes ``G`` is connected. On a disconnected graph a BFS tree only covers
    the root's component, and the result is not a bound for ``G``.

    Args:
        G: an undirected networkx graph; node labels may be arbitrary hashables.
        k: maximum number of BFS roots to try. Returns ``inf`` when ``k <= 0``.

    Returns:
        The smallest BFS-tree diameter observed, or ``float("inf")`` if no root
        was tried.

    Raises:
        ValueError: if ``G`` has no nodes.
    """
    if G.number_of_nodes() == 0:
        raise ValueError("get_diameter_ub requires a non-empty graph")
    # Order roots by decreasing degree. Key on the node's own degree so any
    # label type works. Slicing to k distinct roots avoids rebuilding the same
    # tree when k exceeds the node count.
    roots = sorted(G.nodes, key=lambda node: G.degree[node], reverse=True)[: max(k, 0)]
    depth_limit = None
    min_ub = float("inf")
    for root in roots:
        # In a connected graph every node's eccentricity is <= the diameter
        # <= min_ub. Capping the BFS depth at the current best bound therefore
        # never truncates the spanning tree.
        tree = nx.bfs_tree(G, source=root, depth_limit=depth_limit).to_undirected()
        # The 2-sweep approximation is exact on trees, so this is the tree's diameter.
        ub = nx.algorithms.approximation.diameter(tree, seed=1)
        if ub < min_ub:
            min_ub = ub
            depth_limit = ub
    return min_ub


# Compute each diameter bound per dual graph, with its own labelled progress bar.
for bound_column, bound_fn in (
    ("Diameter LB", get_diameter_lb),
    ("Diameter UB", get_diameter_ub),
):
    tqdm.pandas(desc=bound_column)
    df[bound_column] = df["Dual Graph"].progress_apply(bound_fn)


# Point estimate: the midpoint of the [LB, UB] bracket.
df["Diameter"] = df["Diameter UB"].add(df["Diameter LB"]).div(2)

print("Saving DF")
# Fast gzip (level 1) with a fixed header mtime, which the pandas docs describe
# as producing a reproducible archive. The ".pkl" extension does not let pandas
# infer gzip, so readers must pass compression="gzip" explicitly.
gzip_options = {"method": "gzip", "compresslevel": 1, "mtime": 1}
df.to_pickle("plot_results_df.pkl", compression=gzip_options)


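# Debugging snippet (disabled): inspect one result's dual graph and render it.
# Re-enabling it needs names this file does not import: Decomp, torch, and
# Network (presumably pyvis.network.Network).
# i = 7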
# e = df["Experiment"].iloc[i]
# G = df["Dual Graph"].iloc[i]
# print(G.number_of_nodes(), G.number_of_edges())

# decomp = Decomp(e.model.to(device="cuda" if torch.cuda.is_available() else "cpu"))
# # decomp.bfs()
# nx.is_isomorphic(decomp.recover_from_dual_graph(G, decomp.point2bv(torch.zeros(e.model.input_shape, device=e.model.device, dtype=e.model.dtype)), source=0), G)

# nt = Network(height="1000px", width="100%")
# nt.from_nx(decomp.plot_dual_graph())
# for n in nt.nodes:
#     if n["id"] in [decomp.poly2index[s] for s in decomp if not s.finite]:
#         n["color"] = "red"
# nt.show_buttons()
# nt.save_graph('nx.html')
