# -*- coding: utf-8 -*-
"""
按统一配置风格绘图（单图、matplotlib-only、无显式颜色）
- 复用既有拟合与统计逻辑
- 仅保留“图1：L_inf(N)”
"""

import os

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D  # proxy handle used in the legend
from matplotlib.ticker import FuncFormatter, NullFormatter

# ===============================
# 0) Shared plotting configuration (can be swapped for your own version)
# ===============================
# Top-level entries are read directly by the plotting code; the nested
# dictionaries are keyword sets grouped by the artist they style.
plot_config = dict(
    # Figure geometry and resolution.
    figsize=(8.5, 6.0),
    dpi=300,
    # Line and marker sizing.
    linewidth=2.5,
    markersize=7,
    markeredgewidth=2.5,
    # Text styling for axis labels, title and tick labels.
    xlabel=dict(fontsize=16, labelpad=6),
    ylabel=dict(fontsize=16, labelpad=6),
    title=dict(fontsize=18, pad=8),
    tick_params=dict(labelsize=15),
    # Legend layout and frame appearance.
    legend=dict(
        title_fontsize=15,
        fontsize=12,
        loc="best",
        frameon=True,
        fancybox=True,
        framealpha=0.95,
        borderpad=0.6,
        handlelength=1.8,
        handletextpad=0.6,
        labelspacing=0.4,
        edgecolor="#1f2a35",
        linewidth=1.0,
        facecolor="white",
    ),
    # Axes border (spine) appearance.
    spines=dict(color="#1f2a35", linewidth=3.0),
    # Artist transparency and stacking order.
    alpha=0.95,
    zorder=2,
    # printf-style format string for y tick labels.
    yfmt="%.02f",
)

# Hex colours handed out to domains by position.
DOMAIN_COLORS = [
    "#008A45",
    "#468BCA",
    "#5F5F5E",
    "#7DD2F6",
    "#80C5A2",
    "#B384BA",
    "#D9C2DD",
    "#F27873",
    "#FFD373",
]

# ===============================
# 0.1) Other run parameters
# ===============================
# Inclusive k range [2..5] used to estimate the early-k amplitude A(N).
EARLY_K_RANGE = (2, 5)
# Example/inspection only (this version does not draw figure 2).
SELECTED_DOMAINS = ["algebra", "chemistry"]
# Output directory (created automatically).
OUT_DIR = "figs"

# ===============================
# 1) Data
# ===============================
# k = 1..9: one k per entry of every series stored in avg_data.
k_vals = np.arange(9) + 1
# Model sizes; these are also the outer keys of avg_data.
Ns = [0.5, 1.5, 3, 7, 14, 32]

# Layout: avg_data[N][domain] is a list of 9 values, where N is a model size
# from Ns and entry i is paired with k = k_vals[i] by the fitting code.
# The summary at the end of the script labels these values "CE".
avg_data = {  # <-- paste the complete data dictionary here
    0.5: {"algebra":[0.38870969,0.361388945,0.353070365,0.348855826,0.346396314,0.344687186,0.343482999,0.342608706,0.341813997],
          "discrete":[0.675279706,0.638606598,0.627139191,0.621347593,0.617949222,0.615629954,0.613976834,0.612764024,0.611765409],
          "analysis":[0.400981652,0.373791342,0.365430147,0.361273265,0.358729691,0.357035066,0.355779848,0.354866396,0.35404844],
          "geometry":[0.466163542,0.436566282,0.427406096,0.42283913,0.420065942,0.418225191,0.416887299,0.415889587,0.415204736],
          "code":[0.631458418,0.602949613,0.591226322,0.584666522,0.580381321,0.577346809,0.574992322,0.573220679,0.571673265],
          "number_theory":[0.497854733,0.46661246,0.4569785,0.452169459,0.449364407,0.447499541,0.446133188,0.445094679,0.444341392],
          "chemistry":[1.339087142,1.271470311,1.250391469,1.240096755,1.233782558,1.229949244,1.226851675,1.224914202,1.22309423],
          "physics":[1.334634923,1.270568665,1.251434,1.242290259,1.23670613,1.233275433,1.230599054,1.228866834,1.227004083],
          "biology":[1.609473385,1.52554396,1.498841253,1.485587277,1.477488826,1.472506062,1.468362583,1.465754065,1.463440771]},
    1.5: {"algebra":[0.315373273,0.295693087,0.292465461,0.288742222,0.285844074,0.281397797,0.278885341,0.277676539,0.276707406],
          "discrete":[0.548971303,0.51829583,0.514443422,0.510867107,0.506271037,0.501898826,0.499072908,0.497471617,0.496613025],
          "analysis":[0.320008975,0.300520823,0.297101809,0.293736669,0.290960189,0.286097292,0.283669042,0.282558608,0.281585139],
          "geometry":[0.381928173,0.361264737,0.354469493,0.350939994,0.34759311,0.34447938,0.341947349,0.341449933,0.34123044],
          "code":[0.496565059,0.472588686,0.469751974,0.469487988,0.465333319,0.458521594,0.459409186,0.460353503,0.458793494],
          "number_theory":[0.40562323,0.380090262,0.377454744,0.373362826,0.369605868,0.364911672,0.362885013,0.361347131,0.360983564],
          "chemistry":[1.081343367,1.045566615,1.004760939,0.993294408,0.987886666,0.992343439,0.987540454,0.9873462,0.985602356],
          "physics":[1.088181642,1.05316764,1.018279327,1.008465584,1.004175439,1.006993868,1.003443619,1.002126369,1.000615201],
          "biology":[1.296567972,1.251977601,1.198715438,1.185429408,1.181860986,1.185602747,1.178655365,1.176602818,1.176086816]},
    3.0: {"algebra":[0.293407034,0.271431433,0.265069254,0.261705676,0.259903857,0.259435438,0.258632178,0.257932576,0.257598105],
          "discrete":[0.515105431,0.483781504,0.474643103,0.469935137,0.467344589,0.466593629,0.465419291,0.464394998,0.463808979],
          "analysis":[0.297353593,0.275700087,0.269388127,0.266050589,0.264278796,0.263754787,0.262925032,0.262268188,0.261830373],
          "geometry":[0.355216464,0.331970786,0.325250531,0.321728974,0.319887912,0.319428992,0.318623887,0.317852497,0.317383795],
          "code":[0.461050067,0.442670287,0.435650744,0.431110488,0.428018952,0.426525122,0.425159842,0.423967722,0.422765287],
          "number_theory":[0.377820405,0.351564821,0.343931368,0.339920515,0.337873004,0.337196617,0.336171771,0.335377243,0.334936368],
          "chemistry":[1.002853285,0.941914711,0.925199377,0.916051342,0.911097762,0.908694527,0.906686638,0.904922208,0.903670478],
          "physics":[1.01076463,0.953775111,0.937781996,0.929401956,0.924956043,0.922709849,0.920478335,0.918778389,0.917526107],
          "biology":[1.19694596,1.122303113,1.100536454,1.088636071,1.0817045,1.078472269,1.075787083,1.07308901,1.071046765]},
    7.0: {"algebra":[0.288204221,0.26204617,0.256643353,0.253969777,0.252547993,0.251413784,0.250701174,0.250128568,0.249726416],
          "discrete":[0.479819591,0.441345967,0.432518893,0.428181963,0.425758053,0.423920146,0.422737615,0.421773007,0.421163732],
          "analysis":[0.290004928,0.264157036,0.258188173,0.255194149,0.253549005,0.252328003,0.251520661,0.250857652,0.250593043],
          "geometry":[0.345914588,0.317164791,0.310572916,0.307261623,0.30543702,0.304063303,0.303222131,0.302535498,0.302139551],
          "code":[0.441199812,0.417474524,0.411645293,0.408700994,0.406208481,0.404496178,0.403422294,0.40247628,0.401894785],
          "number_theory":[0.361896146,0.329707016,0.322632794,0.31917792,0.317454128,0.316014357,0.31503946,0.314231925,0.313800938],
          "chemistry":[0.945744303,0.877591883,0.860481984,0.851888154,0.846596452,0.843013343,0.840473939,0.838687015,0.837399648],
          "physics":[0.963366772,0.901523203,0.886857941,0.879699645,0.875654743,0.872681452,0.870705264,0.86922001,0.868353373],
          "biology":[1.129518997,1.047706831,1.025895446,1.014941959,1.008303406,1.003661735,1.000331232,0.998025284,0.996190122]},
    14.0: {"algebra":[0.272534195,0.238723162,0.232248612,0.229093686,0.225822344,0.222602502,0.22076657,0.219750601,0.219655824],
           "discrete":[0.440893491,0.391182109,0.383570543,0.378847212,0.373675607,0.369465387,0.366878352,0.365555448,0.364534252],
           "analysis":[0.274840045,0.243046082,0.237165347,0.234153955,0.23093587,0.228131163,0.22637307,0.225542161,0.22523505],
           "geometry":[0.324313991,0.288687048,0.278685036,0.27490936,0.270873665,0.26849483,0.266268375,0.265813763,0.265783367],
           "code":[0.39472054,0.367021107,0.363279975,0.363190447,0.358321606,0.350160098,0.350872193,0.351812924,0.349731741],
           "number_theory":[0.342765584,0.301169614,0.295042737,0.29067375,0.286605786,0.282855102,0.281269266,0.280003589,0.279662677],
           "chemistry":[0.843074279,0.784809081,0.746377504,0.732448491,0.725489879,0.728018016,0.721253261,0.719769104,0.716085907],
           "physics":[0.8636146,0.801941284,0.770417437,0.757907068,0.752521354,0.75400138,0.748924754,0.746710994,0.74285281],
           "biology":[1.015637891,0.942581468,0.892687265,0.876038928,0.870265925,0.87056405,0.860771824,0.856969113,0.853540643]},
    32.0: {"algebra":[0.250056586,0.22819991,0.224164211,0.22059183,0.217122109,0.212983049,0.211003879,0.20998448,0.209667327],
           "discrete":[0.399170451,0.366727561,0.360532187,0.356714675,0.351094845,0.3467244,0.344268693,0.342847831,0.342680329],
           "analysis":[0.249342815,0.228952177,0.224928463,0.221872451,0.218602385,0.214221892,0.212230905,0.21125327,0.210953639],
           "geometry":[0.29852047,0.275523871,0.267949879,0.264155628,0.259822437,0.256552903,0.254432908,0.253909556,0.253894347],
           "code":[0.365842494,0.347083186,0.345486508,0.345554432,0.341585446,0.335027459,0.33629572,0.337481855,0.336340502],
           "number_theory":[0.313214458,0.286150663,0.281955999,0.277909175,0.273475474,0.269118212,0.267366016,0.266089372,0.265844435],
           "chemistry":[0.79198143,0.754100421,0.716174264,0.704870143,0.699072206,0.70178825,0.696492805,0.695831147,0.693996519],
           "physics":[0.808904331,0.771876999,0.739921357,0.730458777,0.725698069,0.727333029,0.723580858,0.72204132,0.720743992],
           "biology":[0.955801372,0.91056614,0.862690604,0.848422185,0.84358317,0.844337765,0.836054734,0.833301106,0.832135028]},
}
# Domain names, in the insertion order of the N = 0.5 entry of avg_data.
domains = list(avg_data[0.5])
# Each domain gets the palette colour at its position; the index wraps around
# if there are more domains than colours.
color_map = dict(
    (name, DOMAIN_COLORS[idx % len(DOMAIN_COLORS)])
    for idx, name in enumerate(domains)
)

# ===============================
# 2) Fitting functions (carried over from the earlier version)
# ===============================
def fit_LA_for_b(k: np.ndarray, y: np.ndarray, b: float):
    """Least-squares fit of ``y ≈ L_inf + A / (k + b)`` for a fixed offset ``b``.

    With ``b`` held fixed the model is linear in ``(L_inf, A)``, so it is
    solved directly with ``np.linalg.lstsq`` on the regressor ``1 / (k + b)``.

    Args:
        k: 1-D array of k values.
        y: 1-D array of observations, same length as ``k``.
        b: Offset added to ``k`` in the denominator.

    Returns:
        ``(L_inf, A, sse)``: the fitted intercept, the fitted amplitude and
        the sum of squared residuals (a Python float).
    """
    inv_shifted = 1.0 / (k + b)
    # Two-column design matrix: intercept column and 1/(k+b) column.
    design = np.array([np.ones_like(inv_shifted), inv_shifted]).T
    params = np.linalg.lstsq(design, y, rcond=None)[0]
    fitted = design @ params
    residual = y - fitted
    return params[0], params[1], float(np.sum(residual ** 2))

def fit_power_law_A(Ns_list, Avals_list):
    """Fit the power law ``A(N) = A0 * N^{-gamma}`` by log-log least squares.

    Amplitudes are clamped to at least ``1e-12`` before taking logs; the
    clamped values are also the ones used for the R^2 computation.

    Args:
        Ns_list: Sequence of N values (the logarithm is taken of these).
        Avals_list: Sequence of amplitudes, same length as ``Ns_list``.

    Returns:
        ``(A0, gamma, r2, yhat)``: the fitted prefactor and exponent, R^2 of
        the prediction in linear space (``1.0`` when the clamped amplitudes
        have zero variance), and the array of predicted amplitudes.
    """
    sizes = np.asarray(Ns_list, dtype=float)
    amps = np.maximum(np.asarray(Avals_list, dtype=float), 1e-12)

    # log A = log A0 - gamma * log N  ->  regress log A on [1, -log N].
    design = np.array([np.ones_like(sizes), -np.log(sizes)]).T
    log_params = np.linalg.lstsq(design, np.log(amps), rcond=None)[0]
    A0 = float(np.exp(log_params[0]))
    gamma = float(log_params[1])

    yhat = A0 * (sizes ** (-gamma))
    ss_res = float(np.sum((amps - yhat) ** 2))
    ss_tot = float(np.sum((amps - np.mean(amps)) ** 2))
    if ss_tot > 0:
        r2 = 1 - ss_res / ss_tot
    else:
        r2 = 1.0
    return A0, gamma, r2, yhat

def fit_floor_Linf(Ns_list, Lvals_list):
    """Fit ``L_inf(N) = L* + B * N^{-beta}`` with a grid search over ``L*``.

    For every candidate floor ``L*`` on a 300-point grid between 0 and just
    below ``min(y)``, ``log(y - L*)`` is regressed on ``[1, -log N]``; the
    candidate with the smallest sum of squared errors in linear space wins.

    If no candidate is usable, the fit falls back to a pure power law with
    ``L* = 0``. Because ``L* = 0`` is always on the grid, that only happens
    when ``y`` contains a non-positive value, so the fallback clamps ``y`` to
    at least ``1e-12`` before taking logs (the same guard
    ``fit_power_law_A`` uses) instead of calling ``log`` on non-positive
    numbers.

    Args:
        Ns_list: Sequence of N values (the logarithm is taken of these).
        Lvals_list: Sequence of L_inf estimates, same length as ``Ns_list``.

    Returns:
        ``(Lstar, B, beta, r2, pred)``: the fitted floor, prefactor and
        exponent, R^2 against the unclamped ``y`` (``1.0`` when ``y`` has zero
        variance), and the array of predictions at each N.
    """
    N = np.asarray(Ns_list, dtype=float)
    y = np.asarray(Lvals_list, dtype=float)
    ymin = float(np.min(y))
    Lstar_grid = np.linspace(0.0, max(1e-6, ymin * 0.999), 300)

    # The design matrix depends only on N, so build it once rather than once
    # per grid point.
    X = np.vstack([np.ones_like(N), -np.log(N)]).T

    best = None
    best_pred = None
    best_sse = np.inf
    for Lstar in Lstar_grid:
        residual = y - Lstar
        # log(residual) is only defined when every point lies above the floor.
        if np.any(residual <= 0):
            continue
        coef, _, _, _ = np.linalg.lstsq(X, np.log(residual), rcond=None)
        B = float(np.exp(coef[0]))
        beta = float(coef[1])
        yhat = Lstar + B * N ** (-beta)
        sse = float(np.sum((y - yhat) ** 2))
        if sse < best_sse:
            best_sse = sse
            best_pred = yhat
            best = (float(Lstar), B, beta)

    if best is None:
        # Reached only when y has a non-positive entry; clamp so the log is
        # finite and the returned parameters are usable.
        y_pos = np.maximum(y, 1e-12)
        coef, _, _, _ = np.linalg.lstsq(X, np.log(y_pos), rcond=None)
        B = float(np.exp(coef[0]))
        beta = float(coef[1])
        best_pred = B * N ** (-beta)
        best = (0.0, B, beta)
        best_sse = float(np.sum((y - best_pred) ** 2))

    ss_res = best_sse
    ss_tot = float(np.sum((y - np.mean(y)) ** 2))
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 1.0
    return best[0], best[1], best[2], r2, best_pred

# ===============================
# 2.1) Plot helpers: apply styling from the configuration
# ===============================
def _apply_axes_style(ax, cfg):
    """Set the labels, title, tick styling and spine styling of ``ax``.

    The label and title texts are fixed strings; their keyword arguments come
    from ``cfg["xlabel"]``, ``cfg["ylabel"]`` and ``cfg["title"]``. Tick
    styling comes from ``cfg["tick_params"]`` and spine width/colour from
    ``cfg["spines"]``. Tick-label number formatting is not set here.

    Args:
        ax: Axes object to style (modified in place).
        cfg: Configuration mapping containing the keys listed above.
    """
    # Axis labels and title.
    ax.set_xlabel("Model size $N$ (B params, log)", **cfg["xlabel"])
    ax.set_ylabel(r"$L_\infty(N)$ (log)", **cfg["ylabel"])
    ax.set_title("All domains: $L_\\infty(N)$ estimates and fits", **cfg["title"])

    # Tick appearance.
    ax.tick_params(**cfg["tick_params"])

    # Same width and colour on all four spines.
    spine_cfg = cfg["spines"]
    for side in ("left", "bottom", "right", "top"):
        spine = ax.spines[side]
        spine.set_linewidth(spine_cfg["linewidth"])
        spine.set_color(spine_cfg["color"])

def _style_legend(leg, cfg):
    """Apply the title font size and frame styling in ``cfg["legend"]`` to ``leg``.

    Does nothing when ``leg`` is ``None``. The existing title text is kept
    and re-set with the configured font size; the frame's line width, edge
    colour, alpha and face colour are then set from the configuration.

    Args:
        leg: Legend object to style in place, or ``None``.
        cfg: Configuration mapping with a ``"legend"`` sub-mapping.
    """
    if leg is not None:
        current_title = leg.get_title().get_text()
        leg.set_title(current_title, prop={"size": cfg["legend"]["title_fontsize"]})

        box = leg.get_frame()
        box.set_linewidth(cfg["legend"]["linewidth"])
        box.set_edgecolor(cfg["legend"]["edgecolor"])
        box.set_alpha(cfg["legend"]["framealpha"])
        box.set_facecolor(cfg["legend"]["facecolor"])

# ===============================
# 3) Choose b_d for each domain (minimising the SSE summed over all N)
# ===============================
b_grid = np.linspace(0.0, 1.0, 41)
per_domain = {}
for domain_name in domains:
    # One series per model size for this domain, converted once up front.
    curves = [np.array(avg_data[size][domain_name], dtype=float) for size in Ns]

    best_b, best_sse = None, np.inf
    for b_cand in b_grid:
        # Accumulate the fit error of this candidate across all model sizes.
        sse_total = 0.0
        for curve in curves:
            sse_total += fit_LA_for_b(k_vals, curve, b_cand)[2]
        if not (sse_total < best_sse):
            continue
        best_b, best_sse = b_cand, sse_total
    per_domain[domain_name] = {"b": float(best_b)}

# ===============================
# 4) With b_d fixed, estimate L_inf(N) and the early-k amplitude A(N)
# ===============================
k_lo, k_hi = EARLY_K_RANGE
mask = np.logical_and(k_vals >= k_lo, k_vals <= k_hi)
early_k = k_vals[mask]

fits = {}
for d in domains:
    b = per_domain[d]["b"]
    curves = [np.array(avg_data[N][d], dtype=float) for N in Ns]

    # All k -> L_inf; early k only -> A.
    Ls_allN = np.asarray([fit_LA_for_b(k_vals, y, b)[0] for y in curves])
    As_early_allN = np.asarray(
        [fit_LA_for_b(early_k, y[mask], b)[1] for y in curves]
    )

    # Fit both quantities as functions of model size.
    A0, gamma, r2_A, Ahat = fit_power_law_A(Ns, As_early_allN)
    Lstar, B, beta, r2_L, Lhat = fit_floor_Linf(Ns, Ls_allN)

    fits[d] = {
        "b": b,
        "L": Ls_allN,
        "Lhat": Lhat,
        "Aearly": As_early_allN,
        "Ahat": Ahat,
        "A0": A0,
        "gamma": gamma,
        "R2_A": r2_A,
        "Lstar": Lstar,
        "B": B,
        "beta": beta,
        "R2_L": r2_L,
    }

# ===============================
# 5) Pooled A(N) (computed only, not plotted)
# ===============================
# One row per domain, one column per model size: shape [D, len(Ns)].
A_matrix = np.vstack([fits[d]["Aearly"] for d in domains])
A_mean = np.mean(A_matrix, axis=0)
A_std = np.std(A_matrix, axis=0, ddof=1)
# Standard error of the mean across domains.
A_se = A_std / np.sqrt(A_matrix.shape[0])
A0_pool, gamma_pool, r2_pool, Ahat_pool = fit_power_law_A(Ns, A_mean)

# ===============================
# 6) Figure 1: L_inf(N) (uses the shared plot_config)
# ===============================
os.makedirs(OUT_DIR, exist_ok=True)

fig, ax = plt.subplots(figsize=plot_config["figsize"], dpi=plot_config["dpi"])
N_arr = np.array(Ns, dtype=float)

scatter_handles = []
scatter_labels = []

# Per-domain estimates (markers) and fitted curve (dashed line); both are
# drawn in the domain's colour from color_map.
for d in domains:
    color = color_map[d]
    sc = ax.scatter(
        N_arr, fits[d]["L"],
        marker="o",
        color=color,
        # scatter's `s` is a marker area, so square the configured size.
        s=plot_config["markersize"] ** 2,
        linewidths=plot_config["markeredgewidth"],
        alpha=plot_config["alpha"],
        zorder=plot_config["zorder"],
        label=d,
    )
    scatter_handles.append(sc)
    scatter_labels.append(d)
    ax.plot(
        N_arr, fits[d]["Lhat"],
        color=color,
        linestyle="--",
        linewidth=plot_config["linewidth"],
        alpha=plot_config["alpha"],
        zorder=plot_config["zorder"],
    )

ax.set_xscale("log")
ax.set_yscale("log")

# Tick setup is done after ax.set_yscale("log") so it is not replaced by the
# scale's own default locators/formatters.
# Major ticks at 1, 2, 5 x 10^n; minor ticks at the remaining integer multiples.
ax.yaxis.set_major_locator(mticker.LogLocator(base=10, subs=(1.0, 2.0, 5.0)))
ax.yaxis.set_minor_locator(
    mticker.LogLocator(base=10, subs=(3.0, 4.0, 6.0, 7.0, 8.0, 9.0))
)
ax.yaxis.set_minor_formatter(NullFormatter())  # hide minor labels to avoid clutter

# Major tick labels use the printf-style format configured in plot_config
# (two decimals with the current "%.02f").
y_tick_fmt = plot_config["yfmt"]
ax.yaxis.set_major_formatter(FuncFormatter(lambda v, pos: y_tick_fmt % v))
_apply_axes_style(ax, plot_config)

# Grey dashed proxy so the legend has a single "Fit" entry for all fit curves.
fit_proxy = Line2D(
    [0], [0],
    linestyle="--",
    color="#9E9E9E",
    linewidth=plot_config["linewidth"],
    label="Fit",
)
handles = scatter_handles + [fit_proxy]
labels = scatter_labels + ["Fit"]

# Legend: create it first, then refine its appearance. It is anchored below
# the axes, so the configured "loc" entry is intentionally not used.
leg = ax.legend(
    handles, labels,
    ncol=5,
    loc="upper center",
    bbox_to_anchor=(0.5, -0.18),
    fontsize=plot_config["legend"]["fontsize"],
    frameon=plot_config["legend"]["frameon"],
    fancybox=plot_config["legend"]["fancybox"],
    borderpad=plot_config["legend"]["borderpad"],
    handlelength=plot_config["legend"]["handlelength"],
    handletextpad=plot_config["legend"]["handletextpad"],
    labelspacing=plot_config["legend"]["labelspacing"],
)
_style_legend(leg, plot_config)

fig.tight_layout()
out_path = os.path.join(OUT_DIR, "rq1_all_domain_L_inf(N).png")
fig.savefig(out_path, dpi=plot_config["dpi"])
# plt.show()

# ===============================
# 7) Minimal summary (printed)
# ===============================
def avg_ce_at(N, k_idx):
    """Return the mean over all domains of ``avg_data[N][domain][k_idx]``.

    Args:
        N: Outer key of ``avg_data``.
        k_idx: Zero-based position within each domain's series.

    Returns:
        The domain-averaged value as a Python float.
    """
    by_domain = avg_data[N]
    values = [by_domain[name][k_idx] for name in domains]
    return float(np.mean(values))

# Domain-averaged value at the last k (index 8, i.e. k = 9) for the smallest
# and largest model size, and the relative drop between them in percent.
ce_05_k9 = avg_ce_at(0.5, 8)
ce_32_k9 = avg_ce_at(32.0, 8)
drop_pct = 100.0 * (ce_05_k9 - ce_32_k9) / ce_05_k9

metric_names = [
    "domain-avg CE @k=9 (0.5B)",
    "domain-avg CE @k=9 (32B)",
    "relative drop %",
]
metric_values = [ce_05_k9, ce_32_k9, drop_pct]
summary_stats = pd.DataFrame({"metric": metric_names, "value": metric_values})
print(summary_stats)

pooled_line = "\nPooled A(N): A0 = {:.4g}, gamma = {:.3f}, R^2 = {:.3f}".format(
    A0_pool, gamma_pool, r2_pool
)
print(pooled_line)