import os, glob, json, re
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from collections import Counter

def _slug(s: str) -> str:
    return re.sub(r'[^0-9A-Za-z\-]+', '_', s)

def load_jsonl(filepath: str) -> List[dict]:
    """Read a JSON Lines file and return its records in file order.

    Args:
        filepath: Path to a file holding one JSON document per line.

    Returns:
        The decoded value of every non-blank line.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            # Blank or whitespace-only lines (e.g. a trailing empty line) are
            # not records and would otherwise make json.loads raise.
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records

def _extract_size_token(model_name: str) -> str:
    m = re.search(r'(\d+(?:\.\d+)?)B', model_name)
    return f"{m.group(1)}B" if m else "mixed"

def _collect_files_for_model_data(model: str, data: str) -> List[str]:
    base_pattern = (
        f"XXXXX"
        f"{model}_{data}_temp0.6_n256_seed1/eval_results/global_step_0/"
        f"{data}/test_*_-1_seed1_t0.6_s0_e-1_processed.jsonl.processed.jsonl"
    )
    return glob.glob(base_pattern)


def _build_graph_from_item(item: dict) -> Tuple[nx.Graph, nx.Graph, Counter]:
    """Build undirected graphs from the paths stored under ``item['path_node_unique']``.

    Each non-empty path contributes its nodes and every pair of consecutive
    nodes as an edge. Edge direction is ignored: ``(u, v)`` and ``(v, u)`` are
    tallied under the same key, ordered so the smaller endpoint comes first.

    Returns:
        A tuple ``(G, Gw, node_visits)`` where ``G`` is the unweighted graph,
        ``Gw`` has the same edges with ``weight`` set to the number of times
        the edge was traversed, and ``node_visits`` counts how often each node
        occurs across all paths.
    """
    visit_counter = Counter()
    pair_counter = Counter()
    seen_nodes = set()

    for walk in item.get('path_node_unique', []):
        if not walk:
            continue
        seen_nodes.update(walk)
        visit_counter.update(walk)
        for a, b in zip(walk, walk[1:]):
            # Canonical endpoint order so both directions share one counter.
            key = (a, b) if a <= b else (b, a)
            pair_counter[key] += 1

    G = nx.Graph()
    G.add_nodes_from(seen_nodes)
    G.add_edges_from(pair_counter)

    Gw = nx.Graph()
    Gw.add_nodes_from(seen_nodes)
    Gw.add_weighted_edges_from((a, b, w) for (a, b), w in pair_counter.items())

    return G, Gw, visit_counter

def _build_digraph_from_item(item: dict) -> nx.DiGraph:
    """Build a directed graph from the paths stored under ``item['path_node_unique']``.

    Every pair of consecutive nodes in a non-empty path becomes a directed
    edge from the earlier node to the later one. Nodes of single-node paths
    are still added, as isolated nodes.
    """
    digraph = nx.DiGraph()
    seen_nodes = set()
    for walk in item.get('path_node_unique', []):
        if not walk:
            continue
        seen_nodes.update(walk)
        digraph.add_edges_from(zip(walk, walk[1:]))
    # Added last so nodes that never appear in an edge are not lost.
    digraph.add_nodes_from(seen_nodes)
    return digraph

def _sequences_for_item(item: dict) -> Dict[str, np.ndarray]:
    """Compute the rank-ordered sequences that are later fitted for one item.

    All sequences are float arrays sorted in descending order.

    Returns:
        A dict with the keys:

        * ``"deg_unw"``: node degrees of the undirected path graph.
        * ``"visit_prob"``: each node's visit count divided by the total
          number of node occurrences over all paths.
        * ``"bc_dir"``: normalized node betweenness centrality of the directed
          path graph; empty when it has fewer than 3 nodes or no edges.
        * ``"ebc_dir"``: normalized edge betweenness centrality of the
          directed path graph; empty when it has fewer than 2 nodes or no
          edges.
        * ``"edge_freq"``: each undirected edge's traversal count divided by
          the total number of edge traversals.
    """
    empty = np.array([], dtype=float)
    G, Gw, node_visits = _build_graph_from_item(item)

    deg_unw = np.array(sorted((d for _, d in G.degree(weight=None)), reverse=True), dtype=float)

    # node_visits holds one count per node occurrence, so its total equals
    # the summed length of all non-empty paths.
    total_occ = sum(node_visits.values())
    if total_occ > 0:
        visit_prob = np.array(
            sorted((cnt / total_occ for cnt in node_visits.values()), reverse=True),
            dtype=float,
        )
    else:
        visit_prob = empty

    D = _build_digraph_from_item(item)

    if len(D) >= 3 and D.number_of_edges() > 0:
        bc = nx.betweenness_centrality(D, normalized=True, weight=None)
        bc_dir = np.array(sorted(bc.values(), reverse=True), dtype=float)
    else:
        bc_dir = empty

    if len(D) >= 2 and D.number_of_edges() > 0:
        ebc = nx.edge_betweenness_centrality(D, normalized=True, weight=None)
        ebc_dir = np.array(sorted(ebc.values(), reverse=True), dtype=float)
    else:
        ebc_dir = empty

    # Gw already stores, per undirected edge, how many times it was traversed;
    # reuse those weights instead of walking every path a second time.
    edge_weights = [w for _, _, w in Gw.edges(data='weight')]
    total_edge_occ = sum(edge_weights)
    if total_edge_occ > 0:
        edge_freq = np.array(
            sorted((w / total_edge_occ for w in edge_weights), reverse=True),
            dtype=float,
        )
    else:
        edge_freq = empty

    return {
        "deg_unw": deg_unw,
        "visit_prob": visit_prob,
        "bc_dir": bc_dir,
        "ebc_dir": ebc_dir,
        "edge_freq": edge_freq,
    }


@dataclass
class FitResult:
    slope: float 
    intercept: float
    r2: float
    n: int
    slope_per_decade: float

def _fit_semilogy(seq: np.ndarray) -> Optional[Tuple[FitResult, np.ndarray, np.ndarray]]:
    if seq is None or len(seq) == 0:
        return None
    y = np.asarray(seq, dtype=float)
    mask = y > 0
    if mask.sum() < 2:
        return None
    y = y[mask]
    r = np.arange(1, len(y) + 1, dtype=float)

    z = np.log(y)
    x = r
    n = len(x)

    x_mean = x.mean()
    z_mean = z.mean()
    Sxx = ((x - x_mean)**2).sum()
    if Sxx <= 0:
        return None
    Sxz = ((x - x_mean) * (z - z_mean)).sum()

    beta = Sxz / Sxx
    alpha = z_mean - beta * x_mean

    z_hat = alpha + beta * x
    sst = ((z - z_mean)**2).sum()
    sse = ((z - z_hat)**2).sum()
    r2 = 1.0 - (sse / sst) if sst > 0 else 0.0

    y_fit = np.exp(z_hat)
    fit = FitResult(
        slope=beta,
        intercept=alpha,
        r2=r2,
        n=n,
        slope_per_decade=beta / np.log(10.0)
    )
    return fit, r, y_fit

def _fit_loglog(seq: np.ndarray) -> Optional[Tuple[FitResult, np.ndarray, np.ndarray]]:
    if seq is None or len(seq) == 0:
        return None
    y = np.asarray(seq, dtype=float)
    mask = y > 0
    if mask.sum() < 2:
        return None
    y = y[mask]
    r = np.arange(1, len(y) + 1, dtype=float)

    z = np.log(y)
    x = np.log(r)
    n = len(x)

    x_mean = x.mean()
    z_mean = z.mean()
    Sxx = ((x - x_mean)**2).sum()
    if Sxx <= 0:
        return None
    Sxz = ((x - x_mean) * (z - z_mean)).sum()

    beta = Sxz / Sxx
    alpha = z_mean - beta * x_mean

    z_hat = alpha + beta * x
    sst = ((z - z_mean)**2).sum()
    sse = ((z - z_hat)**2).sum()
    r2 = 1.0 - (sse / sst) if sst > 0 else 0.0

    y_fit = np.exp(z_hat)
    fit = FitResult(
        slope=beta,
        intercept=alpha,
        r2=r2,
        n=n,
        slope_per_decade=beta
    )
    return fit, r, y_fit

def _plot_semilogy_with_fit(ax, seq: np.ndarray, title: str, xlabel: str, ylabel: str):
    """Draw *seq* against rank on a log-y axis together with its semilog fit.

    Only the strictly positive entries of *seq* are plotted. When no fit can
    be obtained, a placeholder message is written on the axes instead.

    Returns:
        The fit result, or ``None`` when no fit was available.
    """
    outcome = _fit_semilogy(seq)
    if outcome is None:
        ax.text(0.5, 0.5, "insufficient positive data", ha='center', va='center')
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        return None

    fit, fit_ranks, fit_curve = outcome

    values = np.asarray(seq, dtype=float)
    positives = values[values > 0]
    data_ranks = np.arange(1, len(positives) + 1, dtype=float)

    ax.semilogy(data_ranks, positives, marker='.', linestyle='None', label='data')
    fit_label = f'fit (β={fit.slope:.4g}, R^2={fit.r2:.3f})'
    ax.semilogy(fit_ranks, fit_curve, linestyle='-', label=fit_label)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim(left=1)  # ranks start at 1
    ax.legend(loc='best')
    return fit

def _plot_loglog_with_fit(ax, seq: np.ndarray, title: str, xlabel: str, ylabel: str):
    """Draw *seq* against rank on log-log axes together with its log-log fit.

    Only the strictly positive entries of *seq* are plotted. When no fit can
    be obtained, a placeholder message is written on the axes instead.

    Returns:
        The fit result, or ``None`` when no fit was available.
    """
    outcome = _fit_loglog(seq)
    if outcome is None:
        ax.text(0.5, 0.5, "insufficient positive data", ha='center', va='center')
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        return None

    fit, fit_ranks, fit_curve = outcome

    values = np.asarray(seq, dtype=float)
    positives = values[values > 0]
    data_ranks = np.arange(1, len(positives) + 1, dtype=float)

    ax.loglog(data_ranks, positives, marker='.', linestyle='None', label='data')
    fit_label = f'fit (β={fit.slope:.4g}, R²={fit.r2:.3f})'
    ax.loglog(fit_ranks, fit_curve, linestyle='-', label=fit_label)

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim(left=1)  # ranks start at 1
    ax.legend(loc='best')
    return fit

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Alternative model groups, kept for reference. Swap one of them in for the
# active MODEL_NAMES list below to analyse a different model-size family.
#
# MODEL_NAMES = [
#     "Qwen2.5-Math-1.5B",
#     "Qwen2.5-Math-1.5B-Oat-Zero",
#     "DeepSeek-R1-Distill-Qwen-1.5B",
#     "Nemotron-Research-Reasoning-Qwen-1.5B"
# ]
#
# MODEL_NAMES = [
#     "Qwen2.5-Math-7B",
#     "Qwen2.5-Math-7B-Oat-Zero",
#     "DeepSeek-R1-Distill-Qwen-7B",
#     "AceReason-Nemotron-7B",
# ]

# Active model group.
MODEL_NAMES = [
    "Qwen2.5-14B",
    "Qwen-2.5-14B-SimpleRL-Zoo",
    "DeepSeek-R1-Distill-Qwen-14B",
    "AceReason-Nemotron-14B",
]

DATA_NAMES = ["amc23"]

# Output locations: summaries go in SAVE_DIR, per-row figures in FIG_DIR.
SAVE_DIR = "XXXX"
os.makedirs(SAVE_DIR, exist_ok=True)
FIG_DIR = os.path.join(SAVE_DIR, "fits_semilogy")
os.makedirs(FIG_DIR, exist_ok=True)

# Fitted slopes collected per model and metric; every (model, metric) pair
# gets its own fresh list.
agg_slopes: Dict[str, Dict[str, List[float]]] = {
    m: {
        metric_key: []
        for metric_key in ("degree", "frequency", "betweenness", "edge_betweenness", "edge_frequency")
    }
    for m in MODEL_NAMES
}

# One record per successful fit.
all_slope_rows: List[Dict[str, object]] = []

# Main pass: for every (model, dataset) pair, fit each row of each result
# file, record the fitted slopes, and save one five-panel figure per row.
for model in MODEL_NAMES:
    # Depends only on the model name, so compute it once per model rather
    # than once per row.
    size_token = _extract_size_token(model)
    for data in DATA_NAMES:
        files = _collect_files_for_model_data(model, data)
        if not files:
            print(f"[{model}, {data}] No files")
            continue
        print(f"[{model}, {data}] {len(files)} files")

        for fp in files:
            file_name = os.path.basename(fp)
            try:
                items = load_jsonl(fp)
            except Exception as e:
                # Best effort: report the unreadable file and keep going.
                print(f"  -> load error: {fp}: {e}")
                continue

            for row_idx, item in enumerate(items):
                seqs = _sequences_for_item(item)
                fig, axes = plt.subplots(1, 5, figsize=(35, 6))
                # try/finally so the figure is released even when plotting or
                # saving fails; pyplot keeps every open figure alive otherwise.
                try:
                    fits: Dict[str, Optional[FitResult]] = {}
                    fits['degree'] = _plot_semilogy_with_fit(
                        axes[0], seqs['deg_unw'],
                        title="Degree Rank (semilogy)",
                        xlabel="Rank (linear)", ylabel="Value (log)"
                    )
                    fits['frequency'] = _plot_semilogy_with_fit(
                        axes[1], seqs['visit_prob'],
                        title=r"Visit Probability Rank (semilogy)",
                        xlabel="Rank (linear)", ylabel=r"$p_i$ (log)"
                    )
                    fits['betweenness'] = _plot_semilogy_with_fit(
                        axes[2], seqs['bc_dir'],
                        title="Directed Betweenness Rank (semilogy)",
                        xlabel="Rank (linear)", ylabel="Betweenness (log)"
                    )
                    fits['edge_betweenness'] = _plot_semilogy_with_fit(
                        axes[3], seqs['ebc_dir'],
                        title="Directed Edge Betweenness Rank (semilogy)",
                        xlabel="Rank (linear)", ylabel="Edge Betweenness (log)"
                    )
                    # Edge frequency is the only panel fitted on log-log axes.
                    fits['edge_frequency'] = _plot_loglog_with_fit(
                        axes[4], seqs['edge_freq'],
                        title=r"Edge Frequency Rank (loglog)",
                        xlabel="Rank (log)", ylabel=r"$q_e$ (log)"
                    )

                    # Dict insertion order matches the panel order above.
                    for k, fit in fits.items():
                        if fit is None:
                            continue
                        agg_slopes[model][k].append(fit.slope)
                        all_slope_rows.append({
                            "model": model,
                            "data": data,
                            "metric": k,
                            "file": file_name,
                            "row_idx": row_idx,
                            "beta": float(fit.slope),
                            "beta_per_decade": float(fit.slope_per_decade),
                            "r2": float(fit.r2),
                            "n_points": int(fit.n)
                        })

                    fig.suptitle(f"{model} / {data} / row={row_idx} / file={file_name}", y=1.02, fontsize=12)
                    fig.tight_layout()

                    out_png = os.path.join(
                        FIG_DIR,
                        f"{_slug(model)}_{size_token}_{data}_row{row_idx}_{_slug(file_name)}.png"
                    )
                    # Save through the figure object instead of relying on
                    # pyplot's implicit "current figure".
                    fig.savefig(out_png, dpi=130, bbox_inches='tight')
                finally:
                    plt.close(fig)
# ---------------------------------------------------------------------------
# Aggregate the per-row slopes per (model, metric) and write the results.
# ---------------------------------------------------------------------------
import csv

# Tokens used in the output file names. They are derived from the
# configuration lists rather than from loop variables left over from the
# processing loop, so they are defined even when a list is empty.
size_token = _extract_size_token(MODEL_NAMES[-1]) if MODEL_NAMES else "mixed"
# The summary pools every dataset, so name the files after all of them
# (identical to the single dataset name when only one is configured).
data_token = "_".join(DATA_NAMES)

summary_fields = ["model", "metric", "n", "mean_beta", "std_beta", "mean_beta_per_decade"]
summary_rows = []
for model in MODEL_NAMES:
    for metric in ['degree', 'frequency', 'betweenness', 'edge_betweenness', 'edge_frequency']:
        vals = np.array(agg_slopes[model][metric], dtype=float)
        n = int(vals.size)
        if n == 0:
            mean_beta = np.nan
            std_beta = np.nan
            mean_beta_decade = np.nan
        else:
            mean_beta = float(vals.mean())
            # Sample standard deviation; undefined for one sample, reported as 0.0.
            std_beta = float(vals.std(ddof=1)) if n >= 2 else 0.0
            if metric == 'edge_frequency':
                # Fitted on log-log axes: the slope is already per decade.
                # (The previous self-assignment left the edge_betweenness
                # value in place for this row.)
                mean_beta_decade = mean_beta
            else:
                # Semilog fits use natural logs; rescale to log10 units.
                mean_beta_decade = float(mean_beta / np.log(10.0))

        print(f"{model:35s}  {metric:12s}  n={n:5d}  mean β={mean_beta: .6g}  (per decade={mean_beta_decade: .6g})  std={std_beta: .3g}")
        summary_rows.append({
            "model": model,
            "metric": metric,
            "n": n,
            "mean_beta": mean_beta,
            "std_beta": std_beta,
            "mean_beta_per_decade": mean_beta_decade
        })

# Per-(model, metric) summary. Files are opened as UTF-8 explicitly because
# the JSON is written with ensure_ascii=False.
CSV_PATH  = os.path.join(SAVE_DIR, f"model_mean_slopes_semilogy_{data_token}_{size_token}.csv")
JSON_PATH = os.path.join(SAVE_DIR, f"model_mean_slopes_semilogy_{data_token}_{size_token}.json")
with open(CSV_PATH, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=summary_fields)
    writer.writeheader()
    writer.writerows(summary_rows)
with open(JSON_PATH, "w", encoding="utf-8") as f:
    json.dump(summary_rows, f, ensure_ascii=False, indent=2)

# Every individual fit, one record per (file, row, metric).
ALL_CSV_PATH  = os.path.join(SAVE_DIR, f"model_all_slopes_semilogy_{data_token}_{size_token}.csv")
ALL_JSON_PATH = os.path.join(SAVE_DIR, f"model_all_slopes_semilogy_{data_token}_{size_token}.json")
with open(ALL_CSV_PATH, "w", newline="", encoding="utf-8") as f:
    fieldnames = ["model","data","metric","file","row_idx","beta","beta_per_decade","r2","n_points"]
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(all_slope_rows)
with open(ALL_JSON_PATH, "w", encoding="utf-8") as f:
    json.dump(all_slope_rows, f, ensure_ascii=False, indent=2)

print(f"\n[DONE] Summary saved:\n - {CSV_PATH}\n - {JSON_PATH}\n"
      f"All-sample betas saved:\n - {ALL_CSV_PATH}\n - {ALL_JSON_PATH}\n"
      f"Figures saved under: {FIG_DIR}")
