# -*- coding: utf-8 -*-
# RQ13 (LLaMA backbones) — robust loader + compact plotting (styled like "code2")
# Logic preserved: data loading / fitting / summary follow the original version; only the plotting was refactored and restyled
# -------------------------------------------------------------------------------
# Inputs (cwd):
#   results_llama_3B.csv
#   results_llama_8B.csv
#
# Outputs:
#   figs/rq13_llama_marginal_gain.png
#   figs/rq13_llama_kstar_8090.png
#   out/rq13_llama_fit_summary.csv
#
import os, re, numpy as np, pandas as pd, matplotlib.pyplot as plt

# =============================
# 0) Global style (mimics the style of "code2")
# =============================
from matplotlib import cycler

def apply_global_style():
    """Install the shared matplotlib look used by every figure in this script.

    Updates the global ``plt.rcParams`` in place:
    - 300 dpi rendering and saving;
    - font sizes;
    - line and marker sizes;
    - a dashed grid style (grid itself off by default);
    - framed legends;
    - top/right spines kept visible;
    - a five-colour property cycle.

    Returns nothing.
    """
    palette = ["#4C78A8", "#EA6AA8", "#54A24B", "#E45756", "#72B7B2"]
    settings = {
        # Canvas & resolution
        "figure.dpi": 300,
        "savefig.dpi": 300,
        # Fonts & sizes
        "font.size": 14,
        "axes.titlesize": 17,
        "axes.labelsize": 15,
        "xtick.labelsize": 13,
        "ytick.labelsize": 13,
        "legend.fontsize": 13,
        # Lines & markers
        "lines.linewidth": 2.2,
        "lines.markersize": 7.0,
        # Grid: off by default; individual figures may switch it on
        "axes.grid": False,
        "grid.linestyle": "--",
        "grid.alpha": 0.25,
        "grid.linewidth": 1.0,
        # Legend
        "legend.frameon": True,
        "legend.framealpha": 0.95,
        "legend.edgecolor": "#cccccc",
        # Keep all four spines (they are thickened later by bold_spines)
        "axes.spines.top": True,
        "axes.spines.right": True,
    }
    plt.rcParams.update(settings)
    plt.rcParams["axes.prop_cycle"] = cycler(color=palette)

def bold_spines(ax, width=2.4):
    """Set the line width of the left, right, bottom and top spines of ``ax`` to ``width``."""
    spines = ax.spines
    for name in ("left", "right", "bottom", "top"):
        spines[name].set_linewidth(width)

apply_global_style()

# Basic I/O: output directories for the figures and the summary table
for _out_dir in ("figs", "out"):
    os.makedirs(_out_dir, exist_ok=True)

# -------------------------------
# 1) Robust CSV loader (same logic as the original)
# -------------------------------
K_TARGET = np.arange(1, 10)  # target numbers of merged experts: k = 1..9

def _looks_like_k(col, s):
    """Heuristically decide whether column ``col`` (values ``s``) holds the expert count k.

    A column whose lower-cased name is a known k alias ("k", "num_experts",
    "n_experts", "experts", "domains", "n_domains", "poolm", "m") qualifies when
    at least 80% of its distinct numeric values lie in ``K_TARGET``. Any other
    column qualifies when at least five of its distinct numeric values lie in
    ``K_TARGET``.

    Args:
        col: column label. Need not be a string; non-string labels are stringified.
        s: the column's values. Non-numeric entries are coerced to NaN and ignored.

    Returns:
        bool: True if the column looks like a k column.
    """
    name = str(col).lower()  # str(): tolerate non-string column labels
    try:
        vals = pd.to_numeric(s, errors="coerce").dropna().unique()
    except Exception:
        # Column that cannot be parsed at all (e.g. nested objects): not a k column.
        return False
    hits = np.isin(vals, K_TARGET)
    if name in {"k", "num_experts", "n_experts", "experts", "domains", "n_domains", "poolm", "m"}:
        # An empty column carries no evidence. Checking size first avoids
        # np.mean([]) -> nan plus a RuntimeWarning; the result is still False.
        return hits.size > 0 and bool(hits.mean() >= 0.8)
    return bool(hits.sum() >= 5)

def _pick_value_col(df, exclude):
    pri = [c for c in df.columns if any(tok in c.lower() for tok in
           ["macro_ce","macro loss","macro_loss","ce","loss","nll","metric","value","mean"])]
    cand = [c for c in pri if c not in exclude and pd.api.types.is_numeric_dtype(df[c])]
    if cand:
        return cand[0]
    num_cols = [c for c in df.columns if c not in exclude and pd.api.types.is_numeric_dtype(df[c])]
    if num_cols:
        return num_cols[-1]
    return None

def _wide_try_map_kcols(cols):
    m = {}
    for c in cols:
        lc = c.lower()
        mobj = re.search(r'(^|[^0-9])k\s*[_=\-]?\s*([1-9])([^0-9]|$)', lc)
        if mobj:
            kin = int(mobj.group(2)); m[kin] = c; continue
        if re.fullmatch(r'\s*[1-9]\s*', lc):
            kin = int(lc.strip()); m[kin] = c; continue
        mobj2 = re.search(r'[_\-]k\s*([1-9])$', lc)
        if mobj2:
            kin = int(mobj2.group(1)); m[kin] = c; continue
    if all(k in m for k in range(1,10)):
        return m
    if len(m) >= 7:
        return m
    return {}

def load_series(csv_path):
    """Load a curve indexed by k = 1..9 from ``csv_path``.

    Layouts are tried in this order; the first that yields a series wins:

    (A) Long: a k-like column (see ``_looks_like_k``) plus a value column
        (see ``_pick_value_col``). Repeated k values are averaged.
    (B) Wide: one column per k (see ``_wide_try_map_kcols``). Each column is
        averaged over rows, and missing k are linearly interpolated.
    (C) Row: the first row's first nine numeric columns.
    (D) Column: a single numeric column with at least nine rows; the first
        nine are used.

    Returns:
        (K_TARGET, y) where ``y`` is a float array of nine values aligned with
        ``K_TARGET``.

    Raises:
        ValueError: if no layout matches, or the wide layout's per-k columns
        contain no numeric values at all.
    """
    df = pd.read_csv(csv_path)

    # (A) Long format
    kcol = next((c for c in df.columns if _looks_like_k(c, df[c])), None)
    if kcol is not None:
        vcol = _pick_value_col(df, exclude={kcol})
        if vcol is not None:
            sub = df[[kcol, vcol]].copy()
            sub[kcol] = pd.to_numeric(sub[kcol], errors="coerce")
            sub[vcol] = pd.to_numeric(sub[vcol], errors="coerce")
            # Mean value per distinct k, ordered by k.
            means = sub.dropna().groupby(kcol)[vcol].mean().sort_index()
            kuniq = means.index.to_numpy()
            by_k = {float(kk): float(v) for kk, v in means.items()}
            # Check for an exact k = 1..9 match BEFORE the "nine smallest k"
            # fallback. Otherwise a file that also contains e.g. k = 0 (or
            # k = 1.5) has the wrong rows picked and relabelled as 1..9.
            if all(float(kk) in by_k for kk in K_TARGET):
                return K_TARGET, np.array([by_k[float(kk)] for kk in K_TARGET], float)
            if len(kuniq) >= 9:
                # Fallback: the nine smallest k values, relabelled as 1..9.
                return K_TARGET, means.to_numpy(float)[:9]
            if 7 <= len(kuniq) < 9:
                # Fill the missing k by linear interpolation. np.interp holds the
                # end values constant outside the observed k range.
                return K_TARGET, np.interp(K_TARGET, kuniq.astype(float), means.to_numpy(float))

    # (B) Wide format
    kmap = _wide_try_map_kcols(list(df.columns))
    if kmap:
        y = np.array(
            [pd.to_numeric(df[kmap[kk]], errors="coerce").mean() if kk in kmap else np.nan
             for kk in range(1, 10)],
            float,
        )
        good = ~np.isnan(y)
        if not good.any():
            raise ValueError(
                f"No numeric values in the per-k columns of {csv_path}: {list(kmap.values())}"
            )
        if not good.all():
            # Missing k: linear interpolation, constant beyond the ends.
            idx = np.arange(1, 10)
            y = np.interp(idx, idx[good], y[good])
        return K_TARGET, y

    # (C) Row-of-9
    num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
    if len(df) >= 1 and len(num_cols) >= 9:
        return K_TARGET, df.iloc[0][num_cols[:9]].astype(float).to_numpy()

    # (D) Column-of-9
    if len(df.columns) == 1 and pd.api.types.is_numeric_dtype(df.iloc[:, 0]) and len(df) >= 9:
        return K_TARGET, df.iloc[:9, 0].astype(float).to_numpy()

    raise ValueError(f"Unrecognized format in {csv_path}. Columns: {list(df.columns)}")

# -------------------------------
# 2) Floor+tail fit + bootstrap (same logic as the original)
# -------------------------------
def fit_floor_tail_for_b(k, y, b):
    """Least-squares fit of ``y ~ L_inf + A / (k + b)`` with the offset ``b`` held fixed.

    Args:
        k: array of expert counts.
        y: observed values, same length as ``k``.
        b: fixed offset added to ``k`` inside the tail term.

    Returns:
        Tuple ``(L_inf, A, yhat, sse)``:
        - ``L_inf``: floor (intercept), as a Python float;
        - ``A``: tail coefficient, as a Python float;
        - ``yhat``: fitted values, as an array;
        - ``sse``: sum of squared residuals, as a Python float.
    """
    tail = 1.0 / (k + b)
    design = np.vstack([np.ones_like(tail), tail]).T  # columns: [1, 1/(k+b)]
    beta = np.linalg.lstsq(design, y, rcond=None)[0]
    fitted = design @ beta
    residual = y - fitted
    return float(beta[0]), float(beta[1]), fitted, float(np.sum(residual ** 2))

def choose_b_grid(k, y, grid=None):
    """Pick the offset b that minimises the floor+tail SSE, by exhaustive grid search.

    Args:
        k, y: passed through to ``fit_floor_tail_for_b``.
        grid: candidate b values. Defaults to 401 evenly spaced points on
            [0, 1]. A fresh default is built per call instead of sharing one
            array created at definition time.

    Returns:
        The first grid value, as a float, that achieves the smallest SSE.

    Raises:
        ValueError: if the grid is empty or no candidate gives a finite SSE.
            Previously ``None`` was returned here, which only failed later
            with a TypeError inside ``k + b``.
    """
    if grid is None:
        grid = np.linspace(0, 1, 401)
    best_b, best_sse = None, np.inf
    for b in grid:
        sse = fit_floor_tail_for_b(k, y, b)[3]
        # Strict '<': ties keep the earliest b, and a NaN SSE never wins.
        if sse < best_sse:
            best_b, best_sse = float(b), sse
    if best_b is None:
        raise ValueError("choose_b_grid: no candidate b produced a finite SSE (empty grid or non-finite data)")
    return best_b

def r2_score(y, yhat):
    """Coefficient of determination 1 - SS_res / SS_tot.

    Returns 1.0 when ``y`` has no positive spread (SS_tot is 0 or NaN).
    """
    total = np.sum((y - np.mean(y)) ** 2)
    if not total > 0:
        return 1.0
    return 1.0 - np.sum((y - yhat) ** 2) / total

def bootstrap_fit(k, y, B=1000, refit_b=True, seed=42):
    """Fit the floor+tail curve and estimate 2.5/97.5 percentile intervals by residual bootstrap.

    Procedure:
    - The point fit uses the grid-searched offset b.
    - Each of the ``B`` replicates resamples the point-fit residuals with
      replacement, adds them back to the fitted curve, and refits.
    - When ``refit_b`` is true, b is re-chosen for each replicate; otherwise
      the point-fit b is reused.
    - ``seed`` initialises the NumPy generator used for resampling.

    Returns:
        dict with these keys:
        - point estimates ``b_hat``, ``L_inf``, ``A`` and the fitted curve ``yhat``;
        - percentile intervals ``b_ci``, ``L_ci``, ``A_ci``;
        - the replicate arrays ``b_samps``, ``L_samps``, ``A_samps``.
    """
    def percentile_interval(samples):
        lo = np.percentile(samples, 2.5)
        hi = np.percentile(samples, 97.5)
        return (float(lo), float(hi))

    b_point = choose_b_grid(k, y)
    L_point, A_point, y_fit, _ = fit_floor_tail_for_b(k, y, b_point)
    resid = y - y_fit
    n_obs = len(resid)
    rng = np.random.default_rng(seed)

    boot_b, boot_L, boot_A = [], [], []
    for _rep in range(B):
        y_boot = y_fit + rng.choice(resid, size=n_obs, replace=True)
        b_boot = b_point if not refit_b else choose_b_grid(k, y_boot)
        L_boot, A_boot = fit_floor_tail_for_b(k, y_boot, b_boot)[:2]
        boot_b.append(b_boot)
        boot_L.append(L_boot)
        boot_A.append(A_boot)

    return {
        "b_hat": b_point,
        "L_inf": L_point,
        "A": A_point,
        "yhat": y_fit,
        "b_ci": percentile_interval(boot_b),
        "L_ci": percentile_interval(boot_L),
        "A_ci": percentile_interval(boot_A),
        "b_samps": np.array(boot_b),
        "L_samps": np.array(boot_L),
        "A_samps": np.array(boot_A),
    }

def kstar_from_gain(yhat, frac=0.8):
    """Return the smallest 1-based k at which ``frac`` of the total gain is reached.

    Gain up to index i is ``yhat[0] - yhat[i]``, and the total gain is
    ``yhat[0] - yhat[-1]``.

    Returns:
        The smallest ``k = i + 1`` with ``gain_i / total >= frac``. Returns
        ``len(yhat)``, the largest k, when the total gain is not positive or
        the threshold is never reached. This was previously hard-coded to 9,
        which was only correct for nine-point curves.

    Raises:
        IndexError: if ``yhat`` is empty.
    """
    n = len(yhat)
    total = float(yhat[0] - yhat[-1])
    if total <= 0:
        return n
    start = yhat[0]
    for i, val in enumerate(yhat):
        if (start - val) / total >= frac:
            return i + 1
    return n

# -------------------------------
# 3) Load data & fit (same logic as the original)
# -------------------------------
PATH_3B = "results_llama_3B.csv"
PATH_8B = "results_llama_8B.csv"
k, y3 = load_series(PATH_3B)
y8 = load_series(PATH_8B)[1]  # load_series always returns K_TARGET as the k grid

fit3, fit8 = (bootstrap_fit(k, y_obs, B=1000, refit_b=True) for y_obs in (y3, y8))
r2_3 = r2_score(y3, fit3["yhat"])
r2_8 = r2_score(y8, fit8["yhat"])

# Marginal gains dL(k) = L(k-1) - L(k) on the fitted curve; NaN at the first k (no predecessor)
d3 = np.insert(fit3["yhat"][:-1] - fit3["yhat"][1:], 0, np.nan)
d8 = np.insert(fit8["yhat"][:-1] - fit8["yhat"][1:], 0, np.nan)

# Smallest k reaching 80% / 90% of the total fitted gain
k80_3, k90_3 = (kstar_from_gain(fit3["yhat"], frac) for frac in (0.80, 0.90))
k80_8, k90_8 = (kstar_from_gain(fit8["yhat"], frac) for frac in (0.80, 0.90))

# Save the fit summary: one row per backbone, same columns and order as before
os.makedirs("out", exist_ok=True)
pd.DataFrame([
    {
        "backbone": name,
        "b": fit["b_hat"], "b_ci_lo": fit["b_ci"][0], "b_ci_hi": fit["b_ci"][1],
        "L_inf": fit["L_inf"], "L_ci_lo": fit["L_ci"][0], "L_ci_hi": fit["L_ci"][1],
        "A": fit["A"], "A_ci_lo": fit["A_ci"][0], "A_ci_hi": fit["A_ci"][1],
        "R2": r2, "kstar80": k80, "kstar90": k90,
    }
    for name, fit, r2, k80, k90 in (
        ("LLaMA-3.2 3B", fit3, r2_3, k80_3, k90_3),
        ("LLaMA-3 8B", fit8, r2_8, k80_8, k90_8),
    )
]).to_csv("out/rq13_llama_fit_summary.csv", index=False)

# =============================
# 4) Plotting functions (styled like "code2")
# =============================
def plot_marginal_gain(k, d3, d8, tag3="LLaMA-3.2 3B", tag8="LLaMA-3 8B",
                       out_path="figs/rq13_llama_marginal_gain.png"):
    """Plot the marginal gain dL(k) = L(k-1) - L(k) of both backbones and save the figure.

    Args:
        k: expert counts aligned with ``d3``/``d8``. The first point is
            skipped, because the marginal gain has no predecessor there
            (stored as NaN).
        d3, d8: marginal gains of the two backbones, aligned with ``k``.
        tag3, tag8: legend labels for the two series.
        out_path: destination image file.
    """
    fig, ax = plt.subplots(figsize=(8.0, 5.0), constrained_layout=True)
    try:
        x = k[1:]
        # One Line2D per series (markers + connecting line). This guarantees
        # each series' markers and line share a colour. It also puts real
        # lines in the legend: separate unlabeled plot() lines plus labeled
        # scatter() markers left only PathCollections there, so the legend
        # linewidth loop below did nothing.
        for gains, marker, tag in ((d3, "o", tag3), (d8, "s", tag8)):
            ax.plot(x, gains[1:], marker=marker, linestyle="-", alpha=0.9,
                    label=f"{tag} (Δ from fit)")
        ax.set_xlabel("Number of merged experts $k$")
        ax.set_ylabel(r"$\Delta L(k)=L(k{-}1)-L(k)$")
        ax.set_title("Marginal gain vs. $k$")

        # Axis styling
        bold_spines(ax, width=2.4)
        ax.tick_params(direction="out", length=5.5, width=1.4)

        # Framed legend in the upper-right corner with thicker sample lines
        leg = ax.legend(loc="upper right", ncol=1)
        for legline in leg.get_lines():
            legline.set_linewidth(2.5)

        fig.savefig(out_path)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)

def plot_kstar_bars(k80_3, k90_3, k80_8, k90_8,
                    names=("LLaMA-3.2 3B", "LLaMA-3 8B"),
                    out_path="figs/rq13_llama_kstar_8090.png"):
    """Draw grouped bars of k*_80 and k*_90 for two backbones and save the figure.

    Args:
        k80_3, k90_3: k reaching 80% / 90% of the total gain for the first backbone.
        k80_8, k90_8: the same for the second backbone.
        names: x-tick labels for the two backbones, in that order. Exactly two
            entries are expected.
        out_path: destination image file.
    """
    fig, ax = plt.subplots(figsize=(7.6, 5.0), constrained_layout=True)
    try:
        k80 = [k80_3, k80_8]
        k90 = [k90_3, k90_8]
        idx = np.arange(2)
        w = 0.36

        b1 = ax.bar(idx - w / 2, k80, width=w, label=r"$k^*_{80}$",
                    edgecolor="#333333", linewidth=2.0)
        b2 = ax.bar(idx + w / 2, k90, width=w, label=r"$k^*_{90}$",
                    edgecolor="#333333", linewidth=2.0, color="#95cadb")
        # NOTE(review): the bars start at 0 but the axis starts at 1, so a k* of 1
        # renders as an empty bar (only its label is visible).
        ax.set_ylim(1, 11)
        ax.set_xticks(idx, list(names))
        # Plain "%": text outside $...$ is not LaTeX under the default mathtext
        # renderer (this script never enables text.usetex), so "\%" would
        # print a literal backslash.
        ax.set_ylabel("$k$ to reach 80/90% of total gain")
        ax.set_title("How many experts are enough?")

        # Value labels above each bar (data and ticks unchanged)
        for bars in (b1, b2):
            for rect in bars:
                h = rect.get_height()
                ax.annotate(f"{int(h)}",
                            xy=(rect.get_x() + rect.get_width() / 2, h),
                            xytext=(0, 6),
                            textcoords="offset points",
                            ha="center", va="bottom", fontsize=12)

        bold_spines(ax, width=2.4)
        ax.tick_params(direction="out", length=5.5, width=1.4)

        leg = ax.legend(loc="upper right", ncol=1)
        for legpatch in leg.get_patches():
            legpatch.set_linewidth(1.0)

        fig.savefig(out_path)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)

# -------------------------------
# 5) Render both figures (output file names unchanged)
# -------------------------------
plot_marginal_gain(
    k=k, d3=d3, d8=d8,
    tag3="LLaMA-3.2 3B",
    tag8="LLaMA-3 8B",
)
plot_kstar_bars(k80_3=k80_3, k90_3=k90_3, k80_8=k80_8, k90_8=k90_8)

print("Done. See figs/ and out/ for outputs.")