import glob
import json
from typing import List, Dict, Any, Tuple
from collections import defaultdict
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
from networkx.algorithms import community as nx_comm
import tqdm

# Model checkpoints to analyse. Only one MODEL_NAMES list is active at a time;
# the commented-out lists below/above are the alternative 1.5B and 14B groups.
# main() buckets these names by the size token ("1.5B", "7B", "14B") that
# appears as a substring of each name, and plot_bars() uses this list to fix
# the bar order.

# MODEL_NAMES = [
#     "Qwen2.5-Math-1.5B",
#     "Qwen2.5-Math-1.5B-Oat-Zero",
#     "DeepSeek-R1-Distill-Qwen-1.5B",
#     "Nemotron-Research-Reasoning-Qwen-1.5B"
# ]

MODEL_NAMES = [
    "Qwen2.5-Math-7B",
    "Qwen2.5-Math-7B-Oat-Zero",
    "DeepSeek-R1-Distill-Qwen-7B",
    "AceReason-Nemotron-7B",
]

# MODEL_NAMES = [
#     "Qwen2.5-14B",
#     "Qwen-2.5-14B-SimpleRL-Zoo",
#     "DeepSeek-R1-Distill-Qwen-14B",
#     "AceReason-Nemotron-14B",
# ]

# Dataset names; each one is substituted into the result-file glob pattern and
# into the output directory / CSV file names.
DATA_NAMES = ["aime24"]

def _collect_files_for_model_data(model: str, data: str) -> List[str]:
    base_pattern = (
        f"/XXXXX"
        f"{model}_{data}_temp0.6_n256_seed1/eval_results/global_step_0/"
        f"{data}/test_*_-1_seed1_t0.6_s0_e-1_processed.jsonl.processed.jsonl"
    )
    return glob.glob(base_pattern)

def build_graph_from_paths(paths: List[List[int]]) -> nx.Graph:
    """Build an undirected graph from a collection of node sequences.

    Every element of every sequence becomes a node (cast to ``int``), and each
    pair of consecutive elements becomes an edge unless both ends are the same
    node, so the graph never contains self-loops. Entries of ``paths`` that
    are not non-empty lists are ignored.

    Args:
        paths: Node-id sequences.

    Returns:
        The resulting ``networkx.Graph``.
    """
    graph = nx.Graph()
    for path in paths:
        if not (isinstance(path, list) and path):
            continue
        node_ids = [int(node) for node in path]
        graph.add_nodes_from(node_ids)
        # Consecutive pairs, skipping immediate repeats of the same node.
        graph.add_edges_from(
            (src, dst) for src, dst in zip(node_ids, node_ids[1:]) if src != dst
        )
    return graph

def giant_component_subgraph(G: nx.Graph) -> nx.Graph:
    """Return a copy of the connected component of ``G`` with the most nodes.

    For a graph without nodes, a copy of ``G`` itself is returned.
    """
    if not G.number_of_nodes():
        return G.copy()
    biggest_component = max(nx.connected_components(G), key=len)
    induced = G.subgraph(biggest_component)
    # .copy() detaches the result from G (subgraph() alone is only a view).
    return induced.copy()

def edge_density(G: nx.Graph) -> float:
    """Return ``nx.density(G)``, or NaN when ``G`` has fewer than two nodes."""
    return nx.density(G) if G.number_of_nodes() > 1 else np.nan

def modularity_greedy(G: nx.Graph) -> float:
    """Return the modularity of the greedy-modularity community partition.

    Communities come from ``greedy_modularity_communities`` and are scored
    with ``modularity``. NaN is returned for a graph with no nodes or no
    edges, and also when either networkx call raises.
    """
    if not (G.number_of_nodes() and G.number_of_edges()):
        return np.nan
    try:
        partition = nx_comm.greedy_modularity_communities(G)
        score = nx_comm.modularity(G, partition)
    except Exception:
        # Best-effort metric: any failure is reported as a missing value.
        return np.nan
    return score

def assortativity_degree(G: nx.Graph) -> float:
    """Return the degree assortativity coefficient of ``G``.

    NaN is returned when networkx raises.
    """
    try:
        coefficient = nx.degree_assortativity_coefficient(G)
    except Exception:
        # Best-effort metric: any failure is reported as a missing value.
        return np.nan
    return coefficient

def avg_shortest_path_on_gcc(G: nx.Graph) -> float:
    """Return the average shortest path length within the giant component.

    The metric is computed on the largest connected component only. NaN is
    returned when that component has fewer than two nodes or no edges, or when
    networkx raises.
    """
    component = giant_component_subgraph(G)
    has_enough_nodes = component.number_of_nodes() >= 2
    has_edges = component.number_of_edges() != 0
    if not (has_enough_nodes and has_edges):
        return np.nan
    try:
        length = nx.average_shortest_path_length(component)
    except Exception:
        # Best-effort metric: any failure is reported as a missing value.
        return np.nan
    return length

def avg_clustering(G: nx.Graph) -> float:
    """Return the average clustering coefficient of ``G``.

    NaN is returned when networkx raises.
    """
    try:
        coefficient = nx.average_clustering(G)
    except Exception:
        # Best-effort metric: any failure is reported as a missing value.
        return np.nan
    return coefficient

def global_efficiency(G: nx.Graph) -> float:
    """Return ``nx.global_efficiency(G)``, or NaN when networkx raises."""
    try:
        efficiency = nx.global_efficiency(G)
    except Exception:
        # Best-effort metric: any failure is reported as a missing value.
        return np.nan
    return efficiency

def compare_with_er_random(G: nx.Graph, trials: int = 10, seed: int = 20240920) -> Tuple[float, float]:
    """Compare clustering and path length of ``G`` with Erdos-Renyi graphs.

    ``trials`` random graphs G(n, p) are generated with the same node count as
    ``G`` and ``p = 2m / (n (n - 1))``, i.e. the edge density of ``G``. The
    observed average clustering (whole graph) and average shortest path length
    (giant component only) are each divided by the mean of the same quantity
    over the random graphs.

    Args:
        G: Observed undirected graph.
        trials: Number of random graphs to sample.
        seed: Seed of the generator that draws one seed per random graph, so
            results are reproducible.

    Returns:
        ``(clustering_ratio, path_length_ratio)``. Both are NaN when ``G`` has
        fewer than two nodes or when ``p`` is not strictly between 0 and 1.
        An individual ratio is NaN when its random baseline is unavailable or
        equal to zero.
    """
    n, m = G.number_of_nodes(), G.number_of_edges()
    if n < 2:
        return (np.nan, np.nan)
    # n >= 2 here, so n * (n - 1) >= 2 and the division cannot fail.
    p = (2.0 * m) / (n * (n - 1))
    if not (0 < p < 1):
        # Empty or complete graph: the random baseline would be degenerate.
        return (np.nan, np.nan)

    def _ratio_to_baseline(observed, baseline_samples):
        """Divide ``observed`` by the sample mean, or return NaN if undefined."""
        if not baseline_samples:
            return np.nan
        baseline = np.mean(baseline_samples)
        if np.isnan(baseline) or baseline == 0:
            return np.nan
        return observed / baseline

    C_obs = avg_clustering(G)
    L_obs = avg_shortest_path_on_gcc(G)

    rng = np.random.default_rng(seed)
    C_vals, L_vals = [], []
    for _ in range(trials):
        R = nx.erdos_renyi_graph(n, p, seed=int(rng.integers(0, 2**32 - 1)))
        c = avg_clustering(R)
        l = avg_shortest_path_on_gcc(R)
        # Failed (NaN) measurements are left out of the baseline means.
        if not np.isnan(c):
            C_vals.append(c)
        if not np.isnan(l):
            L_vals.append(l)

    return _ratio_to_baseline(C_obs, C_vals), _ratio_to_baseline(L_obs, L_vals)

def freeman_degree_centralization(G: nx.Graph) -> float:
    """Return Freeman's degree centralization of ``G``.

    Computed as ``sum(d_max - d_i) / ((n - 1) * (n - 2))`` over all node
    degrees ``d_i``, where ``d_max`` is the largest degree and ``n`` the node
    count.

    Returns:
        The centralization value, or NaN when ``G`` has fewer than three nodes
        (the normalising denominator would be zero).
    """
    n = G.number_of_nodes()
    if n < 3:
        return np.nan
    # With n >= 3 the degree list is non-empty and the denominator is >= 2,
    # so no further guards are needed.
    degrees = [degree for _, degree in G.degree()]
    d_max = max(degrees)
    return sum(d_max - degree for degree in degrees) / ((n - 1) * (n - 2))

def gcc_fraction(G: nx.Graph) -> float:
    """Return the fraction of nodes that lie in the largest connected component.

    NaN is returned for a graph without nodes.
    """
    total_nodes = G.number_of_nodes()
    if not total_nodes:
        return np.nan
    largest_component = max(nx.connected_components(G), key=len)
    return len(largest_component) / total_nodes

def transitivity_global(G: nx.Graph) -> float:
    """Return ``nx.transitivity(G)``, or NaN when networkx raises."""
    try:
        value = nx.transitivity(G)
    except Exception:
        # Best-effort metric: any failure is reported as a missing value.
        return np.nan
    return value

def algebraic_connectivity_gcc(G: nx.Graph) -> float:
    """Return the algebraic connectivity of the giant component of ``G``.

    NaN is returned when the giant component has fewer than two nodes or no
    edges, and also when extracting the component or the networkx computation
    raises.
    """
    try:
        component = giant_component_subgraph(G)
        is_trivial = (
            component.number_of_nodes() < 2 or component.number_of_edges() == 0
        )
        return np.nan if is_trivial else nx.algebraic_connectivity(component)
    except Exception:
        # Best-effort metric: any failure is reported as a missing value.
        return np.nan

def iter_jsonl(filepath: str):
    """Lazily yield the decoded JSON value of every line in a JSONL file.

    The file is read as UTF-8. Lines that are blank after stripping whitespace
    are skipped, and so are lines that are not valid JSON.

    Args:
        filepath: Path of the JSONL file to read.

    Yields:
        One decoded JSON value per valid line, in file order.
    """
    with open(filepath, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text:
                try:
                    record = json.loads(text)
                except json.JSONDecodeError:
                    # Malformed line: skip it and keep reading.
                    continue
                yield record

def extract_paths(rec: Dict[str, Any]) -> List[List[int]]:
    """Extract the integer node sequences stored under ``"path_node_unique"``.

    Args:
        rec: One decoded JSONL record. Anything that is not a dict (a JSON
            line may decode to a list, number, string or null) yields no paths.

    Returns:
        One list of ints per list-valued entry of ``rec["path_node_unique"]``.
        Elements that cannot be converted with ``int()`` are dropped, and
        entries that end up empty are omitted. Returns ``[]`` when the key is
        missing or its value is not a list.
    """
    # Guard against non-dict records instead of failing on ``.get``.
    if not isinstance(rec, dict):
        return []
    raw = rec.get("path_node_unique", [])
    if not isinstance(raw, list):
        return []

    out: List[List[int]] = []
    for item in raw:
        if not isinstance(item, list):
            continue
        seq: List[int] = []
        for v in item:
            try:
                seq.append(int(v))
            except (TypeError, ValueError, OverflowError):
                # None / containers, non-numeric strings or NaN, and infinity.
                # NOTE: dropping an element makes its neighbours adjacent in
                # the resulting sequence.
                continue
        if seq:
            out.append(seq)
    return out

def compute_metrics_for_graph(G: nx.Graph) -> Dict[str, float]:
    """Compute every graph metric of this module for a single graph.

    Args:
        G: Graph to measure.

    Returns:
        Mapping from metric name to value. The random-graph comparison uses
        10 trials with seed 42. ``n_nodes`` and ``n_edges`` are stored as
        floats; the other entries are whatever the individual metric helpers
        return (NaN where a metric is undefined).
    """
    clustering_ratio, path_length_ratio = compare_with_er_random(
        G, trials=10, seed=42
    )
    metrics = {
        "edge_density": edge_density(G),
        "modularity": modularity_greedy(G),
        "assortativity": assortativity_degree(G),
        "clustering_over_random": clustering_ratio,
        "avg_shortest_path_over_random": path_length_ratio,
        "freeman_centralization": freeman_degree_centralization(G),
        "gcc_fraction": gcc_fraction(G),
        "n_nodes": float(G.number_of_nodes()),
        "n_edges": float(G.number_of_edges()),
        "global_efficiency": global_efficiency(G),
        "transitivity": transitivity_global(G),
        "algebraic_connectivity": algebraic_connectivity_gcc(G),
    }
    return metrics

def analyze(models: List[str], datas: List[str]) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Compute per-record graph metrics and their per-model means.

    For every model and dataset, each record of each matching result file is
    turned into a graph and measured.

    Args:
        models: Model names to process.
        datas: Dataset names to process.

    Returns:
        ``(df_rows, df_means)``: ``df_rows`` has one row per record with the
        identifying columns (``model``, ``data``, ``source_file``) followed by
        the metrics; ``df_means`` has one row per model with the mean of each
        averaged metric, sorted by model name. Means are grouped by model
        only, so rows of different datasets are pooled. When no record was
        found, ``df_means`` is an empty frame with the expected columns.
    """
    id_cols = ["model", "data", "source_file"]
    # Metrics that are averaged per model (node/edge counts are not).
    metric_cols = [
        "edge_density",
        "modularity",
        "assortativity",
        "clustering_over_random",
        "avg_shortest_path_over_random",
        "freeman_centralization",
        "gcc_fraction",
        "global_efficiency",
        "transitivity",
        "algebraic_connectivity",
    ]
    # Column order of the per-record table.
    cols = id_cols + [
        "edge_density",
        "modularity",
        "assortativity",
        "clustering_over_random",
        "avg_shortest_path_over_random",
        "freeman_centralization",
        "gcc_fraction",
        "n_nodes",
        "n_edges",
        "global_efficiency",
        "transitivity",
        "algebraic_connectivity",
    ]

    records = []
    for model in tqdm.tqdm(models):
        for data in datas:
            for fp in _collect_files_for_model_data(model, data):
                for rec in tqdm.tqdm(iter_jsonl(fp)):
                    graph = build_graph_from_paths(extract_paths(rec))
                    row = compute_metrics_for_graph(graph)
                    row["model"] = model
                    row["data"] = data
                    row["source_file"] = fp
                    records.append(row)

    df_rows = pd.DataFrame(records, columns=cols)
    if df_rows.empty:
        return df_rows, pd.DataFrame(columns=["model"] + metric_cols)

    grouped = df_rows.groupby("model", as_index=False)[metric_cols]
    df_means = grouped.mean(numeric_only=True).sort_values("model")
    return df_rows, df_means

def plot_bars(df_means: pd.DataFrame, out_dir: str):
    """Save one bar chart per metric, with one bar per model.

    ``out_dir`` is created if needed. Bars follow the order of the module-level
    ``MODEL_NAMES``; models in that list that are missing from ``df_means``
    appear as empty (NaN) entries. Each chart is written to
    ``<out_dir>/<metric>_bar.png`` and its path is printed.

    Args:
        df_means: Per-model metric means with a ``model`` column.
        out_dir: Directory that receives the PNG files.
    """
    os.makedirs(out_dir, exist_ok=True)
    if df_means.empty:
        print("empty")
        return

    metric_names = (
        "edge_density",
        "modularity",
        "assortativity",
        "clustering_over_random",
        "avg_shortest_path_over_random",
        "freeman_centralization",
        "gcc_fraction",
        "global_efficiency",
        "transitivity",
        "algebraic_connectivity",
    )
    per_model = df_means.set_index("model").reindex(MODEL_NAMES)

    for metric in metric_names:
        fig, ax = plt.subplots()
        per_model[metric].plot(kind="bar", ax=ax)
        ax.set_title(metric.replace("_", " "))
        ax.set_xlabel("model")
        ax.set_ylabel(metric)
        fig.tight_layout()
        fig_path = os.path.join(out_dir, f"{metric}_bar.png")
        fig.savefig(fig_path)
        # Close explicitly so figures do not accumulate across metrics.
        plt.close(fig)
        print(f"Saved: {fig_path}")

def main():
    """Run the analysis for every dataset and model-size group.

    Models from ``MODEL_NAMES`` are grouped by the size token contained in
    their name. For each (dataset, size) pair that has at least one model, the
    per-record metrics and the per-model means are written as CSV files and
    the bar charts are saved, all inside ``<base_out>/<data>_<size>``.
    """
    base_out = "/XXXXX"
    os.makedirs(base_out, exist_ok=True)

    SIZE_TOKENS = ["1.5B", "7B", "14B"]

    for data in DATA_NAMES:
        for size in SIZE_TOKENS:
            # A model belongs to a size group when the token occurs in its name.
            models_this_size = [name for name in MODEL_NAMES if size in name]
            if not models_this_size:
                continue

            tag = f"{data}_{size}"
            out_dir = os.path.join(base_out, tag)
            os.makedirs(out_dir, exist_ok=True)

            df_rows, df_means = analyze(models_this_size, [data])

            rows_csv = os.path.join(
                out_dir, f"per_row_graph_metrics_{data}_{size}.csv"
            )
            means_csv = os.path.join(
                out_dir, f"per_model_graph_metrics_means_{data}_{size}.csv"
            )
            df_rows.to_csv(rows_csv, index=False)
            df_means.to_csv(means_csv, index=False)
            print(f"Saved per-row metrics: {rows_csv}")
            print(f"Saved per-model means: {means_csv}")

            plot_bars(df_means, out_dir)

    print("Done.")


if __name__ == "__main__":
    main()
