import argparse
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import matplotlib
import matplotlib.pyplot as plt
import numpy as np


# Plotting only: runs offline, with no torch/transformers dependency.
# Select the non-interactive Agg backend so figures render without a display.
matplotlib.use("Agg")
plt.rcParams.update(
    {
        # Embed TrueType (Type 42) fonts in PDF/PS output instead of Type 3 fonts.
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "pdf.use14corefonts": False,
        "font.sans-serif": ["Arial", "DejaVu Sans"],
        # Render negative tick labels with an ASCII hyphen-minus.
        "axes.unicode_minus": False,
        "font.family": "sans-serif",
        "axes.linewidth": 1.8,
    }
)


@dataclass
class CacheBlob:
    """In-memory contents of one EOS-probability ``.npz`` cache.

    The three arrays each have shape (K,), one entry per recorded step.
    """

    indices: np.ndarray  # step index per entry, shape (K,); used as the x values when plotting
    mean_eos_prob: np.ndarray  # value per step, shape (K,); ln-space when meta["value_space"] == "ln"
    counts: np.ndarray  # shape (K,); presumably samples contributing per step (TODO confirm)
    meta: Dict[str, Any]  # parsed metadata such as "metric" / "value_space"; {} when absent


def _mkdir_parent(path: Optional[str]) -> None:
    if not path:
        return
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _parse_meta(meta_arr: Any) -> Dict[str, Any]:
    """
    与 `plot_eos_prob.py` 的 save_cache/load_cache 对齐：
    - meta 在 npz 里通常是 json.dumps 的字符串（0-d array）
    - 也兼容 bytes / object / 非 json 的情况
    """
    try:
        meta_raw = meta_arr.item() if getattr(meta_arr, "shape", None) == () else str(meta_arr)
    except Exception:
        meta_raw = meta_arr

    if isinstance(meta_raw, (bytes, bytearray)):
        try:
            meta_raw = meta_raw.decode("utf-8")
        except Exception:
            meta_raw = str(meta_raw)

    if isinstance(meta_raw, str):
        try:
            parsed = json.loads(meta_raw)
            return parsed if isinstance(parsed, dict) else {}
        except Exception:
            return {}
    return {}


def load_cache(cache_path: str) -> CacheBlob:
    """Load one ``.npz`` EOS-probability cache into a ``CacheBlob``.

    The archive must contain ``indices``, ``mean_eos_prob`` and ``counts``.
    ``meta`` is optional and is decoded by ``_parse_meta``; a missing entry
    yields ``{}``.

    Raises:
        KeyError: If a required array is missing from the archive.
        Errors from ``np.load`` (e.g. ``OSError`` for an unreadable path) propagate.
    """
    # allow_pickle=True keeps object-dtype meta readable, but unpickling a crafted
    # file can execute code: only load caches you produced yourself.
    # The context manager closes the underlying zip file handle (NpzFile keeps it open).
    with np.load(cache_path, allow_pickle=True) as z:
        meta = _parse_meta(z["meta"] if "meta" in z.files else None)
        # Indexing reads each member into memory, so the arrays remain valid after close.
        return CacheBlob(
            indices=z["indices"],
            mean_eos_prob=z["mean_eos_prob"],
            counts=z["counts"],
            meta=meta,
        )


def _log_axis_floor(curves: List[np.ndarray], floor: float = 1e-12) -> float:
    """Return the lower y-limit for a log-scaled axis.

    This is the smallest strictly positive value across ``curves``, clamped to at
    least ``floor``. Non-positive and NaN values are ignored. If no positive value
    exists at all, ``floor`` is returned.
    """
    positive_parts = [c[c > 0] for c in curves]
    positive_parts = [p for p in positive_parts if p.size]
    if not positive_parts:
        return floor
    return max(floor, float(np.min(np.concatenate(positive_parts))))


def plot_curves(blobs: List[CacheBlob], labels: List[str], save_path: str, log_y: bool) -> None:
    """Draw one EOS-probability curve per cache blob and save the figure as a PDF.

    Args:
        blobs: Loaded caches. Each contributes ``indices`` as x and
            ``mean_eos_prob`` as y. Values from a blob whose
            ``meta["value_space"] == "ln"`` are exponentiated back to probabilities.
        labels: Legend label for each blob, in the same order as ``blobs``.
        save_path: Output path. Parent directories are created if needed, and the
            file is always written in PDF format.
        log_y: Use a logarithmic y-axis whose lower limit is the smallest positive
            plotted value (at least 1e-12). Otherwise the y-axis starts at 0.

    Raises:
        ValueError: If ``labels`` and ``blobs`` differ in length.
    """
    if len(labels) != len(blobs):
        raise ValueError(f"labels ({len(labels)}) must match blobs ({len(blobs)}) one-to-one")

    BG = "#FAFAFA"
    SPINE_COLOR = "#666666"
    GRID_COLOR = "#CCCCCC"
    TICK_COLOR = "#333333"
    # Same palette as `plot_length_ngram.py`: blue/purple/red first, orange/green as spares.
    COLORS = ["#00468B", "#9B59B6", "#AE1029", "#FF7F0E", "#2CA02C"]

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    try:
        ax.set_facecolor(BG)
        for side in ("left", "bottom", "top", "right"):
            spine = ax.spines[side]
            spine.set_linewidth(1.8)
            spine.set_color(SPINE_COLOR)
            spine.set_visible(True)
        ax.grid(True, axis="both", alpha=0.3, color=GRID_COLOR, linewidth=0.8, linestyle="-", zorder=0)
        ax.set_axisbelow(True)
        ax.tick_params(axis="both", labelcolor=TICK_COLOR, labelsize=20, length=6, width=1.5)

        # The y-label wording follows the first blob's metric (curves are expected to share one).
        metric = str(blobs[0].meta.get("metric", "cumulative_product")) if blobs else "cumulative_product"
        ylab_base = "P(stop at step=t) (mean)" if metric == "stop_at_t" else "P(end by step=t)"

        all_ys: List[np.ndarray] = []  # plotted y values (probability space), reused for the log floor
        for i, (b, lab) in enumerate(zip(blobs, labels)):
            xs = np.asarray(b.indices)
            ys = np.asarray(b.mean_eos_prob)
            # Caches saved in ln space are mapped back with exp. This is decided per blob, so
            # mixing ln-space and prob-space caches still puts every curve on the same scale.
            if str(b.meta.get("value_space", "prob")) == "ln":
                ys = np.exp(ys)
            all_ys.append(ys)
            ax.plot(xs, ys, linewidth=3.0, color=COLORS[i % len(COLORS)], label=f"{lab}", zorder=4)

        ax.set_xlabel("Token index (step)", fontsize=26, fontweight="bold", labelpad=10)
        ax.set_ylabel(ylab_base, fontsize=26, fontweight="bold", labelpad=10)

        if log_y:
            ax.set_yscale("log")
            # Safe even when no curve has a positive value (falls back to 1e-12).
            ax.set_ylim(bottom=_log_axis_floor(all_ys))
        else:
            ax.set_ylim(bottom=0.0)

        _mkdir_parent(save_path)
        # A single figure-level legend at the bottom (matches the `plot_length_ngram.py` style).
        handles, labs = ax.get_legend_handles_labels()
        if handles:
            legend = fig.legend(
                handles,
                labs,
                loc="lower center",
                bbox_to_anchor=(0.5, 0.0),  # inside the figure so bbox_inches="tight" does not crop it
                ncol=min(3, len(handles)),
                frameon=True,
                framealpha=0.95,
                edgecolor="#888888",
                fancybox=True,
                shadow=False,
                prop={"weight": "bold", "size": 20},
            )
            legend.get_frame().set_linewidth(1.5)

        # Reserve the bottom 8% of the figure for the legend.
        fig.tight_layout(rect=(0.0, 0.08, 1.0, 1.0))
        fig.savefig(
            save_path,
            dpi=600,
            bbox_inches="tight",
            format="pdf",
            metadata={"Creator": "matplotlib", "Producer": "matplotlib"},
        )
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f"已保存图像: {save_path}")


def main() -> None:
    """Command-line entry point: load one or more npz caches and plot their EOS curves.

    Exits via ``SystemExit`` with a message when the number of ``--label`` values
    differs from ``--cache``, or when a cache path is not an existing file. Both
    checks run before any cache is loaded.
    """
    ap = argparse.ArgumentParser(description="Plot EOS prob curve(s) from npz cache (no model forward).")
    ap.add_argument("--cache", type=str, nargs="+", required=True, help="one or more .npz cache files")
    ap.add_argument("--label", type=str, nargs="*", default=None, help="labels for each cache (default=basename)")
    # NOTE(review): --name is parsed but nothing in this file reads args.name (no title is drawn).
    ap.add_argument("--name", type=str, default="run", help="plot title prefix")
    ap.add_argument("--out", type=str, default="eos_prob_curve.pdf", help="output pdf path")
    ap.add_argument("--log-y", action="store_true", help="use log y-axis")
    args = ap.parse_args()

    cache_paths = [str(p) for p in args.cache]

    # Validate cheap CLI constraints before loading (potentially large) caches.
    if args.label:
        if len(args.label) != len(cache_paths):
            raise SystemExit(f"--label 个数需与 --cache 一致：{len(args.label)} vs {len(cache_paths)}")
        labels = list(args.label)
    else:
        labels = [os.path.basename(p) for p in cache_paths]

    missing = [p for p in cache_paths if not os.path.isfile(p)]
    if missing:
        raise SystemExit(f"cache file(s) not found: {', '.join(missing)}")

    blobs = [load_cache(p) for p in cache_paths]
    plot_curves(blobs, labels=labels, save_path=str(args.out), log_y=bool(args.log_y))


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

