import os

# CUDA_VISIBLE_DEVICES must be set before vllm is imported (below) to take effect.
# Default to GPU 0, but respect a value the caller already exported
# (e.g. `CUDA_VISIBLE_DEVICES=2 python this_script.py`) instead of clobbering it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

import json
import argparse
import pandas as pd
from typing import Dict, List
from tqdm import tqdm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer


# User-turn prompt shared by every dataset: asks for step-by-step reasoning and
# a final answer after a "####" separator. `{question}` is filled in by
# construct_prompt_text. All whitespace inside the literal (leading newline,
# trailing space after "question.") is part of the prompt; keep it unchanged
# so results stay comparable with earlier runs.
UNIFIED_PROMPT_TEMPLATE = """
Think step by step, and answer the following question. 
Return the answer at the end of the response after a separator ####, e.g., #### 3/5.
Q: {question}
"""

def load_dataset_data(dataset_name: str) -> pd.DataFrame:
    """Load the parquet data for ``dataset_name`` and attach row ids.

    Supported names are ``"math"``, ``"mmlupro"`` and ``"bbh"``, each mapped to a
    hard-coded local parquet file (adjust the paths below for your machine).
    An ``id`` column of the form ``"<dataset_name>#<row_index>"`` is added,
    replacing any existing ``id`` column.

    Args:
        dataset_name: Key of the dataset to load.

    Returns:
        The loaded DataFrame, or an empty DataFrame (after printing a warning)
        when ``dataset_name`` is not recognised, so the caller can skip it.
    """
    # Please double-check these paths before running.
    dataset_paths: Dict[str, str] = {
        "math": "/home/-/datasets/MATH/train-00000-of-00001-7320a6f3aba8ebd2_5000.parquet",
        "mmlupro": "/home/-/datasets/mmlupro/test-00000-of-00001_5000.parquet",
        "bbh": "/home/-/datasets/bbh/bbh_all.parquet",
    }

    print(f"Loading data for: {dataset_name}...")

    path = dataset_paths.get(dataset_name)
    if path is None:
        # Previously this case returned an empty frame silently; say why.
        print(
            f"Warning: unknown dataset '{dataset_name}'; "
            f"expected one of {sorted(dataset_paths)}."
        )
        return pd.DataFrame()

    df = pd.read_parquet(path)
    df["id"] = [f"{dataset_name}#{i}" for i in range(len(df))]
    return df

def construct_prompt_text(row, tokenizer, task):
    """Build the complete generation prompt for a single dataset row.

    Args:
        row: One DataFrame row (``pd.Series``), as yielded by ``DataFrame.iterrows``.
        tokenizer: Tokenizer exposing ``apply_chat_template`` (the script passes
            an ``AutoTokenizer``).
        task: Dataset name. For ``"mmlupro"`` the ``options`` (or, if absent,
            ``choices``) value is rendered as lettered choices ("A. ...",
            "B. ...") under the question. Other tasks use the first of the
            ``question`` / ``problem`` / ``input`` columns that exists.

    Returns:
        The chat-templated prompt, right-stripped, with a pre-filled assistant
        line "Okay, I think I have finished thinking." followed by a closing
        ``</think>`` tag, so generation continues after that tag.

    Raises:
        KeyError: an ``mmlupro`` row lacks ``question``, or lacks both
            ``options`` and ``choices``.
        TypeError: the options value is neither ``None`` nor iterable.
    """
    if task == "mmlupro":
        q = row["question"]
        raw_options = row["options"] if "options" in row else row["choices"]

        # Normalise options to a plain list without swallowing unexpected
        # errors (the old bare `except: pass` also hid KeyboardInterrupt).
        if raw_options is None:
            options = []
        elif isinstance(raw_options, list):
            options = raw_options
        else:
            # Other containers (e.g. numpy arrays, as parquet list columns
            # usually load) are converted to a list.
            options = list(raw_options)

        if options:
            option_text = "\n".join(
                f"{chr(65 + i)}. {opt}" for i, opt in enumerate(options)
            )
            question_text = f"{q}\n{option_text}"
        else:
            question_text = q
    else:
        # MATH-style rows use "problem". "input" is an extra fallback, since
        # BBH releases commonly use it (TODO confirm against bbh_all.parquet).
        # Without it such rows would get an empty question.
        question_text = row.get("question", row.get("problem", row.get("input", "")))

    user_content = UNIFIED_PROMPT_TEMPLATE.format(question=question_text)
    messages = [{"role": "user", "content": user_content}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Drop trailing whitespace from the template so the pre-filled text below
    # starts on exactly one fresh line.
    prompt = prompt.rstrip()
    # Pre-fill a reasoning section that is already closed with </think>, so the
    # model's output starts after it (output dir defaults to "a-NoThinking").
    prompt += (
        "\n"
        "Okay, I think I have finished thinking.\n"
        "</think>\n"
    )
    return prompt
# ================= Main logic =================

def _build_prompts(df: pd.DataFrame, tokenizer, ds_name: str) -> List[str]:
    """Return one prompt per row of ``df``; rows that fail to format get ``""``."""
    prompts: List[str] = []
    # Explicit loop (not DataFrame.apply) so a failing row can be pinpointed by index.
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Constructing Prompts"):
        try:
            prompts.append(construct_prompt_text(row, tokenizer, ds_name))
        except Exception as e:
            print(f"Error formatting row {idx}: {e}")
            prompts.append("")  # placeholder keeps prompts aligned with rows
    return prompts


def _generate_batch(llm, prompts: List[str], sampling_params):
    """Generate one completion per non-empty prompt.

    Returns ``(texts, token_counts)``, two lists aligned with ``prompts``.
    Empty placeholder prompts (rows that failed to format) are not submitted
    to the engine and yield ``""`` / ``0``.
    """
    texts = [""] * len(prompts)
    token_counts = [0] * len(prompts)

    positions = [i for i, p in enumerate(prompts) if p]
    skipped = len(prompts) - len(positions)
    if skipped:
        print(f"Skipping {skipped} prompt(s) that failed to format.")
    if not positions:
        return texts, token_counts

    # use_tqdm=True shows vLLM's own progress inside the batch.
    outputs = llm.generate(
        [prompts[i] for i in positions], sampling_params, use_tqdm=True
    )
    # LLM.generate returns results in the same order as the submitted prompts,
    # so zipping with `positions` maps each result back to its row.
    for pos, output in zip(positions, outputs):
        completion = output.outputs[0]
        texts[pos] = completion.text
        token_counts[pos] = len(completion.token_ids)
    return texts, token_counts


def main(args):
    """Run inference on ``args.data_name`` and write the results as JSONL.

    For each target dataset this function does the following:
    - loads the rows and builds one prompt per row;
    - generates completions in chunks of ``args.save_interval`` rows;
    - appends each chunk to ``<output_dir>/<dataset>_result.jsonl``, containing
      the original columns plus ``chat_history``, ``output`` and ``token_count``.

    Because each chunk is appended as soon as it finishes, completed chunks
    stay on disk even if a later one fails. An existing result file for the
    dataset is deleted first.

    Raises:
        ValueError: if ``args.save_interval`` is not a positive integer
            (checked before the model is loaded).
    """
    if args.save_interval <= 0:
        raise ValueError(
            f"--save_interval must be a positive integer, got {args.save_interval}"
        )

    print(f"Initializing vLLM with model: {args.model_path}")

    # Engine stat logging is disabled. Progress is visible via the tqdm bars
    # (prompt building, and use_tqdm=True during generation).
    llm = LLM(
        model=args.model_path,
        disable_log_stats=True,
        max_model_len=20000,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.8,
        top_k=20,
        max_tokens=8192,
    )

    os.makedirs(args.output_dir, exist_ok=True)

    target_datasets = [args.data_name]

    for ds_name in target_datasets:
        print(f"\n{'='*30}\nProcessing dataset: {ds_name}\n{'='*30}")

        # 1. Load data.
        full_df = load_dataset_data(ds_name)
        if full_df.empty:
            print("DataFrame is empty, skipping...")
            continue

        # 2. Pre-build every prompt.
        print("Formatting prompts (Pre-processing)...")
        full_df["chat_history"] = _build_prompts(full_df, tokenizer, ds_name)

        # Prepare the output file; batches are appended, so start from scratch.
        output_filename = os.path.join(args.output_dir, f"{ds_name}_result.jsonl")
        if os.path.exists(output_filename):
            print(f"Warning: Overwriting {output_filename}")
            os.remove(output_filename)

        # 3. Chunked inference.
        batch_size = args.save_interval
        total_rows = len(full_df)
        num_batches = (total_rows + batch_size - 1) // batch_size  # ceiling division

        print(f"Starting Inference: Total {total_rows} samples, {num_batches} batches.")

        for batch_idx, start_idx in enumerate(range(0, total_rows, batch_size)):
            batch_df = full_df.iloc[start_idx : start_idx + batch_size].copy()
            prompts = batch_df["chat_history"].tolist()

            print(f"\n[Batch {batch_idx+1}/{num_batches}] Generating {len(prompts)} samples...")

            texts, token_counts = _generate_batch(llm, prompts, sampling_params)
            batch_df["output"] = texts
            batch_df["token_count"] = token_counts

            # Append this batch right away so finished work survives later failures.
            batch_df.to_json(
                output_filename,
                orient="records",
                lines=True,
                force_ascii=False,
                mode="a",
            )
            print(f"Batch {batch_idx+1} saved.")

    print("\nAll tasks completed!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Batch-generate answers with vLLM and save them as JSONL."
    )
    parser.add_argument("--model_path", type=str, default="/mnt/sharedata/", help="model path")
    parser.add_argument("--output_dir", type=str, default="./a-NoThinking", help="output directory")
    # Keep `choices` in sync with the datasets known to load_dataset_data, so a
    # typo fails at argument parsing instead of silently skipping after model load.
    parser.add_argument(
        "--data_name",
        type=str,
        default="bbh",
        choices=["math", "mmlupro", "bbh"],
        help="dataset name",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=100,
        help="number of samples generated and appended to the output file per batch (must be > 0)",
    )

    args = parser.parse_args()
    main(args)