import os
import argparse
from pathlib import Path
from typing import List, Dict, Any, Tuple
from tqdm import tqdm
import torch
import torch.nn.functional as F
import accelerate
from accelerate import Accelerator, InitProcessGroupKwargs
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from utils import save_dataset, load_single_dataset
import datasets
from datetime import timedelta

# Process-group options: a 1800-second (30-minute) timeout is handed to the
# Accelerator below through its kwargs handlers.
init_proc = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
# Module-level Accelerator shared by the whole script; main() uses it for the
# target device, for splitting tasks between processes, for the main-process
# check and for the final synchronisation barrier.
distributed_state = Accelerator(kwargs_handlers=[init_proc])

def _str_to_bool(value: str) -> bool:
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This parser maps
    the usual spellings explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string is kept as False: it was the only way to switch the
    # flag off under the previous ``type=bool`` behaviour.
    if lowered in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model_path``,
        ``tokenizer_path``, ``debug_mode``, ``input_json``, ``model_type``,
        ``output_path``, ``batch_size``, ``model_tag``, ``dtype`` and
        ``only_save_necessary_columns``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model-path", type=str, required=True, help="HF保存的模型本地/远端路径，比如 xxx")
    ap.add_argument("--tokenizer-path", type=str, required=True, help="HF保存的tokenizer本地/远端路径，比如 xxx")
    # Kept as a plain string: main() only tests its truthiness, so any
    # non-empty value switches debug mode on.
    ap.add_argument("--debug-mode", type=str, required=False, default=None)
    ap.add_argument("--input-json", type=str, required=True, help="输入json路径")
    ap.add_argument(
        "--model_type",
        type=str,
        required=False,
        default="rm1",
        # Only these three modes are implemented by the scoring code; reject
        # anything else up front instead of silently producing empty scores.
        choices=["rm1", "rm2", "rm3"],
        help="rm1是语言模型，搜集语言模型在input-ids上的logp；rm2是ex-rm，搜集rm在eos token上的打分；rm3是类似于qrm的，搜集在reponse位置上的打分",
    )
    ap.add_argument("--output-path", type=str, required=True, help="输出json路径")
    ap.add_argument("--batch-size", type=int, default=4)
    ap.add_argument("--model-tag", type=str, required=True, help="写入字典的键名前缀，例如 qwen2_7b")
    ap.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["auto", "bfloat16", "float16", "float32"],
        help="模型dtype；建议auto或bfloat16",
    )
    ap.add_argument(
        "--only_save_necessary_columns",
        type=_str_to_bool,
        required=False,
        default=True,
        help="是否仅仅保留最少字段",
    )
    return ap.parse_args()


def to_dtype(name: str):
    """Map a ``--dtype`` string to a dtype value.

    The literal ``"auto"`` is passed through unchanged; every other name is
    looked up as an attribute of the ``torch`` module (for example
    ``"bfloat16"`` becomes ``torch.bfloat16``).
    """
    return "auto" if name == "auto" else getattr(torch, name)


def build_tasks(data: List[Dict[str, Any]]) -> List[Tuple[int, int]]:
    """Flatten the dataset into a list of ``(sample_idx, response_idx)`` pairs.

    One pair is produced for every entry of each sample's ``"responses"``
    list, in sample order and then response order. Samples without a
    ``"responses"`` key contribute nothing.
    """
    return [
        (sample_idx, resp_idx)
        for sample_idx, sample in enumerate(data)
        for resp_idx, _ in enumerate(sample.get("responses", []))
    ]


def encode_one_pair(tokenizer, prompt_messages, response_text):
    """Tokenize one prompt/response pair with the tokenizer's chat template.

    Returns:
        A ``(full_ids, prompt_ids)`` tuple, each being row 0 of the value
        returned by ``tokenizer.apply_chat_template``:
          * ``full_ids``   - the prompt followed by an assistant turn that
            contains ``response_text``.
          * ``prompt_ids`` - the prompt alone, rendered with the generation
            prefix; its length is used by the caller to locate where the
            response starts inside ``full_ids``.
    """
    shared_kwargs = {"tokenize": True, "return_tensors": "pt"}

    # Prompt only, with the generation prefix appended (most chat templates
    # insert the assistant start marker here).
    prompt_batch = tokenizer.apply_chat_template(
        prompt_messages,
        add_generation_prompt=True,
        **shared_kwargs,
    )

    # Whole conversation: prompt plus the assistant reply.
    conversation = prompt_messages + [{"role": "assistant", "content": response_text}]
    full_batch = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=False,
        **shared_kwargs,
    )

    return full_batch[0], prompt_batch[0]


def collate_right_pad(batch_ids: List[torch.Tensor], pad_id: int):
    """Right-pad a list of 1-D id tensors to a common length.

    Returns:
        ``(input_ids, attention_mask, lengths)`` where ``input_ids`` and
        ``attention_mask`` are ``[B, max_len]`` long tensors (padding slots
        hold ``pad_id`` and mask value 0) and ``lengths`` is a ``[B]`` long
        tensor with the unpadded length of every sequence.
    """
    seq_lens = [seq.size(0) for seq in batch_ids]
    lengths = torch.tensor(seq_lens, dtype=torch.long)
    longest = int(lengths.max().item())

    input_ids = torch.full((len(batch_ids), longest), pad_id, dtype=torch.long)
    attn_mask = torch.zeros_like(input_ids)
    for row, (seq, n_tokens) in enumerate(zip(batch_ids, seq_lens)):
        # Copy the real tokens into the left part of the row and mark them.
        input_ids[row, :n_tokens] = seq
        attn_mask[row, :n_tokens] = 1
    return input_ids, attn_mask, lengths


@torch.no_grad()
def compute_batch_logps(
    model_type,
    model,
    tokenizer,
    full_ids_list: List[torch.Tensor],
    prompt_ids_list: List[torch.Tensor],
    device: torch.device,
) -> List[List[float]]:
    """Score one batch of prompt+response sequences.

    The batch is right-padded, run through ``model`` once, and reduced to one
    result per sequence depending on ``model_type``:

      * ``"rm1"`` - ``[sum, mean]`` of the log-probabilities of the response
        tokens (padding excluded). Sequences are reduced one at a time in a
        loop to keep peak memory down. An empty response span yields a sum of
        0.0 and a NaN mean.
      * ``"rm2"`` - a single float: the value at the last non-padding position.
      * ``"rm3"`` - ``[sum, mean]`` of the values over all non-padding positions.

    Args:
        model_type: one of ``"rm1"``, ``"rm2"``, ``"rm3"``.
        model: the model to call. For rm2/rm3 the call is unpacked as a
            3-tuple whose last element is used as per-token values
            (assumed to be ``[B, L]`` - TODO confirm against the model class).
        tokenizer: only ``pad_token_id`` / ``eos_token_id`` are read.
        full_ids_list: 1-D id tensors of prompt + response.
        prompt_ids_list: 1-D id tensors of the prompt alone; only their
            lengths are used (rm1) to locate the first response token.
        device: device the padded batch is moved to before the forward pass.

    Returns:
        One entry per input sequence, in input order.

    Raises:
        ValueError: if ``model_type`` is not a supported mode.
    """
    # Fail loudly on an unsupported mode. Previously this fell through every
    # branch and returned an empty list, so the caller silently wrote no scores.
    if model_type not in ("rm1", "rm2", "rm3"):
        raise ValueError(
            f"unsupported model_type {model_type!r}; expected 'rm1', 'rm2' or 'rm3'"
        )

    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Fallback: most LLMs use the EOS token as padding.
        pad_id = tokenizer.eos_token_id

    # Right-pad to a rectangular batch.
    input_ids, attention_mask, lengths = collate_right_pad(full_ids_list, pad_id)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    results = []
    if model_type == "rm1":
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # [B, L, V]
        for b in range(input_ids.size(0)):
            seq_len = int(lengths[b].item())
            # Index of the first response token inside full_ids.
            resp_start = prompt_ids_list[b].size(0)
            # logits[t] predicts token t + 1, so the response tokens
            # [resp_start, seq_len) are scored by logits [resp_start - 1, seq_len - 1).
            start_idx = max(resp_start - 1, 0)
            end_idx = max(seq_len - 1, 0)
            # Normalise only the response rows: log_softmax is row-wise, so the
            # values are unchanged while the [*, V] float32 temporary no longer
            # covers the prompt positions.
            resp_log_probs = F.log_softmax(
                logits[b, start_idx:end_idx, :].float(), dim=-1
            )  # [R, V]
            target_ids = input_ids[b, start_idx + 1 : end_idx + 1]  # [R]
            # gather keeps the index on the same device as the logits.
            resp_token_logps = resp_log_probs.gather(
                -1, target_ids.unsqueeze(-1)
            ).squeeze(-1)  # [R]
            results.append([float(resp_token_logps.sum()), float(resp_token_logps.mean())])

    elif model_type == "rm2":
        # NOTE(review): `pad_id` is forwarded as an extra keyword to the model
        # here but not in the rm3 branch - confirm the model's forward accepts it.
        _, _, values = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True, use_cache=False, pad_id=pad_id)
        # With right padding, (number of real tokens - 1) is the last real position.
        last_idx = attention_mask.sum(dim=-1, keepdim=True) - 1
        scores = values.gather(dim=-1, index=last_idx).squeeze(-1)
        results.extend(scores.detach().cpu().numpy().tolist())

    else:  # "rm3"
        _, _, values = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True, use_cache=False)
        # NOTE(review): the reduction covers every non-padding token, prompt
        # included, while the CLI help describes rm3 as scoring the response
        # positions - confirm which is intended.
        masked_values = values * attention_mask
        scores_sum = masked_values.sum(-1)
        scores_avg = scores_sum / attention_mask.sum(-1).clamp_min(1.0)
        results.extend([[float(scores_sum[b]), float(scores_avg[b])] for b in range(input_ids.size(0))])

    return results


def main():
    """Score every (sample, response) pair of the input dataset and save the result.

    Steps:
      1. Parse the CLI arguments and load the dataset (every process reads the
         file itself, which is simpler than loading on rank 0 and broadcasting).
      2. Split the flat task list between processes, load tokenizer and model,
         and score the local tasks batch by batch.
      3. Gather all results; the main process writes them into a new column
         named after ``--model-tag`` and saves the dataset.
      4. Synchronise all processes before exiting.
    """
    args = parse_args()
    data: datasets.Dataset = load_single_dataset(args.input_json)
    if args.debug_mode:
        # Debug runs use at most the first 10 rows; clamp so that datasets
        # smaller than 10 rows do not raise an out-of-range error.
        data = data.select(range(min(10, len(data))))
    data = data.to_list()
    use_8bit = "qwen3-8B-base" in args.model_path

    # Build the flat task list and let Accelerate split it between processes.
    tasks_all = build_tasks(data)
    with distributed_state.split_between_processes(tasks_all) as local_tasks:
        # Tokenizer and model.
        dtype = to_dtype(args.dtype)
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)
        if tokenizer.pad_token is None:
            # Reusing EOS as the pad token is the usual safe choice for causal LMs.
            tokenizer.pad_token = tokenizer.eos_token

        # `dtype` is either a torch.dtype or the string "auto". It is passed
        # through unchanged: "auto" used to be replaced by None, which loads in
        # the framework default dtype rather than the checkpoint's dtype.
        if args.model_type == "rm1":
            model = AutoModelForCausalLM.from_pretrained(
                args.model_path,
                load_in_8bit=use_8bit,
                torch_dtype=dtype,
                device_map=None,
            )
        else:
            model = AutoModelForCausalLMWithValueHead.from_pretrained(
                args.model_path,
                load_in_8bit=use_8bit,
                torch_dtype=dtype,
                device_map=None,
            )
            from safetensors.torch import load_file

            # The value-head weights live in a separate file. They are read on
            # the CPU (not a hard-coded "cuda") so this also works on hosts
            # without CUDA; load_state_dict copies them into the existing
            # parameters wherever those are.
            try:
                value_head_sd: dict = load_file(
                    os.path.join(args.model_path, "value_head.safetensors"), device="cpu"
                )
            except FileNotFoundError:
                # NOTE(review): torch.load unpickles the file - only use
                # checkpoints from a trusted source.
                value_head_sd: dict = torch.load(
                    os.path.join(args.model_path, "value_head.bin"), map_location="cpu"
                )
            # Keys are stored as "v_head.summary.*" but the head itself expects
            # "summary.*". Strip the prefix when present instead of popping
            # fixed names, so un-prefixed checkpoints no longer raise KeyError.
            prefix = "v_head."
            value_head_sd = {
                (key[len(prefix):] if key.startswith(prefix) else key): tensor
                for key, tensor in value_head_sd.items()
            }
            model.v_head.load_state_dict(value_head_sd)

        model.eval()
        if not use_8bit:
            # The 8-bit model is not moved explicitly.
            model.to(distributed_state.device)
        tokenizer.padding_side = "right"

        # Score the local tasks batch by batch.
        local_results: List[Tuple[int, int, List[float]]] = []
        bs = max(1, int(args.batch_size))

        for i in tqdm(range(0, len(local_tasks), bs)):
            chunk = local_tasks[i : i + bs]

            # Encode every pair of this batch first.
            full_ids_list = []
            prompt_ids_list = []
            for (sample_idx, resp_idx) in chunk:
                item = data[sample_idx]
                prompt_messages = item["prompt"]  # already a list of role/content messages
                response_text = item["responses"][resp_idx]
                full_ids, prompt_ids = encode_one_pair(tokenizer, prompt_messages, response_text)
                full_ids_list.append(full_ids)
                prompt_ids_list.append(prompt_ids)

            # One forward pass for the batch; per-sequence reduction happens inside.
            batch_logps = compute_batch_logps(
                model_type=args.model_type,
                model=model,
                tokenizer=tokenizer,
                full_ids_list=full_ids_list,
                prompt_ids_list=prompt_ids_list,
                device=distributed_state.device,
            )

            local_results.extend(
                (sample_idx, resp_idx, lp)
                for (sample_idx, resp_idx), lp in zip(chunk, batch_logps)
            )

    # Gather the results of all processes. Each process contributes its list
    # wrapped in an outer list, so the gathered object is a list of per-process lists.
    local_results = [local_results]
    gathered: List[List[Tuple[int, int, List[float]]]] = accelerate.utils.gather_object(local_results)
    if distributed_state.is_main_process:
        flat: List[Tuple[int, int, List[float]]] = []
        for part in gathered:
            flat.extend(part)

        # Write the scores back into the sample dicts.
        key_name = f"{args.model_tag}"
        # Pre-create one empty slot per response so the column lines up with "responses".
        for item in data:
            n = len(item.get("responses", []))
            item[key_name] = [[] for _ in range(n)]

        for sample_idx, resp_idx, lp in flat:
            data[sample_idx][key_name][resp_idx] = lp

        # Make sure the output directory exists.
        Path(args.output_path).parent.mkdir(parents=True, exist_ok=True)

        data = datasets.Dataset.from_list(data)
        # Column pruning is skipped whenever the model tag contains "ref".
        only_save_necessary_columns = bool(
            args.only_save_necessary_columns and "ref" not in args.model_tag
        )
        if only_save_necessary_columns:
            keep = ["reflogp", args.model_tag, "data_source", "scores"]
            unnecessary_keys = [k for k in data.column_names if k not in keep]
            data = data.remove_columns(unnecessary_keys)
        save_dataset(data, args.output_path)

    # Keep every process alive until the main process has finished writing.
    distributed_state.wait_for_everyone()


# Run the scoring pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
    