# -*- coding: utf-8 -*-
# RQ9 order-sensitivity (DARE) -- standalone figure generator
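# Generates:
#   figs/rq9_range_bar_0p5_32_72B_DARE.png  (default output path of plot_range_bars;
#                                            pass an out_pdf ending in .pdf for a PDF)
#   out/rq9_dispersion_summary.csv          (per-(N,k) dispersion summary table)
#
# CSVs expected in the DARE/ subdirectory of the current working directory
# (missing files are skipped with a notice):
#   results_dare_0.5B.csv, results_dare_1.5B.csv, results_dare_3B.csv,
#   results_dare_7B.csv, results_dare_14B.csv, results_dare_32B.csv, results_dare_72B.csv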
#
# The script:
#   - parses 'model' column (e.g., '1-2-9') to get k
#   - extracts macro-averaged CE as 'avg_ce'
#   - aggregates across orders to compute mean, std, range, CV per (N,k)
#   - draws range bars for N in {0.5, 32, 72}B

import os, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# -----------------------
# I/O & basic setup
# -----------------------
# Output folders: figures go to figs/, the dispersion summary table to out/.
os.makedirs("figs", exist_ok=True)
os.makedirs("out", exist_ok=True)

# (model-size label in billions, results CSV path) pairs. The label is later
# converted with float() to become the N value; paths are relative to the
# current working directory.
CSV_LIST = [
    (size, f"DARE/results_dare_{size}B.csv")
    for size in ("0.5", "1.5", "3", "7", "14", "32", "72")
]

# Per-domain CE column names; their row mean is the fallback macro average
# when no explicit macro-average column exists (matched case-insensitively).
DOMAIN_NAMES = [
    "algebra",
    "analysis",
    "discrete",
    "geometry",
    "number_theory",
    "biology",
    "chemistry",
    "physics",
    "code",
]

# Accepted (lower-cased, stripped) spellings of an explicit macro-average column.
MACRO_COL_CANDIDATES = "avg avg. average macro macro_avg macro-avg".split()

# -----------------------
# Helpers
# -----------------------
def parse_k_from_model(model_str):
    """Count how many experts a merge-order label lists.

    The label is stripped, split on '-' and ',', and empty pieces are
    discarded, so '1-2-9' -> 3 and '4' -> 1.  A label consisting only of
    separators yields 0.

    Returns:
        int count of pieces, or NaN when ``model_str`` is not a str
        (e.g. a pandas NaN) or is blank.
    """
    if not isinstance(model_str, str) or not model_str.strip():
        return np.nan
    pieces = re.split(r"[-,]", model_str.strip())
    return sum(1 for piece in pieces if piece)

def find_macro_avg_col(df):
    """Return the first column whose normalised name is a macro-average alias.

    Every column name is lower-cased and stripped up front, then compared
    against ``MACRO_COL_CANDIDATES`` in DataFrame column order.

    Returns:
        The original (un-normalised) column label, or None if nothing matches.
    """
    normalised = [(col, col.lower().strip()) for col in df.columns]
    return next(
        (col for col, key in normalised if key in MACRO_COL_CANDIDATES),
        None,
    )

def ensure_avg_ce(df):
    """Return a per-row macro-averaged CE Series for ``df``.

    Resolution order:
      1. an explicit macro-average column (located by ``find_macro_avg_col``),
         coerced to numeric;
      2. the row mean over per-domain columns whose lower-cased, stripped
         name is in ``DOMAIN_NAMES`` -- used only when at least 5 are present;
      3. as a last resort, the row mean over all numeric-dtype columns.

    Unparseable cells in steps 1-2 become NaN; row means skip NaN values.

    Raises:
        ValueError: if no column of any of the above kinds exists.
    """
    macro_col = find_macro_avg_col(df)
    if macro_col is not None:
        return pd.to_numeric(df[macro_col], errors="coerce")

    present = [c for c in df.columns if c.lower().strip() in DOMAIN_NAMES]
    if len(present) >= 5:
        # pd.to_numeric only accepts 1-D input (passing a DataFrame raises
        # TypeError), so coerce each domain column separately.
        domain_vals = df[present].apply(pd.to_numeric, errors="coerce")
        return domain_vals.mean(axis=1)

    # NOTE(review): this fallback averages *every* numeric column, which would
    # include any non-CE numeric fields (ids, seeds, ...) if the CSV has them.
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    if numeric_cols:
        return df[numeric_cols].mean(axis=1)
    raise ValueError("Cannot infer macro-averaged CE")

def load_one(csv_path, N_label):
    """Read one per-size results CSV into a tidy frame with columns N, k, avg_ce.

    Column names are whitespace-stripped.  The merge-order column is the first
    column named (case-insensitively) "model", then "order", then "path";
    failing that, the first column in which more than half of the cells,
    rendered as strings, contain a digit-separator-digit pattern such as
    "1-2" or "3,4".  k comes from ``parse_k_from_model`` on that column and
    avg_ce from ``ensure_avg_ce``.  Rows with NaN k or avg_ce are dropped,
    k is cast to int, and only rows with 1 <= k <= 9 are kept.  Every row
    gets N = float(N_label).

    Raises:
        ValueError: if no merge-order column can be identified.
    """
    df = pd.read_csv(csv_path)
    df.columns = [name.strip() for name in df.columns]

    # Preferred: an explicitly named column, honouring the candidate priority.
    lowered = [(name, name.lower()) for name in df.columns]
    model_col = next(
        (
            name
            for wanted in ("model", "order", "path")
            for name, low in lowered
            if low == wanted
        ),
        None,
    )

    # Fallback: sniff for a column whose cells mostly look like "1-2-9".
    if model_col is None:
        order_pattern = r"\d+(?:[-,]\d+)+"
        model_col = next(
            (
                name
                for name in df.columns
                if df[name].astype(str).str.contains(order_pattern).mean() > 0.5
            ),
            None,
        )

    if model_col is None:
        raise ValueError(f"[{csv_path}] cannot find 'model' column.")

    k_series = df[model_col].apply(parse_k_from_model)
    ce_series = ensure_avg_ce(df)

    tidy = pd.DataFrame({"N": float(N_label), "k": k_series, "avg_ce": ce_series})
    tidy = tidy.dropna(subset=["k", "avg_ce"])
    tidy["k"] = tidy["k"].astype(int)
    return tidy[tidy["k"].between(1, 9)]

def load_all():
    """Load every CSV listed in ``CSV_LIST`` and stack them into one frame.

    Best effort: files that do not exist are collected and reported in a
    single "[INFO]" line; files that exist but raise while being loaded by
    ``load_one`` are reported with a "[WARN]" line.  Both kinds are skipped.

    Returns:
        The concatenation (with a fresh index) of every successfully loaded frame.

    Raises:
        RuntimeError: if no CSV was loaded successfully.
    """
    loaded, absent = [], []
    for size_label, path in CSV_LIST:
        if not os.path.exists(path):
            absent.append(path)
            continue
        try:
            loaded.append(load_one(path, size_label))
        except Exception as err:
            print(f"[WARN] Failed parsing {path}: {err}")

    if absent:
        print("[INFO] Missing CSVs (skipped):", absent)
    if not loaded:
        raise RuntimeError("No CSV parsed successfully.")
    return pd.concat(loaded, ignore_index=True)

def agg_dispersion(df_all):
    """Summarise how macro CE varies across merge orders for each (N, k).

    Args:
        df_all: DataFrame with columns "N", "k" and "avg_ce", one row per
            merge order (e.g. the output of ``load_all``).

    Returns:
        DataFrame sorted by N then k, one row per (N, k), with columns:
          N, k   -- group keys (k as int);
          mean   -- mean avg_ce over the orders;
          std    -- sample std (ddof=1), 0.0 when only one order exists;
          range  -- max - min, 0.0 when only one order exists;
          cv     -- std / mean, NaN when the mean is exactly 0;
          count  -- number of non-NaN avg_ce values.
        Groups whose avg_ce values are all NaN are omitted.  If no group
        survives, an empty frame with the same columns is returned.
    """
    columns = ["N", "k", "mean", "std", "range", "cv", "count"]
    rows = []
    for (N, k), sub in df_all.groupby(["N", "k"]):
        vals = sub["avg_ce"].dropna().to_numpy()
        if vals.size == 0:
            continue
        mu = float(np.mean(vals))
        if vals.size > 1:
            std = float(np.std(vals, ddof=1))
            rng = float(np.max(vals) - np.min(vals))
        else:
            # A single order has no spread; ddof=1 would otherwise give NaN.
            std = 0.0
            rng = 0.0
        cv = std / mu if mu != 0 else np.nan
        rows.append({"N": N, "k": int(k), "mean": mu,
                     "std": std, "range": rng, "cv": cv, "count": int(vals.size)})
    # Explicit columns keep the schema (and the sort below) valid when `rows`
    # is empty; pd.DataFrame([]) would have no "N"/"k" columns -> KeyError.
    return (
        pd.DataFrame(rows, columns=columns)
        .sort_values(["N", "k"])
        .reset_index(drop=True)
    )

# -----------------------
# Load & aggregate
# -----------------------
# Stack every available per-size CSV into one long (N, k, avg_ce) table,
df_all = load_all()
# reduce it to per-(N, k) dispersion statistics across merge orders,
stats = agg_dispersion(df_all)
# and persist that summary table without the pandas index.
stats.to_csv(path_or_buf="out/rq9_dispersion_summary.csv", index=False)
print("Saved: out/rq9_dispersion_summary.csv")

import colorsys
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

def _low_sat_green_to_blue(n_steps):
    """
    生成低饱和度 绿->蓝 的渐变色列表（按 k 从小到大映射）。
    使用 HSV：H 约 120°(绿) -> 220°(蓝)，S 取低值 0.25，V 取 0.85。
    """
    if n_steps <= 1:
        h, s, v = 120/360.0, 0.25, 0.85
        r, g, b = colorsys.hsv_to_rgb(h, s, v)
        return [(r, g, b)]
    colors = []
    h0, h1 = 120/360.0, 220/360.0
    s, v = 0.25, 0.85
    for i in range(n_steps):
        t = i/(n_steps-1)
        h = h0 + t*(h1 - h0)
        r, g, b = colorsys.hsv_to_rgb(h, s, v)
        colors.append((r, g, b))
    return colors

def plot_range_bars(stats, out_pdf="figs/rq9_range_bar_0p5_32_72B_DARE.png", N_list=(0.5,32.0,72.0)):
    """Draw stacked per-N bar charts of the across-order CE range versus k.

    One subplot is drawn for each size in ``N_list`` that occurs in
    ``stats["N"]`` (top to bottom, in ``N_list`` order).  All subplots share
    one k axis: the sorted union of k values across the drawn sizes; a k that
    is missing for a given N is drawn as a zero-height bar.  Bars are coloured
    by k with a low-saturation green->blue gradient.  When the k=1 range for
    a panel exists and is positive, every other positive bar is annotated with
    its percentage change relative to k=1 ("⬇" = smaller range, "⬆" = larger).

    Args:
        stats: DataFrame with at least the columns "N", "k" and "range"
            (as produced by ``agg_dispersion``).
        out_pdf: output image path.  matplotlib infers the format from the
            extension, so the default writes a PNG despite the parameter name.
        N_list: model sizes (floats, in billions) to plot; compared exactly
            against ``stats["N"]``.

    Side effects:
        Writes ``out_pdf`` and prints its path.  If none of ``N_list`` is
        present, prints a warning and returns without drawing anything.
    """
    # Hand-picked y-axis tops for the default sizes, keyed by N (not by panel
    # position) so a missing size cannot shift limits onto the wrong panel and
    # extra sizes cannot index past the end.  Other sizes are autoscaled.
    preset_y_top = {0.5: 1.01, 32.0: 0.51, 72.0: 0.51}

    # Keep only the requested sizes that actually exist.
    available_N = set(stats["N"].unique())
    Ns_present = [N for N in N_list if N in available_N]
    if not Ns_present:
        print("[WARN] None of the requested Ns present; skip bars.")
        return

    # Shared k axis: union of all k seen in the drawn panels, sorted, so tick
    # positions and the k -> colour mapping are identical in every subplot.
    selected = stats[stats["N"].isin(Ns_present)]
    all_k = sorted(selected["k"].astype(int).unique().tolist())
    k_to_color = dict(zip(all_k, _low_sat_green_to_blue(len(all_k))))
    bar_colors = [k_to_color[k] for k in all_k]

    n_panels = len(Ns_present)
    fig, axes = plt.subplots(
        nrows=n_panels,
        ncols=1,
        sharex=True,
        figsize=(9, 6)
    )
    # plt.subplots returns a bare Axes (not an array) when there is one panel.
    if not isinstance(axes, np.ndarray):
        axes = np.array([axes])

    # Put the single y-axis label on the middle panel (1-based index).
    ylabel_panel = (n_panels + 1) // 2

    for i, (ax, N) in enumerate(zip(axes, Ns_present), start=1):
        sub = stats[stats["N"] == N]
        range_map = {int(k): float(r) for k, r in zip(sub["k"], sub["range"])}
        # Align to the shared k axis; missing k -> zero-height bar.
        heights = [range_map.get(k, 0.0) for k in all_k]

        # Bars with a thick black outline.
        ax.bar(all_k, heights, color=bar_colors, edgecolor="black", linewidth=2)

        N_txt = int(N) if float(N).is_integer() else N
        ax.set_title(f"Range across orders vs k (N={N_txt}B)", fontsize=20)
        if i == ylabel_panel:
            ax.set_ylabel("Range (max - min)", fontsize=17)

        # Only the bottom panel shows x ticks / tick labels / the x label.
        if i < n_panels:
            ax.tick_params(axis="x", which="both", bottom=False, labelbottom=False)
        else:
            ax.set_xlabel("k (experts)", fontsize=17)
            ax.set_xticks(all_k)

        # A lone positional value to set_ylim sets the *bottom* limit, so pass
        # (bottom, top) explicitly.  Never go below the tallest bar (+10%
        # headroom for the percentage labels) so no bar is clipped.
        y_top = preset_y_top.get(N)
        if y_top is not None:
            tallest = max(heights, default=0.0)
            ax.set_ylim(0, max(y_top, 1.1 * tallest))

        # Thick black frame.
        for spine in ax.spines.values():
            spine.set_linewidth(2)
            spine.set_edgecolor("black")

        # Percentage change of each bar relative to the k=1 range.
        base = range_map.get(1)
        if base is not None and base > 0:
            for x, r in zip(all_k, heights):
                # Skip the baseline itself and zero-height (missing) bars.
                if x == 1 or r <= 0:
                    continue
                pct = 100.0 * (base - r) / base
                # pct > 0 means the range shrank; show the direction rather
                # than printing a down arrow next to a negative number.
                arrow = "⬇" if pct >= 0 else "⬆"
                ax.text(x, r, f"{arrow}{abs(pct):.0f}%", ha="center", va="bottom", fontsize=12)

        # A little horizontal breathing room between bars and the axes.
        ax.margins(x=0.02)

    fig.tight_layout()
    fig.savefig(out_pdf, dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("Saved:", out_pdf)

# Render the range-bar figure with plot_range_bars' default output path and sizes.
plot_range_bars(stats=stats)