# poolsize_domains_fit_plots_selective.py
# ------------------------------------------------------------
# Goals:
#   1) Restrict the candidate expert pool to experts 1..M (M=7 and M=8, as
#      configured in __main__).
#   2) Fit and plot only the domains you select by hand
#      (SELECT_DOMAINS_M7 / SELECT_DOMAINS_M8).
#   3) For every kept domain and every model size N, fit
#      L(k) = L_inf(N) + A(N) / (k + b_d), with one shift b_d per domain
#      shared across all N.
#      -> plot the per-domain L_inf(N) and A(N) points with their fitted
#         trends (two panels: left L_inf, right A).
# Dependencies: numpy, pandas, matplotlib
# Input: <BASE_DIR>/results_dare_<N>B.csv (N = model size in billions of params)
# Output (in OUT_DIR), one set per pool size M:
#   - rq_poolM{M}_selDomains_Linf_A_vs_N.png (two-panel figure)
#   - a per-domain fit table (CSV)
# ------------------------------------------------------------
import os, re, glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import LogLocator, FuncFormatter, NullFormatter

# ============ Global plotting style ============
# Resolution, font sizes, a framed legend and thicker black axes edges,
# applied once for every figure this script produces.
_RC_STYLE = {
    # Output resolution
    "figure.dpi": 150,
    "savefig.dpi": 300,
    # Font sizes: base text, subplot titles, axis labels, tick labels, legend
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 11,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 10,
    # Frames and edges (axes edge width is thickened again per-axes later)
    "legend.frameon": True,
    "axes.edgecolor": "black",
    "axes.linewidth": 1.4,
}
plt.rcParams.update(_RC_STYLE)


def _beautify_axes(ax):
    """Apply the shared axis look used by every panel.

    Thickens the spines and ticks and draws a faint major grid along y.
    It also configures the y axis with log locators: sparse major ticks at
    1, 2 and 5 times each power of 10, labelled with two decimals (values
    below 0.005 therefore read "0.00"). Minor ticks stay unlabelled, so no
    scientific-notation labels sneak in.
    Intended for axes whose y scale is already set to "log".
    """
    def _two_decimal_label(value, _pos):
        # Blank label for non-finite or non-positive tick values.
        if np.isfinite(value) and value > 0:
            return f"{value:.2f}"
        return ""

    ax.tick_params(which="both", width=1.3, length=4)
    for spine in ax.spines.values():
        spine.set_linewidth(1.6)
    ax.grid(True, which="major", alpha=0.18, axis="y")

    yaxis = ax.yaxis
    # Major ticks: 1x, 2x and 5x each decade, formatted as plain decimals.
    yaxis.set_major_locator(LogLocator(base=10.0, subs=[1.0, 2.0, 5.0], numticks=8))
    yaxis.set_major_formatter(FuncFormatter(_two_decimal_label))
    # Minor ticks: present but unlabelled.
    yaxis.set_minor_locator(LogLocator(base=10.0, subs=np.arange(1, 10) * 0.1))
    yaxis.set_minor_formatter(NullFormatter())
    
# ============ User configuration: edit only this section ============
# Domains ("class" values) kept for each candidate-pool size.
SELECT_DOMAINS_M7 = [
    "algebra",
    "analysis",
    "geometry",
    "code",
    "number_theory",
    "chemistry",
    "physics",
]
SELECT_DOMAINS_M8 = [
    "algebra",
    "analysis",
    "chemistry",
    "physics",
]

# Input directory (holds results_dare_<N>B.csv) and output directory;
# the output directory is created at import time if it does not exist.
BASE_DIR = "DARE"
OUT_DIR = "7domains/5_and_7_outputs"
os.makedirs(OUT_DIR, exist_ok=True)

# Domains drawn in the A(N) panel; an empty list means every fitted domain.
REP_DOMAINS_FOR_A_M7 = []
REP_DOMAINS_FOR_A_M8 = []

# Candidate shifts b for the fits: 41 evenly spaced points on [0, 1] (step 0.025).
B_GRID = np.linspace(0.0, 1.0, num=41)

# ============ 工具函数 ============
def parse_expert_ids(model_str: str):
    """Extract the integer expert ids encoded in a model identifier.

    The identifier is converted to ``str`` and split on '-' and '+'. Each
    non-empty token that parses as a number is truncated to ``int``. Any other
    token (e.g. a textual model name, "nan", "inf") is ignored.

    >>> parse_expert_ids("1-3+5")
    [1, 3, 5]
    >>> parse_expert_ids("base")
    []
    """
    ids = []
    for tok in re.split(r'[-+]', str(model_str)):
        tok = tok.strip()
        if not tok:
            continue
        try:
            ids.append(int(float(tok)))
        except (ValueError, OverflowError):
            # ValueError: non-numeric token or "nan"; OverflowError: "inf".
            # Only these can come out of int(float(str)); anything else is a
            # real bug and should surface instead of being swallowed.
            continue
    return ids

def add_meta(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with two derived columns; *df* itself is untouched.

    - ``experts``: list of integer expert ids parsed from each ``model`` value.
    - ``k``: how many ids were parsed for that row.
    """
    parsed = [parse_expert_ids(name) for name in df["model"]]
    enriched = df.copy()
    enriched["experts"] = parsed
    enriched["k"] = [len(ids) for ids in parsed]
    return enriched

def read_all_dare_csvs(base_dir):
    """Load every ``results_dare_<N>B.csv`` in *base_dir* into a dict keyed by N.

    N (a float parsed from the file name) is the model size. In each file,
    rows whose ``class`` is "overall" are dropped, and the ``experts`` / ``k``
    columns are added via add_meta. The returned dict is ordered by
    increasing N.

    Files whose name has no parseable N are skipped with a warning. If two
    files map to the same N (e.g. "1B" and "1.0B"), the later one in sorted
    path order wins, also with a warning.

    Raises:
        FileNotFoundError: no file matches the glob, or none of the matching
            names carries a parseable model size.
    """
    paths = sorted(glob.glob(os.path.join(base_dir, "results_dare_*B.csv")))
    if not paths:
        raise FileNotFoundError("未找到 results_dare_*B.csv，请检查 BASE_DIR。")
    data = {}
    for p in paths:
        name = os.path.basename(p)
        m = re.search(r"results_dare_([0-9.]+)B\.csv", name)
        if not m:
            print(f"[WARN] Skipping {name}: no model size in file name.")
            continue
        try:
            N = float(m.group(1))
        except ValueError:
            # The character class also admits strings such as "1.2.3" or ".".
            print(f"[WARN] Skipping {name}: invalid model size {m.group(1)!r}.")
            continue
        if N in data:
            print(f"[WARN] {name}: duplicate model size N={N:g}; replacing the earlier file.")
        df = pd.read_csv(p)
        df = df[df["class"].astype(str) != "overall"].copy()
        data[N] = add_meta(df)
    if not data:
        raise FileNotFoundError(
            f"No results_dare_<N>B.csv file in {base_dir!r} has a parseable model size N."
        )
    return dict(sorted(data.items()))

def filter_by_pool_and_domains(dfN: pd.DataFrame, M: int, domains_keep: list) -> pd.DataFrame:
    """Return a copy of the rows of *dfN* that fit the pool and domain selection.

    A row is kept when every id in its ``experts`` list lies in {1, ..., M}
    (an empty list qualifies) and its ``class`` is one of *domains_keep*.
    """
    pool = set(range(1, M + 1))
    in_pool = dfN["experts"].apply(pool.issuperset)
    in_domains = dfN["class"].isin(domains_keep)
    keep = in_pool & in_domains
    return dfN[keep].copy()

def fit_LA_for_b(k, y, b):
    """Least-squares fit of y = L_inf + A / (k + b) for a fixed shift b.

    Args:
        k: numpy array of expert counts.
        y: numpy array of losses, same length as *k*.
        b: the shift added to k before inverting.

    Returns:
        (L_inf, A, R2, yhat, ss_res). R2 is 1.0 when y has zero variance.
    """
    inv_shift = 1.0 / (k + b)
    design = np.vstack([np.ones_like(inv_shift), inv_shift]).T
    solution = np.linalg.lstsq(design, y, rcond=None)[0]
    fitted = (design @ solution).astype(float)

    residual_ss = float(np.sum((y - fitted) ** 2))
    total_ss = float(np.sum((y - np.mean(y)) ** 2))
    if total_ss > 0:
        r_squared = 1.0 - residual_ss / total_ss
    else:
        r_squared = 1.0

    intercept = float(solution[0])
    slope = float(solution[1])
    return intercept, slope, r_squared, fitted, residual_ss

def choose_global_b_for_domain(across_N_k, b_grid=B_GRID):
    """Pick one shift b shared by all model sizes of a domain.

    For each candidate b in *b_grid* (scanned in order), L(k) = L_inf + A/(k+b)
    is fit independently at every N with fit_LA_for_b. The residual sums of
    squares are added across N, and the b with the smallest total wins (the
    first one on ties).

    Args:
        across_N_k: iterable of (N, k_arr, y_arr) tuples, one per model size.
        b_grid: candidate b values.

    Returns:
        The best b, or None when any series has fewer than 3 points or when
        no candidate total compares below +inf (e.g. every total is NaN).
    """
    # Materialise once so an iterator argument is usable for every b.
    series = list(across_N_k)
    # The minimum-size requirement does not depend on b: check it once
    # instead of re-checking it inside the b loop.
    if any(len(k_arr) < 3 for _, k_arr, _ in series):
        return None

    best_b, best_sse = None, np.inf
    for b in b_grid:
        total = 0.0
        for _, k_arr, y_arr in series:
            total += fit_LA_for_b(k_arr, y_arr, b)[4]  # index 4 = residual SS
        if total < best_sse:
            best_sse = total
            best_b = b
    return best_b

def fit_power_A(Ns, Avals):
    """Fit A(N) = A0 * N**(-gamma) by linear least squares in log-log space.

    A values are floored at 1e-12 before taking logs. R2 is then measured
    on the linear scale against those floored values (1.0 if they have zero
    variance).

    Returns:
        (A0, gamma, R2, yhat), where yhat is the fitted A at each N.
    """
    sizes = np.array(Ns, float)
    amps = np.maximum(np.array(Avals, float), 1e-12)

    design = np.vstack([np.ones_like(sizes), -np.log(sizes)]).T
    (log_a0, gamma), *_ = np.linalg.lstsq(design, np.log(amps), rcond=None)
    A0 = np.exp(log_a0)
    fitted = A0 * sizes ** (-gamma)

    residual_ss = float(np.sum((amps - fitted) ** 2))
    total_ss = float(np.sum((amps - np.mean(amps)) ** 2))
    r_squared = 1.0 - residual_ss / total_ss if total_ss > 0 else 1.0
    return A0, float(gamma), r_squared, fitted

def fit_floor_Linf(Ns, Lvals):
    """Fit L_inf(N) = Lstar + B * N**(-beta) with a non-negative floor Lstar.

    Lstar is grid-searched over 300 points in [0, max(1e-6, 0.999*min(L))].
    For each candidate, log(L - Lstar) is regressed linearly on log N to get
    B and beta. The candidate with the smallest linear-scale SSE is kept.

    Returns:
        (Lstar, B, beta, R2, yhat). Lstar = 0 is always a candidate, so a valid
        fit exists whenever every L value is positive. If no candidate leaves
        all residuals L - Lstar positive (in practice: some L <= 0), the power
        law cannot be fit in log space; a warning is printed and all five
        outputs are NaN (yhat is an all-NaN array of the input's shape).
    """
    N = np.array(Ns, float)
    y = np.array(Lvals, float)
    # The design matrix does not depend on Lstar: build it once.
    X = np.vstack([np.ones_like(N), -np.log(N)]).T
    ymin = float(np.min(y))
    Lstar_grid = np.linspace(0.0, max(1e-6, ymin * 0.999), 300)

    best = None
    best_sse = np.inf
    best_pred = None
    for Lstar in Lstar_grid:
        resid = y - Lstar
        if np.any(resid <= 0):
            continue  # log(resid) undefined for this floor
        coefs, _, _, _ = np.linalg.lstsq(X, np.log(resid), rcond=None)
        logB, beta = coefs
        B = np.exp(logB)
        yhat = Lstar + B * N ** (-beta)
        sse = float(np.sum((y - yhat) ** 2))
        if sse < best_sse:
            best_sse = sse
            best = (float(Lstar), float(B), float(beta))
            best_pred = yhat

    if best is None:
        # The former "no floor" fallback fit log(y) directly, but it was only
        # reached when some y <= 0, so it always produced NaN/-inf input for
        # lstsq. Report an explicit "no fit" instead.
        print(f"[WARN] fit_floor_Linf: L_inf values are not all positive "
              f"(min={ymin:.4g}); floor fit skipped.")
        nan = float("nan")
        return nan, nan, nan, nan, np.full_like(y, np.nan)

    ss_res = float(np.sum((y - best_pred) ** 2))
    ss_tot = float(np.sum((y - np.mean(y)) ** 2))
    R2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 1.0
    return best[0], best[1], best[2], R2, best_pred

# ============ Main pipeline: per-domain fits and plots for each candidate-pool size M (7 and 8 in __main__) ============

# def _beautify_axes(ax):
#     """统一坐标轴外观：加粗边框、网格与刻度、两位小数标签（即便是 log）"""
#     for spine in ax.spines.values():
#         spine.set_linewidth(1.6)          # 再加深一点
#     ax.tick_params(which="both", width=1.3, length=4)
#     ax.grid(True, which="major", alpha=0.18)
#     # y 轴两位小数（注意 log 刻度也按数值格式化）
#     ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _:
#                                                f"{y:.2f}" if np.isfinite(y) else ""))

def _collect_domain_curves(M, domains_use, data_byN):
    """Collect the mean CE loss per k for every (N, domain) pair inside the pool.

    Returns {domain: [(N, k_arr, loss_arr), ...]}. Only pairs with at least 3
    distinct k values are kept, which is also what choose_global_b_for_domain
    requires.
    """
    curves = {d: [] for d in domains_use}
    for N, dfN in data_byN.items():
        dfN_M = filter_by_pool_and_domains(dfN, M=M, domains_keep=domains_use)
        grp = dfN_M.groupby(["class", "k"])["CE Loss"].mean().reset_index()
        for d in domains_use:
            sub = grp[grp["class"] == d].sort_values("k")
            if len(sub) >= 3:
                curves[d].append((N,
                                  sub["k"].to_numpy(float),
                                  sub["CE Loss"].to_numpy(float)))
    return curves


def _fit_domain_trends(domain_curves):
    """Fit the per-domain shift b, then L_inf(N) / A(N) and their trends over N.

    Returns (fit_rows, L_points, A_points): one table row per fitted domain,
    plus {domain: (Ns, values, fitted values)} for each panel. A(N) values are
    floored at 1e-12 before the power-law fit.
    """
    fit_rows, L_points, A_points = [], {}, {}
    for d, series in domain_curves.items():
        if not series:
            continue
        b = choose_global_b_for_domain(series, b_grid=B_GRID)
        if b is None:
            # Presumably only when no candidate total is comparable (e.g. all
            # NaN); skip rather than pass None on to fit_LA_for_b.
            print(f"[WARN] domain {d}: no usable shift b found; skipped.")
            continue

        Ns_d, Ls_d, As_d = [], [], []
        for N, k_arr, y_arr in sorted(series, key=lambda t: t[0]):
            L_inf, A, _, _, _ = fit_LA_for_b(k_arr, y_arr, b)
            Ns_d.append(N)
            Ls_d.append(L_inf)
            As_d.append(max(A, 1e-12))
        Ns_d = np.array(Ns_d, float)
        Ls_d = np.array(Ls_d, float)
        As_d = np.array(As_d, float)

        A0, gamma, R2_A, Ahat = fit_power_A(Ns_d, As_d)
        Lstar, B, beta, R2_L, Lhat = fit_floor_Linf(Ns_d, Ls_d)

        fit_rows.append({
            "domain": d, "b_hat": b,
            "A0": A0, "gamma": gamma, "R2_A": R2_A,
            "Lstar": Lstar, "B": B, "beta": beta, "R2_L": R2_L,
            "Nmin": float(min(Ns_d)), "Nmax": float(max(Ns_d)),
            "nN": int(len(Ns_d)),
        })
        L_points[d] = (Ns_d, Ls_d, Lhat)
        A_points[d] = (Ns_d, As_d, Ahat)
    return fit_rows, L_points, A_points


def _plot_Linf_A_panels(M, L_points, A_points, show_A, out_pdf_path):
    """Draw panel (a) L_inf(N) and panel (b) A(N) with a shared legend, then save."""
    # One fixed colour per domain, so points, fit line and both panels agree.
    # Otherwise panel (b) restarts the colour cycle when it shows only a
    # subset of domains, and the shared legend (built from panel (a)'s
    # handles) would mislabel it.
    domain_color = {d: f"C{i % 10}" for i, d in enumerate(sorted(L_points))}

    fig = plt.figure(figsize=(9.2, 4.2))
    ax1 = fig.add_subplot(1, 2, 1)
    for d in sorted(L_points):
        Ns_d, Ls_d, Lhat = L_points[d]
        ax1.scatter(Ns_d, Ls_d, marker="o", s=22, color=domain_color[d], label=f"{d} est.")
        ax1.plot(Ns_d, Lhat, linestyle="--", linewidth=1.6, color=domain_color[d], label=f"{d} fit")
    ax1.set_xscale("log"); ax1.set_yscale("log")
    ax1.set_xlabel("Model size $N$ (B params, log)")
    ax1.set_ylabel("$L_\\infty(N)$ (log)")
    ax1.set_title(f"(a) $L_\\infty(N)$ across selected domains (M={M})")
    _beautify_axes(ax1)

    ax2 = fig.add_subplot(1, 2, 2)
    for d in sorted(show_A):
        Ns_d, As_d, Ahat = A_points[d]
        ax2.scatter(Ns_d, As_d, marker="s", s=22, color=domain_color[d], label=f"{d} est.")
        ax2.plot(Ns_d, Ahat, linestyle="--", linewidth=1.6, color=domain_color[d], label=f"{d} fit")
    ax2.set_xscale("log"); ax2.set_yscale("log")
    ax2.set_xlabel("Model size $N$ (B params, log)")
    ax2.set_ylabel("$A(N)$ (log)")
    ax2.set_title(f"(b) $A(N)$ across selected domains (M={M})")
    _beautify_axes(ax2)

    # Two-line suptitle, top-aligned at y=0.98 (the default alignment).
    # The tight_layout rect keeps the axes and their titles below it.
    fig.suptitle(
        f"Per-domain fits with restricted candidate pool M={M}\nSelected domains: {', '.join(sorted(L_points.keys()))}",
        y=0.98, fontsize=15
    )
    fig.tight_layout(rect=[0.02, 0.0, 0.98, 0.84])  # [left, bottom, right, top]
    fig.subplots_adjust(wspace=0.25)

    # Shared, de-duplicated legend (first handle per label wins). Its top
    # edge is anchored at the figure's bottom edge, so it hangs below the
    # axes however many rows it needs; savefig(bbox_inches="tight") keeps it
    # in the output.
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    by_label = {}
    for h, l in zip(h1 + h2, l1 + l2):
        by_label.setdefault(l, h)
    ncols = min(6, max(1, int(np.ceil(len(by_label) / 2))))
    fig.legend(list(by_label.values()), list(by_label.keys()),
               loc="upper center", bbox_to_anchor=(0.5, 0.0),
               ncol=ncols, frameon=True, borderaxespad=0.6)

    fig.savefig(out_pdf_path, bbox_inches="tight")
    plt.close(fig)
    print(f"[INFO] Saved figure -> {out_pdf_path}")


def process_and_plot_for_pool(M, selected_domains, rep_domains_for_A, data_byN, out_pdf_path):
    """Run the per-domain pipeline for one candidate-pool size M and save its outputs.

    Steps:
      1. Keep rows whose experts all lie in 1..M and whose class is one of
         *selected_domains*.
      2. Fit L(k) = L_inf(N) + A(N)/(k + b_d) per domain and N, with b_d
         shared across N.
      3. Fit the L_inf(N) floor law and the A(N) power law.
      4. Write a per-domain fit table to OUT_DIR and save the two-panel
         figure.

    Args:
        M: pool size; experts 1..M are allowed.
        selected_domains: domain ("class") names to fit. Names absent from
            every N are warned about and dropped.
        rep_domains_for_A: domains drawn in the A(N) panel; empty = all fitted.
        data_byN: {N: DataFrame} as produced by read_all_dare_csvs.
        out_pdf_path: figure path (any extension matplotlib can save). The fit
            table is written to OUT_DIR as "<figure stem>_fit_table.csv".

    Returns:
        (fit_df, show_A): the fit table sorted by domain (possibly empty) and
        the domains drawn in the A(N) panel.

    Raises:
        ValueError: *data_byN* is empty, or none of *selected_domains* occurs
            in the data.
    """
    if not data_byN:
        raise ValueError("data_byN is empty; nothing to fit.")

    # Union over every N, so a domain absent only at the smallest N still counts.
    all_domains_avail = set()
    for dfN in data_byN.values():
        all_domains_avail.update(dfN["class"].unique())
    missing = [d for d in selected_domains if d not in all_domains_avail]
    if missing:
        print(f"[WARN] 选中的 domain 在数据中找不到：{missing}")
    domains_use = [d for d in selected_domains if d in all_domains_avail]
    if len(domains_use) == 0:
        raise ValueError("选中的 domains 在数据中均不存在，请检查 SELECT_DOMAINS_*。")
    print(f"[INFO] M={M} 使用的 domains：{domains_use}")

    domain_curves = _collect_domain_curves(M, domains_use, data_byN)
    fit_rows, L_points, A_points = _fit_domain_trends(domain_curves)

    # Explicit columns keep sort_values("domain") valid when nothing was fitted.
    columns = ["domain", "b_hat", "A0", "gamma", "R2_A", "Lstar", "B", "beta",
               "R2_L", "Nmin", "Nmax", "nN"]
    fit_df = pd.DataFrame(fit_rows, columns=columns).sort_values("domain")

    # Derive the table name from the figure's stem so it never collides with
    # the figure file: `.replace(".pdf", ...)` left a ".png" name unchanged,
    # so the CSV landed on the figure's path and savefig overwrote it.
    stem = os.path.splitext(os.path.basename(out_pdf_path))[0]
    table_name = f"{stem}_fit_table.csv"
    fit_df.to_csv(os.path.join(OUT_DIR, table_name), index=False)
    print(f"[INFO] Saved per-domain fit table -> {table_name}")

    # Domains shown in the A(N) panel.
    if rep_domains_for_A:
        show_A = [d for d in rep_domains_for_A if d in A_points]
    else:
        show_A = list(A_points.keys())

    if L_points:
        _plot_Linf_A_panels(M, L_points, A_points, show_A, out_pdf_path)
    else:
        print(f"[WARN] M={M}: no domain has >= 3 k values at any N; figure not written.")

    return fit_df, show_A

# ============ 执行 ============
if __name__ == "__main__":
    data_byN = read_all_dare_csvs(BASE_DIR)

    # (pool size M, domains to fit, domains shown in the A(N) panel),
    # processed in this order.
    pool_configs = [
        (7, SELECT_DOMAINS_M7, REP_DOMAINS_FOR_A_M7),
        (8, SELECT_DOMAINS_M8, REP_DOMAINS_FOR_A_M8),
    ]
    shown_A = {}
    for pool_M, domains, rep_domains in pool_configs:
        fig_path = os.path.join(OUT_DIR, f"rq_poolM{pool_M}_selDomains_Linf_A_vs_N.png")
        _, shown_A[pool_M] = process_and_plot_for_pool(
            M=pool_M,
            selected_domains=domains,
            rep_domains_for_A=rep_domains,
            data_byN=data_byN,
            out_pdf_path=fig_path,
        )

    print("\n[INFO] A(N) panel domains：")
    for pool_M, domains_shown in shown_A.items():
        print(f"  M={pool_M} -> {domains_shown}")
    print(f"\n[DONE] Outputs saved to: {OUT_DIR}")