import argparse
import csv
import math
import os


def _mkdir_parent(path: str) -> None:
    if path is None:
        return
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _pretty(tok: str) -> str:
    return tok.replace("\n", "\\n").replace("\t", "\\t")


def _mean_by_position_from_vllm(outs):
    """
    outs: vLLM 的 RequestOutput list
    统计口径：对每个生成位置 t，取 chosen token 的 logprob/prob，
    对所有“生成长度 >= t+1”的样本求均值。
    返回：mean_prob, mean_logprob, counts
    """
    sum_prob = []
    sum_lp = []
    cnt = []

    for ro in outs:
        if not ro.outputs:
            continue
        o = ro.outputs[0]
        if o.logprobs is None:
            raise RuntimeError("没有拿到 logprobs（无法统计）。请确认 SamplingParams(logprobs>0) 且 vLLM 版本支持。")

        token_ids = list(o.token_ids)
        n = min(len(token_ids), len(o.logprobs))
        for i in range(n):
            d = o.logprobs[i]  # dict[token_id] -> Logprob
            tid = token_ids[i]
            lp_obj = d.get(tid, next(iter(d.values())))
            lp = float(lp_obj["logprob"]) if isinstance(lp_obj, dict) else float(lp_obj.logprob)
            prob = math.exp(lp) if lp > -745 else 0.0

            if i >= len(sum_prob):
                sum_prob.append(0.0)
                sum_lp.append(0.0)
                cnt.append(0)
            sum_prob[i] += prob
            sum_lp[i] += lp
            cnt[i] += 1

    mean_prob = [(sum_prob[i] / cnt[i]) if cnt[i] else float("nan") for i in range(len(sum_prob))]
    mean_lp = [(sum_lp[i] / cnt[i]) if cnt[i] else float("nan") for i in range(len(sum_lp))]
    return mean_prob, mean_lp, cnt


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the CLI parser; option names, defaults and help texts are unchanged."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str, default="/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/baseline-8k-minibsz32-dapo-math/best_model/actor/huggingface")
    # Alternative default (base model):
    #   /mnt/shared-storage-user/p1-shared/wangfuting/shared/models/Qwen3-4B-Base
    ap.add_argument("--question", type=str, default="In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$ be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$ and $\\angle BDC = 90^\\circ$. Suppose that $AD = 1$ and that $\\frac{BD}{CD} = \\frac{3}{2}$. If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$ where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$. Let's think step by step and output the final answer within \\boxed{}.")

    # Batch (parquet) statistics mode.
    ap.add_argument("--parquet", type=str, default=None, help="parquet 文件路径；提供后统计每个位置的平均 prob")
    ap.add_argument("--prompt_col", type=str, default="prompt", help="parquet 中 prompt 列名（默认 prompt）")
    ap.add_argument("--n_samples", type=int, default=1000, help="按顺序取的条数（默认 1000）")
    ap.add_argument("--start_idx", type=int, default=0, help="按顺序取的起始下标（默认 0）")
    ap.add_argument("--remove_system", action="store_true", help="如果 prompt 是 messages 且首条是 system，则移除")

    # Sampling parameters.
    ap.add_argument("--max_tokens", type=int, default=16000)
    ap.add_argument("--temperature", type=float, default=0.6)
    ap.add_argument("--top_p", type=float, default=1.0)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--logprobs_k", type=int, default=5, help="vLLM SamplingParams.logprobs")

    ap.add_argument("--trust_remote_code", action="store_true")
    ap.add_argument("--out_png", type=str, default="/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/eval_scripts/analysis/plots/token_prob_baseline.png")
    ap.add_argument("--out_csv", type=str, default="/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/eval_scripts/analysis/plots/token_prob_baseline.csv")
    return ap


def _as_chat_messages(prompt, remove_system: bool):
    """Return one parquet prompt cell as a message list, optionally without a leading system turn."""
    # NOTE(review): list-valued parquet cells are expected to arrive as
    # array-like objects; ``tolist()`` turns them into plain lists so they can
    # be indexed and sliced below. Other values pass through untouched.
    messages = prompt.tolist() if hasattr(prompt, "tolist") else prompt
    if (
        remove_system
        and isinstance(messages, (list, tuple))
        and len(messages) > 1  # never strip the only message of a conversation
        and isinstance(messages[0], dict)
        and messages[0].get("role") == "system"
    ):
        messages = list(messages[1:])
    return messages


def _load_parquet_prompts(args, tok):
    """Read the requested row slice and render each row with the chat template.

    Returns ``(start, prompts)`` where ``start`` is the effective first row.
    Exits with a message when the slice selects no rows.
    """
    import pandas as pd

    df = pd.read_parquet(args.parquet)
    start = max(0, int(args.start_idx))
    n = min(int(args.n_samples), len(df) - start)
    if n <= 0:
        raise SystemExit(
            f"没有可用样本：start_idx={start}, n_samples={args.n_samples}, 总行数={len(df)}"
        )
    raw = df.iloc[start : start + n][args.prompt_col].tolist()
    prompts = [
        tok.apply_chat_template(
            _as_chat_messages(p, bool(args.remove_system)),
            tokenize=False,
            add_generation_prompt=True,
        )
        for p in raw
    ]
    return start, prompts


def _collect_token_steps(out):
    """Return ``(step, token_id, token_text, logprob, prob)`` rows for one completion."""
    if out.logprobs is None:
        raise RuntimeError("没有拿到 logprobs（无法统计）。请确认 SamplingParams(logprobs>0) 且 vLLM 版本支持。")

    tokens = getattr(out, "tokens", None)
    steps = []
    # zip() guards against token_ids and logprobs having different lengths.
    for i, (tid, d) in enumerate(zip(out.token_ids, out.logprobs)):
        if not d:
            # No logprob entry for this step; nothing to record.
            continue
        # d maps token_id -> logprob entry; fall back to the first entry only
        # when the chosen token is absent.
        lp_obj = d[tid] if tid in d else next(iter(d.values()))
        is_mapping = isinstance(lp_obj, dict)
        lp = float(lp_obj["logprob"]) if is_mapping else float(lp_obj.logprob)
        # exp() underflows to zero around -745 in double precision.
        prob = math.exp(lp) if lp > -745 else 0.0

        tok_text = tokens[i] if tokens is not None and i < len(tokens) else None
        if tok_text is None:
            decoded = lp_obj.get("decoded_token") if is_mapping else getattr(lp_obj, "decoded_token", None)
            tok_text = decoded if decoded is not None else f"<tok:{tid}>"
        steps.append((i, int(tid), str(tok_text), lp, prob))
    return steps


def _save_position_means(args, mean_prob, mean_lp, cnt) -> None:
    """Write per-position means to ``args.out_csv`` and plot ``mean_prob`` to ``args.out_png``."""
    import matplotlib

    matplotlib.use("Agg", force=True)  # headless backend; no display required
    import matplotlib.pyplot as plt

    _mkdir_parent(args.out_csv)
    with open(args.out_csv, "w", encoding="utf-8", newline="") as f:
        w = csv.writer(f)
        w.writerow(["position", "mean_prob", "mean_logprob", "count"])
        w.writerows(
            (i, p, lp, c) for i, (p, lp, c) in enumerate(zip(mean_prob, mean_lp, cnt))
        )
    print("csv:", args.out_csv)

    # Only the mean probability is plotted; the mean logprob stays in the CSV.
    xs = list(range(len(mean_prob)))
    fig, ax = plt.subplots(figsize=(14, 5))
    ax.plot(xs, mean_prob, linewidth=1.8, label="mean_prob")
    ax.set_xlabel("token position (generated)")
    ax.set_ylabel("mean prob")
    ax.grid(True, alpha=0.25)
    ax.legend(loc="upper right")
    _mkdir_parent(args.out_png)
    plt.tight_layout()
    plt.savefig(args.out_png, dpi=220, bbox_inches="tight")
    plt.close(fig)
    print("png:", args.out_png)


def _save_single_trace(args, steps) -> None:
    """Write the per-token rows to ``args.out_csv`` and plot prob/logprob to ``args.out_png``."""
    _mkdir_parent(args.out_csv)
    with open(args.out_csv, "w", encoding="utf-8", newline="") as f:
        w = csv.writer(f)
        w.writerow(["step", "token_id", "token", "logprob", "prob"])
        w.writerows(steps)
    print("csv:", args.out_csv)

    import matplotlib

    matplotlib.use("Agg", force=True)  # headless backend; no display required
    import matplotlib.pyplot as plt

    xs = [s[0] for s in steps]
    lps = [s[3] for s in steps]
    ps = [s[4] for s in steps]

    fig, ax1 = plt.subplots(figsize=(14, 5))
    ax1.plot(xs, ps, linewidth=1.8, label="prob")
    ax1.set_xlabel("token index")
    ax1.set_ylabel("prob")
    ax1.grid(True, alpha=0.25)

    # Logprob goes on a secondary axis because its scale differs from prob.
    ax2 = ax1.twinx()
    ax2.plot(xs, lps, linewidth=1.2, linestyle="--", color="tab:orange", label="logprob")
    ax2.set_ylabel("logprob")

    # Merge the handles of both axes into one legend.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
    _mkdir_parent(args.out_png)
    plt.tight_layout()
    plt.savefig(args.out_png, dpi=220, bbox_inches="tight")
    plt.close(fig)
    print("png:", args.out_png)


def main() -> None:
    """Generate with vLLM and record the probability of every chosen token.

    Two modes, selected by ``--parquet``:

    * batch mode: render ``--n_samples`` prompts from the parquet file,
      generate once per prompt and write the per-position mean
      probability/logprob (CSV) plus a mean-probability plot (PNG);
    * single mode: generate for ``--question`` and write one CSV row and one
      plot point per generated token.

    Prompts are built before the model is loaded so that input problems
    (empty slice, tokenizer without a chat template) fail fast.
    """
    args = _build_arg_parser().parse_args()

    # Heavy dependencies are imported lazily so that ``--help`` stays fast.
    from transformers import AutoTokenizer
    import torch
    from vllm import LLM, SamplingParams

    tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=bool(args.trust_remote_code))
    if not hasattr(tok, "apply_chat_template"):
        raise SystemExit("该 tokenizer 不支持 apply_chat_template，无法构造 prompt；请使用带 chat template 的 tokenizer。")

    if args.parquet:
        start, prompts = _load_parquet_prompts(args, tok)
        print(
            f"=== parquet 统计（按顺序切片） ===\n"
            f"file: {args.parquet}\n"
            f"col: {args.prompt_col}\n"
            f"start_idx: {start}\n"
            f"n_samples: {len(prompts)}\n"
        )
    else:
        prompts = [
            tok.apply_chat_template(
                [{"role": "user", "content": args.question}],
                tokenize=False,
                add_generation_prompt=True,
            )
        ]

    # Use every visible GPU for tensor parallelism (at least 1).
    tp = int(torch.cuda.device_count() or 1)
    llm = LLM(model=args.model, tensor_parallel_size=tp, trust_remote_code=bool(args.trust_remote_code))
    sp = SamplingParams(
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        seed=args.seed,
        logprobs=args.logprobs_k,
    )
    outs = llm.generate(prompts, sampling_params=sp)

    if args.parquet:
        mean_prob, mean_lp, cnt = _mean_by_position_from_vllm(outs)
        _save_position_means(args, mean_prob, mean_lp, cnt)
        return

    out = outs[0].outputs[0]
    print("\n=== 模型回答 ===\n")
    print(out.text)
    _save_single_trace(args, _collect_token_steps(out))


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()


