# -*- coding: utf-8 -*-
"""
Joint fit & plotting for CE ~ f(N, k)

- 数据准备、拟合、预测、制图与存盘模块化
- 使用 Pathlib 管理路径
- 轴标题与图标题通过参数传入
- 颜色固定为用户指定调色板（可覆盖），曲线与散点同色
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Tuple, Dict, Any, Optional, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from matplotlib.lines import Line2D
from scipy.optimize import curve_fit

# -----------------------------
# 数据与模型
# -----------------------------
def to_dataframe(rows: Iterable[Dict[str, Any]]) -> pd.DataFrame:
    """
    Build a validated, sorted DataFrame from row dictionaries.

    Args:
        rows: Records that must at least contain the keys "N", "k" and "CE".
            Extra keys are kept as additional columns.

    Returns:
        A new DataFrame with "N", "k" and "CE" cast to float, sorted by
        ("N", "k") and with a fresh 0..n-1 index.

    Raises:
        ValueError: If any of the required columns is absent.
    """
    required_cols = {"N", "k", "CE"}
    frame = pd.DataFrame(rows).copy()

    absent = required_cols - set(frame.columns)
    if absent:
        raise ValueError(f"Missing required columns: {absent}")

    numeric_types = {"N": float, "k": float, "CE": float}
    frame = frame.astype(numeric_types)
    frame = frame.sort_values(["N", "k"])
    return frame.reset_index(drop=True)

def joint_model(X: Tuple[np.ndarray, np.ndarray],
                Ls: float, B: float, beta: float,
                A0: float, gamma: float, b0: float) -> np.ndarray:
    """
    Evaluate the joint scaling model

        CE(N, k) = Ls + B * N^(-beta) + (A0 * N^(-gamma)) / (k + b0)

    Args:
        X: Pair (N, k); arrays or scalars that broadcast against each other.
        Ls, B, beta, A0, gamma, b0: Model coefficients.

    Returns:
        Predicted CE with the broadcast shape of N and k.
    """
    n_vals, k_vals = X
    size_term = B * n_vals ** (-beta)
    expert_term = A0 * n_vals ** (-gamma) / (k_vals + b0)
    return Ls + size_term + expert_term

@dataclass
class FitResult:
    """
    Result of ``fit_joint_model``.

    The first six fields are the ``joint_model`` coefficients, in the same
    order as that function's parameters. They are followed by in-sample fit
    metrics and the covariance matrix reported by ``curve_fit``.
    """

    # Model coefficients: CE = Ls + B*N^-beta + A0*N^-gamma / (k + b0)
    Ls: float      # constant term
    B: float       # coefficient of the size term N^(-beta)
    beta: float    # exponent of N in the size term
    A0: float      # coefficient of the expert term A0*N^(-gamma)/(k + b0)
    gamma: float   # exponent of N in the expert term
    b0: float      # offset added to k in the expert-term denominator
    # In-sample fit quality
    rmse: float    # root-mean-squared residual
    r2: float      # coefficient of determination
    # Parameter covariance (pcov) returned by scipy's curve_fit, if available
    popt_cov: Optional[np.ndarray] = None

# -----------------------------
# 拟合、评估与预测
# -----------------------------
def fit_joint_model(df: pd.DataFrame,
                    p0: Iterable[float] = (0.2, 0.6, 0.3, 0.7, 0.2, 0.3),
                    bounds: Tuple[Iterable[float], Iterable[float]] = ((0, 0, 0, 0, 0, 0), (5, 5, 5, 5, 5, 5)),
                    maxfev: int = 300_000) -> FitResult:
    """
    Fit ``joint_model`` to the observations in ``df`` by nonlinear least squares.

    Args:
        df: Data with numeric columns "N", "k" and "CE" (e.g. from
            ``to_dataframe``).
        p0: Initial guess, ordered (Ls, B, beta, A0, gamma, b0).
        bounds: (lower, upper) bounds, same ordering as ``p0``; passed
            unchanged to ``scipy.optimize.curve_fit``.
        maxfev: Maximum number of function evaluations for ``curve_fit``.

    Returns:
        FitResult with the fitted coefficients as plain floats, the in-sample
        RMSE and R^2, and the parameter covariance from ``curve_fit``. R^2 is
        NaN when CE has zero variance, because it is undefined in that case.

    Raises:
        Whatever ``curve_fit`` raises, e.g. RuntimeError when the fit does not
        converge within ``maxfev`` evaluations.
    """
    X = np.vstack([df["N"].to_numpy(dtype=float), df["k"].to_numpy(dtype=float)])
    y = df["CE"].to_numpy(dtype=float)

    popt, pcov = curve_fit(joint_model, X, y, p0=p0, bounds=bounds, maxfev=maxfev)
    # Store plain Python floats rather than numpy scalars (cleaner CSV/JSON output).
    Ls, B, beta, A0, gamma, b0 = (float(v) for v in popt)

    yhat = joint_model(X, *popt)
    resid = y - yhat
    ss_res = float(np.sum(resid ** 2))
    ss_tot = float(np.sum((y - np.mean(y)) ** 2))
    rmse = float(np.sqrt(np.mean(resid ** 2)))
    # Guard the zero-variance case: plain float division would raise
    # ZeroDivisionError here.
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0.0 else float("nan")

    return FitResult(Ls=Ls, B=B, beta=beta, A0=A0, gamma=gamma, b0=b0,
                     rmse=rmse, r2=r2, popt_cov=pcov)

def predict_by_k(N_value: float,
                 k_values: Iterable[float],
                 params: FitResult) -> pd.DataFrame:
    """
    Predict CE at a fixed model size for several expert counts.

    Args:
        N_value: Model size N held constant for every row.
        k_values: Expert counts to evaluate.
        params: Fitted model coefficients.

    Returns:
        DataFrame with column "k" and one prediction column named
        ``f"CE@{N_value}B"``.
    """
    ks = np.asarray(list(k_values), dtype=float)
    ns = np.full_like(ks, float(N_value), dtype=float)
    coeffs = (params.Ls, params.B, params.beta, params.A0, params.gamma, params.b0)
    predicted = joint_model((ns, ks), *coeffs)
    column_name = f"CE@{N_value}B"
    return pd.DataFrame({"k": ks, column_name: predicted})

# -----------------------------
# 绘图（打包为函数）
# -----------------------------

# 默认调色盘采用 Okabe–Ito（色弱友好、对比清晰）
OKABE_ITO_PALETTE = [
    "#004c6d",  # black
    "#5390a6",  # orange
    "#79b4c3",  # sky blue
    "#009E73",  # bluish green
    "#F0E442",  # yellow
    "#0072B2",  # blue
    "#D55E00",  # vermillion
    "#a1d9e0",  # reddish purple
]

_DEFAULT_PALETTE = OKABE_ITO_PALETTE

def _format_B(N: float) -> str:
    """把 72 -> 72B，1.5 -> 1.5B（不去小数）"""
    if float(N).is_integer():
        return f"{int(N)}B"
    return f"{N}B"

# ★★★ changed: 基于“排序”的显式颜色分配（最大 N 用首色，最小 N 用末色）
def _build_rank_color_map(
    N_values: List[float],
    palette: List[str],
    color_map_override: Optional[Dict[float, str]] = None,
) -> Dict[float, str]:
    """
    构造颜色映射（显式控制）：
      - 若提供 color_map_override，优先使用其中 N→颜色的映射；
      - 其余按 N 从大到小分配：
          max(N) → palette[0]，min(N) → palette[-1]，
          中间值顺序占用 palette[1:-1]（循环复用以适配更多曲线）。
    """
    if not palette:
        palette = _DEFAULT_PALETTE

    Ns_sorted_desc = sorted(set(float(n) for n in N_values), reverse=True)
    cmap: Dict[float, str] = {}

    # 显式覆盖优先
    if color_map_override:
        for n, c in color_map_override.items():
            cmap[float(n)] = c

    # 过滤掉已覆盖的 N，再分配
    remaining = [n for n in Ns_sorted_desc if n not in cmap]

    if not remaining:
        return cmap

    if len(remaining) == 1:
        cmap[remaining[0]] = palette[0]
        return cmap

    # 端点
    maxN, minN = remaining[0], remaining[-1]
    cmap[maxN] = palette[0]
    cmap[minN] = palette[-1]

    middle_Ns = remaining[1:-1]
    middle_palette = palette[1:-1] if len(palette) >= 3 else palette

    for i, n in enumerate(middle_Ns):
        cmap[n] = middle_palette[i % len(middle_palette)]

    return cmap

def plot_ce_vs_N(
    df: pd.DataFrame,
    params: FitResult,
    domain: str,
    k_curves: Iterable[float],
    xlabel: str,
    ylabel: str,
    title: str,
    colors: Optional[List[str]] = None,
    color_map_override: Optional[Dict[float, str]] = None,
    N_min: Optional[float] = None,
    N_max: Optional[float] = None,
    points_marker_size: float = 128.0,
    figsize: Tuple[float, float] = (8.2, 6.0),
    dpi: int = 300
) -> plt.Figure:
    """
    Plot CE against model size N, with one fitted curve per expert count k.

    For every k in ``k_curves``, draw a dotted curve of ``joint_model`` over
    200 evenly spaced N values in [N_min, N_max]. Observed points from ``df``
    whose k matches (within ``np.isclose`` tolerance) are overlaid in the same
    colour. The legend shows one marker per k plus a grey "Fit" line entry.

    Args:
        df: Observations with columns "N", "k", "CE".
        params: Fitted coefficients used to draw the curves.
        domain: Currently unused; kept for interface compatibility.
        k_curves: Expert counts to draw. Any iterable is accepted: it is
            materialized once (duplicates dropped), so generators work.
        xlabel, ylabel, title: Axis labels and axes title.
        colors: Palette for the rank-based colour assignment; falls back to
            the module default palette when None or empty.
        color_map_override: Explicit k -> colour entries that take precedence.
        N_min, N_max: Curve range; default to the smallest/largest N in ``df``.
        points_marker_size: Marker area of the scatter points.
        figsize, dpi: Figure size in inches and resolution.

    Returns:
        The created Figure. It is not closed here; the caller owns it.
    """
    palette = colors or _DEFAULT_PALETTE
    # Materialize k_curves exactly once. It feeds the colour map, the curves
    # and the legend, and a one-shot iterator would otherwise be exhausted
    # after its first use.
    ks = sorted({float(k) for k in k_curves})

    fig = plt.figure(figsize=figsize, dpi=dpi)
    ax = fig.add_subplot(111)

    # Thick dark frame around the axes.
    for sp in ax.spines.values():
        sp.set_color("#1f2a35")
        sp.set_linewidth(3.0)

    # The curve range defaults to the span of model sizes present in the data.
    Ns_in_data = np.unique(df["N"].to_numpy(dtype=float))
    if N_min is None:
        N_min = float(Ns_in_data.min())
    if N_max is None:
        N_max = float(Ns_in_data.max())

    NN = np.linspace(N_min, N_max, 200)
    coeffs = (params.Ls, params.B, params.beta, params.A0, params.gamma, params.b0)

    # One colour per plotted k.
    color_map = _build_rank_color_map(ks, palette, color_map_override)
    k_observed = df["k"].to_numpy(dtype=float)

    # Curves and scatter points
    for k in ks:
        color = color_map.get(k, palette[0])
        curve = joint_model((NN, np.full_like(NN, k, dtype=float)), *coeffs)

        ax.plot(
            NN, curve,
            linewidth=3.5,
            linestyle=":",
            color=color,
            zorder=2,
            alpha=0.8
        )

        # Use the same tolerant comparison both to detect and to select the
        # observed points (previously selection used exact ==).
        mask = np.isclose(k_observed, k)
        if mask.any():
            pts = df.loc[mask].sort_values("N")
            ax.scatter(
                pts["N"], pts["CE"],
                s=points_marker_size,
                color=color,
                edgecolors="white",
                linewidths=0.8,
                zorder=3
            )

    ax.set_xlabel(xlabel, fontsize=18, labelpad=6)
    ax.set_ylabel(ylabel, fontsize=16, labelpad=6)
    ax.set_title(title, fontsize=20)
    ax.tick_params(labelsize=13)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

    # Legend: one marker per k, plus a generic dotted "Fit" entry.
    handles: List[Line2D] = []
    for k in ks:
        color = color_map.get(k, palette[0])
        handles.append(Line2D(
            [0], [0],
            marker='o',
            linestyle='None',
            markersize=10.0,
            markerfacecolor=color,
            markeredgecolor="white",
            markeredgewidth=0.8,
            label=f"k={int(k)}"
        ))

    handles.append(Line2D(
        [0, 1], [0, 0],
        linestyle=':',
        linewidth=3.2,
        color="#6E7781",
        label="Fit"
    ))

    leg = ax.legend(
        handles=handles,
        title="Experts",
        title_fontsize=16,
        fontsize=14,
        loc="best",
        frameon=True,
        fancybox=True,
        framealpha=0.95,
        borderpad=0.6,
        handlelength=1.8,
        handletextpad=0.6,
        labelspacing=0.4
    )
    leg.get_frame().set_edgecolor("#1f2a35")
    leg.get_frame().set_linewidth(1.0)
    leg.get_frame().set_facecolor("white")

    fig.tight_layout()
    return fig

# -----------------------------
# I/O 工具
# -----------------------------
def save_dataframe_csv(df: pd.DataFrame, path: Path, index: bool = False) -> None:
    """
    Write ``df`` to ``path`` as CSV, creating missing parent directories.

    ``index`` is forwarded to ``DataFrame.to_csv``; row labels are omitted by
    default.
    """
    folder = path.parent
    folder.mkdir(exist_ok=True, parents=True)
    df.to_csv(path, index=index)

def save_figure(fig: plt.Figure, path: Path, dpi: int = 170) -> None:
    """
    Save ``fig`` to ``path`` (creating parent directories) and close it.

    Args:
        fig: Figure to write; it is always closed afterwards.
        path: Output file; the format follows the file extension.
        dpi: Resolution used for saving, overriding the figure's own dpi.
    """
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=dpi)
    finally:
        # Close even when mkdir/savefig raises, so failed saves do not leave
        # figures open in pyplot's figure manager.
        plt.close(fig)

# -----------------------------
# 一站式运行入口
# -----------------------------
def run_pipeline(rows: Iterable[Dict[str, Any]],
                 domain: str,
                 output_dir: Path,
                 xlabel: str = "Merged experts (k)",
                 ylabel: Optional[str] = None,
                 title: Optional[str] = None,
                 colors: Optional[List[str]] = None,
                 N_pred: Optional[float] = None,
                 k_grid: Iterable[int] = range(1, 10),
                 color_map_override: Optional[Dict[float, str]] = None,
                 k_curves: Iterable[float] = (1, 3, 6, 9),
                 ) -> Dict[str, Path]:
    """
    Run the pipeline end to end.

    Steps:
      1) rows -> DataFrame -> joint fit (coefficients, RMSE, R^2)
      2) if ``N_pred`` is given, predict CE over ``k_grid`` at that N and
         write the table to CSV
      3) plot CE vs N with one fitted curve per k in ``k_curves``
      4) write a one-row summary CSV with the coefficients and metrics

    Args:
        rows: Records with keys "N", "k", "CE".
        domain: Domain name, used in output file names and the summary.
        output_dir: Directory for all outputs (str or Path).
        xlabel: Not applied to the figure. Its x-axis is model size N and is
            always labelled "Model Size (B)". Kept for backward compatibility.
        ylabel: y-axis label of the figure; defaults to "CE Loss".
        title: Figure title.
        colors: Palette forwarded to the plot (colours are assigned per k).
        N_pred: Optional model size for the prediction table.
        k_grid: Expert counts for the prediction table.
        color_map_override: Explicit k -> colour overrides forwarded to the plot.
        k_curves: Expert counts drawn as curves in the figure.

    Returns:
        Paths of the written files: "figure_png" and "summary_csv", plus
        "pred_csv" when ``N_pred`` is given.
    """
    output_dir = Path(output_dir)
    df = to_dataframe(rows)
    fitres = fit_joint_model(df)

    # Prediction table at a fixed model size
    pred_path: Optional[Path] = None
    if N_pred is not None:
        pred_df = predict_by_k(N_pred, k_grid, fitres)
        pred_path = output_dir / f"pred_{_format_B(N_pred)}_by_k_{domain}.csv"
        save_dataframe_csv(pred_df, pred_path)

    # Figure: previously ylabel, colors and color_map_override were ignored
    # here, and k_curves/domain were hard-coded.
    fig = plot_ce_vs_N(
        df=df,
        params=fitres,
        domain=domain,
        k_curves=k_curves,
        xlabel="Model Size (B)",
        ylabel=ylabel if ylabel is not None else "CE Loss",
        title=title,
        colors=colors,
        color_map_override=color_map_override,
    )
    fig_path = output_dir / f"{domain}_CE_vs_N.png"
    save_figure(fig, fig_path)

    # One-row summary of the fit
    summary_df = pd.DataFrame([{
        "domain": domain,
        "L*": fitres.Ls,
        "B": fitres.B,
        "beta": fitres.beta,
        "A0": fitres.A0,
        "gamma": fitres.gamma,
        "b0": fitres.b0,
        "RMSE": fitres.rmse,
        "R2": fitres.r2
    }])
    summary_path = output_dir / "per_domain_joint_fit_summary.csv"
    save_dataframe_csv(summary_df, summary_path)

    outputs: Dict[str, Path] = {
        "figure_png": fig_path,
        "summary_csv": summary_path,
    }
    if pred_path is not None:
        outputs["pred_csv"] = pred_path
    return outputs

# -----------------------------
# 示例用法
# -----------------------------
if __name__ == "__main__":
    import json

    dom = "Algebra"
    XLABEL = "Number of Experts for Merging (K)"
    YLABEL = "CE Loss"
    TITLE  = f"{dom}: CE vs. Model Size"

    # Layout of data.json as implied by the access pattern below:
    #   {"AVG": {"<N>B": {"<domain>": [CE at k=1, CE at k=2, ...], ...}, ...}}
    # JSON is UTF-8 by specification, so do not rely on the platform default encoding.
    with open("data.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    rows = []
    for n_str, content in data["AVG"].items():
        # To exclude a model size (e.g. 72B), uncomment:
        # if n_str.startswith("72"):
        #     continue
        N = float(n_str.replace("B", ""))  # e.g. "72.7B" -> 72.7
        # The position in the list for `dom` is the expert count k, starting at 1.
        rows.extend(
            {"N": N, "k": k, "CE": val}
            for k, val in enumerate(content[dom], start=1)
        )

    out_dir = Path("figs")

    # Optional: custom palette passed to run_pipeline.
    color_list = ['#196271', '#2e7875', '#428f76', '#57a676', '#6cbd74']
    # Optional: explicit colour overrides passed to run_pipeline.
    # NOTE(review): the figure colours its curves per expert count k, so keys
    # here should be k values. The N-style keys 72.7 / 0.5 below match no
    # plotted k and do not change the figure.
    color_map_override = {
        72.7: "#004c6d",
        0.5:  "#81d470",
    }

    paths = run_pipeline(
        rows=rows,
        domain=dom,
        output_dir=out_dir,
        xlabel=XLABEL,
        ylabel=YLABEL,
        title=TITLE,
        colors=color_list,
        N_pred=None,
        color_map_override=color_map_override,
    )

    for name, path in paths.items():
        print(f"{name}: {path.resolve()}")