# -*- coding: utf-8 -*-
# RQ13: Cross-backbone (LLaMA) external validation
# - Read results_llama_3B.csv, results_llama_8B.csv
# - Aggregate macro CE by merge order -> by k
# - Fit L(k)=L_inf + A/(k+b) with b grid search
# - Plot curve+fit, and parameter comparison
# ------------------------------------------------
import os, re, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# -----------------------------
# I/O
# -----------------------------
# Input result tables, one per LLaMA backbone.
CSV_3B = "results_llama_3B.csv"  # LLaMA-3.2 3B
CSV_8B = "results_llama_8B.csv"  # LLaMA-3   8B

# Output locations: figures go to figs/, tabular summaries to out/.
for _out_dir in ("figs", "out"):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir  # keep the module namespace identical to the two-call version

# -----------------------------
# Utilities
# -----------------------------
def find_ce_col(df):
    """Return the name of the column that holds the CE / loss values.

    Candidate rules are tried from most to least specific; matching is
    case-insensitive and, within one rule, the first column in
    ``df.columns`` order wins:

    1. ``CE`` as a standalone token, i.e. not embedded in a longer word
       (``'CE Loss'``, ``'ce_loss'``, ``'macro-CE'``);
    2. a name containing ``loss`` (``'loss'``, ``'CELoss'``);
    3. legacy fallback: a name that merely contains the letters ``ce``.

    Rule 3 used to be tried first, which let unrelated columns such as
    ``'source'`` or ``'sentence_id'`` shadow the real CE column; it is kept
    only as a last resort so inputs that resolved before still resolve.

    Args:
        df: object exposing a ``columns`` iterable (e.g. a pandas DataFrame).

    Returns:
        The matching entry of ``df.columns`` (the original object, not its
        string form).

    Raises:
        ValueError: if no column name matches any rule.
    """
    # Pair each column with its string form once; names may be non-strings.
    labelled = [(col, str(col)) for col in df.columns]
    patterns = (
        r'(?<![a-z])ce(?![a-z])',  # 'ce' not glued to other letters
        r'loss',
        r'ce',                     # legacy substring match
    )
    for pattern in patterns:
        for col, text in labelled:
            if re.search(pattern, text, flags=re.I):
                return col
    raise ValueError("Cannot find a CE/Loss column in the CSV.")

def parse_k_from_model(model_str):
    """
    Count the integer tokens in a merge-order label.

    The label is stringified and every maximal run of digits counts as one
    entry, so '1-2-5', '1,2,5' and '[1 2 5]' all give k = 3; a label with no
    digits gives k = 0.
    """
    label = str(model_str)
    return sum(1 for _ in re.finditer(r'\d+', label))

def macro_series_from_csv(path):
    """
    Load a results CSV and reduce it to one mean-CE value per k.

    Expected columns:
      - 'model' (any capitalisation): merge order as a string; k is the
        number of integer tokens in it (see parse_k_from_model)
      - a CE/loss column, located with find_ce_col
      - presumably one row per evaluation domain/problem for each order;
        this function does not read a domain column, it simply averages all
        rows that share the same 'model' value

    Returns a 4-tuple, everything ordered by ascending k:
      k_vals: np.array of the distinct k values
      Lk: for each k, the mean over orders of the per-order mean CE
      Nk: for each k, the number of orders with a non-missing per-order CE
      per_k_orders: dict k -> list of per-order mean CE values

    Raises ValueError when no 'model' column can be found.
    """
    frame = pd.read_csv(path)
    if 'model' not in frame.columns:
        # Tolerate a differently-capitalised header such as 'Model'.
        alternatives = [col for col in frame.columns if col.lower() == 'model']
        if not alternatives:
            raise ValueError("CSV must contain a 'model' column.")
        frame = frame.rename(columns={alternatives[0]: 'model'})

    loss_col = find_ce_col(frame)  # e.g., 'CE Loss'

    # Step 1: collapse the rows of each merge order into a single mean CE.
    per_order = frame.groupby('model', as_index=False)[loss_col].mean()
    per_order['k'] = per_order['model'].apply(parse_k_from_model)

    # Step 2: average those per-order values over all orders of equal size k.
    stats = per_order.groupby('k')[loss_col].agg(['mean', 'count']).reset_index()
    k_vals = stats['k'].to_numpy()
    mean_ce = stats['mean'].to_numpy()
    n_orders = stats['count'].to_numpy()

    # Raw per-order values per k, kept for diagnostics.
    orders_by_k = {
        int(k): rows[loss_col].to_list()
        for k, rows in per_order.groupby('k')
    }

    # Return everything in ascending-k order.
    ascending = np.argsort(k_vals)
    k_sorted = k_vals[ascending]
    return (
        k_sorted,
        mean_ce[ascending],
        n_orders[ascending],
        {int(k): orders_by_k[int(k)] for k in k_sorted},
    )

def fit_floor_tail(k, y, b):
    """
    Least-squares fit of L(k) = L_inf + A/(k+b) with the offset b held fixed.

    With b fixed the model is linear in (L_inf, A): regress y on an
    intercept column and the regressor 1/(k+b).

    Returns (L_inf, A, yhat, SSE): the two coefficients as Python floats, the
    fitted values at the given k, and the sum of squared residuals.
    """
    k_arr = np.asarray(k, float)
    target = np.asarray(y, float)

    inv_shifted = 1.0 / (k_arr + float(b))
    # Design matrix: column 0 is the intercept (L_inf), column 1 multiplies A.
    design = np.vstack([np.ones_like(inv_shifted), inv_shifted]).T

    solution = np.linalg.lstsq(design, target, rcond=None)
    coef = solution[0]
    L_inf = float(coef[0])
    A = float(coef[1])

    fitted = design @ coef
    residual = target - fitted
    sse = float(np.sum(residual**2))
    return L_inf, A, fitted, sse

def choose_b_grid(k, y, b_grid=np.linspace(0.0, 1.0, 81)):
    """Pick the offset b from b_grid whose floor+tail fit has the lowest SSE.

    Each candidate b is fitted with fit_floor_tail; the first candidate with
    the strictly smallest SSE wins, so ties resolve to the earliest grid
    value. The default grid is 81 evenly spaced values on [0, 1].

    Args:
        k: sequence of k values.
        y: sequence of observed values, same length as k.
        b_grid: iterable of candidate offsets b.

    Returns:
        (b_star, L_inf, A, yhat) for the winning candidate, where yhat holds
        the fitted values at the given k.

    Raises:
        ValueError: if no candidate yields a finite SSE (for example an
            empty grid, or NaN values in the data). Previously this case
            surfaced as an opaque TypeError from passing b=None to the refit.
    """
    best_sse = np.inf
    best_fit = None  # (b, L_inf, A, yhat) of the best candidate so far
    for b in b_grid:
        L_inf, A, yhat, sse = fit_floor_tail(k, y, b)
        # Strict '<' keeps the earliest b on ties and rejects NaN/inf SSE.
        if sse < best_sse:
            best_sse = sse
            best_fit = (float(b), L_inf, A, yhat)
    if best_fit is None:
        raise ValueError(
            "No candidate b produced a finite SSE; check b_grid and the input data."
        )
    # The winning yhat is kept from the search, so no second fit is needed.
    b_star, L_inf, A, yhat = best_fit
    return b_star, L_inf, A, yhat

def r2_score(y, yhat):
    """
    Coefficient of determination, 1 - SS_res / SS_tot.

    When SS_tot is not strictly positive (e.g. all observations are equal)
    the ratio is undefined and 1.0 is returned instead.
    """
    observed = np.asarray(y, float)
    predicted = np.asarray(yhat, float)
    ss_res = np.sum((observed - predicted)**2)
    ss_tot = np.sum((observed - np.mean(observed))**2)
    if not ss_tot > 0:
        return 1.0
    return 1.0 - ss_res/ss_tot

def rel_improve(y):
    """
    Relative drop from the first to the last element of y, in percent.

    For a series ordered by ascending k this is the improvement from the
    smallest k to the largest k, relative to the value at the smallest k.
    """
    series = np.asarray(y, float)
    first, last = series[0], series[-1]
    return (first - last) / first * 100.0

# -----------------------------
# Load both backbones and fit
# -----------------------------
def process_backbone(tag, csv_path):
    """
    Run the full pipeline for one backbone: load the CSV, reduce it to a
    mean-CE-per-k series, fit L(k) = L_inf + A/(k+b) with a grid search over
    b, and collect fit diagnostics.

    Returns a dict with keys: tag, k, Lk, Nk (the series), b, L_inf, A, yhat
    (the fit), R2, max_resid (largest absolute residual), rel_drop_%
    (relative drop from first to last k, in percent) and per_k_orders.
    """
    k, Lk, Nk, per_k_orders = macro_series_from_csv(csv_path)
    b, L_inf, A, yhat = choose_b_grid(k, Lk)

    # Assemble the record in a fixed key order: series, fit, diagnostics.
    record = {"tag": tag, "k": k, "Lk": Lk, "Nk": Nk}
    record.update(b=b, L_inf=L_inf, A=A, yhat=yhat)
    record["R2"] = r2_score(Lk, yhat)
    record["max_resid"] = float(np.max(np.abs(Lk - yhat)))
    record["rel_drop_%"] = rel_improve(Lk)
    record["per_k_orders"] = per_k_orders
    return record

res_3B = process_backbone("LLaMA-3.2 3B", CSV_3B)
res_8B = process_backbone("LLaMA-3 8B",  CSV_8B)

# Build one summary row per backbone, write the table to disk and echo it.
summary_rows = []
for r in (res_3B, res_8B):
    row = {"backbone": r["tag"]}
    row["R2"] = r["R2"]
    row["b"] = r["b"]
    row["L_inf"] = r["L_inf"]
    row["A"] = r["A"]
    # First / last entries of the k-sorted series.
    row["L(k=1)"] = r["Lk"][0]
    row["L(k=max)"] = r["Lk"][-1]
    row["relative_drop_%"] = r["rel_drop_%"]
    row["max_abs_residual"] = r["max_resid"]
    summary_rows.append(row)

summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv("out/rq13_llama_fit_summary.csv", index=False)
print(summary_df)

# -----------------------------
# Figure (a): curve + fit
# -----------------------------
# =============================
# Plot styling (newly added)
# =============================
from matplotlib import cycler

def apply_global_style():
    """Install the matplotlib rcParams shared by the figures in this script.

    Sets font sizes, line/marker sizes, grid appearance (grid itself is off),
    legend framing, figure/savefig DPI and the default colour cycle.
    """
    # Fonts and font sizes (title and axis labels slightly larger).
    typography = {
        "font.size": 14,
        "axes.titlesize": 17,
        "axes.labelsize": 14,
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
        "legend.fontsize": 13,
    }
    # Lines and grid.
    strokes = {
        "lines.linewidth": 2.2,
        "lines.markersize": 7.5,
        "axes.grid": False,
        "grid.linestyle": "--",
        "grid.alpha": 0.25,
        "grid.linewidth": 1.0,
    }
    # Legend frame and output resolution.
    legend_and_output = {
        "legend.frameon": True,
        "legend.framealpha": 0.95,
        "legend.edgecolor": "#cccccc",
        "figure.dpi": 300,
        "savefig.dpi": 300,
    }
    # Default colour cycle for artists that do not set an explicit colour.
    palette = ["#4C78A8", "#F58518", "#54A24B", "#E45756", "#72B7B2"]

    plt.rcParams.update({
        **typography,
        **strokes,
        **legend_and_output,
        "axes.prop_cycle": cycler(color=palette),
    })

def bold_spines(ax, width=2.5):
    """Set the line width of all four spines of ``ax`` to ``width``."""
    for name in ("left", "right", "bottom", "top"):
        spine = ax.spines[name]
        spine.set_linewidth(width)

apply_global_style()

# =============================
# Figure (a): curve + fit (styled)
# =============================
fig, ax = plt.subplots(figsize=(8, 5), constrained_layout=True)

# One explicit colour per backbone; data points and fitted curve of the same
# backbone share it and are told apart by marker vs. dashed line.
c1 = "#4C78A8"  # 3B colour
c2 = "#ea6aa8"  # 8B colour

_series = ((res_3B, c1, "o"), (res_8B, c2, "s"))

# Observed mean CE per k, drawn as scatter points (both backbones first, so
# the artist creation order matches: data, data, fit, fit).
for _res, _color, _marker in _series:
    ax.scatter(_res["k"], _res["Lk"], marker=_marker,
               label=f"{_res['tag']} data", color=_color)

# Fitted curves evaluated at the observed k values, as dashed lines.
for _res, _color, _marker in _series:
    ax.plot(_res["k"], _res["yhat"], linestyle="--", linewidth=2.4,
            label=f"{_res['tag']} fit (R$^2$={_res['R2']:.3f})",
            color=_color, alpha=0.9, dash_capstyle="round")

ax.set_xlabel("Number of merged experts $k$", fontsize=17)
ax.set_ylabel("Macro-averaged CE", fontsize=15)
ax.set_title("Open-source backbones: macro CE vs. $k$ with floor+tail fits")

# Thicker axes frame.
bold_spines(ax, width=2.5)

# Outward, slightly longer ticks for readability.
ax.tick_params(direction="out", length=5.5, width=1.4)

# Framed legend in the upper-right corner, with thicker legend line samples.
leg = ax.legend(loc="upper right", ncol=1)
for legline in leg.get_lines():
    legline.set_linewidth(2.5)

fig.savefig("figs/rq13_llama_curve_fit.png")
plt.close(fig)
