# filename: compute_per_token_logps.py
import argparse
import json
from pathlib import Path
from typing import List, Dict, Any, Tuple
from tqdm import tqdm
import torch
import torch.nn.functional as F
import accelerate
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from utils import save_dataset, load_single_dataset
import datasets


# Module-level Accelerator shared by main(). It is used for:
#   - this process's device;
#   - sharding the task list (split_between_processes);
#   - deciding which process writes the output;
#   - the final barrier.
# It is constructed at import time, so merely importing this module creates it.
distributed_state = Accelerator()

def parse_args():
    """
    Parse this script's command-line options.

    ``--model_type`` is also accepted as ``--model-type``, matching the hyphenated
    style of every other option. It is restricted to rm1/rm2/rm3, so a typo is
    rejected at parse time instead of surfacing later.

    Returns:
        argparse.Namespace with attributes: model_path, tokenizer_path, debug_mode,
        input_json, model_type, output_path, batch_size, model_tag, dtype.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model-path",     type=str, required=True,  help="HF保存的模型本地/远端路径，比如 xxx")
    ap.add_argument("--tokenizer-path", type=str, required=True,  help="HF保存的tokenizer本地/远端路径，比如 xxx")
    ap.add_argument("--debug-mode",     type=str, required=False, default=None)
    ap.add_argument("--input-json",     type=str, required=True,  help="输入json路径")
    ap.add_argument(
        "--model_type",
        "--model-type",
        dest="model_type",
        type=str,
        required=False,
        default="rm1",
        choices=["rm1", "rm2", "rm3"],
        help="rm1是语言模型，搜集语言模型在input-ids上的logp；rm2是ex-rm，搜集rm在eos token上的打分；rm3是类似于qrm的，搜集在reponse位置上的打分",
    )
    ap.add_argument("--output-path",    type=str, required=True,  help="输出json路径")
    ap.add_argument("--batch-size",     type=int, default=4)
    ap.add_argument("--model-tag",      type=str, required=True,  help="写入字典的键名前缀，例如 qwen2_7b")
    ap.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["auto", "bfloat16", "float16", "float32"],
        help="模型dtype；建议auto或bfloat16"
    )
    return ap.parse_args()


def to_dtype(name: str):
    """
    Translate a ``--dtype`` choice into a torch dtype.

    "auto" passes through unchanged. Any other name is looked up as an attribute of
    ``torch`` (e.g. "bfloat16" -> torch.bfloat16); an unknown name raises AttributeError.
    """
    return name if name == "auto" else getattr(torch, name)


def build_tasks(data: List[Dict[str, Any]]) -> List[Tuple[int, int]]:
    """
    Flatten the dataset into a linear work list of (sample_idx, response_idx) pairs.

    There is one pair per response, ordered by sample and then by response.
    Samples without a "responses" key contribute nothing.
    """
    return [
        (sample_idx, response_idx)
        for sample_idx, sample in enumerate(data)
        for response_idx, _ in enumerate(sample.get("responses", []))
    ]


def encode_one_pair(tokenizer, prompt_messages, response_text):
    """
    Tokenize one (prompt, response) pair with the tokenizer's chat template.

    Returns:
        (full_ids, prompt_ids), both 1-D LongTensors:
          full_ids   -- the prompt followed by an assistant turn containing
                        ``response_text``, rendered without a generation prompt;
          prompt_ids -- the prompt alone, rendered with the generation prompt
                        (many templates insert the assistant start marker there).
                        The caller uses its length to locate where the response
                        begins inside full_ids.
    """
    def _render(messages, with_generation_prompt):
        # With return_tensors="pt" the template returns a batch of one row; keep row 0.
        batch = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=with_generation_prompt,
            tokenize=True,
            return_tensors="pt",
        )
        return batch[0]

    prompt_ids = _render(prompt_messages, True)
    assistant_turn = {"role": "assistant", "content": response_text}
    full_ids = _render(prompt_messages + [assistant_turn], False)
    return full_ids, prompt_ids


def collate_right_pad(batch_ids: List[torch.Tensor], pad_id: int):
    """
    Stack 1-D id tensors into a right-padded batch.

    Returns:
        input_ids      -- [B, L] long tensor, padded on the right with ``pad_id``;
        attention_mask -- [B, L] long tensor, 1 on real tokens and 0 on padding;
        lengths        -- [B] long tensor of the original sequence lengths.
    where L is the longest input length.
    """
    lengths = torch.tensor([seq.size(0) for seq in batch_ids], dtype=torch.long)
    width = int(lengths.max().item())

    padded = torch.full((len(batch_ids), width), pad_id, dtype=torch.long)
    for row, seq in enumerate(batch_ids):
        padded[row, : seq.size(0)] = seq

    # Position j of row i is a real token iff j < lengths[i].
    positions = torch.arange(width).unsqueeze(0)
    mask = (positions < lengths.unsqueeze(1)).long()
    return padded, mask, lengths


@torch.no_grad()
def compute_batch_logps(
    model_type,
    model,
    tokenizer,
    full_ids_list: List[torch.Tensor],
    prompt_ids_list: List[torch.Tensor],
    device: torch.device,
) -> List[Any]:
    """
    Run one forward pass over a batch and return one result per (prompt, response) pair.

    The batch is right-padded with ``collate_right_pad``. The pad id is
    ``tokenizer.pad_token_id``, falling back to ``eos_token_id``. Padded positions
    never contribute to any result.

    Args:
        model_type: which quantity to extract:
            "rm1" -- per-token log-probs of the response under a causal LM
                     (``model(...).logits``, shape [B, L, V]);
            "rm2" -- the value-head output at the last non-padding token;
            "rm3" -- the mean value-head output over all non-padding tokens
                     (prompt tokens included).
            For "rm2"/"rm3", ``model`` must return a 3-tuple whose last element is
            a per-token value tensor of shape [B, L]. trl's
            ``AutoModelForCausalLMWithValueHead`` behaves this way.
        model: the model to run; the batch is moved to ``device`` before the call.
        tokenizer: only ``pad_token_id`` / ``eos_token_id`` are read.
        full_ids_list: 1-D prompt+response token ids, one tensor per pair.
        prompt_ids_list: 1-D prompt-only token ids (with generation prefix). Only the
            length is used: it is taken as the index of the first response token
            in the matching full ids.
            NOTE(review): this assumes prompt ids are a prefix of full ids for the
            chat template in use -- confirm.
        device: device to move the batch to.

    Returns:
        A list with exactly one entry per pair, in input order:
        "rm1" -> list[float]; "rm2" -> float; "rm3" -> [float] (one-element list).

    Raises:
        ValueError: for an unknown ``model_type``.
    """
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Fallback: most causal LMs use EOS as the pad token.
        pad_id = tokenizer.eos_token_id

    input_ids, attention_mask, lengths = collate_right_pad(full_ids_list, pad_id)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    results: List[Any] = []
    if model_type == "rm1":
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits  # [B, L, V]
        for b in range(input_ids.size(0)):
            seq_len = int(lengths[b].item())
            # logits[b, t] predicts input_ids[b, t + 1]. The first response token is
            # at index len(prompt_ids), so its log-prob comes from position len(prompt_ids) - 1.
            start = max(prompt_ids_list[b].size(0) - 1, 0)
            # log_softmax is row-wise, so normalising only the response rows gives the
            # same values as normalising the whole sequence. It keeps the fp32 [T, V]
            # temporary limited to the response span.
            step_logits = logits[b, start : seq_len - 1, :].float()
            targets = input_ids[b, start + 1 : seq_len]
            token_logps = (
                F.log_softmax(step_logits, dim=-1)
                .gather(-1, targets.unsqueeze(-1))
                .squeeze(-1)
            )
            results.append(token_logps.cpu().tolist())

    elif model_type == "rm2":
        # No `pad_id=` kwarg: the value-head wrapper forwards extra kwargs to the
        # underlying model's forward(), which has no such parameter. Depending on the
        # transformers version it is either ignored or raises TypeError.
        _, _, values = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
            use_cache=False,
        )
        # With right padding, the last non-padding token sits at index length - 1.
        last_idx = attention_mask.sum(dim=-1, keepdim=True) - 1  # [B, 1]
        scores = values.gather(dim=-1, index=last_idx).squeeze(-1)  # [B]
        # .float() first: numpy (and thus .numpy()) has no bfloat16.
        results.extend(scores.float().cpu().tolist())

    elif model_type == "rm3":
        _, _, values = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
            use_cache=False,
        )
        # Average in fp32: token counts above 256 are not exact in bf16.
        values = values.float()
        mask = attention_mask.to(values.dtype)
        scores = (values * mask).sum(-1) / mask.sum(-1).clamp_min(1.0)  # [B]
        # Exactly one entry per pair, stored as [score]. This gives the same output
        # format for every batch size.
        results.extend([score] for score in scores.cpu().tolist())

    else:
        raise ValueError(
            f"unknown model_type {model_type!r}; expected 'rm1', 'rm2' or 'rm3'"
        )

    return results


def _load_model(args):
    """
    Load the model selected by ``--model_type`` onto this process's device, in eval mode.

    "rm1" loads a plain causal LM, whose logits are used. Every other type loads trl's
    ``AutoModelForCausalLMWithValueHead``, whose value head is used.
    """
    model_cls = AutoModelForCausalLM if args.model_type == "rm1" else AutoModelForCausalLMWithValueHead
    model = model_cls.from_pretrained(
        args.model_path,
        # Pass "auto" straight through: from_pretrained resolves it to the checkpoint's
        # dtype. Mapping it to None would mean "library default", which is float32
        # on transformers 4.x.
        torch_dtype=to_dtype(args.dtype),
        device_map=None,
    )
    model.eval()
    model.to(distributed_state.device)
    return model


def main():
    """
    Score every (prompt, response) pair of the input dataset with one model and
    store the results in the dataset under the key ``args.model_tag``.

    Each Accelerate process reads the whole dataset, loads the model and handles
    its own slice of (sample_idx, response_idx) tasks. The results are gathered on
    the main process, which writes the output file. All processes then meet at a
    final barrier.
    """
    args = parse_args()
    # Every rank reads the local input file itself. This is simpler than reading on
    # rank 0 and broadcasting.
    data: datasets.Dataset = load_single_dataset(args.input_json)
    if args.debug_mode:
        # Clamp so debug mode also works on datasets with fewer than 10 rows.
        data = data.select(range(min(10, len(data))))
    data = data.to_list()

    # Build the linear task list and let Accelerate shard it across processes.
    tasks_all = build_tasks(data)
    with distributed_state.split_between_processes(tasks_all) as local_tasks:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)
        if tokenizer.pad_token is None:
            # Using EOS as the pad token is the usual safe choice for causal LMs.
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
        model = _load_model(args)

        local_results: List[Tuple[int, int, Any]] = []
        bs = max(1, int(args.batch_size))

        # Show one progress bar per node instead of one per process.
        for i in tqdm(range(0, len(local_tasks), bs), disable=not distributed_state.is_local_main_process):
            chunk = local_tasks[i : i + bs]

            # Encode the whole chunk first.
            full_ids_list = []
            prompt_ids_list = []
            for sample_idx, resp_idx in chunk:
                item = data[sample_idx]
                prompt_messages = item["prompt"]  # expected to be role/content messages
                response_text = item["responses"][resp_idx]
                full_ids, prompt_ids = encode_one_pair(tokenizer, prompt_messages, response_text)
                full_ids_list.append(full_ids)
                prompt_ids_list.append(prompt_ids)

            # One forward pass per chunk; results are then extracted sequence by sequence.
            batch_logps = compute_batch_logps(
                model_type=args.model_type,
                model=model,
                tokenizer=tokenizer,
                full_ids_list=full_ids_list,
                prompt_ids_list=prompt_ids_list,
                device=distributed_state.device,
            )

            local_results.extend(
                (sample_idx, resp_idx, lp)
                for (sample_idx, resp_idx), lp in zip(chunk, batch_logps)
            )

    # Wrap in a list so gather_object yields one list per process.
    gathered = accelerate.utils.gather_object([local_results])
    if distributed_state.is_main_process:
        flat: List[Tuple[int, int, Any]] = [entry for part in gathered for entry in part]

        key_name = f"{args.model_tag}"
        # Pre-fill one empty slot per response, aligned with "responses".
        for item in data:
            item[key_name] = [[] for _ in range(len(item.get("responses", [])))]

        for sample_idx, resp_idx, lp in flat:
            data[sample_idx][key_name][resp_idx] = lp

        Path(args.output_path).parent.mkdir(parents=True, exist_ok=True)
        save_dataset(data, args.output_path)

    # Keep every process alive until the main process has written the output.
    distributed_state.wait_for_everyone()


# Script entry point. The usage notes below launch it with `accelerate launch`.
if __name__ == "__main__":
    main()


"""

# REF
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
~/verl_cs/.conda/bin/accelerate launch \
  --num_processes 8 ~/verl_cs/scripts/bon2_compute_logp_for_responses.py \
  --model-path      ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft \
  --tokenizer-path  ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft \
  --input-json      ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842_scored.json \
  --output-path     ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842_ref.json \
  --batch-size 2 \
  --model-tag reflogp

# DPO
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
~/verl_cs/.conda/bin/accelerate launch \
  --num_processes 8 ~/verl_cs/scripts/bon2_compute_logp_for_responses.py \
  --model-path      ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/dpo_rm \
  --tokenizer-path  ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft \
  --input-json      ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842_ref.json \
  --output-path     ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842_dpo.json \
  --batch-size 2 \
  --model-tag rmlogp

# IPVRM
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
~/verl_cs/.conda/bin/accelerate launch \
  --num_processes 8 ~/verl_cs/scripts/bon2_compute_logp_for_responses.py \
  --model-path      ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/implicit-drm-beta5-gamma2.5 \
  --tokenizer-path  ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft \
  --input-json      ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842_ref.json \
  --output-path     ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842_ipvrm.json \
  --batch-size 2 \
  --model-tag rmlogp

# ImplicitPRM
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
~/verl_cs/.conda/bin/accelerate launch \
  --num_processes 8 ~/verl_cs/scripts/bon2_compute_logp_for_responses.py \
  --model-path      ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/implicit-prm \
  --tokenizer-path  ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft \
  --input-json      ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842_ref.json \
  --output-path     ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842_implicitprm.json \
  --batch-size 2 \
  --model-tag rmlogp

# QRM
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
~/verl_cs/.conda/bin/accelerate launch \
  --num_processes 8 ~/verl_cs/scripts/bon2_compute_logp_for_responses.py \
  --model-path      ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/qrm \
  --tokenizer-path  ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft \
  --input-json      ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842_scored.json \
  --output-path     ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft/prime-rl-rollouts/bon_test_0_842_qrm.json \
  --batch-size 1 \
  --model-tag rmlogp \
  --model_type rm3


  

"""