from datasets import load_from_disk, concatenate_datasets
from vllm import LLM, SamplingParams
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
from tqdm import tqdm
import json
import os
from pathlib import Path

def ensure_file_directory_exists(file_path):
    """Make sure the parent directory of ``file_path`` exists.

    The file itself is never created. When the parent directory is missing
    it is created (including intermediate directories) and a notice is
    printed.

    Args:
        file_path: Path (string or path-like) of the file whose directory
            is required.

    Returns:
        The resolved form of ``file_path`` as a string.
    """
    target = Path(file_path)
    parent = target.parent

    # Nothing to create: hand back the resolved path right away.
    if parent.exists():
        return str(target.resolve())

    parent.mkdir(parents=True, exist_ok=True)
    print(f"Created directory: {parent}")
    return str(target.resolve())

def get_eval_prompt(instance, tokenizer):
    """Build the chat-formatted evaluation prompt for one example.

    A fixed system instruction and a user turn containing the example's
    question and options are passed through the tokenizer's chat template
    with the generation prompt appended and tokenization disabled.

    Args:
        instance: Mapping providing the ``question`` and ``options`` keys.
        tokenizer: Object exposing ``apply_chat_template`` (e.g. a Hugging
            Face tokenizer).

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns for the two-turn
        conversation (untokenized, since ``tokenize=False``).
    """
    system_prompt = """
    You are a multiple-choice answering agent for scientific and technical exam questions.
Follow these rules strictly:

Read the question and options carefully.

Choose exactly ONE option.

When giving the answer, also explain clearly why this option is correct (and why others are less suitable if relevant).
"""
    user_turn = {
        "role": "user",
        "content": f"[question]: {instance['question']} \n [options]: {instance['options']}",
    }
    system_turn = {"role": "system", "content": system_prompt}

    return tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )

def main(args, model, dataset, output_dir, task):
    """Generate greedy and sampled answers for every example and save them.

    For each example in ``dataset`` one greedy completion (temperature 0)
    and ``samples_num`` sampled completions (temperature 0.8) are produced
    with vLLM. All records are written to
    ``<output_dir>/<model basename>_<task>.json``.

    Args:
        args: Parsed CLI namespace; ``tensor_parallel_size`` and
            ``gpu_memory_utilization`` are read from it.
        model: Model name or local path, used for both the tokenizer and
            the vLLM engine.
        dataset: Iterable of mappings with the keys ``question``,
            ``options``, ``explanation``, ``answer`` and ``label``.
        output_dir: Directory that receives the result file; created if
            it does not exist yet.
        task: Task tag appended to the output file name.
    """
    tokenizer = AutoTokenizer.from_pretrained(model)
    tokenizer.pad_token = tokenizer.eos_token

    llm = LLM(
        model=model,
        tensor_parallel_size=args.tensor_parallel_size,
        max_model_len=4096,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    stop_tokens = ["<|endoftext|>", "<|im_end|>"]
    max_new_tokens = 512
    samples_num = 50

    prompts_all = [get_eval_prompt(example, tokenizer) for example in dataset]

    greedy_params = SamplingParams(
        temperature=0, max_tokens=max_new_tokens, stop=stop_tokens
    )
    greedy_gathered = [
        request.outputs[0].text
        for request in llm.generate(prompts_all, greedy_params)
    ]

    # Ask vLLM for ``n`` completions per prompt instead of submitting every
    # prompt ``samples_num`` times: each request then carries its own
    # samples, so no stride arithmetic is needed to regroup them.
    sampling_params = SamplingParams(
        n=samples_num,
        temperature=0.8,
        max_tokens=max_new_tokens,
        stop=stop_tokens,
    )
    samples_gathered = [
        [completion.text for completion in request.outputs]
        for request in llm.generate(prompts_all, sampling_params)
    ]

    result = [
        {
            "question": example['question'],
            "options": example['options'],
            "solution": example['explanation'],
            "answer": example['answer'],
            "greedy_sample": greedy,
            "samples": samples,
            "label": example['label'],
        }
        for example, greedy, samples in zip(
            dataset, greedy_gathered, samples_gathered
        )
    ]

    # normpath strips a trailing separator; without it a model path such as
    # "/ckpt/my-model/" has an empty basename and the file would be named
    # "_<task>.json".
    model_name = os.path.basename(os.path.normpath(model))

    # Do not rely on the caller having created the output directory.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, model_name + f"_{task}.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    # Command-line entry point: parse options, load the dataset, run inference.
    parser = argparse.ArgumentParser(description="自动化代码补全处理脚本")
    parser.add_argument("--model", type=str, required=True, help="使用的模型名称")
    parser.add_argument("--inference_dataset_path", type=str, required=True, help="数据集存储路径")
    parser.add_argument("--output_dir", type=str, required=True, help="结果输出目录")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.4)
    parser.add_argument("--tensor_parallel_size", type=int, default=2)
    parser.add_argument("--task", type=str, required=True, help="")
    args = parser.parse_args()

    # Load the dataset from a JSON file. Read it explicitly as UTF-8 so the
    # result does not depend on the platform's default locale encoding
    # (the output file is written as UTF-8 as well).
    with open(args.inference_dataset_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    # Create the output directory.
    os.makedirs(args.output_dir, exist_ok=True)

    # Run the inference pipeline.
    main(args, args.model, dataset, args.output_dir, args.task)