# -*- coding: utf-8 -*-
"""
File: plot_rq5_side_by_side.py
Purpose: Generate the RQ5 cross-domain "side-by-side" figure (Real vs 3-point forecast).
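Inputs:
  - all_methods.xlsx with a 'mean' sheet that has a two-row header (method, column).
    For each method, the sheet holds a label column followed by nine k-value columns (k=1..9).
    Rows are grouped into blocks headed by model-size labels such as "7B".
    Other labels name domains; labels containing "domain" are header rows.
Outputs:
  - figs/rq5_all_data_vs_fits.png
  - out/rq5_cross_domain.csv (per series: b, |error| at k=9, MAPE, k*)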
Notes:
  - Pure Matplotlib (no seaborn). One figure with two axes.
  - This file is **independent**; it duplicates loader/utilities so it can run standalone.
"""
from __future__ import annotations

import os
import re
import argparse
from typing import List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


# ========= I/O dirs =========
# Create the default output folders (figs/ for the PNG, out/ for the CSV) at import time.
for _out_dir in ("figs", "out"):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir


# ========= Light Matplotlib styling =========
def _apply_style() -> None:
    """Apply the light, seaborn-free look used by this script.

    Updates ``plt.rcParams`` in place; it is invoked once at import time.
    """
    frame = {
        "figure.dpi": 160,
        "savefig.dpi": 220,
        "figure.facecolor": "white",
        "axes.facecolor": "white",
        "axes.edgecolor": "#222222",
        "axes.labelcolor": "#222222",
        # Key change: keep the grid off globally so no vertical grid lines appear.
        "axes.grid": False,
        # Key change: thicker axis frame lines (spines).
        "axes.linewidth": 3.0,
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
    grid = {
        "grid.color": "#cccccc",
        "grid.linestyle": "-",
        "grid.linewidth": 0.8,
        "grid.alpha": 0.6,
    }
    text_and_lines = {
        "lines.linewidth": 2.0,
        "legend.frameon": True,
        "legend.fontsize": 12,
        "font.size": 11,
    }
    ticks = {
        "xtick.color": "#222222",
        "ytick.color": "#222222",
        # (Optional) slightly heavier major ticks for a consistent look with the thick spines.
        "xtick.major.width": 1.2,
        "ytick.major.width": 1.2,
    }
    export = {
        # Type 42 = embed TrueType fonts in PDF/PS output.
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }
    for group in (frame, grid, text_and_lines, ticks, export):
        plt.rcParams.update(group)


_apply_style()


# ========= Small utilities =========
def mape(y: Sequence[float], yhat: Sequence[float], eps: float = 1e-12) -> float:
    """Mean absolute percentage error of ``yhat`` against ``y``, in percent.

    Only positions where both inputs are finite contribute. Each denominator
    is ``|y|`` floored at ``eps``, so zeros in ``y`` never divide by zero.
    Returns NaN when no position is usable.
    """
    actual = np.asarray(y, float)
    predicted = np.asarray(yhat, float)
    usable = np.isfinite(actual) & np.isfinite(predicted)
    if not usable.any():
        return np.nan
    a = actual[usable]
    p = predicted[usable]
    rel_err = np.abs(a - p) / np.maximum(eps, np.abs(a))
    return float(rel_err.mean()) * 100.0


def k_star(y: Sequence[float], delta: float = 0.01) -> int:
    """Return the 1-based k at which the curve first stops improving by more than ``delta``.

    Scans consecutive pairs ``(L(k), L(k+1))`` and returns the smallest k with
    ``L(k) - L(k+1) <= delta * L(k)``. Pairs containing a non-finite value are
    skipped. If no pair qualifies, returns ``len(y)`` (0 for an empty input).
    """
    vals = np.asarray(y, float)
    n = len(vals)
    hits = (
        pos
        for pos, (cur, nxt) in enumerate(zip(vals[:-1], vals[1:]), start=1)
        if np.isfinite(cur) and np.isfinite(nxt) and cur - nxt <= delta * cur
    )
    return next(hits, n)


def _best_b_safe(k_vals: Sequence[float], y: Sequence[float], b_grid: Optional[np.ndarray] = None) -> Tuple[float, Optional[float]]:
    """Grid-search the offset ``b`` in ``y ~ a + c / (k + b)`` by least squares.

    For each candidate ``b``, the linear coefficients ``(a, c)`` are fitted by
    ordinary least squares on the usable points, and the sum of squared
    residuals (SSE) is recorded. The ``b`` with the smallest SSE wins (the
    first one on ties).

    Args:
        k_vals: x-positions (e.g. number of experts), same length as ``y``.
        y: observed values; entries whose ``y`` or ``k`` is non-finite are ignored.
        b_grid: candidate offsets. Defaults to ``np.linspace(0, 1, 41)``.

    Returns:
        ``(best_b, best_sse)``, or ``(0.0, None)`` when fewer than two usable
        points remain or no candidate yields a finite SSE.
    """
    if b_grid is None:
        # Built per call instead of as a shared (mutable) default-argument array.
        b_grid = np.linspace(0, 1, 41)
    y = np.asarray(y, float)
    k_all = np.asarray(k_vals, float)
    # A point is usable only if both k and y are finite; a NaN k would poison lstsq.
    mask = np.isfinite(y) & np.isfinite(k_all)
    if mask.sum() < 2:
        return 0.0, None
    k = k_all[mask]
    yy = y[mask]
    best_b, best_sse = None, np.inf
    for b in b_grid:
        with np.errstate(divide="ignore"):
            x = 1.0 / (k + b)
        if not np.all(np.isfinite(x)):
            # k + b == 0 makes the 1/(k+b) basis undefined for this candidate.
            continue
        X = np.column_stack([np.ones_like(x), x])
        coef, _, _, _ = np.linalg.lstsq(X, yy, rcond=None)
        sse = float(np.sum((yy - X @ coef) ** 2))
        if np.isfinite(sse) and sse < best_sse:
            best_sse, best_b = sse, float(b)
    if best_b is None:
        return 0.0, None
    return best_b, best_sse


def _three_point_forecast_safe(k_vals: Sequence[float], y: Sequence[float], b: float, prefer_idx=(0, 1, 3)) -> Optional[np.ndarray]:
    """Fit ``y ~ a + c / (k + b)`` on (ideally) three points and evaluate it at every k.

    The fit points are the positions in ``prefer_idx`` (k=1, 2, 4 when
    ``k_vals`` is 1..9) whose ``y`` is finite. If fewer than three of them are
    usable, the earliest other finite positions are added until three points
    are selected or the finite points run out.

    Args:
        k_vals: x-positions, same length as ``y``.
        y: observed values; non-finite entries are never used for fitting.
        b: offset of the ``1 / (k + b)`` basis (e.g. from ``_best_b_safe``).
        prefer_idx: 0-based positions to try first.

    Returns:
        The forecast over all of ``k_vals``, or ``None`` when fewer than two
        finite points are available.
    """
    y = np.asarray(y, float)
    k_vals = np.asarray(k_vals, float)
    finite = np.isfinite(y)

    pick = [i for i in prefer_idx if i < len(y) and finite[i]]
    if len(pick) < 3:
        # Always top up to three points, no matter how many preferred ones were usable.
        extra = [i for i in range(len(y)) if finite[i] and i not in pick]
        pick = (pick + extra)[:3]

    if len(pick) < 2:
        return None

    x = 1.0 / (k_vals[pick] + b)
    Xe = np.column_stack([np.ones_like(x), x])
    coef, _, _, _ = np.linalg.lstsq(Xe, y[pick], rcond=None)
    return coef[0] + coef[1] / (k_vals + b)


# ========= Loader for cross-domain workbook =========
def _flush_mean_block(num: pd.DataFrame, block_rows: List[Tuple[int, object]], N: float, method: object,
                      rows_macro: List[dict], rows_domain: List[dict]) -> None:
    """Append the macro row and the non-empty domain rows of one finished block.

    ``num`` holds the nine numeric k-columns of the method; ``block_rows`` is a
    list of ``(row position, label)`` pairs. Does nothing for an empty block.
    """
    if not block_rows:
        return
    positions = [r for (r, _) in block_rows]
    # NaN-skipping mean over the block's domain rows, one value per k.
    means = num.iloc[positions].mean(axis=0).to_numpy(dtype=float)
    row_macro = {"N": N, "Method": method}
    row_macro.update({f"k{i}": float(v) for i, v in enumerate(means, start=1)})
    rows_macro.append(row_macro)
    for r, lab in block_rows:
        vals = num.iloc[r].to_numpy(dtype=float)
        if np.all(np.isnan(vals)):
            # Entirely empty rows (e.g. blank spacer lines) are not domains.
            continue
        row_dom = {"N": N, "Method": method, "domain": str(lab).strip()}
        row_dom.update({f"k{i}": float(v) for i, v in enumerate(vals, start=1)})
        rows_domain.append(row_dom)


def _parse_blocked_mean_sheet(df: pd.DataFrame, verbose: bool = True) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Parse the blocked 'mean' sheet into macro-averaged and per-domain tables.

    ``df`` must have two-level columns ``(method, sub-column)``. For each method,
    the first sub-column holds row labels and the next nine hold the k=1..9
    values. Rows are grouped into blocks:

    * a label like ``"7B"`` (surrounding whitespace allowed) starts a new block
      for model size N=7;
    * labels containing "domain" are header rows and are skipped;
    * every other row is a domain row of the current block.

    The first block's N comes from the label column's header (e.g. ``"1.5B"``)
    when it starts with that form; otherwise it is NaN. Rows are addressed by
    position, so the frame's index labels do not matter.

    Returns:
        ``(tidy_macro, tidy_domain)``:

        * tidy_macro has one row per (method, block), with columns
          ``N, Method, k1..k9``. Each k is the NaN-skipping mean over the
          block's domain rows.
        * tidy_domain has one row per domain row with at least one numeric
          value, with columns ``N, Method, domain, k1..k9``.

    Raises:
        ValueError: if the columns are not a MultiIndex, or if a method has
            fewer than 10 sub-columns.
    """
    if not isinstance(df.columns, pd.MultiIndex):
        raise ValueError("Expect MultiIndex columns on 'mean' sheet.")
    size_label = re.compile(r'^(\d+(?:\.\d+)?)B$')
    methods = list(dict.fromkeys(df.columns.get_level_values(0)))
    rows_macro: List[dict] = []
    rows_domain: List[dict] = []

    for m in methods:
        sub = df[m]
        if sub.shape[1] < 10:
            raise ValueError(f"Method {m}: expect 10 cols (label + 9 k-values).")
        labels = sub.iloc[:, 0].tolist()
        # Coerce the nine k-value columns once; non-numeric cells become NaN.
        num = sub.iloc[:, 1:10].apply(pd.to_numeric, errors='coerce')

        m0 = re.match(r'^(\d+(?:\.\d+)?)B', str(sub.columns[0]))
        curr_N = float(m0.group(1)) if m0 else np.nan

        block_rows: List[Tuple[int, object]] = []
        for ridx, label in enumerate(labels):
            hit = size_label.match(label.strip()) if isinstance(label, str) else None
            if hit:
                # A size label closes the current block and opens the next one.
                _flush_mean_block(num, block_rows, curr_N, m, rows_macro, rows_domain)
                block_rows = []
                curr_N = float(hit.group(1))
                continue
            if isinstance(label, str) and "domain" in label.lower():
                continue
            block_rows.append((ridx, label))
        _flush_mean_block(num, block_rows, curr_N, m, rows_macro, rows_domain)

    tidy_macro = pd.DataFrame(rows_macro)
    tidy_domain = pd.DataFrame(rows_domain)
    if verbose and not tidy_macro.empty:
        got = tidy_macro.groupby("Method")["N"].apply(lambda x: sorted(set(x)))
        print("[mean] Parsed (Method → Ns):")
        for m, Ns in got.items():
            print(f"  - {m}: N={Ns}")
    return tidy_macro, tidy_domain


def load_cross_methods(path_xlsx: str = "all_methods.xlsx", verbose: bool = True) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Read the 'mean' sheet of the cross-domain workbook and parse it into tidy tables.

    Args:
        path_xlsx: path to the workbook.
        verbose: forwarded to the parser, which then prints the Ns found per method.

    Returns:
        ``(tidy_macro, tidy_domain)`` as produced by ``_parse_blocked_mean_sheet``.
        Whitespace is stripped from the ``Method`` values (converted to str).

    Raises:
        ValueError: if the 'mean' sheet is missing or yields no macro records.
    """
    # ExcelFile keeps the workbook handle open; the context manager releases it.
    with pd.ExcelFile(path_xlsx) as xls:
        if "mean" not in xls.sheet_names:
            raise ValueError("Sheet 'mean' not found.")
        df_mean = pd.read_excel(xls, sheet_name="mean", header=[0, 1])
    tidy_macro, tidy_domain = _parse_blocked_mean_sheet(df_mean, verbose=verbose)
    for tdf in (tidy_macro, tidy_domain):
        if "Method" in tdf.columns:
            tdf["Method"] = tdf["Method"].astype(str).str.strip()
    if tidy_macro.empty:
        raise ValueError("No records (N,Method,k1..k9) parsed from 'mean'.")
    return tidy_macro, tidy_domain


def plot_side_by_side(xlsx: str, out_png: str, out_csv: str, delta: float = 0.01, method_display_map: Optional[dict] = None) -> None:
    """Draw real curves (left) next to their 3-point forecasts (right) and write a summary CSV.

    Each (Method, N) group of the macro table from ``load_cross_methods`` gives
    one curve; a group with several rows uses their k-wise mean. For each curve:

    * ``b`` is grid-searched with ``_best_b_safe`` on all points;
    * the forecast comes from ``_three_point_forecast_safe``;
    * the CSV records ``b``, |error| at k=9, MAPE over all k, and ``k_star``
      at threshold ``delta``.

    Args:
        xlsx: workbook path passed to ``load_cross_methods``.
        out_png: figure output path; missing parent folders are created.
        out_csv: summary CSV output path; missing parent folders are created.
        delta: relative-improvement threshold for ``k_star``.
        method_display_map: maps raw method names to legend/CSV names; unmapped
            names are shown as-is. Defaults to the four "... merge" methods.
    """
    tidy_macro, _ = load_cross_methods(xlsx)

    if method_display_map is None:
        method_display_map = {"Average merge": "Average", "TA merge": "TA", "TIES merge": "TIES", "DARE merge": "DARE"}

    def disp(m: str) -> str:
        return method_display_map.get(m, m)

    k_vals = np.arange(1, 10)
    value_cols = [f"k{i}" for i in range(1, 10)]

    # k-wise mean of each (Method, N) group, computed once and reused by both
    # panels and the CSV. groupby yields the groups in sorted key order.
    curves = {
        (m, float(N)): dfm[value_cols].astype(float).mean(axis=0).to_numpy()
        for (m, N), dfm in tidy_macro.groupby(["Method", "N"])
    }
    series_keys = sorted(curves)
    # max(1, ...) keeps get_cmap valid even if no group survived (e.g. every N is NaN).
    cmap = plt.get_cmap("tab20", max(1, len(series_keys)))
    color_map = {key: cmap(i) for i, key in enumerate(series_keys)}

    rows_summary: List[dict] = []
    fig, (axL, axR) = plt.subplots(1, 2, figsize=(14.5, 6), sharey=True)
    try:
        for (m, Nf), y in curves.items():
            color = color_map[(m, Nf)]
            # Left panel: observed curve.
            axL.plot(k_vals, y, color=color, linestyle="-", alpha=0.95)

            # Right panel: 3-point forecast, with b chosen on the full curve.
            best_b, _ = _best_b_safe(k_vals, y, b_grid=np.linspace(0, 1, 41))
            yhat_e = _three_point_forecast_safe(k_vals, y, best_b)
            if yhat_e is not None:
                axR.plot(k_vals, yhat_e, color=color, linestyle="--", alpha=0.95)
                err_k9 = abs(y[-1] - yhat_e[-1]) if np.isfinite(y[-1]) else np.nan
                mape_k = mape(y, yhat_e)
            else:
                err_k9 = np.nan
                mape_k = np.nan
            rows_summary.append({
                "series": f"{m}@{Nf}B",
                "method": disp(m),
                "N": Nf,
                "b": best_b,
                "abs_err_k9": err_k9,
                "MAPE_%": mape_k,
                "k_star(1%)": k_star(y, delta=delta),
            })

        axL.set_xlabel("k (experts)", fontsize=16)
        axL.set_ylabel("Cross-domain CE (macro avg)", fontsize=16)
        axL.set_title("Real data", fontsize=19)
        axR.set_xlabel("k (experts)", fontsize=16)
        axR.set_title("3-point forecast", fontsize=19)

        for ax in (axL, axR):
            # Horizontal guides only; the global style keeps the grid off.
            ax.grid(True, which="major", axis="y")
            ax.set_xlim(k_vals.min(), k_vals.max())

        style_handles = [
            Line2D([0], [0], color="#444444", linestyle="-", lw=2, label="Real"),
            Line2D([0], [0], color="#444444", linestyle="--", lw=2, label="Fit"),
        ]
        series_handles = [Line2D([0], [0], color=color_map[key], lw=3, label=f"{disp(key[0])}@{key[1]:g}B")
                          for key in series_keys]
        leg1 = axR.legend(handles=style_handles, loc="upper right", fontsize=14)
        axR.add_artist(leg1)
        fig.legend(handles=series_handles, loc="lower center", ncol=8,
                   bbox_to_anchor=(0.5, -0.08), fontsize=11)

        for path in (out_png, out_csv):
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
        fig.tight_layout(rect=[0, 0.04, 1, 1])
        fig.savefig(out_png, dpi=220, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)

    pd.DataFrame(rows_summary).to_csv(out_csv, index=False)


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; every default reproduces the original hard-coded run."""
    parser = argparse.ArgumentParser(
        description="Plot RQ5 cross-domain real curves vs 3-point forecasts.")
    parser.add_argument("--xlsx", default="all_methods.xlsx",
                        help="input workbook containing a 'mean' sheet")
    parser.add_argument("--out-png", default="figs/rq5_all_data_vs_fits.png",
                        help="output figure path")
    parser.add_argument("--out-csv", default="out/rq5_cross_domain.csv",
                        help="output per-series summary CSV path")
    parser.add_argument("--delta", type=float, default=0.01,
                        help="relative-improvement threshold for k_star")
    return parser


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Script entry point: parse ``argv`` (``sys.argv[1:]`` when None) and render the figure."""
    args = _build_parser().parse_args(argv)
    plot_side_by_side(args.xlsx, args.out_png, args.out_csv, delta=args.delta)


if __name__ == "__main__":
    main()