# -*- coding: utf-8 -*-
# RQ4 @32B: 4 methods -> mean CE vs k, variance vs k (cross-domain macro)
# 风格对齐参考“绘图代码2”（画布大小、画框设置、线条格式、颜色），但严格保留原始数据与绘制逻辑：
# - 均值：拟合 y(k) = y_inf + M/(k+b)，并画散点+虚线拟合曲线
# - 方差：不拟合，连线（log 轴；仅绘图时对 0 加极小值）
# - I/O 路径与文件名保持一致

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter

# Make sure the "figs" output directory exists before any figure is saved;
# exist_ok=True keeps re-runs from failing when it is already there.
os.makedirs("figs", exist_ok=True)

# -----------------------------
# 0) Utilities（拟合逻辑原样保留）
# -----------------------------
def _fit_linear_for_b(k, y, b):
    """Least-squares fit of y ≈ y_inf + M / (k + b) for one fixed shift ``b``.

    With ``b`` fixed the model is linear in ``(y_inf, M)``, so it is solved
    with ``np.linalg.lstsq`` on the design matrix ``[1, 1 / (k + b)]``.

    Parameters
    ----------
    k : array-like supporting ``k + b`` (e.g. a NumPy array of expert counts)
    y : observed values, one per entry of ``k``
    b : scalar shift added to ``k``

    Returns
    -------
    (y_inf, M, y_hat, sse) where ``y_hat`` holds the fitted values at ``k``
    and ``sse`` is the sum of squared residuals as a Python float.
    """
    inv_shifted = 1.0 / (k + b)
    # Two columns: a constant term (for y_inf) and 1/(k+b) (for M).
    design = np.vstack((np.ones_like(inv_shifted), inv_shifted)).T
    solution = np.linalg.lstsq(design, y, rcond=None)[0]
    floor, amplitude = solution.tolist()
    fitted = design @ solution
    residual = y - fitted
    return floor, amplitude, fitted, float(np.sum(residual ** 2))

def fit_floor_tail_mean(k, y, b_grid=np.linspace(0.0, 1.0, 101)):
    """Fit y(k) = y_inf + M / (k + b) by scanning ``b`` over a grid (unconstrained).

    For every candidate ``b`` the linear sub-problem is solved by
    ``_fit_linear_for_b``; the candidate with the smallest SSE wins, and on
    ties the earliest grid value is kept.

    Parameters
    ----------
    k : array-like of expert counts (must support ``k + b``)
    y : observed values, one per entry of ``k``
    b_grid : iterable of candidate shifts; the default grid is created once
        at definition time and is only read, never modified.

    Returns
    -------
    (b, y_inf, M, y_hat, sse) for the best candidate.

    Raises
    ------
    ValueError
        If no candidate yields a usable SSE, i.e. ``b_grid`` is empty or
        every SSE is NaN/inf (for example because ``y`` contains NaN).
    """
    best = None
    best_sse = np.inf
    for b in b_grid:
        y_inf, M, y_hat, sse = _fit_linear_for_b(k, y, b)
        # A NaN SSE never compares smaller, so such candidates are skipped.
        if sse < best_sse:
            best_sse = sse
            best = (b, y_inf, M, y_hat)
    if best is None:
        # Previously this fell through to ``best[0]`` and failed with an
        # opaque "NoneType is not subscriptable" TypeError.
        raise ValueError(
            "fit_floor_tail_mean: no candidate in b_grid produced a finite SSE "
            "(empty grid or non-finite data)."
        )
    b, y_inf, M, y_hat = best
    return b, y_inf, M, y_hat, best_sse

# -----------------------------
# 1) Fallback data (32B cross-domain macro), values unchanged
# -----------------------------
# Expert counts k = 1..9 shared by both plots.
k_vals = np.arange(1, 10, dtype=float)

# Mean CE per method: nine values, one per entry of k_vals.
fallback_mean_32B = {
    "AVG": [
        0.517304885, 0.487779532, 0.470490538,
        0.464454014, 0.459946078, 0.457620518,
        0.454727274, 0.453815997, 0.452999324,
    ],
    "TA": [0.5036, 0.4821, 0.4675, 0.4641, 0.4631, 0.4601, 0.4599, 0.459, 0.4586],
    "TIES": [
        0.4956623, 0.479648, 0.4655754,
        0.45924, 0.45452, 0.4511964,
        0.446945, 0.444825, 0.4427,
    ],
    "DARE": [0.5165, 0.4885, 0.4708, 0.4644, 0.4601, 0.4580, 0.4539, 0.4543, 0.4523],
}

# Variance per method: eight values each; the plotting code pairs them with
# k_vals[:-1].
fallback_var_32B = {
    "AVG": [0.00098, 0.000967, 0.000967, 0.000715, 0.000708, 0.000376, 0.000122, 0.000043],
    "TA": [0.001063, 0.000783, 0.000783, 0.000547, 0.000502, 0.000286, 0.000099, 0.000033],
    "TIES": [0.001091, 0.00065, 0.000616, 0.000485, 0.000517, 0.000298, 0.000107, 0.000037],
    "DARE": [0.001021, 0.000981, 0.000947, 0.000737, 0.000704, 0.000372, 0.000124, 0.000042],
}

# -----------------------------
# 2) Read all_methods.xlsx (if present), parsing column and method names leniently (original logic kept)
# -----------------------------
# Workbook that may override the fallback series; a relative path, so it is
# looked up from the current working directory.
xlsx_path = "all_methods.xlsx"

def _read_metric(sheet_name):
    """Load one sheet of the workbook and normalise its column and method names.

    Column names are stripped and lower-cased, then mapped onto the canonical
    names ``n``, ``method``, ``k`` and ``value`` through a small alias table.
    Method labels are stripped and rewritten by a fixed chain of
    case-sensitive regex replacements.

    Parameters
    ----------
    sheet_name : sheet identifier forwarded to ``pandas.read_excel``.

    Returns
    -------
    The normalised DataFrame, or ``None`` when the sheet cannot be read or a
    required canonical column is missing. A ``[warn]`` line is printed in
    both cases so the fall-back to built-in data is visible.
    """
    try:
        df = pd.read_excel(xlsx_path, sheet_name=sheet_name)
    except Exception as exc:
        # Deliberately best-effort: any read problem means "use the fallback
        # data", but say so instead of failing silently.
        print(f"[warn] {sheet_name}: could not read sheet ({exc!r}); using fallbacks.")
        return None
    df = df.rename(columns={c: str(c).strip().lower() for c in df.columns})
    # Aliases are listed in priority order for each canonical column.
    col_aliases = {
        "n": ["n","model","model_size","size"],
        "method": ["method","algo","rule"],
        "k": ["k","experts","num","number"],
        "value": ["value","mean","ce","loss","var","variance"],
    }
    colmap = {}
    for canon, cands in col_aliases.items():
        # Rename only the first alias that is present: renaming every match
        # would create duplicate canonical columns (e.g. two "value"
        # columns), which turns df["value"] into a DataFrame downstream.
        hit = next((c for c in cands if c in df.columns), None)
        if hit is not None:
            colmap[hit] = canon
    df = df.rename(columns=colmap)
    need = {"n","method","k","value"}
    missing = need - set(df.columns)
    if missing:
        print(f"[warn] {sheet_name}: missing required columns {sorted(missing)}; using fallbacks.")
        return None

    df["method"] = df["method"].astype(str).str.strip()
    # Order matters: the parenthesised "ties(...)" variants are rewritten
    # before the bare "ties" pattern is tried.
    df["method"] = (df["method"]
                    .str.replace(r"ties\s*\(\s*0\.5\s*\)", "TIES", regex=True)
                    .str.replace(r"ties\s*\(\s*1(\.0)?\s*\)", "TIES(1.0)", regex=True)
                    .str.replace(r"^ties$", "TIES(1.0)", regex=True)
                    .str.replace(r"^ta$", "TA", regex=True)
                    .str.replace(r"AVG.*", "AVG", regex=True)
                    .str.replace(r"dare.*", "DARE", regex=True))
    return df

def _inject_from_df(df, target_dict, label, n_points=None):
    """Replace fallback series in ``target_dict`` with the N=32B rows of ``df``.

    Parameters
    ----------
    df : DataFrame or None
        Frame with ``n``, ``method``, ``k`` and ``value`` columns; ``None``
        makes the call a no-op.
    target_dict : dict
        Maps method name -> list of values; updated in place.
    label : str
        Metric name, used only in the warning messages.
    n_points : int, optional
        How many leading points (ordered by ``k``) a series must provide and
        how many are stored. When omitted, the length of the series already
        in ``target_dict`` is used (9 if the dict is empty). A hard-coded 9
        was wrong for the variance data, whose fallback series hold 8 points
        and are plotted against ``k[:-1]``: an 8-point sheet was rejected
        and a 9-point one broke the plot with a shape mismatch.
    """
    if df is None:
        return

    if n_points is None:
        # Injected series must keep the length of the fallback they replace.
        n_points = len(next(iter(target_dict.values()))) if target_dict else 9

    def _parse_size(raw):
        # "32B" / "32b" / 32 -> 32.0; anything unparsable -> NaN.
        try:
            return float(str(raw).lower().replace("b","").strip())
        except (TypeError, ValueError):
            return np.nan

    sub = df.copy()
    sub["n_num"] = sub["n"].apply(_parse_size)
    sub = sub[np.isfinite(sub["n_num"]) & np.isclose(sub["n_num"], 32.0)]
    if sub.empty:
        print(f"[warn] {label}: no N=32B rows found in xlsx; using fallbacks where needed.")
        return

    for m in ["AVG","TA","TIES","TIES(1.0)","DARE"]:
        cur = sub[sub["method"] == m]
        if cur.empty:
            continue
        series = cur.sort_values("k")["value"].to_numpy(dtype=float)
        if series.size >= n_points:
            target_dict[m] = series[:n_points].tolist()
        else:
            print(f"[warn] {label}: {m} has only {series.size} points; need {n_points}. Skipping.")

# When the workbook is present, read both sheets first and then let each one
# replace the matching fallback series.
if os.path.exists(xlsx_path):
    df_mean, df_var = _read_metric("mean"), _read_metric("var")
    for _frame, _store, _tag in (
        (df_mean, fallback_mean_32B, "mean"),
        (df_var, fallback_var_32B, "var"),
    ):
        _inject_from_df(_frame, _store, _tag)

# -----------------------------
# 3) Choose the methods to draw (unchanged)
# -----------------------------
all_methods = ["AVG","TA","TIES","DARE"]
# Keep the order of all_methods, dropping any method that has no series.
methods_for_mean = [name for name in all_methods if name in fallback_mean_32B]
methods_for_var = [name for name in all_methods if name in fallback_var_32B]

# Abort early when there is nothing to plot; the mean check comes first.
if len(methods_for_mean) == 0:
    raise RuntimeError("No mean series for N=32B were found.")
if len(methods_for_var) == 0:
    raise RuntimeError("No variance series for N=32B were found.")

print("[info] Mean methods:", methods_for_mean)
print("[info] Var  methods:", methods_for_var)

# -----------------------------
# 4) Style and colour mapping (aligned with "plotting code 2")
# -----------------------------
# Default colour cycle handed to _method_colors.
# NOTE(review): the name (and the original comment) call this the
# colour-blind-friendly Okabe–Ito palette, but these hex values do not look
# like the canonical Okabe–Ito colours — confirm before relying on that.
OKABE_ITO = ["#4087ad", "#9462af", "#ff5e7d", "#ffa600", "#8EC8ED", "#AED594", "#D693BE", "#F5B3A5"]

def _method_colors(methods, palette=None):
    """Map each method name to a colour, in alphabetical order of the names.

    ``palette`` falls back to ``OKABE_ITO`` when it is ``None`` or empty, and
    colours wrap around when there are more methods than colours.
    """
    colors = palette if palette else OKABE_ITO
    n_colors = len(colors)
    mapping = {}
    # Sorting makes the assignment independent of the caller's ordering.
    for idx, name in enumerate(sorted(methods)):
        mapping[name] = colors[idx % n_colors]
    return mapping

def _apply_ax_style(ax, *, xfmt=None, yfmt=None):
    """Apply the shared frame and tick styling to ``ax``.

    Every spine is drawn dark and thick and tick labels use size 17.
    ``xfmt`` / ``yfmt`` are optional printf-style strings; when given, the
    matching axis gets a ``FormatStrFormatter`` as its major formatter.
    """
    # Axis spines: dark colour, heavy line.
    for spine in ax.spines.values():
        spine.set_color("#1f2a35")
        spine.set_linewidth(3.0)
    ax.tick_params(labelsize=17)

    requested = {"xaxis": xfmt, "yaxis": yfmt}
    for axis_attr, fmt in requested.items():
        if not fmt:
            continue
        getattr(ax, axis_attr).set_major_formatter(FormatStrFormatter(fmt))

def _legend_style(ax):
    """Add a legend titled "Method" to ``ax`` and style its frame."""
    legend_options = dict(
        title="Method",
        title_fontsize=17,
        fontsize=15,
        loc="best",
        frameon=True,
        fancybox=True,
        framealpha=0.95,
        borderpad=0.6,
        handlelength=1.8,
        handletextpad=0.6,
        labelspacing=0.4,
    )
    leg = ax.legend(**legend_options)
    if not leg:
        return
    # Thin dark outline on a white background.
    frame = leg.get_frame()
    frame.set_edgecolor("#1f2a35")
    frame.set_linewidth(1.0)
    frame.set_facecolor("white")

# -----------------------------
# 5) 绘图封装（保持原逻辑；仅替换风格）
# -----------------------------
def _plot_points(ax, x, y, label, color):
    """Draw ``(x, y)`` as unconnected round markers carrying the legend label."""
    marker_style = dict(
        marker="o",
        markersize=6,
        markerfacecolor=color,
        markeredgecolor="#3b3b3b",
        markeredgewidth=2.0,
    )
    # An empty linestyle leaves markers only.
    ax.plot(
        x, y,
        linestyle="", linewidth=2.0,
        color=color, label=label,
        alpha=0.55, zorder=3,
        **marker_style,
    )

def _plot_fit_mean(ax, k, y, label, color):
    """Plot a mean series as points plus its fitted dashed curve.

    The curve y_inf + M / (k + b) comes from ``fit_floor_tail_mean`` and is
    evaluated on 400 evenly spaced points between ``min(k)`` and ``max(k)``.
    Only the points carry the legend label.
    """
    b_hat, floor, amplitude, _fitted, _sse = fit_floor_tail_mean(k, y)
    grid = np.linspace(min(k), max(k), 400)
    curve = floor + amplitude / (grid + b_hat)

    _plot_points(ax, k, y, label, color)
    dashed = dict(linestyle="--", linewidth=3.0, color=color, alpha=0.55, zorder=2)
    ax.plot(grid, curve, **dashed)

def _plot_var_polyline(ax, k, y, label, color):
    """Draw a variance series as a marker polyline (no fitting).

    Parameters
    ----------
    ax : object with a matplotlib-style ``plot`` method
    k : 1-D sequence of x positions; its first ``len(y)`` entries are used
    y : variance values; non-positive entries are replaced, for plotting
        only, by machine epsilon so that a log axis can display them
    label, color : legend label and line/marker colour

    Raises
    ------
    ValueError
        If ``y`` holds more points than ``k``.
    """
    y_plot = np.array(y, dtype=float)
    y_plot[y_plot <= 0] = np.finfo(float).eps
    # Align x with the series actually supplied. The former hard-coded
    # ``k[:-1]`` only worked when len(y) == len(k) - 1 and otherwise failed
    # inside matplotlib with a shape mismatch (e.g. for a 9-point series).
    x_plot = np.asarray(k)[: y_plot.size]
    if x_plot.size != y_plot.size:
        raise ValueError(
            f"_plot_var_polyline: {label} has {y_plot.size} points "
            f"but only {x_plot.size} k values are available."
        )
    ax.plot(
        x_plot, y_plot, linestyle="-", marker="o",
        linewidth=2.0, markersize=6,
        color=color, markerfacecolor=color,
        markeredgecolor="#3b3b3b", markeredgewidth=2.0,
        label=label, alpha=0.55, zorder=2
    )

# -----------------------------
# 6) Draw and export (uniform canvas size / resolution / title / axis-label style)
# -----------------------------
# One method -> colour mapping per figure, built by _method_colors from the
# methods selected for that figure.
color_map_mean = _method_colors(methods_for_mean, OKABE_ITO)
color_map_var  = _method_colors(methods_for_var,  OKABE_ITO)

# (a) Mean CE vs k (points plus fitted curve)
fig1, ax1 = plt.subplots(figsize=(8.5, 6.0), dpi=300)
for method_name in methods_for_mean:
    _plot_fit_mean(
        ax1,
        k_vals,
        np.array(fallback_mean_32B[method_name], dtype=float),
        method_name,
        color_map_mean[method_name],
    )

ax1.set_xlabel("Number of merged experts $k$", fontsize=16, labelpad=6)
ax1.set_ylabel("Cross-domain mean CE", fontsize=16, labelpad=6)
ax1.set_title("Mean CE vs. $k$ at 32B", fontsize=18, pad=8)
ax1.set_xticks(k_vals)
_apply_ax_style(ax1, yfmt="%.2f")
_legend_style(ax1)
fig1.tight_layout()

# Export as raster (PNG, 300 dpi) and vector (PDF), then release the figure.
fig1_png = "figs/rq4_methods_32b_mean_vs_k.png"
fig1_pdf = "figs/rq4_methods_32b_mean_vs_k.pdf"
fig1.savefig(fig1_png, bbox_inches="tight", dpi=300)
fig1.savefig(fig1_pdf, bbox_inches="tight")
plt.close(fig1)
for _saved_path in (fig1_png, fig1_pdf):
    print("[saved]", _saved_path)

# (b) Variance vs k (no fit: polyline through the points; log y-axis kept)
fig2, ax2 = plt.subplots(figsize=(8.5, 6.0), dpi=300)
for method_name in methods_for_var:
    _plot_var_polyline(
        ax2,
        k_vals,
        np.array(fallback_var_32B[method_name], dtype=float),
        method_name,
        color_map_var[method_name],
    )

ax2.set_xlabel("Number of merged experts $k$", fontsize=16, labelpad=6)
ax2.set_ylabel("Variance across merges", fontsize=16, labelpad=6)
ax2.set_title("Variance vs. $k$ at 32B", fontsize=18, pad=8)
ax2.set_xticks(k_vals[:-1])
ax2.set_yscale("log")  # log scale kept from the original logic
_apply_ax_style(ax2, yfmt="%.1e")
_legend_style(ax2)
fig2.tight_layout()

# Export as raster (PNG, 300 dpi) and vector (PDF), then release the figure.
fig2_png = "figs/rq4_methods_32b_var_vs_k.png"
fig2_pdf = "figs/rq4_methods_32b_var_vs_k.pdf"
fig2.savefig(fig2_png, bbox_inches="tight", dpi=300)
fig2.savefig(fig2_pdf, bbox_inches="tight")
plt.close(fig2)
for _saved_path in (fig2_png, fig2_pdf):
    print("[saved]", _saved_path)