import argparse
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch


# Keep plotting consistent with plot_length.py: use the non-interactive Agg backend
# (no display needed) and embed TrueType (type 42) fonts in PDF/PS output.
matplotlib.use("Agg")
plt.rcParams.update(
    {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "pdf.use14corefonts": False,
        "font.sans-serif": ["Arial", "DejaVu Sans"],
        "axes.unicode_minus": False,
    }
)


def load_data(file_path: str) -> List[Dict[str, Any]]:
    """Parse a UTF-8 JSONL file into a list of records.

    Blank (or whitespace-only) lines are skipped; every other line is stripped
    and decoded with ``json.loads``.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        return [json.loads(row) for row in map(str.strip, handle) if row]


def _mkdir_parent(path: Optional[str]) -> None:
    """Ensure the directory that would contain ``path`` exists.

    Does nothing when ``path`` is empty/None or has no directory component.
    """
    parent_dir = os.path.dirname(path) if path else ""
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)


def _pick_first(d: Dict[str, Any], keys: List[str]) -> Any:
    """Return ``d[k]`` for the first ``k`` in ``keys`` that is present with a non-None value.

    Falls back to ``None`` when no key qualifies.
    """
    return next((d[key] for key in keys if key in d and d[key] is not None), None)


def _to_prompt_text(prompt_obj: Any, tokenizer) -> Optional[str]:
    """Turn a raw prompt field into plain prompt text.

    - ``str``: returned unchanged.
    - ``list`` (typically chat messages): rendered with ``tokenizer.apply_chat_template``
      (``tokenize=False``, ``add_generation_prompt=True``) when the tokenizer has that method.
    - Anything else, or a template call that raises: ``None`` (best effort; callers then
      treat the prompt as missing).
    """
    if isinstance(prompt_obj, str):
        return prompt_obj
    looks_like_messages = isinstance(prompt_obj, list)
    if not (looks_like_messages and hasattr(tokenizer, "apply_chat_template")):
        return None
    try:
        return tokenizer.apply_chat_template(prompt_obj, tokenize=False, add_generation_prompt=True)
    except Exception:
        return None


def _extract_prompt_and_response(item: Dict[str, Any], tokenizer) -> Tuple[Optional[str], Optional[str]]:
    """Pull ``(prompt_text, response_text)`` out of one JSONL record.

    The prompt comes from the first non-None of prompt/input/question/messages/query and is
    normalised by ``_to_prompt_text``. The response comes from the first non-None of
    generated_text/output/response/text/completion. A dict response contributes its
    ``"text"`` entry, and any other non-string response is stringified. Either element may be None.
    """
    raw_prompt = _pick_first(item, ["prompt", "input", "question", "messages", "query"])
    raw_response = _pick_first(item, ["generated_text", "output", "response", "text", "completion"])
    prompt = _to_prompt_text(raw_prompt, tokenizer)

    # Some dumps nest the generation as {"text": "..."}.
    if isinstance(raw_response, dict):
        raw_response = raw_response.get("text")

    if raw_response is None or isinstance(raw_response, str):
        response = raw_response
    else:
        response = str(raw_response)
    return prompt, response


@dataclass
class CacheBlob:
    """Per-step curve summary that is saved to / loaded from an ``.npz`` cache.

    The three arrays are element-wise aligned: entry k describes step ``indices[k]``.
    """

    indices: np.ndarray  # (K,) step indices t of the plotted points
    mean_eos_prob: np.ndarray  # (K,) mean metric value per step; ln- or prob-space per meta["value_space"]
    counts: np.ndarray  # (K,) number of samples averaged at each step
    meta: Dict[str, Any]  # JSON-serialisable run metadata (stored as a JSON string in the cache)


def save_cache(cache_path: str, blob: CacheBlob) -> None:
    """Write ``blob`` to a compressed ``.npz`` archive at ``cache_path``.

    Parent directories are created as needed. ``np.savez_compressed`` appends ``.npz``
    when ``cache_path`` lacks that suffix, so the file on disk may be
    ``cache_path + ".npz"``. ``load_cache`` accepts either spelling.
    """
    _mkdir_parent(cache_path)
    np.savez_compressed(
        cache_path,
        indices=blob.indices.astype(np.int32),
        mean_eos_prob=blob.mean_eos_prob.astype(np.float64),
        counts=blob.counts.astype(np.int64),
        meta=json.dumps(blob.meta, ensure_ascii=False),  # stored as a 0-d unicode array
    )


def load_cache(cache_path: str) -> CacheBlob:
    """Load a cache written by ``save_cache``.

    ``cache_path`` may be given with or without the ``.npz`` suffix that
    ``np.savez_compressed`` adds, so the same ``--cache`` value works for saving and loading.

    Raises:
        FileNotFoundError: if neither ``cache_path`` nor ``cache_path + ".npz"`` exists.
    """
    path = os.fspath(cache_path)
    if not os.path.exists(path) and os.path.exists(path + ".npz"):
        path += ".npz"
    # NOTE(review): allow_pickle=True keeps object-dtype entries loadable, but it lets a
    # crafted .npz execute arbitrary code -- only load caches you trust.
    with np.load(path, allow_pickle=True) as z:  # closes the underlying zip file on exit
        meta_arr = z["meta"]
        meta_raw = meta_arr.item() if meta_arr.shape == () else str(meta_arr)
        if isinstance(meta_raw, str):
            meta = json.loads(meta_raw)
        elif isinstance(meta_raw, dict):
            meta = dict(meta_raw)
        else:
            meta = {}
        # NpzFile.__getitem__ reads each array fully into memory, so the arrays stay valid
        # after the file is closed.
        return CacheBlob(
            indices=z["indices"],
            mean_eos_prob=z["mean_eos_prob"],
            counts=z["counts"],
            meta=meta,
        )


def _tokenize_and_filter(
    tokenizer,
    prompts: List[Optional[str]],
    responses: List[str],
    max_resp_tokens: Optional[int],
    max_total_tokens: Optional[int],
) -> Tuple[List[List[int]], List[List[int]], int]:
    """Tokenize each (prompt, response) pair (no special tokens) and apply the length rules.

    - ``max_resp_tokens``: truncate every response to at most this many tokens.
    - ``max_total_tokens``: keep only samples with
      ``len(prompt_ids) + len(resp_ids) > max_total_tokens``, where ``resp_ids`` is measured
      after truncation.

    ``None`` prompts tokenize as "". Returns ``(prompt_ids_list, resp_ids_list, n_dropped)``.
    """

    def _encode(text: str) -> List[int]:
        enc = tokenizer(text, add_special_tokens=False, return_attention_mask=False, return_tensors=None)
        return list(enc["input_ids"])

    prompt_ids_list: List[List[int]] = []
    resp_ids_list: List[List[int]] = []
    n_dropped = 0
    for p, r in zip(prompts, responses):
        p_ids = _encode("" if p is None else p)
        r_ids = _encode(r)
        if max_resp_tokens is not None:
            r_ids = r_ids[: int(max_resp_tokens)]
        # Strictly greater than the threshold, matching the "keep if > threshold" contract
        # advertised by the CLI help, the log line and the cache metadata.
        if max_total_tokens is not None and len(p_ids) + len(r_ids) <= int(max_total_tokens):
            n_dropped += 1
            continue
        prompt_ids_list.append(p_ids)
        resp_ids_list.append(r_ids)
    return prompt_ids_list, resp_ids_list, n_dropped


def _eos_log_terms(logits: torch.Tensor, eos_id: int, min_log: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return ``(ln p_eos, ln(1 - p_eos))`` for each row of a ``[n, V]`` logits tensor.

    Both results have shape ``[n]`` and are floored at ``min_log``. The softmax statistics
    are computed in fp32 even when the model runs under bf16/fp16 autocast.
    """
    lf = logits.float()
    log_p = (lf[:, eos_id] - torch.logsumexp(lf, dim=-1)).clamp_min(min_log)
    # In fp32 the 1 - 1e-12 cap rounds to 1.0, so p == 1 still gives log1p(-1) = -inf.
    # The clamp_min below is what actually guards that case.
    p = log_p.exp().clamp(min=0.0, max=1.0 - 1e-12)
    log_q = torch.log1p(-p).clamp_min(min_log)
    return log_p, log_q


def _compute_eos_prob_curve(
    model,
    tokenizer,
    prompts: List[Optional[str]],
    responses: List[str],
    stride: int,
    max_resp_tokens: Optional[int],
    max_total_tokens: Optional[int],
    batch_size: int,
    amp_dtype: str,
    metric: str,
    chunk_size: int = 32768,
) -> CacheBlob:
    """Compute the mean per-step stopping statistic of ``model`` over teacher-forced responses.

    With ``p_t = P(EOS | prompt + resp[:t])`` for ``t = 0..len(resp)``, ``metric`` selects:

    - ``"stop_at_t"`` (PMF): ``ln(prod_{k<t}(1 - p_k) * p_t)``, averaged in ln space.
    - ``"stop_by_t"`` (CDF): ``1 - prod_{k<=t}(1 - p_k)``, averaged in probability space.

    Each kept sample's prompt is fed once. The response is then streamed through the KV cache
    in chunks of ``chunk_size`` tokens. Every position enters the running survival product,
    but only steps that are multiples of ``stride`` are recorded. A step t is averaged only
    over samples whose (truncated) response has at least t tokens.

    Args:
        model: causal LM called as ``model(input_ids=..., position_ids=..., past_key_values=...,
            use_cache=True)``. It must return an object with ``.logits`` (indexed as
            ``[batch, seq, vocab]``) and ``.past_key_values``.
        tokenizer: callable tokenizer exposing ``eos_token_id`` / ``bos_token_id``.
        prompts / responses: aligned lists. A ``None`` prompt is treated as empty.
        stride: spacing of the recorded steps (values < 1 are treated as 1).
        max_resp_tokens: optional response truncation length.
        max_total_tokens: if set, keep only samples whose prompt+response length exceeds it.
        batch_size: not used by the computation (samples are streamed one at a time); only
            recorded in ``meta``.
        amp_dtype: "fp16" / "bf16" enables autocast with that dtype; anything else disables it.
        metric: "stop_at_t" or "stop_by_t".
        chunk_size: response tokens per forward call. Lower it if the per-chunk logits OOM.

    Returns:
        ``CacheBlob`` with ``indices = 0, stride, ... <= max response length``, the per-step means,
        the per-step sample counts and run metadata.

    Raises:
        ValueError: mismatched ``prompts``/``responses`` lengths or an unknown ``metric``.
        RuntimeError: the tokenizer has no ``eos_token_id``.
    """
    if len(prompts) != len(responses):
        raise ValueError(f"prompts/responses length mismatch: {len(prompts)} vs {len(responses)}")
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        raise RuntimeError("tokenizer.eos_token_id 为 None，无法计算 EOS prob。")
    # Validate before the (potentially slow) tokenization pass.
    metric = str(metric)
    if metric not in ("stop_at_t", "stop_by_t"):
        raise ValueError(f"不支持的 metric={metric}，请使用 stop_at_t 或 stop_by_t。")
    stride = max(1, int(stride))
    chunk_size = max(1, int(chunk_size))

    prompt_ids_list, resp_ids_list, dropped_by_len = _tokenize_and_filter(
        tokenizer, prompts, responses, max_resp_tokens, max_total_tokens
    )
    total_before = len(prompts)
    resp_lens = [len(r_ids) for r_ids in resp_ids_list]

    if max_total_tokens is not None:
        kept = len(resp_lens)
        print(
            f"[length_filter_keep_long] threshold={int(max_total_tokens)} (keep if > threshold) | "
            f"kept={kept}/{total_before} | dropped={dropped_by_len}"
        )

    max_step = max(resp_lens, default=0)
    indices = np.arange(0, max_step + 1, stride, dtype=np.int32)  # steps 0, stride, 2*stride, ... <= max_step
    # Per-step running sums of the chosen metric: ln values for stop_at_t, probabilities for stop_by_t.
    acc = np.zeros((len(indices),), dtype=np.float64)
    cnt = np.zeros((len(indices),), dtype=np.int64)
    eps = 1e-30
    min_log = float(np.log(eps))  # floor for every log term so log(0) never yields -inf

    device = next(model.parameters()).device
    use_amp = amp_dtype in ("fp16", "bf16")
    autocast_dtype = torch.float16 if amp_dtype == "fp16" else torch.bfloat16

    def _autocast():
        return torch.autocast(device_type=str(device.type), dtype=autocast_dtype, enabled=use_amp)

    # Step 0 conditions on the prompt alone, so an empty prompt is replaced by a single BOS
    # (or EOS when the tokenizer has no BOS).
    bos_id = tokenizer.bos_token_id
    empty_prompt_fallback = [int(bos_id) if bos_id is not None else int(eos_id)]

    # tqdm is optional: the loop runs without it.
    try:
        from tqdm import tqdm as _tqdm  # type: ignore
    except Exception:  # pragma: no cover
        _tqdm = None

    model.eval()
    # NOTE: the model is not asked to trim its output, so each forward presumably returns logits
    # for every fed position ([1, n, V]). Peak memory therefore grows with the prompt length and
    # with chunk_size.
    with torch.inference_mode():
        sample_iter = range(len(resp_lens))
        if _tqdm is not None:
            sample_iter = _tqdm(sample_iter, desc="EOS prob samples", total=len(resp_lens))

        for i in sample_iter:
            p_ids = prompt_ids_list[i] or empty_prompt_fallback
            r_ids = resp_ids_list[i]

            # ---- step 0: context = prompt ----
            prompt = torch.tensor([p_ids], dtype=torch.long, device=device)
            attn = torch.ones_like(prompt)
            pos = torch.arange(prompt.shape[1], device=device, dtype=torch.long).unsqueeze(0)
            with _autocast():
                out = model(input_ids=prompt, attention_mask=attn, position_ids=pos, use_cache=True)
            # log_surv = ln surv(t) = ln prod_{k<=t}(1 - p_k), here for t = 0.
            log_p0, log_surv = _eos_log_terms(out.logits[:, -1, :], eos_id, min_log)  # both [1]
            past = out.past_key_values
            del out  # release the prompt logits before the next forward

            if metric == "stop_at_t":
                acc[0] += float(log_p0.item())  # stop_at(0) = p_0
            else:
                acc[0] += float((-torch.expm1(log_surv)).clamp(0.0, 1.0).item())  # stop_by(0) = p_0
            cnt[0] += 1

            # ---- steps 1..len(r_ids): context = prompt + resp[:t] ----
            past_len = len(p_ids)
            for rs in range(0, len(r_ids), chunk_size):
                block = r_ids[rs : rs + chunk_size]
                m = len(block)
                inp = torch.tensor([block], dtype=torch.long, device=device)  # [1, m]
                # Positions continue from the number of tokens already fed. Counting them here
                # avoids depending on the internal layout of the returned cache object.
                pos_t = torch.arange(past_len, past_len + m, device=device, dtype=torch.long).unsqueeze(0)
                with _autocast():
                    out = model(input_ids=inp, past_key_values=past, position_ids=pos_t, use_cache=True)
                past = out.past_key_values
                past_len += m
                # Output row j is conditioned on resp[: rs + j + 1], i.e. it gives p_t for t = rs + j + 1.
                log_p_block, log_q_block = _eos_log_terms(out.logits[0], eos_id, min_log)  # [m] each
                del out  # do not keep this chunk's full logits alive during the next forward
                prefix_log_q = torch.cumsum(log_q_block, dim=0)  # ln prod_{u<=j}(1 - p_u) within the chunk

                steps = np.arange(rs + 1, rs + m + 1)
                hit = steps % stride == 0
                if hit.any():
                    if metric == "stop_at_t":
                        # stop(t) = surv(t-1) * p_t
                        zero = log_q_block.new_zeros(1)
                        log_surv_prev = log_surv + torch.cat([zero, prefix_log_q[:-1]], dim=0)
                        vals = (log_surv_prev + log_p_block).clamp_min(min_log)
                    else:
                        # stop_by(t) = 1 - surv(t)
                        log_surv_after = (log_surv + prefix_log_q).clamp_min(min_log)
                        vals = (-torch.expm1(log_surv_after)).clamp(0.0, 1.0)
                    # One host copy per chunk instead of one .item() sync per recorded step.
                    ki = steps[hit] // stride  # unique within a chunk, so fancy += is safe
                    acc[ki] += vals.double().cpu().numpy()[hit]
                    cnt[ki] += 1

                # Carry the survival product across chunks.
                log_surv = (log_surv + prefix_log_q[-1:]).clamp_min(min_log)  # keep shape [1]

    mean = np.divide(acc, np.maximum(cnt, 1), dtype=np.float64)
    if metric == "stop_at_t":
        value_space = "ln"
        note_metric = "stop_at_t = Π_{k<t}(1-p_k)*p_t（PMF）"
    else:
        value_space = "prob"
        note_metric = "stop_by_t = 1-Π_{k<=t}(1-p_k) = p1+(1-p1)p2+...（CDF）"
    meta = {
        "eos_token_id": int(eos_id),
        "stride": int(stride),
        "max_step": int(max_step),
        "max_resp_tokens": None if max_resp_tokens is None else int(max_resp_tokens),
        "max_total_tokens": None if max_total_tokens is None else int(max_total_tokens),
        "batch_size": int(batch_size),
        "chunk_size": int(chunk_size),
        "amp_dtype": amp_dtype,
        "n_samples": int(len(resp_lens)),
        "metric": metric,
        "value_space": value_space,
        "log_kind": "ln",
        "eps": float(eps),
        "note_stride": "stride 仅用于画图抽点；内部仍逐位置计算 step=0..t",
        "n_samples_before_len_filter": int(total_before),
        "n_dropped_by_len": int(dropped_by_len),
        "note_metric": note_metric,
        "note": "KV-cache streaming + length filter: keep samples where total_tokens(prompt+resp) > max_total_tokens",
    }
    return CacheBlob(indices=indices, mean_eos_prob=mean, counts=cnt, meta=meta)


def plot_curve(blob: CacheBlob, title: str, save_path: str, log_y: bool) -> None:
    """Plot ``blob.mean_eos_prob`` against ``blob.indices`` and save the figure to ``save_path``.

    The y label and legend text follow ``blob.meta["metric"]``. When
    ``blob.meta["value_space"] == "ln"``, the values are natural logs: they are drawn on a
    linear axis with a small padding, and ``log_y`` is ignored. The output format follows the
    file extension when matplotlib supports it, and falls back to PDF otherwise (e.g. no extension).
    """
    BG = "#FAFAFA"
    COLOR = "#00468B"

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    ax.set_facecolor(BG)
    for side in ["left", "bottom", "top", "right"]:
        ax.spines[side].set_linewidth(1.8)
        ax.spines[side].set_color("#666666")
        ax.spines[side].set_visible(True)
    ax.grid(True, axis="both", alpha=0.3, color="#CCCCCC", linewidth=0.8, linestyle="-")
    ax.set_axisbelow(True)
    ax.tick_params(axis="both", labelcolor="#333333", labelsize=14, length=6, width=1.5)

    xs = blob.indices
    ys = np.asarray(blob.mean_eos_prob, dtype=np.float64)
    metric = str(blob.meta.get("metric", "cumulative_product"))
    value_space = str(blob.meta.get("value_space", "prob"))
    if metric == "stop_at_t":
        legend = "mean P(stop at step=t)"
        ylab = "P(stop at step=t) (mean)"
    elif metric == "stop_by_t":
        legend = "mean P(stop by step=t)"
        ylab = "P(stop by step=t) (mean)"
    else:
        legend = "mean cumulative EOS prob (product)"
        ylab = "Cumulative EOS probability (mean product)"

    if value_space == "ln":
        # Values are ln(prob), i.e. negative numbers on a linear axis.
        ax.plot(xs, ys, linewidth=3.0, color=COLOR, label=f"ln {legend}")
        ax.set_ylabel(f"ln {ylab}", fontsize=18, fontweight="bold", labelpad=10)
        # Pad the y range using finite values only; an Inf limit would make set_ylim raise.
        finite = ys[np.isfinite(ys)]
        if finite.size:
            ymin = float(finite.min())
            ymax = float(finite.max())
            pad = max(0.2, 0.05 * (ymax - ymin + 1e-9))
            ax.set_ylim(ymin - pad, ymax + pad)
    else:
        ax.plot(xs, ys, linewidth=3.0, color=COLOR, label=legend)
        ax.set_ylabel(ylab, fontsize=18, fontweight="bold", labelpad=10)
    ax.set_xlabel("Token index (step)", fontsize=18, fontweight="bold", labelpad=10)
    ax.set_title(title, fontsize=20, fontweight="bold", pad=16)

    if log_y and value_space != "ln":
        ax.set_yscale("log")
        positive = ys[ys > 0]
        ax.set_ylim(bottom=max(1e-12, float(positive.min())) if positive.size else 1e-12)
    elif value_space != "ln":
        ax.set_ylim(bottom=0.0)

    # Sample-count / config summary in the upper-right corner.
    n_step0 = int(blob.counts[0]) if len(blob.counts) else 0
    ax.text(
        0.98,
        0.96,
        f"n_samples(step=0)={n_step0}\nmax_step={int(blob.meta.get('max_step', -1))}\nstride={int(blob.meta.get('stride', -1))}",
        transform=ax.transAxes,
        va="top",
        ha="right",
        fontsize=12,
        bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.95, edgecolor="#999999", linewidth=1.0),
        family="monospace",
    )

    # The summary box already occupies the upper-right corner, so a legend there would cover it.
    # Use a left-hand corner the curve is less likely to pass through: the CDF (stop_by_t)
    # tends to rise, while the other metrics typically fall.
    legend_loc = "upper left" if metric == "stop_by_t" else "lower left"
    ax.legend(loc=legend_loc, frameon=True, framealpha=0.95)

    _mkdir_parent(save_path)
    ext = os.path.splitext(save_path)[1].lstrip(".").lower()
    fmt = ext if ext in fig.canvas.get_supported_filetypes() else "pdf"
    fig.tight_layout()
    fig.savefig(save_path, dpi=600, bbox_inches="tight", format=fmt)
    plt.close(fig)
    print(f"已保存图像: {save_path}")
    print(f"已保存图像: {save_path}")


def main() -> None:
    """Command-line entry point.

    With ``--load-cache --cache PATH``, plot a previously saved cache. Otherwise:
    1. load the tokenizer and model;
    2. read prompt/response pairs from the JSONL ``--inputs``;
    3. compute the EOS-probability curve;
    4. optionally save it to ``--cache``;
    5. plot it to ``--out``.
    """
    ap = argparse.ArgumentParser(description="Plot mean EOS token prob vs token index (teacher forcing) + cache")
    ap.add_argument("--model", type=str, required=False, default="/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/Qwen3-4B-Base", help="HF model path/name")
    ap.add_argument("--tokenizer", type=str, required=False, default=None, help="tokenizer path/name (default=model)")
    ap.add_argument("--inputs", type=str, nargs="+", required=False, default=None, help="one or more JSONL files")
    ap.add_argument("--name", type=str, default="run", help="plot title prefix")
    ap.add_argument("--max-samples", type=int, default=None, help="max samples per input (sequential)")
    ap.add_argument("--stride", type=int, default=1000, help="token index stride on x-axis")
    ap.add_argument("--max-resp-tokens", type=int, default=None, help="cap response token length (after prompt)")
    ap.add_argument(
        "--max-total-tokens",
        type=int,
        default=-1,
        help="keep only samples whose total tokens (prompt+response after max-resp-tokens) > this threshold. Set <=0 to disable.",
    )
    ap.add_argument("--batch-size", type=int, default=4, help="batch size for forward")
    ap.add_argument(
        "--metric",
        type=str,
        default="stop_at_t",
        choices=["stop_at_t", "stop_by_t"],
        help="stop_at_t: 恰好在t停止(PMF, 输出ln均值)；stop_by_t: 到t为止已停止(CDF, 输出概率均值)",
    )
    ap.add_argument("--amp", type=str, default="bf16", choices=["no", "fp16", "bf16"], help="autocast dtype")
    ap.add_argument("--trust-remote-code", action="store_true")

    ap.add_argument("--cache", type=str, default=None, help="npz cache path to save/load")
    ap.add_argument("--load-cache", action="store_true", help="load cache and plot directly")

    ap.add_argument("--out", type=str, default="eos_prob_curve.pdf", help="output pdf path")
    ap.add_argument("--log-y", action="store_true", help="use log y-axis")
    args = ap.parse_args()

    title = f"{args.name}: mean EOS prob vs token index"

    if args.load_cache:
        if not args.cache:
            raise SystemExit("--load-cache 需要同时提供 --cache")
        blob = load_cache(args.cache)
        plot_curve(blob, title=title, save_path=args.out, log_y=bool(args.log_y))
        return

    if not args.model or not args.inputs:
        raise SystemExit("需要提供 --model 和 --inputs（或使用 --load-cache）。")

    # transformers is imported only here, so the --load-cache path above works without it.
    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer  # type: ignore
    except Exception as e:  # pragma: no cover
        raise SystemExit(
            "导入 transformers 失败（常见原因：huggingface_hub/transformers 版本不匹配）。"
            f"\n原始错误：{repr(e)}"
        ) from e

    tok_path = args.tokenizer or args.model
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(tok_path, trust_remote_code=bool(args.trust_remote_code))
    print("Loading model (transformers forward for exact EOS prob)...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=bool(args.trust_remote_code),
        torch_dtype=torch.bfloat16 if args.amp == "bf16" else (torch.float16 if args.amp == "fp16" else None),
        device_map="auto",
    )

    # All input files are merged into a single run; run the script once per file to plot them separately.
    prompts: List[Optional[str]] = []
    responses: List[str] = []
    for p in args.inputs:
        print(f"Loading data: {p}")
        data = load_data(p)
        if args.max_samples is not None:
            data = data[: int(args.max_samples)]
        for item in data:
            if not isinstance(item, dict):
                continue
            pr, rs = _extract_prompt_and_response(item, tokenizer)
            if rs is None:
                continue
            prompts.append(pr)
            responses.append(rs)
    if not responses:
        raise SystemExit("没有可用的样本（找不到 response 字段）。")

    # A threshold <= 0 disables the long-sample filter.
    max_total_tokens: Optional[int] = int(args.max_total_tokens) if args.max_total_tokens is not None else None
    if max_total_tokens is not None and max_total_tokens <= 0:
        max_total_tokens = None

    blob = _compute_eos_prob_curve(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        responses=responses,
        stride=int(args.stride),
        max_resp_tokens=args.max_resp_tokens,
        max_total_tokens=max_total_tokens,
        batch_size=int(args.batch_size),
        amp_dtype=args.amp,
        metric=str(args.metric),
    )
    blob.meta.update(
        {
            "model": args.model,
            "tokenizer": tok_path,
            "inputs": args.inputs,
            "max_samples": None if args.max_samples is None else int(args.max_samples),
        }
    )

    if args.cache:
        save_cache(args.cache, blob)
        # np.savez_compressed appends ".npz" when the name lacks it; report the file actually written.
        written = args.cache if args.cache.endswith(".npz") else f"{args.cache}.npz"
        print(f"已保存 cache: {written}")

    plot_curve(blob, title=title, save_path=args.out, log_y=bool(args.log_y))


if __name__ == "__main__":
    main()