# -*- coding: utf-8 -*-
# RQ9 order-sensitivity (DARE) -- standalone figure generator
# Generates:
#   figs/rq9_violin_ce_32B_DARE.png
#   figs/rq9_std_heatmap_DARE.pdf
#   figs/rq9_range_bar_0p5_32_72B_DARE.pdf
#
# CSVs expected under ./DARE/ (relative to the current directory):
#   results_dare_0.5B.csv, results_dare_1.5B.csv, results_dare_3B.csv,
#   results_dare_7B.csv, results_dare_14B.csv, results_dare_32B.csv, results_dare_72B.csv
#
# The script:
#   - parses 'model' column (e.g., '1-2-9') to get k
#   - extracts macro-averaged CE as 'avg_ce' (uses Avg/Avg./average/macro if present;
#     otherwise averages across the nine domain columns)
#   - aggregates across orders to compute mean, std, range, CV per (N,k)
#   - draws a violin plot for N=32B, a std heatmap over (N,k), and range bars for N in {0.5, 32, 72}B

import os, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# -----------------------
# I/O & basic setup
# -----------------------
# Output folders: figures are written to figs/, the summary table to out/.
os.makedirs("figs", exist_ok=True)
os.makedirs("out", exist_ok=True)

# (N label, CSV path) pairs to load, one per model size (the label is the size
# in billions of parameters). Entries can be removed for sizes that have no
# data; paths that do not exist are reported and skipped by load_all().
CSV_LIST = [
    (size, f"DARE/results_dare_{size}B.csv")
    for size in ("0.5", "1.5", "3", "7", "14", "32", "72")
]

# The nine domain names (domain mapping 1..9). ensure_avg_ce() matches CSV
# column names against this list when no macro-average column is present.
DOMAIN_NAMES = [
    "algebra",
    "analysis",
    "discrete",
    "geometry",
    "number_theory",
    "biology",
    "chemistry",
    "physics",
    "code",
]

# Accepted names for a macro-average column; column names are lower-cased and
# stripped before being compared against these.
MACRO_COL_CANDIDATES = [
    "avg",
    "avg.",
    "average",
    "macro",
    "macro_avg",
    "macro-avg",
]

# -----------------------
# Helpers
# -----------------------
def parse_k_from_model(model_str):
    """
    Count the entries in an order string such as '1-2-9' (-> 3) or '3-7' (-> 2).

    Entries may be separated by hyphens or commas; empty pieces produced by
    repeated or trailing separators are not counted.
    Returns np.nan when the input is not a string or is blank after stripping.
    """
    if isinstance(model_str, str):
        text = model_str.strip()
        if text:
            # count the non-empty tokens between '-' / ',' separators
            return sum(1 for token in re.split(r"[-,]", text) if token)
    return np.nan

def find_macro_avg_col(df):
    """
    Return the first column of df whose lower-cased, stripped name is listed in
    MACRO_COL_CANDIDATES, or None when no column qualifies.
    """
    # Normalise every name up front, then pick the first hit in column order.
    normalised = [(name, name.lower().strip()) for name in df.columns]
    return next(
        (name for name, key in normalised if key in MACRO_COL_CANDIDATES),
        None,
    )

def ensure_avg_ce(df):
    """
    Return the per-row macro-averaged CE ('avg_ce') as a numeric Series.

    Resolution order:
      1. a macro-average column found by find_macro_avg_col(): its values,
         coerced to numeric (unparseable cells become NaN);
      2. otherwise, if at least 5 columns named in DOMAIN_NAMES are present,
         the row-wise mean of those columns after numeric coercion;
      3. otherwise the row-wise mean of every numeric-dtype column. This is a
         last resort and may include columns that are not CE values.

    Raises:
      ValueError: if none of the three options applies.
    """
    macro_col = find_macro_avg_col(df)
    if macro_col is not None:
        return pd.to_numeric(df[macro_col], errors="coerce")

    # try averaging the domain columns
    present = [c for c in df.columns if c.lower().strip() in DOMAIN_NAMES]
    if len(present) >= 5:  # at least 5 domains present -> average them (robust)
        # pd.to_numeric only accepts 1-d input and raises TypeError when handed
        # a DataFrame, so coerce column by column before averaging.
        return df[present].apply(pd.to_numeric, errors="coerce").mean(axis=1)

    # last resort: average whatever columns already have a numeric dtype
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    if numeric_cols:
        return df[numeric_cols].mean(axis=1)

    raise ValueError("Cannot infer macro-averaged CE; please include Avg/average or 9 domain columns.")

def load_one(csv_path, N_label):
    """
    Read one results CSV and reduce it to a frame with columns N, k, avg_ce.

    The order column is located by name ('model', 'order' or 'path', compared
    case-insensitively, in that priority); failing that, the first column in
    which more than half of the values contain a digit sequence joined by
    hyphens or commas is used.

    Args:
      csv_path: path passed to pd.read_csv.
      N_label:  model-size label; stored in column N as float(N_label).

    Returns:
      DataFrame with one row per input row that has both k and avg_ce,
      restricted to 1 <= k <= 9.

    Raises:
      ValueError: if no order column can be identified.
    """
    table = pd.read_csv(csv_path)
    table.columns = [name.strip() for name in table.columns]

    # 1) exact name match, honouring the candidate priority
    order_col = next(
        (
            name
            for wanted in ("model", "order", "path")
            for name in table.columns
            if name.lower() == wanted
        ),
        None,
    )

    # 2) content heuristic: mostly hyphen/comma-joined digit sequences
    if order_col is None:
        order_col = next(
            (
                name
                for name in table.columns
                if table[name].astype(str).str.contains(r"\d+(?:[-,]\d+)+").mean() > 0.5
            ),
            None,
        )

    if order_col is None:
        raise ValueError(f"[{csv_path}] cannot find 'model' column with hyphenated order.")

    # number of experts per row, and the macro-averaged CE per row
    expert_counts = table[order_col].apply(parse_k_from_model)
    macro_ce = ensure_avg_ce(table)

    result = pd.DataFrame({
        "N": float(N_label),
        "k": expert_counts,
        "avg_ce": macro_ce,
    })
    result = result.dropna(subset=["k", "avg_ce"])
    result["k"] = result["k"].astype(int)
    # keep only k in [1..9]
    return result[result["k"].between(1, 9)]

def load_all():
    """
    Load every CSV listed in CSV_LIST and stack the per-file frames.

    Paths that do not exist are collected and reported once. A file that
    exists but fails to parse is reported with a warning and skipped, so one
    bad file does not block the remaining sizes.

    Returns:
      DataFrame: concatenation of the load_one() results with a fresh index.

    Raises:
      RuntimeError: if no file could be loaded at all.
    """
    frames = []
    missing = []
    for N_label, fname in CSV_LIST:
        if not os.path.exists(fname):
            missing.append(fname)
            continue
        try:
            frames.append(load_one(fname, N_label))
        except Exception as e:  # deliberate best-effort: warn and move on
            print(f"[WARN] Failed parsing {fname}: {e}")
    if missing:
        print("[INFO] Missing CSVs (skipped):", missing)
    if not frames:
        # Name the paths that were actually probed; they are not in the cwd root.
        expected = ", ".join(fname for _, fname in CSV_LIST)
        raise RuntimeError(
            "No CSV parsed successfully. Expected files (relative to cwd): " + expected
        )
    return pd.concat(frames, ignore_index=True)

def agg_dispersion(df_all):
    """
    Compute mean, std, range and CV of avg_ce per (N, k) across orders.

    Args:
      df_all: DataFrame with columns N, k and avg_ce.

    Returns:
      (stats, series_map) where
        stats      is a DataFrame with columns N, k, mean, std, range, cv, count,
                   sorted by (N, k). It is empty, but still has these columns,
                   when df_all yields no groups.
        series_map maps (N, k) -> numpy array of the group's non-NaN avg_ce
                   values (used for the violin plot).

    std is the sample standard deviation (ddof=1). std and range are 0.0 for a
    group with a single value; cv is NaN when the group mean is 0.
    """
    columns = ["N", "k", "mean", "std", "range", "cv", "count"]
    series_map = {}
    rows = []
    for (N, k), sub in df_all.groupby(["N", "k"]):
        vals = sub["avg_ce"].dropna().values
        n_vals = len(vals)
        if n_vals == 0:
            continue
        mu = float(np.mean(vals))
        std = float(np.std(vals, ddof=1)) if n_vals > 1 else 0.0
        rng = float(np.max(vals) - np.min(vals)) if n_vals > 1 else 0.0
        cv = std / mu if mu != 0 else np.nan
        rows.append({
            "N": N, "k": int(k), "mean": mu, "std": std, "range": rng, "cv": cv, "count": int(n_vals)
        })
        series_map[(N, int(k))] = vals
    # Passing the column list keeps sort_values() valid when rows is empty;
    # without it the frame has no "N"/"k" columns and sorting raises KeyError.
    stats = pd.DataFrame(rows, columns=columns).sort_values(["N", "k"]).reset_index(drop=True)
    return stats, series_map

# -----------------------
# Load & aggregate
# -----------------------
# Load every available CSV into one long table with columns N, k, avg_ce.
df_all = load_all()
# Per-(N, k) dispersion statistics, plus the raw per-group values that the
# violin plot consumes.
stats, series_map = agg_dispersion(df_all)
# Persist the summary table (out/ is created during the setup section above).
stats.to_csv("out/rq9_dispersion_summary.csv", index=False)
print("Saved: out/rq9_dispersion_summary.csv")

# -----------------------
# Figure (a): Violin plot @ 32B
# -----------------------
import colorsys
from matplotlib.colors import LinearSegmentedColormap, to_rgb
import matplotlib as mpl
import matplotlib.pyplot as plt

def _low_sat_green_to_blue(n_steps):
    """
    生成低饱和度 绿->蓝 的渐变色列表（按 k 从小到大映射）。
    使用 HSV：H 约 120°(绿) -> 220°(蓝)，S 取 0.25，V 取 0.85。
    """
    if n_steps <= 1:
        h, s, v = 120/360.0, 0.25, 0.85
        r, g, b = colorsys.hsv_to_rgb(h, s, v)
        return [(r, g, b)]
    colors = []
    h0, h1 = 120/360.0, 220/360.0
    s, v = 0.25, 0.85
    for i in range(n_steps):
        t = i/(n_steps-1)
        h = h0 + t*(h1 - h0)
        r, g, b = colorsys.hsv_to_rgb(h, s, v)
        colors.append((r, g, b))
    return colors

def _darker(rgb, factor=0.75):
    r, g, b = rgb
    return (r*factor, g*factor, b*factor)

def plot_violin_32B(series_map, out_pdf="figs/rq9_violin_ce_32B_DARE.png", N_pick=32.0):
    """
    Draw one violin per k showing the across-order spread of macro-avg CE at a
    single model size.

    Args:
      series_map: dict mapping (N, k) -> sequence of avg_ce values.
      out_pdf:    output path handed to fig.savefig (saved with dpi=300).
                  Despite the parameter name, the default is a PNG file.
      N_pick:     model size to plot; keys whose N lies within 1e-6 of it are used.

    Prints a warning and returns without drawing when no key matches N_pick.
    """
    # ---- data ----
    # Collect the values through the keys that actually matched. Re-indexing
    # with (N_pick, k) would raise KeyError for an N that is inside the
    # tolerance but not exactly equal to N_pick.
    by_k = {}
    for (N, k), vals in series_map.items():
        if abs(N - N_pick) < 1e-6:
            by_k.setdefault(k, vals)
    if not by_k:
        print(f"[WARN] No data for N={N_pick}B; skip violin.")
        return
    k_sorted = sorted(by_k)
    data = [by_k[k] for k in k_sorted]

    # ---- colours: low-saturation green -> blue, assigned in ascending k ----
    palette = _low_sat_green_to_blue(len(k_sorted))

    # ---- drawing ----
    fig, ax = plt.subplots(figsize=(9, 6))
    parts = ax.violinplot(
        dataset=data,
        positions=k_sorted,
        showmeans=False,
        showmedians=True,
        showextrema=True,
        widths=0.8
    )

    # colour each violin body and give it a darker outline
    for pc, c in zip(parts['bodies'], palette):
        pc.set_facecolor(c)
        pc.set_edgecolor(_darker(c, 0.65))
        pc.set_linewidth(0.9)
        pc.set_alpha(0.9)

    # one shared style for the median, extrema and bar lines
    median_color = _darker(to_rgb("#2b6cb0"), 0.85)
    for comp_key in ('cmedians', 'cmins', 'cmaxes', 'cbars'):
        if comp_key in parts and parts[comp_key] is not None:
            comp = parts[comp_key]
            # the component is either a single collection or an iterable of
            # lines, depending on the matplotlib version
            try:
                comp.set_color(median_color)
                comp.set_linewidth(1.2)
            except AttributeError:
                for line in comp:
                    line.set_color(median_color)
                    line.set_linewidth(1.2)

    # axes and grid; the title follows N_pick ("32B" for the default)
    n_label = int(N_pick) if float(N_pick).is_integer() else N_pick
    ax.set_title(f"Across-order CE distribution by k (DARE, {n_label}B)", fontsize=20)
    ax.set_xlabel("k (experts)", fontsize=17)
    ax.set_ylabel("Macro-avg CE across orders", fontsize=17)
    ax.set_xticks(k_sorted)
    ax.grid(axis="y", alpha=0.25, linestyle="--", linewidth=0.7)

    fig.tight_layout()
    fig.savefig(out_pdf, dpi=300)
    plt.close(fig)
    print("Saved:", out_pdf)

# Render figure (a) with the function defaults (default output path, N=32B).
plot_violin_32B(series_map)

# -----------------------
# Figure (b): Std heatmap over (N,k)
# -----------------------
def plot_std_heatmap(stats, out_pdf="figs/rq9_std_heatmap_DARE.pdf"):
    """
    Draw a heatmap of the across-order std, with model sizes N as rows and k as
    columns.

    Args:
      stats:   DataFrame with columns N, k and std (one row per (N, k)).
      out_pdf: output path handed to plt.savefig.

    (N, k) combinations that are absent from stats are left as NaN cells.
    Prints a warning and returns without drawing when stats has no rows.
    """
    # Same skip-with-warning behaviour as the other figure functions.
    if len(stats) == 0:
        print("[WARN] No dispersion stats; skip heatmap.")
        return

    Ns = sorted(stats["N"].unique())
    ks = sorted(stats["k"].unique())

    # One pass over the table; the first row seen for an (N, k) pair wins.
    first_std = {}
    for N, k, s in zip(stats["N"], stats["k"], stats["std"]):
        first_std.setdefault((N, k), float(s))

    H = np.full((len(Ns), len(ks)), np.nan)
    for i, N in enumerate(Ns):
        for j, k in enumerate(ks):
            H[i, j] = first_std.get((N, k), np.nan)

    plt.figure(figsize=(9, 6))
    im = plt.imshow(H, aspect="auto", interpolation="nearest")
    plt.colorbar(im, label="Across-order std of macro-avg CE")
    plt.xticks(range(len(ks)), ks)
    # integer sizes are shown without a decimal point: "32B" rather than "32.0B"
    plt.yticks(range(len(Ns)), [f"{int(N) if float(N).is_integer() else N}B" for N in Ns])
    plt.xlabel("k (experts)", fontsize=17)
    plt.ylabel("Model size N", fontsize=17)
    plt.title("Order-induced std over (N,k) (DARE)", fontsize=20)
    plt.tight_layout()
    plt.savefig(out_pdf)
    plt.close()
    print("Saved:", out_pdf)

# Render figure (b) with the default output path.
plot_std_heatmap(stats)

# -----------------------
# Figure (c): Range bars for N in {0.5, 32, 72}B
# -----------------------
def plot_range_bars(stats, out_pdf="figs/rq9_range_bar_0p5_32_72B_DARE.pdf", N_list=(0.5,32.0,72.0)):
    """
    Draw one bar chart per requested model size showing the across-order range
    (max - min) against k, stacked vertically in a single figure.

    Args:
      stats:   DataFrame with columns N, k and range.
      out_pdf: output path handed to plt.savefig.
      N_list:  model sizes to plot; sizes absent from stats are dropped.

    When a size has a k=1 row with a positive range, every bar is annotated
    with its relative reduction versus that k=1 range, in percent.
    Prints a warning and returns without drawing when no requested size is present.
    """
    available = set(stats["N"].unique())  # built once, not once per requested N
    Ns_present = [N for N in N_list if N in available]
    if not Ns_present:
        print("[WARN] None of the requested Ns present; skip bars.")
        return

    cols = 1
    rows = len(Ns_present)
    plt.figure(figsize=(7.4, 1.8*rows + 0.6))
    for i, N in enumerate(Ns_present, start=1):
        sub = stats[stats["N"]==N].sort_values("k")
        ks = sub["k"].astype(int)
        ranges = sub["range"].astype(float)
        ax = plt.subplot(rows, cols, i)
        ax.bar(ks, ranges)
        ax.set_title(f"Range across orders vs k (N={int(N) if float(N).is_integer() else N}B)")
        ax.set_xlabel("k (experts)")
        ax.set_ylabel("Range (max - min)")
        ax.set_xticks(ks.tolist())
        # annotate relative reduction from k=1
        # Take the scalar with .iloc[0]: float() on a one-element Series is
        # deprecated in recent pandas and fails outright on several rows.
        base_rows = sub.loc[sub["k"] == 1, "range"]
        base = float(base_rows.iloc[0]) if len(base_rows) else None
        if base is not None and base > 0:
            for x, r in zip(ks, ranges):
                rr = float(r)
                pct = 100.0*(base-rr)/base
                ax.text(x, rr, f"{pct:.0f}%", ha="center", va="bottom", fontsize=8)
    plt.tight_layout()
    plt.savefig(out_pdf)
    plt.close()
    print("Saved:", out_pdf)

# Render figure (c) with the defaults (N in {0.5, 32, 72}B).
plot_range_bars(stats)
