import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, LogLocator, NullFormatter
from matplotlib.lines import Line2D

# ===============================
# 0) Shared plotting configuration (same as / close to code 2)
# ===============================
plot_config = dict(
    figsize=(8.5, 6.0),
    dpi=300,
    linewidth=2.5,
    # Marker size in points: line markers take it as-is, while scatter expects
    # an area and therefore needs this value squared.
    markersize=7,
    markeredgewidth=2.0,
    xlabel=dict(fontsize=16, labelpad=6),
    ylabel=dict(fontsize=16, labelpad=6),
    title=dict(fontsize=18, pad=8),
    tick_params=dict(labelsize=15),
    legend=dict(
        title_fontsize=15,
        fontsize=12,
        loc="upper center",
        frameon=True,
        fancybox=True,
        framealpha=0.95,
        borderpad=0.6,
        handlelength=1.8,
        handletextpad=0.6,
        labelspacing=0.4,
        edgecolor="#1f2a35",
        linewidth=1.0,
        facecolor="white",
    ),
    spines=dict(color="#1f2a35", linewidth=3.0),
    alpha=0.95,
    zorder=2,
)

# Colour palette (hex strings); entries are picked by position.
DOMAIN_COLORS = [
    "#008A45", "#468BCA", "#5F5F5E",
    "#7DD2F6", "#80C5A2", "#B384BA",
    "#D9C2DD", "#F27873", "#FFD373",
]

# Directory that receives the rendered figure.
OUT_DIR = "figs"

# ===============================
# 1) Data (same as in code 1)
# ===============================
# Abscissa of every series: the integers 1..9 (one per list entry in avg_data).
k_vals = np.arange(1, 10)
# Model sizes N used as the outer keys of avg_data. The ints (3, 7, 14, 32)
# address the float keys (3.0, 7.0, ...) because equal ints and floats hash alike.
Ns = [0.5, 1.5, 3, 7, 14, 32]

# avg_data[N][domain] -> list of 9 values, aligned position-by-position with k_vals.
# NOTE(review): the metric behind these numbers is not stated in this file
# (presumably an averaged loss-like quantity per domain) — confirm against code 1.
avg_data = {
    0.5: {
        "algebra": [0.38870969,0.361388945,0.353070365,0.348855826,0.346396314,0.344687186,0.343482999,0.342608706,0.341813997],
        "discrete": [0.675279706,0.638606598,0.627139191,0.621347593,0.617949222,0.615629954,0.613976834,0.612764024,0.611765409],
        "analysis": [0.400981652,0.373791342,0.365430147,0.361273265,0.358729691,0.357035066,0.355779848,0.354866396,0.35404844],
        "geometry": [0.466163542,0.436566282,0.427406096,0.42283913,0.420065942,0.418225191,0.416887299,0.415889587,0.415204736],
        "code": [0.631458418,0.602949613,0.591226322,0.584666522,0.580381321,0.577346809,0.574992322,0.573220679,0.571673265],
        "number_theory": [0.497854733,0.46661246,0.4569785,0.452169459,0.449364407,0.447499541,0.446133188,0.445094679,0.444341392],
        "chemistry": [1.339087142,1.271470311,1.250391469,1.240096755,1.233782558,1.229949244,1.226851675,1.224914202,1.22309423],
        "physics": [1.334634923,1.270568665,1.251434,1.242290259,1.23670613,1.233275433,1.230599054,1.228866834,1.227004083],
        "biology": [1.609473385,1.52554396,1.498841253,1.485587277,1.477488826,1.472506062,1.468362583,1.465754065,1.463440771],
    },
    1.5: {
        "algebra": [0.315373273,0.295693087,0.292465461,0.288742222,0.285844074,0.281397797,0.278885341,0.277676539,0.276707406],
        "discrete": [0.548971303,0.51829583,0.514443422,0.510867107,0.506271037,0.501898826,0.499072908,0.497471617,0.496613025],
        "analysis": [0.320008975,0.300520823,0.297101809,0.293736669,0.290960189,0.286097292,0.283669042,0.282558608,0.281585139],
        "geometry": [0.381928173,0.361264737,0.354469493,0.350939994,0.34759311,0.34447938,0.341947349,0.341449933,0.34123044],
        "code": [0.496565059,0.472588686,0.469751974,0.469487988,0.465333319,0.458521594,0.459409186,0.460353503,0.458793494],
        "number_theory": [0.40562323,0.380090262,0.377454744,0.373362826,0.369605868,0.364911672,0.362885013,0.361347131,0.360983564],
        "chemistry": [1.081343367,1.045566615,1.004760939,0.993294408,0.987886666,0.992343439,0.987540454,0.9873462,0.985602356],
        "physics": [1.088181642,1.05316764,1.018279327,1.008465584,1.004175439,1.006993868,1.003443619,1.002126369,1.000615201],
        "biology": [1.296567972,1.251977601,1.198715438,1.185429408,1.181860986,1.185602747,1.178655365,1.176602818,1.176086816],
    },
    3.0: {
        "algebra": [0.293407034,0.271431433,0.265069254,0.261705676,0.259903857,0.259435438,0.258632178,0.257932576,0.257598105],
        "discrete": [0.515105431,0.483781504,0.474643103,0.469935137,0.467344589,0.466593629,0.465419291,0.464394998,0.463808979],
        "analysis": [0.297353593,0.275700087,0.269388127,0.266050589,0.264278796,0.263754787,0.262925032,0.262268188,0.261830373],
        "geometry": [0.355216464,0.331970786,0.325250531,0.321728974,0.319887912,0.319428992,0.318623887,0.317852497,0.317383795],
        "code": [0.461050067,0.442670287,0.435650744,0.431110488,0.428018952,0.426525122,0.425159842,0.423967722,0.422765287],
        "number_theory": [0.377820405,0.351564821,0.343931368,0.339920515,0.337873004,0.337196617,0.336171771,0.335377243,0.334936368],
        "chemistry": [1.002853285,0.941914711,0.925199377,0.916051342,0.911097762,0.908694527,0.906686638,0.904922208,0.903670478],
        "physics": [1.01076463,0.953775111,0.937781996,0.929401956,0.924956043,0.922709849,0.920478335,0.918778389,0.917526107],
        "biology": [1.19694596,1.122303113,1.100536454,1.088636071,1.0817045,1.078472269,1.075787083,1.07308901,1.071046765],
    },
    7.0: {
        "algebra": [0.288204221,0.26204617,0.256643353,0.253969777,0.252547993,0.251413784,0.250701174,0.250128568,0.249726416],
        "discrete": [0.479819591,0.441345967,0.432518893,0.428181963,0.425758053,0.423920146,0.422737615,0.421773007,0.421163732],
        "analysis": [0.290004928,0.264157036,0.258188173,0.255194149,0.253549005,0.252328003,0.251520661,0.250857652,0.250593043],
        "geometry": [0.345914588,0.317164791,0.310572916,0.307261623,0.30543702,0.304063303,0.303222131,0.302535498,0.302139551],
        "code": [0.441199812,0.417474524,0.411645293,0.408700994,0.406208481,0.404496178,0.403422294,0.40247628,0.401894785],
        "number_theory": [0.361896146,0.329707016,0.322632794,0.31917792,0.317454128,0.316014357,0.31503946,0.314231925,0.313800938],
        "chemistry": [0.945744303,0.877591883,0.860481984,0.851888154,0.846596452,0.843013343,0.840473939,0.838687015,0.837399648],
        "physics": [0.963366772,0.901523203,0.886857941,0.879699645,0.875654743,0.872681452,0.870705264,0.86922001,0.868353373],
        "biology": [1.129518997,1.047706831,1.025895446,1.014941959,1.008303406,1.003661735,1.000331232,0.998025284,0.996190122],
    },
    14.0: {
        "algebra": [0.272534195,0.238723162,0.232248612,0.229093686,0.225822344,0.222602502,0.22076657,0.219750601,0.219655824],
        "discrete": [0.440893491,0.391182109,0.383570543,0.378847212,0.373675607,0.369465387,0.366878352,0.365555448,0.364534252],
        "analysis": [0.274840045,0.243046082,0.237165347,0.234153955,0.23093587,0.228131163,0.22637307,0.225542161,0.22523505],
        "geometry": [0.324313991,0.288687048,0.278685036,0.27490936,0.270873665,0.26849483,0.266268375,0.265813763,0.265783367],
        "code": [0.39472054,0.367021107,0.363279975,0.363190447,0.358321606,0.350160098,0.350872193,0.351812924,0.349731741],
        "number_theory": [0.342765584,0.301169614,0.295042737,0.29067375,0.286605786,0.282855102,0.281269266,0.280003589,0.279662677],
        "chemistry": [0.843074279,0.784809081,0.746377504,0.732448491,0.725489879,0.728018016,0.721253261,0.719769104,0.716085907],
        "physics": [0.8636146,0.801941284,0.770417437,0.757907068,0.752521354,0.75400138,0.748924754,0.746710994,0.74285281],
        "biology": [1.015637891,0.942581468,0.892687265,0.876038928,0.870265925,0.87056405,0.860771824,0.856969113,0.853540643],
    },
    32.0: {
        "algebra": [0.250056586,0.22819991,0.224164211,0.22059183,0.217122109,0.212983049,0.211003879,0.20998448,0.209667327],
        "discrete": [0.399170451,0.366727561,0.360532187,0.356714675,0.351094845,0.3467244,0.344268693,0.342847831,0.342680329],
        "analysis": [0.249342815,0.228952177,0.224928463,0.221872451,0.218602385,0.214221892,0.212230905,0.21125327,0.210953639],
        "geometry": [0.29852047,0.275523871,0.267949879,0.264155628,0.259822437,0.256552903,0.254432908,0.253909556,0.253894347],
        "code": [0.365842494,0.347083186,0.345486508,0.345554432,0.341585446,0.335027459,0.33629572,0.337481855,0.336340502],
        "number_theory": [0.313214458,0.286150663,0.281955999,0.277909175,0.273475474,0.269118212,0.267366016,0.266089372,0.265844435],
        "chemistry": [0.79198143,0.754100421,0.716174264,0.704870143,0.699072206,0.70178825,0.696492805,0.695831147,0.693996519],
        "physics": [0.808904331,0.771876999,0.739921357,0.730458777,0.725698069,0.727333029,0.723580858,0.72204132,0.720743992],
        "biology": [0.955801372,0.91056614,0.862690604,0.848422185,0.84358317,0.844337765,0.836054734,0.833301106,0.832135028],
    },
}

# Domain names, in the insertion order of the N = 0.5 table.
domains = list(avg_data[0.5].keys())
# One palette colour per domain; the modulo wraps around if there are more
# domains than colours.
color_map = {d: DOMAIN_COLORS[i % len(DOMAIN_COLORS)] for i, d in enumerate(domains)}

# ===============================
# 2) Fit function (A / residual computation consistent with code 1)
# ===============================
def fit_LA_for_b(k, y, b):
    """Least-squares fit of ``y ≈ L_inf + A / (k + b)`` for a fixed shift ``b``.

    With ``b`` held fixed the model is linear in ``(L_inf, A)``, so it is
    solved as an ordinary linear least-squares problem on the regressor
    ``x = 1 / (k + b)``.

    Args:
        k: 1-D array-like of abscissa values (lists, tuples and arrays all work).
        y: 1-D array-like of observations, one per entry of ``k``.
        b: Scalar offset added to every ``k`` before taking the reciprocal.

    Returns:
        Tuple ``(L_inf, A, sse)`` of Python floats: the fitted intercept, the
        fitted coefficient of ``1 / (k + b)``, and the sum of squared residuals.

    Raises:
        ValueError: If ``k + b`` is zero for some entry, which would put an
            infinite value into the design matrix.
    """
    # Coerce up front so plain Python sequences behave exactly like arrays.
    k = np.asarray(k, dtype=float)
    y = np.asarray(y, dtype=float)

    denom = k + b
    if np.any(denom == 0.0):
        raise ValueError("k + b must be non-zero for every k (got a zero denominator)")

    x = 1.0 / denom
    # Design matrix with columns [1, x] -> coefficients [L_inf, A].
    X = np.vstack([np.ones_like(x), x]).T
    coef, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
    L_inf, A = float(coef[0]), float(coef[1])
    y_hat = X @ coef
    sse = float(np.sum((y - y_hat) ** 2))
    return L_inf, A, sse

# ===============================
# 3) Per domain: pick b_d (minimising the SSE summed over all N), then collect A(N)
# ===============================
b_grid = np.linspace(0.0, 1.0, 41)  # candidate shifts 0.000, 0.025, ..., 1.000
per_domain_est = {}

for d in domains:
    # The per-N series do not depend on b, so convert them to float arrays
    # once here instead of once per grid point.
    curves = [np.array(avg_data[N][d], float) for N in Ns]

    # Grid search: keep the first b that gives the smallest total SSE.
    best_b, best_sse = None, np.inf
    for b in b_grid:
        total_sse = 0.0
        for y in curves:
            total_sse += fit_LA_for_b(k_vals, y, b)[2]
        if total_sse < best_sse:
            best_b, best_sse = b, total_sse

    if best_b is None:
        # Only reachable when every candidate produced a non-finite SSE.
        raise ValueError(f"no finite SSE found on the b grid for domain {d!r}")

    # Refit every model size at the selected b and keep only the A coefficient,
    # in the same order as Ns.
    As = [fit_LA_for_b(k_vals, y, best_b)[1] for y in curves]
    per_domain_est[d] = {"b": float(best_b), "A": np.array(As, dtype=float)}

# ===============================
# 4) Plot A(N) as markers joined by lines, styled like code 2
# ===============================
os.makedirs(OUT_DIR, exist_ok=True)
fig, ax = plt.subplots(figsize=plot_config["figsize"], dpi=plot_config["dpi"])

N_arr = np.array(Ns, dtype=float)
scatter_handles, scatter_labels = [], []

# Settings shared by the marker layer and the line layer of every domain.
layer_style = {"alpha": plot_config["alpha"], "zorder": plot_config["zorder"]}
# scatter() takes marker sizes as areas, hence the square of the point size.
marker_area = plot_config["markersize"] ** 2

for d, color in color_map.items():
    Avals = per_domain_est[d]["A"]

    # Marker layer; its handle is reused as this domain's legend entry.
    sc = ax.scatter(
        N_arr,
        Avals,
        marker="o",
        color=color,
        s=marker_area,
        linewidths=plot_config["markeredgewidth"],
        label=d,
        **layer_style,
    )
    scatter_handles.append(sc)
    scatter_labels.append(d)

    # Line layer: straight segments between consecutive N values only
    # (no fitted curve is drawn).
    ax.plot(
        N_arr,
        Avals,
        color=color,
        linestyle="-",
        linewidth=plot_config["linewidth"],
        **layer_style,
    )

# Both axes are logarithmic (as in code 1).
ax.set_xscale("log")
ax.set_yscale("log")


def _two_decimals(value, pos):
    """Format a tick value with two decimals (``pos`` is unused)."""
    return f"{value:.2f}"


# Readable y ticks: major ticks at 1, 2, 5 per decade; the remaining
# multiples become unlabelled minor ticks.
y_axis = ax.yaxis
y_axis.set_major_locator(LogLocator(base=10, subs=(1.0, 2.0, 5.0)))
y_axis.set_minor_locator(LogLocator(base=10, subs=(3.0, 4.0, 6.0, 7.0, 8.0, 9.0)))
y_axis.set_minor_formatter(NullFormatter())
y_axis.set_major_formatter(FuncFormatter(_two_decimals))

# Axis labels, title, tick labels and spine styling.
ax.set_xlabel("Model size $N$ (B params, log)", **plot_config["xlabel"])
ax.set_ylabel(r"$A(N)$ (log)", **plot_config["ylabel"])
ax.set_title("All domains: $A(N)$ estimates (points + lines)", **plot_config["title"])
ax.tick_params(**plot_config["tick_params"])
spine_cfg = plot_config["spines"]
for side in ("left", "bottom", "right", "top"):
    spine = ax.spines[side]
    spine.set_linewidth(spine_cfg["linewidth"])
    spine.set_color(spine_cfg["color"])

# Legend below the axes, with the same look as in code 2.
legend_cfg = plot_config["legend"]
legend_passthrough = (
    "fontsize",
    "frameon",
    "fancybox",
    "borderpad",
    "handlelength",
    "handletextpad",
    "labelspacing",
)
leg = ax.legend(
    scatter_handles,
    scatter_labels,
    ncol=5,
    loc="upper center",
    bbox_to_anchor=(0.5, -0.18),
    **{key: legend_cfg[key] for key in legend_passthrough},
)
# Fine-tune the legend frame: border width/colour, transparency, background.
frame = leg.get_frame()
frame.set_linewidth(legend_cfg["linewidth"])
frame.set_edgecolor(legend_cfg["edgecolor"])
frame.set_alpha(legend_cfg["framealpha"])
frame.set_facecolor(legend_cfg["facecolor"])

fig.tight_layout()
out_path = os.path.join(OUT_DIR, "rq1_all_domain_A(N).png")
fig.savefig(out_path, dpi=plot_config["dpi"])
# plt.show()

# ===============================
# 5) Lightweight report: print the selected b for every domain
# ===============================
for name, estimate in per_domain_est.items():
    # Right-align the domain name in a 14-character column.
    print(f"{name:>14s}  best_b = {estimate['b']:.3f}")
print(f"\nSaved figure to: {out_path}")