# -*- coding: utf-8 -*-
"""
Plot utilities for R(k) study:
- Panel A: median R(k) curves (all / math-like / science-like) with IQR band
- Panel B: k90 heatmap across domains x model sizes

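For the data and statistics part, fill in your original code in the placeholder
section so that the following variables are ready:
- k_vals: 1D array-like of k
- med_all, q25_all, q75_all, med_math, med_sci: 1D arrays aligned with k_vals
- Ns: list of model sizes (column labels on the heatmap x-axis)
- domains_order: list of domains (row labels on the heatmap y-axis)
- k90_mat: 2D array of shape (len(domains_order), len(Ns))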
"""

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

# -----------------------------
# Shared plotting configuration (you can paste your own configuration here)
# -----------------------------
# Styling options shared by the plotting helpers defined further down.
plot_config = {
    # Figure size and resolution passed to plt.figure in plot_panel_a;
    # "dpi" is also the resolution used when the figure is saved.
    "figsize": (7.2, 5.0),
    "dpi": 300,
    # Line width of the median curves in plot_panel_a.
    "linewidth": 4.5,
    # NOTE(review): the two marker settings are not read by any code in this
    # file — presumably kept for other plots sharing this config; confirm.
    "markersize": 12,
    "markeredgewidth": 2.5,
    # Keyword arguments forwarded by _apply_labels to set_xlabel / set_ylabel /
    # set_title respectively.
    "xlabel": {"fontsize": 16, "labelpad": 6},
    "ylabel": {"fontsize": 16, "labelpad": 6},
    "title": {"fontsize": 18, "pad": 8},
    # Keyword arguments forwarded by _apply_ticks to ax.tick_params.
    "tick_params": {"labelsize": 15},
    # Legend options consumed by _legend.
    "legend": {
        "title_fontsize": 15,
        "fontsize": 12,
        "loc": "best",
        "frameon": True,
        "fancybox": True,
        "framealpha": 0.95,
        "borderpad": 0.6,
        "handlelength": 1.8,
        "handletextpad": 0.6,
        "labelspacing": 0.4,
        "edgecolor": "#1f2a35",
        "linewidth": 1.0,
        "facecolor": "white"
    },
    # Colour and width applied to all four axes spines by _apply_spines.
    "spines": {"color": "#1f2a35", "linewidth": 3.0},
    # Opacity and draw order of the median curves in plot_panel_a.
    "alpha": 0.95,
    "zorder": 2,
    # printf-style pattern used by _apply_y_formatter for y tick labels.
    "yfmt": "%.02f",
}


# -----------------------------
# Data
# -----------------------------
# k grid 1..9; every series in avg_data has one value per entry of this grid.
k_vals = np.arange(1, 10)
# Model sizes; these are the top-level keys of avg_data and are written to the
# "N(B)" column of the k table.
Ns = [0.5, 1.5, 3.0, 7.0, 14.0, 32.0]

# avg_data[N][domain] is a list of 9 values, one per k in k_vals (k = 1..9),
# for each model size N in Ns and each of the nine domains.
# NOTE(review): presumably these are averaged loss-like scores (lower is
# better), since fractional_return measures the drop from the first value —
# confirm against the source of these numbers.
avg_data = {
    0.5: {
        "algebra": [0.38870969,0.361388945,0.353070365,0.348855826,0.346396314,0.344687186,0.343482999,0.342608706,0.341813997],
        "discrete": [0.675279706,0.638606598,0.627139191,0.621347593,0.617949222,0.615629954,0.613976834,0.612764024,0.611765409],
        "analysis": [0.400981652,0.373791342,0.365430147,0.361273265,0.358729691,0.357035066,0.355779848,0.354866396,0.35404844],
        "geometry": [0.466163542,0.436566282,0.427406096,0.42283913,0.420065942,0.418225191,0.416887299,0.415889587,0.415204736],
        "code": [0.631458418,0.602949613,0.591226322,0.584666522,0.580381321,0.577346809,0.574992322,0.573220679,0.571673265],
        "number_theory": [0.497854733,0.46661246,0.4569785,0.452169459,0.449364407,0.447499541,0.446133188,0.445094679,0.444341392],
        "chemistry": [1.339087142,1.271470311,1.250391469,1.240096755,1.233782558,1.229949244,1.226851675,1.224914202,1.22309423],
        "physics": [1.334634923,1.270568665,1.251434,1.242290259,1.23670613,1.233275433,1.230599054,1.228866834,1.227004083],
        "biology": [1.609473385,1.52554396,1.498841253,1.485587277,1.477488826,1.472506062,1.468362583,1.465754065,1.463440771],
    },
    1.5: {
        "algebra": [0.315373273,0.295693087,0.292465461,0.288742222,0.285844074,0.281397797,0.278885341,0.277676539,0.276707406],
        "discrete": [0.548971303,0.51829583,0.514443422,0.510867107,0.506271037,0.501898826,0.499072908,0.497471617,0.496613025],
        "analysis": [0.320008975,0.300520823,0.297101809,0.293736669,0.290960189,0.286097292,0.283669042,0.282558608,0.281585139],
        "geometry": [0.381928173,0.361264737,0.354469493,0.350939994,0.34759311,0.34447938,0.341947349,0.341449933,0.34123044],
        "code": [0.496565059,0.472588686,0.469751974,0.469487988,0.465333319,0.458521594,0.459409186,0.460353503,0.458793494],
        "number_theory": [0.40562323,0.380090262,0.377454744,0.373362826,0.369605868,0.364911672,0.362885013,0.361347131,0.360983564],
        "chemistry": [1.081343367,1.045566615,1.004760939,0.993294408,0.987886666,0.992343439,0.987540454,0.9873462,0.985602356],
        "physics": [1.088181642,1.05316764,1.018279327,1.008465584,1.004175439,1.006993868,1.003443619,1.002126369,1.000615201],
        "biology": [1.296567972,1.251977601,1.198715438,1.185429408,1.181860986,1.185602747,1.178655365,1.176602818,1.176086816],
    },
    3.0: {
        "algebra": [0.293407034,0.271431433,0.265069254,0.261705676,0.259903857,0.259435438,0.258632178,0.257932576,0.257598105],
        "discrete": [0.515105431,0.483781504,0.474643103,0.469935137,0.467344589,0.466593629,0.465419291,0.464394998,0.463808979],
        "analysis": [0.297353593,0.275700087,0.269388127,0.266050589,0.264278796,0.263754787,0.262925032,0.262268188,0.261830373],
        "geometry": [0.355216464,0.331970786,0.325250531,0.321728974,0.319887912,0.319428992,0.318623887,0.317852497,0.317383795],
        "code": [0.461050067,0.442670287,0.435650744,0.431110488,0.428018952,0.426525122,0.425159842,0.423967722,0.422765287],
        "number_theory": [0.377820405,0.351564821,0.343931368,0.339920515,0.337873004,0.337196617,0.336171771,0.335377243,0.334936368],
        "chemistry": [1.002853285,0.941914711,0.925199377,0.916051342,0.911097762,0.908694527,0.906686638,0.904922208,0.903670478],
        "physics": [1.01076463,0.953775111,0.937781996,0.929401956,0.924956043,0.922709849,0.920478335,0.918778389,0.917526107],
        "biology": [1.19694596,1.122303113,1.100536454,1.088636071,1.0817045,1.078472269,1.075787083,1.07308901,1.071046765],
    },
    7.0: {
        "algebra": [0.288204221,0.26204617,0.256643353,0.253969777,0.252547993,0.251413784,0.250701174,0.250128568,0.249726416],
        "discrete": [0.479819591,0.441345967,0.432518893,0.428181963,0.425758053,0.423920146,0.422737615,0.421773007,0.421163732],
        "analysis": [0.290004928,0.264157036,0.258188173,0.255194149,0.253549005,0.252328003,0.251520661,0.250857652,0.250593043],
        "geometry": [0.345914588,0.317164791,0.310572916,0.307261623,0.30543702,0.304063303,0.303222131,0.302535498,0.302139551],
        "code": [0.441199812,0.417474524,0.411645293,0.408700994,0.406208481,0.404496178,0.403422294,0.40247628,0.401894785],
        "number_theory": [0.361896146,0.329707016,0.322632794,0.31917792,0.317454128,0.316014357,0.31503946,0.314231925,0.313800938],
        "chemistry": [0.945744303,0.877591883,0.860481984,0.851888154,0.846596452,0.843013343,0.840473939,0.838687015,0.837399648],
        "physics": [0.963366772,0.901523203,0.886857941,0.879699645,0.875654743,0.872681452,0.870705264,0.86922001,0.868353373],
        "biology": [1.129518997,1.047706831,1.025895446,1.014941959,1.008303406,1.003661735,1.000331232,0.998025284,0.996190122],
    },
    14.0: {
        "algebra": [0.272534195,0.238723162,0.232248612,0.229093686,0.225822344,0.222602502,0.22076657,0.219750601,0.219655824],
        "discrete": [0.440893491,0.391182109,0.383570543,0.378847212,0.373675607,0.369465387,0.366878352,0.365555448,0.364534252],
        "analysis": [0.274840045,0.243046082,0.237165347,0.234153955,0.23093587,0.228131163,0.22637307,0.225542161,0.22523505],
        "geometry": [0.324313991,0.288687048,0.278685036,0.27490936,0.270873665,0.26849483,0.266268375,0.265813763,0.265783367],
        "code": [0.39472054,0.367021107,0.363279975,0.363190447,0.358321606,0.350160098,0.350872193,0.351812924,0.349731741],
        "number_theory": [0.342765584,0.301169614,0.295042737,0.29067375,0.286605786,0.282855102,0.281269266,0.280003589,0.279662677],
        "chemistry": [0.843074279,0.784809081,0.746377504,0.732448491,0.725489879,0.728018016,0.721253261,0.719769104,0.716085907],
        "physics": [0.8636146,0.801941284,0.770417437,0.757907068,0.752521354,0.75400138,0.748924754,0.746710994,0.74285281],
        "biology": [1.015637891,0.942581468,0.892687265,0.876038928,0.870265925,0.87056405,0.860771824,0.856969113,0.853540643],
    },
    32.0: {
        "algebra": [0.250056586,0.22819991,0.224164211,0.22059183,0.217122109,0.212983049,0.211003879,0.20998448,0.209667327],
        "discrete": [0.399170451,0.366727561,0.360532187,0.356714675,0.351094845,0.3467244,0.344268693,0.342847831,0.342680329],
        "analysis": [0.249342815,0.228952177,0.224928463,0.221872451,0.218602385,0.214221892,0.212230905,0.21125327,0.210953639],
        "geometry": [0.29852047,0.275523871,0.267949879,0.264155628,0.259822437,0.256552903,0.254432908,0.253909556,0.253894347],
        "code": [0.365842494,0.347083186,0.345486508,0.345554432,0.341585446,0.335027459,0.33629572,0.337481855,0.336340502],
        "number_theory": [0.313214458,0.286150663,0.281955999,0.277909175,0.273475474,0.269118212,0.267366016,0.266089372,0.265844435],
        "chemistry": [0.79198143,0.754100421,0.716174264,0.704870143,0.699072206,0.70178825,0.696492805,0.695831147,0.693996519],
        "physics": [0.808904331,0.771876999,0.739921357,0.730458777,0.725698069,0.727333029,0.723580858,0.72204132,0.720743992],
        "biology": [0.955801372,0.91056614,0.862690604,0.848422185,0.84358317,0.844337765,0.836054734,0.833301106,0.832135028],
    },
}

# All domain names, in the insertion order of the 0.5 entry of avg_data.
domains = list(avg_data[0.5])
# The two groups whose medians are drawn as separate curves in Panel A.
math_domains = ["algebra", "analysis", "geometry", "number_theory", "discrete"]
science_domains = ["code", "chemistry", "physics", "biology"]

# -----------------------------
# Helpers
# -----------------------------
def monotone_envelope(vals):
    """Return the running minimum of ``vals`` as a float array.

    Element ``i`` of the result is ``min(vals[0], ..., vals[i])``, so the
    output never increases from one position to the next.
    """
    series = np.asarray(vals, dtype=float)
    return np.minimum.accumulate(series)

def fractional_return(L_series):
    """Express a series as the share of its total first-to-last drop.

    The series is first replaced by its non-increasing envelope (see
    ``monotone_envelope``). With ``first`` and ``last`` the envelope's end
    points, each entry becomes ``(first - value) / (first - last)``. When
    that total drop is zero, negative or not finite, an all-zero array of
    the same shape is returned instead.
    """
    envelope = monotone_envelope(L_series)
    first, last = envelope[0], envelope[-1]
    total_drop = first - last
    # A flat or undefined curve has no drop to normalise by.
    if np.isfinite(total_drop) and total_drop > 0:
        return (first - envelope) / total_drop
    return np.zeros_like(envelope)

# Fractional-return curve for every (model size, domain) combination.
R_by_pair = {
    (size, name): fractional_return(avg_data[size][name])
    for size in Ns
    for name in domains
}

def stack_R_for(pairs):
    """Summarise the R(k) curves of ``pairs`` position by position.

    ``pairs`` are keys of ``R_by_pair``. Their curves are stacked as rows and
    the median, 25th percentile and 75th percentile are taken down each
    column; the three 1D arrays are returned in that order.
    """
    curves = [R_by_pair[key] for key in pairs]
    stacked = np.vstack(curves)
    q25, q75 = (np.percentile(stacked, q, axis=0) for q in (25, 75))
    return np.median(stacked, axis=0), q25, q75

def _pairs_for(domain_names):
    """List the (model size, domain) key of every size in Ns for each given domain."""
    return [(size, name) for size in Ns for name in domain_names]


all_pairs = _pairs_for(domains)
math_pairs = _pairs_for(math_domains)
sci_pairs = _pairs_for(science_domains)

# The IQR is kept only for the pooled curves; the two groups need the median only.
med_all, q25_all, q75_all = stack_R_for(all_pairs)
med_math, _, _ = stack_R_for(math_pairs)
med_sci, _, _ = stack_R_for(sci_pairs)

def first_k_reach(R, thr, k_values=None):
    """Return the first k at which R(k) reaches ``thr``.

    Args:
        R: 1D array-like of R(k) values, one per entry of the k grid.
            Plain lists are accepted.
        thr: Threshold; position ``i`` qualifies when ``R[i] >= thr``.
        k_values: Optional k grid aligned with ``R``. Defaults to the
            module-level ``k_vals``.

    Returns:
        The k (as ``int``) at the first qualifying position, or the last k
        of the grid when no position qualifies.

    Raises:
        ValueError: If ``R`` is empty or its shape differs from the k grid.
    """
    grid = np.asarray(k_vals if k_values is None else k_values)
    curve = np.asarray(R, dtype=float)
    if curve.size == 0:
        raise ValueError("R must contain at least one value")
    if curve.shape != grid.shape:
        raise ValueError(
            f"R has shape {curve.shape} but the k grid has shape {grid.shape}"
        )
    # NaN entries compare False and therefore never count as reaching thr.
    hits = np.flatnonzero(curve >= thr)
    if hits.size:
        return int(grid[hits[0]])
    return int(grid[-1])

# -----------------------------
# Tables
# -----------------------------
# The "R6" / "residual_1-R6" columns report R at k = 6, as their names say.
# NOTE(review): the previous code read R[5-1] (k = 5) under the "R6" label;
# the column names are kept and the index is corrected to match them — confirm
# that k = 6 is the intended reference point.
R6_K = 6
r6_idx = list(k_vals).index(R6_K)

rows = []
k90_by_pair = {}
for N in Ns:
    for d in domains:
        R = R_by_pair[(N, d)]
        k85 = first_k_reach(R, 0.85)
        k90 = first_k_reach(R, 0.90)
        k90_by_pair[(N, d)] = k90
        r6 = float(R[r6_idx])
        rows.append({
            "N(B)": N,
            "domain": d,
            "k85": k85,
            "k90": k90,
            "R6": r6,
            "residual_1-R6": 1.0 - r6,
        })
k_table = pd.DataFrame(rows).sort_values(["domain", "N(B)"])

Rk_df = pd.DataFrame({
    "k": k_vals,
    "median_all": med_all,
    "q25_all": q25_all,
    "q75_all": q75_all,
    "median_math": med_math,
    "median_science": med_sci
})

# Save both tables in the current working directory.
k_table_path = os.path.join(os.getcwd(), "rq2_k_elbows.csv")
Rk_table_path = os.path.join(os.getcwd(), "rq2_Rk_median_iqr.csv")
k_table.to_csv(k_table_path, index=False)
Rk_df.to_csv(Rk_table_path, index=False)

domains_order = ["algebra","analysis","geometry","number_theory","discrete","code","chemistry","physics","biology"]
# k90 per (domain, model size): rows follow domains_order, columns follow Ns.
# Previously this matrix was allocated with zeros and never filled.
k90_mat = np.array(
    [[k90_by_pair[(N, d)] for N in Ns] for d in domains_order],
    dtype=float,
)


# -----------------------------
# IO
# -----------------------------
def ensure_dir(path: str) -> str:
    """Create ``path`` (with any missing parents) if needed and return it unchanged."""
    os.makedirs(name=path, exist_ok=True)
    return path

# Figures go to a "fig" folder under the current working directory; the folder
# is created when this module is loaded if it does not exist yet.
FIG_DIR = ensure_dir(os.path.join(os.getcwd(), "fig"))

# -----------------------------
# Styling & utilities
# -----------------------------
def _apply_spines(ax, cfg):
    """Give all four spines of ``ax`` the width and colour from ``cfg["spines"]``."""
    spine_cfg = cfg["spines"]
    for side in ("left", "bottom", "right", "top"):
        spine = ax.spines[side]
        spine.set_linewidth(spine_cfg["linewidth"])
        spine.set_color(spine_cfg["color"])

def _apply_ticks(ax, cfg):
    """Forward ``cfg["tick_params"]`` as keyword arguments to ``ax.tick_params``."""
    tick_kwargs = cfg["tick_params"]
    ax.tick_params(**tick_kwargs)

def _apply_labels(ax, xlabel, ylabel, title, cfg):
    """Set the axis labels and title of ``ax``.

    Each text is passed to its setter together with the keyword options
    stored under the ``"xlabel"``, ``"ylabel"`` and ``"title"`` keys of ``cfg``.
    """
    targets = (
        (ax.set_xlabel, xlabel, "xlabel"),
        (ax.set_ylabel, ylabel, "ylabel"),
        (ax.set_title, title, "title"),
    )
    for setter, text, key in targets:
        setter(text, **cfg[key])

def _apply_y_formatter(ax, cfg):
    """Format y tick labels with the printf-style pattern in ``cfg["yfmt"]``.

    Nothing is changed when the key is missing or its value is falsy.
    """
    pattern = cfg.get("yfmt")
    if not pattern:
        return

    def _render(value, pos):
        return pattern % value

    ax.yaxis.set_major_formatter(FuncFormatter(_render))

def _finalize(ax, cfg, xticks=None, yticks=None, tight=True):
    """Apply the finishing touches shared by all panels.

    Args:
        ax: Axes to finish.
        cfg: Plot configuration; its ``"spines"`` and ``"tick_params"``
            entries are applied through ``_apply_spines`` / ``_apply_ticks``.
        xticks: Explicit x tick positions, or None to leave them untouched.
        yticks: Explicit y tick positions, or None to leave them untouched.
        tight: When true, run ``tight_layout`` on the figure that owns ``ax``.
    """
    if xticks is not None:
        ax.set_xticks(xticks)
    if yticks is not None:
        ax.set_yticks(yticks)
    _apply_spines(ax, cfg)
    _apply_ticks(ax, cfg)
    if tight:
        # Lay out the figure that owns ``ax`` instead of whatever pyplot
        # currently considers the active figure.
        ax.figure.tight_layout()

def save_figure(fig, out_path, dpi):
    """Write ``fig`` to ``out_path`` at resolution ``dpi`` and close it.

    The figure is closed even when saving raises, so a failed write does not
    leave the figure open in pyplot; the exception still propagates.
    """
    try:
        fig.savefig(out_path, dpi=dpi)
    finally:
        plt.close(fig)

def _legend(ax, cfg, title=None):
    """Draw the legend of ``ax`` using the options in ``cfg["legend"]``.

    Args:
        ax: Axes whose labelled artists are listed in the legend.
        cfg: Plot configuration; only its ``"legend"`` entry is read.
        title: Optional legend title.

    The optional ``"title_fontsize"`` entry is now forwarded as well, so a
    title is drawn at the configured size; when the entry is absent,
    matplotlib's default is used.
    """
    opts = cfg["legend"]
    leg = ax.legend(title=title,
                    title_fontsize=opts.get("title_fontsize"),
                    loc=opts["loc"],
                    fontsize=opts["fontsize"],
                    frameon=opts["frameon"],
                    fancybox=opts["fancybox"],
                    framealpha=opts["framealpha"],
                    borderpad=opts["borderpad"],
                    handlelength=opts["handlelength"],
                    handletextpad=opts["handletextpad"],
                    labelspacing=opts["labelspacing"])
    if leg is None:
        return
    frame = leg.get_frame()
    if frame is not None:
        # Frame colours and border width are set on the frame patch itself.
        frame.set_edgecolor(opts["edgecolor"])
        frame.set_linewidth(opts["linewidth"])
        frame.set_facecolor(opts["facecolor"])

# -----------------------------
# Panel A: line curves + IQR band
# -----------------------------
def plot_panel_a(k_vals,
                 med_all, q25_all, q75_all,
                 med_math, med_sci,
                 cfg=plot_config,
                 out_path=None):
    """
    Draw the median R(k) curves (all / math-like / science-like) together
    with the IQR band of the pooled curves.

    No colours are specified, following the "matplotlib only, single plot,
    default colours" requirement.

    When ``out_path`` is given the figure is saved there and closed, and the
    function returns None; otherwise ``(fig, ax)`` is returned.
    """
    fig = plt.figure(figsize=cfg["figsize"], dpi=cfg["dpi"])
    ax = fig.add_subplot(111)

    # IQR band, drawn one layer below the curves.
    ax.fill_between(k_vals, q25_all, q75_all,
                    label="IQR (all)",
                    alpha=0.2,
                    zorder=cfg["zorder"] - 1)

    # Median curves: identical width / opacity / layer, told apart by linestyle.
    shared = {
        "linewidth": cfg["linewidth"],
        "alpha": cfg["alpha"],
        "zorder": cfg["zorder"],
    }
    curves = (
        (med_all, {"label": "Median $R(k)$ (all)"}),
        (med_math, {"linestyle": "--", "label": "Median (math-like)"}),
        (med_sci, {"linestyle": ":", "label": "Median (science-like)"}),
    )
    for values, extra in curves:
        ax.plot(k_vals, values, **shared, **extra)

    # Reference lines at the 0.85 and 0.90 levels.
    for level in (0.85, 0.90):
        ax.axhline(level, linestyle="--", linewidth=1.5, alpha=0.8)

    # Axes labels, title, tick format, legend and final layout.
    _apply_labels(ax,
                  xlabel="Number of merged experts $k$",
                  ylabel="Fractional return $R(k)$",
                  title="RQ2: Fractional return vs. $k$ (in-domain Average)",
                  cfg=cfg)
    _apply_y_formatter(ax, cfg)
    _legend(ax, cfg)
    _finalize(ax, cfg, xticks=k_vals)

    if out_path is None:
        return fig, ax
    save_figure(fig, out_path, cfg["dpi"])

# -----------------------------
# Export: same output locations as the original script
# -----------------------------
if __name__ == "__main__":
    # The data variables are expected to exist by this point, e.g. via
    # from your_stats_block import k_vals, med_all, q25_all, q75_all, med_math, med_sci, Ns, domains_order, k90_mat
    # Fail early, naming whatever is missing.
    required = ["k_vals","med_all","q25_all","q75_all","med_math","med_sci","Ns","domains_order","k90_mat"]
    missing = [name for name in required if name not in globals()]
    if missing:
        raise RuntimeError(f"缺少变量，请先在数据部分生成：{missing}")

    fig_a_path = os.path.join(FIG_DIR, "rq2_panel_a_median_Rk.png")
    # NOTE(review): the Panel B path is prepared but no Panel B plotting call
    # is present in this block — confirm whether that step is still pending.
    fig_b_path = os.path.join(FIG_DIR, "rq2_panel_b_k90_heatmap.png")

    plot_panel_a(k_vals,
                 med_all, q25_all, q75_all,
                 med_math, med_sci,
                 cfg=plot_config,
                 out_path=fig_a_path)
