import json
import glob
import os
from typing import Dict, List, Tuple
import itertools
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import tqdm
import csv

# Models whose evaluation outputs are analysed. Only the uncommented entries are
# processed; the commented-out 14B and 1.5B variants are alternative model sets
# kept for convenience. The active 7B set matches the "_7b" suffix hard-coded in
# the CSV filenames written by main().
MODEL_NAMES = [
    # "Qwen2.5-14B",
    # "Qwen-2.5-14B-SimpleRL-Zoo",
    # "DeepSeek-R1-Distill-Qwen-14B",
    # "AceReason-Nemotron-14B",
    # "Qwen2.5-Math-1.5B",
    # "Qwen2.5-Math-1.5B-Oat-Zero",
    # "DeepSeek-R1-Distill-Qwen-1.5B",
    # "Nemotron-Research-Reasoning-Qwen-1.5B",
    "Qwen2.5-Math-7B",
    "Qwen2.5-Math-7B-Oat-Zero",
    "DeepSeek-R1-Distill-Qwen-7B",
    "AceReason-Nemotron-7B",
]

# Benchmark names. Each is substituted into the result-file glob pattern in
# build_model_to_paths() and gets its own figure and CSV files in main().
DATA_NAMES = ["aime25", "amc23"]

# Graphlet types in output order: index i matches entry i of the count arrays
# produced by count_graphlets() (P3=0, K3=1, P4=2, ..., K4=7).
GRAPHLET_ORDER = ["P3", "K3", "P4", "Star3", "C4", "TailedTri", "Diamond", "K4"]

def build_model_to_paths() -> Dict[str, List[str]]:
    """Locate the processed evaluation result files of every model.

    Returns:
        A dict keyed by each name in MODEL_NAMES (in that order). Each value is
        the sorted list of files matching the result glob for that model.
        Files from all DATA_NAMES are pooled into a single list. Callers that
        need per-dataset files must filter by the ``{data}`` directory
        component themselves.
    """

    def _glob_one(model: str, data: str) -> List[str]:
        # One glob per (model, dataset) pair; the dataset name appears both in
        # the run directory name and as the parent directory of the file.
        pattern = (
            f"/XXXX"
            f"{model}_{data}_temp0.6_n256_seed1/eval_results/global_step_0/{data}/"
            f"test_*_-1_seed1_t0.6_s0_e-1_processed.jsonl.processed.jsonl"
        )
        return glob.glob(pattern)

    return {
        model: sorted(
            found for data in DATA_NAMES for found in _glob_one(model, data)
        )
        for model in MODEL_NAMES
    }

def build_graph_from_paths(path_node_unique: List[List[int]]) -> nx.Graph:
    """Build an undirected graph linking consecutive nodes of each path.

    Args:
        path_node_unique: Sequence of paths, each a list of node ids.

    Returns:
        A ``networkx.Graph`` with one edge per consecutive pair of nodes in
        every path. The following are ignored:
          - entries that are not lists, or that have fewer than two nodes;
          - steps that repeat the same node, so no self-loops are created.
        Node ids are converted with ``int()``.
    """
    graph = nx.Graph()
    usable_paths = (
        p for p in path_node_unique if isinstance(p, list) and len(p) >= 2
    )
    for path in usable_paths:
        # Pair each node with its successor; skip "stay" steps before casting.
        graph.add_edges_from(
            (int(a), int(b)) for a, b in zip(path, path[1:]) if a != b
        )
    return graph

def count_graphlets(G: "nx.Graph") -> np.ndarray:
    """Count induced, connected 3- and 4-node graphlets of an undirected graph.

    Args:
        G: Undirected graph. A ``networkx.Graph`` is expected, but only
            ``iter(G)`` (nodes) and ``G[u]`` (neighbours of ``u``) are used.
            A symmetric ``dict`` mapping each node to its neighbours therefore
            also works. Self-loops are ignored. Node labels must be mutually
            comparable, because they break ties when sorting by degree.

    Returns:
        ``int64`` array of length 8 in ``GRAPHLET_ORDER`` order
        ``[P3, K3, P4, Star3, C4, TailedTri, Diamond, K4]``. Each entry is the
        number of node subsets whose induced subgraph is that graphlet. Every
        connected 3- or 4-node subset is counted exactly once.
    """
    counts = np.zeros(8, dtype=np.int64)
    # Neighbour sets without self-loops (graphlets are simple graphs).
    adj = {u: {v for v in G[u] if v != u} for u in G}
    nodes = list(adj)
    if len(nodes) < 3:
        return counts
    deg = {u: len(adj[u]) for u in nodes}

    # Total order by (degree, label). Orienting every edge towards the
    # higher-ranked endpoint lets each triangle be found once, from its
    # lowest-ranked vertex.
    order = sorted(nodes, key=lambda u: (deg[u], u))
    rank = {u: i for i, u in enumerate(order)}
    fwd = {u: {v for v in adj[u] if rank[v] > rank[u]} for u in nodes}

    # Wedges (2-paths) centred at each node, whether open or closed.
    wedges = sum(d * (d - 1) // 2 for d in deg.values())
    triangles = sum(len(fwd[u] & fwd[v]) for u in nodes for v in fwd[u])
    counts[1] = triangles
    # Each triangle closes exactly three wedges; the rest are induced P3s.
    counts[0] = wedges - 3 * triangles
    if len(nodes) < 4:
        return counts

    # A connected 4-node graph is identified uniquely by its edge count and
    # sorted degree sequence; map that pair to its GRAPHLET_ORDER index.
    four_node_index = {
        (3, (1, 1, 2, 2)): 2,  # P4
        (3, (1, 1, 1, 3)): 3,  # Star3
        (4, (2, 2, 2, 2)): 4,  # C4
        (4, (1, 2, 2, 3)): 5,  # TailedTri
        (5, (2, 2, 3, 3)): 6,  # Diamond
        (6, (3, 3, 3, 3)): 7,  # K4
    }

    def classify(quad):
        """Add one to the count of the graphlet induced by the 4 nodes ``quad``."""
        degs = [0, 0, 0, 0]
        m = 0
        for i, j in itertools.combinations(range(4), 2):
            if quad[j] in adj[quad[i]]:
                m += 1
                degs[i] += 1
                degs[j] += 1
        idx = four_node_index.get((m, tuple(sorted(degs))))
        if idx is not None:
            counts[idx] += 1

    def grow(sub, ext, closed_nbhd, root_rank):
        """ESU extension step (Wernicke 2006).

        ``sub`` is the current connected node list. ``ext`` is the candidate
        list, which this call consumes. ``closed_nbhd`` is ``sub`` plus all
        neighbours of ``sub``.
        """
        while ext:
            w = ext.pop()
            grown = sub + [w]
            if len(grown) == 4:
                classify(grown)
                continue
            # Only w's *exclusive* neighbours (outside the closed neighbourhood
            # of sub) become new candidates. This guarantees every connected
            # subset is reached along exactly one branch.
            new_ext = ext + [
                y for y in adj[w] if y not in closed_nbhd and rank[y] > root_rank
            ]
            grow(grown, new_ext, closed_nbhd | adj[w], root_rank)

    # Each connected 4-subset is enumerated from its lowest-ranked node.
    # Enumeration cost is proportional to the number of connected 4-subsets.
    for root in nodes:
        root_rank = rank[root]
        grow(
            [root],
            [v for v in adj[root] if rank[v] > root_rank],
            adj[root] | {root},
            root_rank,
        )
    return counts

def read_lines_from_files(paths: List[str]):
    """Yield the decoded JSON value of every non-blank line in ``paths``.

    Files are read as UTF-8 in the order given. Blank lines and lines that
    are not valid JSON are skipped. Other errors, such as a missing file,
    propagate to the caller.

    Args:
        paths: JSONL file paths.

    Yields:
        Whatever ``json.loads`` returns for each line. This is not necessarily
        a dict; any JSON value is passed through.
    """
    for fp in paths:
        with open(fp, "r", encoding="utf-8") as f:
            for raw in f:
                line = raw.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    # Best-effort reader: tolerate truncated or corrupt records
                    # without hiding unrelated errors.
                    continue
                yield obj

def aggregate_model(model: str, paths: List[str]) -> Tuple[np.ndarray, np.ndarray]:
    """Average graphlet counts over the usable JSONL records in ``paths``.

    A record is usable when:
      - it is a JSON object;
      - its ``"path_node_unique"`` field is a list;
      - the graph built from that list has at least one node.

    Args:
        model: Model name; not used by the computation.
        paths: JSONL result files to read.

    Returns:
        ``(avg_counts, props)``, two float64 arrays of length 8.
          - ``avg_counts``: the per-record mean of ``count_graphlets``.
          - ``props``: ``avg_counts`` normalised to sum to 1, or all zeros if
            that sum is 0.
        Both arrays are all zeros when no record is usable.
    """
    total = np.zeros(8, dtype=np.float64)
    n_records = 0
    for obj in read_lines_from_files(paths):
        # A valid JSON line need not be an object (e.g. a bare list or number);
        # skip it instead of crashing on .get().
        if not isinstance(obj, dict):
            continue
        pnu = obj.get("path_node_unique", None)
        if not isinstance(pnu, list):
            continue
        G = build_graph_from_paths(pnu)
        if G.number_of_nodes() == 0:
            continue
        # Running sum instead of stacking every per-record array in memory.
        total += count_graphlets(G)
        n_records += 1
    if n_records == 0:
        return np.zeros(8, dtype=np.float64), np.zeros(8, dtype=np.float64)

    avg_counts = total / n_records
    s = avg_counts.sum()
    props = avg_counts / s if s > 0 else np.zeros_like(avg_counts)
    return avg_counts, props

def _paths_for_dataset(paths: List[str], data: str) -> List[str]:
    """Return the entries of ``paths`` whose parent directory is named ``data``.

    Result files are globbed as ``.../global_step_0/{data}/test_*...jsonl`` in
    build_model_to_paths(), so the file's immediate parent directory names
    the benchmark it belongs to.
    """
    return [p for p in paths if os.path.basename(os.path.dirname(p)) == data]


def _write_model_csv(
    path: str, header: List[str], model2values: Dict[str, np.ndarray], fmt: str
) -> None:
    """Write ``header`` and then one row per model, in MODEL_NAMES order.

    Each value is rendered with ``format(value, fmt)``. Models missing from
    ``model2values`` get a row of zeros.
    """
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for model in MODEL_NAMES:
            values = model2values.get(model, np.zeros(8, dtype=np.float64))
            writer.writerow([model] + [format(v, fmt) for v in values])
    print(f"Saved {path}")


def main():
    """Plot and tabulate per-model graphlet statistics for each dataset.

    For every name in DATA_NAMES, aggregate each model's result files for that
    dataset. Then write three outputs to ``outdir``:
      - a grouped bar chart of graphlet proportions (PDF);
      - a CSV of the proportions;
      - a CSV of the average counts.
    """
    model_to_paths = build_model_to_paths()
    x = np.arange(len(GRAPHLET_ORDER))
    width = 0.8 / max(1, len(MODEL_NAMES))
    g_labels = [f"G{i}" for i in range(1, len(GRAPHLET_ORDER) + 1)]
    header = ["Model"] + g_labels
    outdir = "XXXX"

    for data in DATA_NAMES:
        plt.figure(figsize=(16, 6))
        model2props = {}
        model2counts = {}

        for i, model in enumerate(tqdm.tqdm(MODEL_NAMES, desc=data)):
            # build_model_to_paths() pools every dataset's files per model.
            # Without this filter, each dataset's figure and CSVs would show
            # the same numbers, computed from all datasets combined.
            paths = _paths_for_dataset(model_to_paths.get(model, []), data)
            avg_counts, props = aggregate_model(model, paths)
            model2counts[model] = avg_counts
            model2props[model] = props
            plt.bar(x + i * width, props, width=width, label=model)

        # Centre each tick under its group of bars.
        plt.xticks(x + (len(MODEL_NAMES) - 1) * width / 2, g_labels)
        plt.ylabel("Proportion")
        plt.title(f"Graphlet Proportions ({data})")
        plt.legend(ncol=2, fontsize=8)
        plt.tight_layout()
        os.makedirs(outdir, exist_ok=True)

        pdf_path = f"{outdir}/all_models_{data}_graphlets.pdf"
        plt.savefig(pdf_path)
        print(f"Saved {pdf_path}")
        plt.close()

        _write_model_csv(
            f"{outdir}/all_models_{data}_graphlets_proportion_7b.csv",
            header,
            model2props,
            ".10f",
        )
        _write_model_csv(
            f"{outdir}/all_models_{data}_graphlets_counts_7b.csv",
            header,
            model2counts,
            ".0f",
        )


# Script entry point: run the full analysis only when executed directly.
if __name__ == "__main__":
    main()
