# -*- coding: utf-8 -*-
"""
File: plot_rq5_mape_by_domain.py
Purpose: Generate the RQ5 per-domain MAPE boxplot (ALL + each domain).
Inputs:
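  - all_methods.xlsx with a 'mean' sheet: two header rows (method name on top);
    per method at least ten columns (a row-label column followed by nine
    k-value columns); rows are grouped into blocks, each introduced by an
    N label such as '7B'.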
Outputs:
  - figs/rq5_mape_box_domains.png
Notes:
  - Pure Matplotlib (no seaborn). One figure (single axes).
  - This file is **independent**; it duplicates loader/utilities so it can run standalone.
"""
from __future__ import annotations

import os
import re
import argparse
from typing import List, Optional, Sequence, Tuple, Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# ========= I/O dirs =========
# Create the output directories at import time, relative to the current
# working directory; exist_ok=True makes the calls no-ops when they exist.
# NOTE(review): nothing visible in this file writes to "out" — confirm it is
# still needed.
os.makedirs("figs", exist_ok=True)
os.makedirs("out", exist_ok=True)


# ========= Light Matplotlib styling =========
def _apply_style() -> None:
    """Install the light Matplotlib defaults used by this script.

    Mutates the global ``plt.rcParams`` in place, so every figure created
    afterwards in the same process picks the settings up.
    """
    ink = "#222222"

    figure_rc = {
        "figure.dpi": 160,
        "savefig.dpi": 220,
        "figure.facecolor": "white",
    }
    axes_rc = {
        "axes.facecolor": "white",
        "axes.edgecolor": ink,
        "axes.labelcolor": ink,
        # Key change: keep the grid off globally so that no vertical grid
        # lines show up; plots enable the grid per axis where wanted.
        "axes.grid": False,
        # Key change: thicken the axes frame lines (spines).
        "axes.linewidth": 3.0,
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
    grid_rc = {
        "grid.color": "#cccccc",
        "grid.linestyle": "-",
        "grid.linewidth": 0.8,
        "grid.alpha": 0.6,
    }
    text_and_lines_rc = {
        "lines.linewidth": 2.0,
        "legend.frameon": True,
        "legend.fontsize": 12,
        "font.size": 11,
    }
    ticks_rc = {
        "xtick.color": ink,
        "ytick.color": ink,
        # (Optional) tick marks are slightly thicker too, for a more
        # uniform look next to the heavy spines.
        "xtick.major.width": 1.2,
        "ytick.major.width": 1.2,
    }
    export_rc = {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }

    plt.rcParams.update({
        **figure_rc,
        **axes_rc,
        **grid_rc,
        **text_and_lines_rc,
        **ticks_rc,
        **export_rc,
    })


# Apply the style once, when the module is loaded.
_apply_style()

# ========= Small utilities =========
def mape(y: Sequence[float], yhat: Sequence[float], eps: float = 1e-12) -> float:
    """Return the mean absolute percentage error of *yhat* against *y*, in percent.

    Only positions where both values are finite contribute. The denominator
    ``|y|`` is floored at *eps* so that zero targets do not divide by zero.
    Returns NaN when no position has a finite pair.
    """
    actual = np.asarray(y, float)
    predicted = np.asarray(yhat, float)

    usable = np.isfinite(actual) & np.isfinite(predicted)
    if not usable.any():
        return np.nan

    actual = actual[usable]
    predicted = predicted[usable]
    relative_error = np.abs(actual - predicted) / np.maximum(eps, np.abs(actual))
    return float(np.mean(relative_error)) * 100.0


def _best_b_safe(k_vals: Sequence[float], y: Sequence[float], b_grid: np.ndarray = np.linspace(0, 1, 41)) -> Tuple[float, Optional[float]]:
    """Grid-search the shift ``b`` that best fits ``y ~ c0 + c1 / (k + b)``.

    For every candidate ``b`` the intercept ``c0`` and slope ``c1`` are fitted
    by ordinary least squares; the candidate with the smallest sum of squared
    residuals wins (the first one on ties).

    Args:
        k_vals: Abscissa values, aligned with *y*.
        y: Observed values; non-finite entries are ignored.
        b_grid: Candidate shifts to try, in order.

    Returns:
        ``(best_b, best_sse)``. ``(0.0, None)`` when fewer than two usable
        points remain or no candidate yields a finite fit.
    """
    y = np.asarray(y, float)
    mask = np.isfinite(y)
    if mask.sum() < 2:
        return 0.0, None
    k = np.asarray(k_vals, float)[mask]
    yy = y[mask]

    # A non-finite k would put NaN/inf into the design matrix for every
    # candidate, so drop those points up front.
    k_ok = np.isfinite(k)
    if not k_ok.all():
        k, yy = k[k_ok], yy[k_ok]
        if k.size < 2:
            return 0.0, None

    best_b, best_sse = None, np.inf
    for b in np.asarray(b_grid, float).ravel():
        # k + b == 0 (e.g. k = 0 with b = 0) makes the regressor infinite;
        # such a candidate cannot be fitted, so skip it instead of handing
        # inf/NaN to lstsq.
        with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
            x = 1.0 / (k + b)
        if not np.isfinite(x).all():
            continue
        X = np.vstack([np.ones_like(x), x]).T
        try:
            coef, _, _, _ = np.linalg.lstsq(X, yy, rcond=None)
        except np.linalg.LinAlgError:
            # The solver did not converge for this candidate; try the next.
            continue
        resid = yy - X @ coef
        sse = float(np.sum(resid ** 2))
        if np.isfinite(sse) and (sse < best_sse):
            best_sse, best_b = sse, float(b)
    if best_b is None:
        return 0.0, None
    return best_b, best_sse


def _three_point_forecast_safe(k_vals: Sequence[float], y: Sequence[float], b: float, prefer_idx=(0, 1, 3)) -> Optional[np.ndarray]:
    """Fit ``y ~ c0 + c1 / (k + b)`` on a few early points and predict every k.

    The fit uses the positions in *prefer_idx* whose observation is finite.
    If some of those are unusable, other finite positions are appended in
    ascending order.

    NOTE(review): when fewer than two preferred positions are usable the
    selection is padded to two points only, whereas with exactly two usable
    it is padded to three — confirm this asymmetry is intended.

    Returns:
        Predictions for all of *k_vals*, or ``None`` when fewer than two
        finite observations exist.
    """
    obs = np.asarray(y, float)
    ks = np.asarray(k_vals, float)
    is_finite = np.isfinite(obs)
    n_obs = len(obs)

    chosen = [idx for idx in prefer_idx if idx < n_obs and is_finite[idx]]
    if len(chosen) < 3:
        # Pad with the remaining finite positions, up to the target size.
        target = 2 if len(chosen) < 2 else 3
        fallback = [idx for idx in range(n_obs) if is_finite[idx] and idx not in chosen]
        chosen = (chosen + fallback)[:target]

    if len(chosen) < 2:
        return None

    # Least-squares fit of intercept and slope on the transformed abscissa.
    x_fit = 1.0 / (ks[chosen] + b)
    design = np.array([np.ones_like(x_fit), x_fit]).T
    coef = np.linalg.lstsq(design, obs[chosen], rcond=None)[0]
    intercept, slope = coef[0], coef[1]
    return intercept + slope / (ks + b)


# ========= Loader =========
# Row label that opens a new block, e.g. "7B" or "1.3B" (whole cell).
_N_ROW_LABEL_RE = re.compile(r'^(\d+(?:\.\d+)?)B$')
# Same label as a prefix of the first column's header, which names the first block.
_N_HEADER_RE = re.compile(r'^(\d+(?:\.\d+)?)B')


def _flush_block(numeric: pd.DataFrame, block_rows: list, curr_N: float, method, rows_macro: list, rows_domain: list) -> None:
    """Append the macro row and the per-domain rows of one finished N-block."""
    if not block_rows:
        return
    positions = [pos for pos, _ in block_rows]
    block = numeric.iloc[positions]

    # Macro row: column-wise mean over every row of the block (NaNs skipped).
    means = block.mean(axis=0).to_numpy(dtype=float)
    row_macro = {"N": curr_N, "Method": method}
    for i, value in enumerate(means, start=1):
        row_macro[f"k{i}"] = float(value)
    rows_macro.append(row_macro)

    # Domain rows: one per block row that has at least one numeric value.
    for (_, label), vals in zip(block_rows, block.to_numpy(dtype=float)):
        if np.all(np.isnan(vals)):
            continue
        row_domain = {"N": curr_N, "Method": method, "domain": str(label).strip()}
        for i, v in enumerate(vals, start=1):
            row_domain[f"k{i}"] = float(v)
        rows_domain.append(row_domain)


def _parse_blocked_mean_sheet(df: pd.DataFrame, verbose: bool = True) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Turn the block-structured 'mean' sheet into two tidy frames.

    Each top-level column label is a method. Within a method, the first
    column holds row labels and the next nine hold the k-values. A row whose
    label looks like ``"<number>B"`` starts a new block and sets ``N``; rows
    whose label contains "domain" are treated as sub-headers and skipped;
    every other row belongs to the current block. The first block takes its
    ``N`` from the header of the label column.

    Rows are addressed by position, so the result does not depend on the
    frame's index labels.

    Args:
        df: Frame with two-level MultiIndex columns (method, column name).
        verbose: When true, print the N values found for each method.

    Returns:
        ``(tidy_macro, tidy_domain)``: one row per (method, N) holding the
        block mean of k1..k9, and one row per (method, N, domain) holding
        that row's k1..k9.

    Raises:
        ValueError: If the columns are not a MultiIndex, or a method has
            fewer than ten columns.
    """
    if not isinstance(df.columns, pd.MultiIndex):
        raise ValueError("Expect MultiIndex columns on 'mean' sheet.")
    # Unique methods in first-seen order.
    methods = list(dict.fromkeys(df.columns.get_level_values(0)))
    rows_macro, rows_domain = [], []

    for m in methods:
        sub = df[m]
        if sub.shape[1] < 10:
            raise ValueError(f"Method {m}: expect 10 cols (label + 9 k-values).")

        labels = sub.iloc[:, 0].tolist()
        # Coerce the nine value columns once; anything non-numeric becomes NaN.
        numeric = sub.iloc[:, 1:10].apply(pd.to_numeric, errors='coerce')

        curr_N = np.nan
        header_match = _N_HEADER_RE.match(str(sub.columns[0]))
        if header_match:
            curr_N = float(header_match.group(1))

        block_rows = []
        for pos, label in enumerate(labels):
            n_match = _N_ROW_LABEL_RE.match(label) if isinstance(label, str) else None
            if n_match:
                # A new N label closes the block collected so far.
                _flush_block(numeric, block_rows, curr_N, m, rows_macro, rows_domain)
                block_rows = []
                curr_N = float(n_match.group(1))
                continue
            if isinstance(label, str) and "domain" in label.lower():
                continue
            block_rows.append((pos, label))

        # The last block has no closing N label.
        _flush_block(numeric, block_rows, curr_N, m, rows_macro, rows_domain)

    tidy_macro = pd.DataFrame(rows_macro)
    tidy_domain = pd.DataFrame(rows_domain)
    if verbose and not tidy_macro.empty:
        got = tidy_macro.groupby("Method")["N"].apply(lambda x: sorted(set(x)))
        print("[mean] Parsed (Method → Ns):")
        for m, Ns in got.items():
            print(f"  - {m}: N={Ns}")
    return tidy_macro, tidy_domain


def load_cross_methods(path_xlsx: str = "all_methods.xlsx", verbose: bool = True) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the 'mean' sheet of the workbook and return its tidy frames.

    Args:
        path_xlsx: Path to the Excel workbook.
        verbose: Forwarded to the sheet parser.

    Returns:
        ``(tidy_macro, tidy_domain)`` as produced by
        ``_parse_blocked_mean_sheet``, with the ``Method`` column converted
        to stripped strings.

    Raises:
        ValueError: If the workbook has no 'mean' sheet, or no macro records
            could be parsed from it.
    """
    # The context manager closes the workbook handle even when a check or
    # the read below raises.
    with pd.ExcelFile(path_xlsx) as xls:
        if "mean" not in xls.sheet_names:
            raise ValueError("Sheet 'mean' not found.")
        df_mean = pd.read_excel(xls, sheet_name="mean", header=[0, 1])

    tidy_macro, tidy_domain = _parse_blocked_mean_sheet(df_mean, verbose=verbose)
    for tdf in (tidy_macro, tidy_domain):
        # An empty frame has no columns at all, hence the guard.
        if "Method" in tdf.columns:
            tdf["Method"] = tdf["Method"].astype(str).str.strip()
    if tidy_macro.empty:
        raise ValueError("No records (N,Method,k1..k9) parsed from 'mean'.")
    return tidy_macro, tidy_domain


def _early_fit_mape(k_vals: np.ndarray, y: np.ndarray) -> float:
    """Return the MAPE (%) of the early-point forecast against *y*; NaN if no forecast is possible."""
    best_b, _ = _best_b_safe(k_vals, y, b_grid=np.linspace(0, 1, 41))
    yhat = _three_point_forecast_safe(k_vals, y, best_b)
    if yhat is None:
        return np.nan
    return mape(y, yhat)


def plot_mape_by_domain(xlsx: str, out_png: str) -> None:
    """Draw the per-domain MAPE boxplot and save it to *out_png*.

    For every (method, N) curve — and every (method, N, domain) curve — the
    shift ``b`` is grid-searched, the early-point forecast is fitted, and its
    MAPE over k1..k9 is recorded. The first box ("ALL") shows the macro
    curves; the remaining boxes show each domain, ordered by median error.

    Args:
        xlsx: Path to the workbook read by ``load_cross_methods``.
        out_png: Destination image path; its parent directory is created if
            missing.

    Raises:
        ValueError: Propagated from ``load_cross_methods`` when the workbook
            cannot be parsed.
    """
    tidy_macro, tidy_domain = load_cross_methods(xlsx, verbose=False)

    k_vals = np.arange(1, 10)
    value_cols = [f"k{i}" for i in range(1, 10)]

    # ALL (macro): one MAPE per (method, N); non-finite errors are dropped.
    macro_mape_vals: List[float] = []
    for _, dfm in tidy_macro.groupby(["Method", "N"]):
        y = dfm[value_cols].astype(float).mean(axis=0).to_numpy()
        err = _early_fit_mape(k_vals, y)
        if np.isfinite(err):
            macro_mape_vals.append(err)

    # Per-domain: one MAPE per (method, N, domain), collected by domain.
    domain_mape: Dict[str, List[float]] = {}
    # An empty frame has no columns, so grouping it would raise KeyError.
    if not tidy_domain.empty:
        for (_, _, d), sub in tidy_domain.groupby(["Method", "N", "domain"]):
            y = sub[value_cols].astype(float).mean(axis=0).to_numpy()
            err = _early_fit_mape(k_vals, y)
            if np.isfinite(err):
                domain_mape.setdefault(d, []).append(err)

    ordered_domains = sorted(domain_mape, key=lambda d: np.median(domain_mape[d]))
    labels = ["ALL"] + ordered_domains
    groups = [macro_mape_vals] + [domain_mape[d] for d in ordered_domains]

    # Make sure the destination directory exists for non-default paths too.
    out_dir = os.path.dirname(out_png)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(max(10, 0.8 * len(labels)), 5.8))
    try:
        bp = ax.boxplot(
            groups,
            notch=True,
            patch_artist=True,
            showmeans=True,
            showfliers=False,
            meanprops=dict(marker="D", markersize=6, markerfacecolor="#333333", markeredgecolor="white"),
            medianprops=dict(color="#333333", linewidth=2.5),
            whiskerprops=dict(color="#777777"),
            capprops=dict(color="#777777"),
        )
        # Label the ticks after the fact: the boxplot keyword for this was
        # renamed from `labels` to `tick_labels` in Matplotlib 3.9, and
        # setting the labels directly works on either side of the rename.
        ax.set_xticklabels(labels)

        cmap_boxes = plt.get_cmap("PuBu", len(labels))
        for i, patch in enumerate(bp["boxes"]):
            patch.set_facecolor(cmap_boxes(i))
            patch.set_alpha(0.6)
            patch.set_edgecolor("#555555")
            patch.set_linewidth(1.5)

        # Horizontal grid lines only, drawn behind the boxes.
        ax.yaxis.grid(True, linestyle="--", linewidth=0.6, alpha=0.6)
        ax.set_axisbelow(True)
        ax.tick_params(axis="x", labelrotation=25, labelsize=14)
        ax.set_ylabel("MAPE across k (%)", fontsize=18)
        ax.set_title("Early-3pt forecast error by domain (ALL + per-domain)", fontsize=20)

        # Print each median just above its median line.
        ymin, ymax = ax.get_ylim()
        offset = 0.02 * (ymax - ymin)
        for i, g in enumerate(groups, start=1):
            if not g:
                continue
            med = float(np.median(g))
            ax.text(i, med + offset, f"{med:.1f}", ha="center", va="bottom", fontsize=12, color="#333333")

        fig.tight_layout()
        fig.savefig(out_png, dpi=220)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Generate the RQ5 per-domain MAPE boxplot.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) means "no arguments", so a plain ``main()`` call
            uses the default paths and never reads ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Generate the RQ5 per-domain MAPE boxplot (ALL + each domain).",
    )
    parser.add_argument(
        "--xlsx",
        default="all_methods.xlsx",
        help="input workbook containing the 'mean' sheet (default: %(default)s)",
    )
    parser.add_argument(
        "--out",
        default="figs/rq5_mape_box_domains.png",
        help="output image path (default: %(default)s)",
    )
    args = parser.parse_args([] if argv is None else list(argv))
    plot_mape_by_domain(args.xlsx, args.out)


if __name__ == "__main__":
    import sys

    main(sys.argv[1:])