from datasets import load_from_disk, concatenate_datasets
from vllm import LLM, SamplingParams
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
from tqdm import tqdm
import json
import os
from pathlib import Path

def ensure_file_directory_exists(file_path):
    """Ensure the parent directory of ``file_path`` exists.

    If the parent directory is missing it is created, including any missing
    ancestors, and a message naming it is printed. The file itself is not
    created.

    Returns:
        The absolute, resolved form of ``file_path`` as a string.
    """
    target = Path(file_path)
    parent = target.parent

    needs_creation = not parent.exists()
    if needs_creation:
        parent.mkdir(parents=True, exist_ok=True)
        print(f"Created directory: {parent}")

    return str(target.resolve())

def get_eval_prompt(question, tokenizer):
    """Render ``question`` as a single-turn chat prompt string.

    The conversation is an empty system turn followed by the question as the
    user turn. It is passed through ``tokenizer.apply_chat_template`` with
    ``tokenize=False`` and ``add_generation_prompt=True``, and the result of
    that call is returned unchanged.
    """
    conversation = [
        dict(role="system", content=""),
        dict(role="user", content=question),
    ]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

def _output_path(output_dir, model, task):
    """Return the result file path ``<output_dir>/<model name>_<task>.json``.

    The model name is the last path component of ``model``. A trailing
    separator is removed with ``os.path.normpath`` first. Without that,
    ``os.path.basename("/models/qwen/")`` is ``""`` and every such run would
    write to ``_<task>.json``.
    """
    model_name = os.path.basename(os.path.normpath(model))
    return os.path.join(output_dir, f"{model_name}_{task}.json")


def _assemble_records(dataset, greedy_texts, sample_groups):
    """Pair each dataset example with its generated completions.

    Args:
        dataset: sized iterable of examples exposing ``problem``,
            ``solution``, ``answer`` and ``label`` keys.
        greedy_texts: one greedy completion per example, in dataset order.
        sample_groups: one list of sampled completions per example, in
            dataset order.

    Returns:
        A list of dicts with keys ``question``, ``solution``, ``answer``,
        ``greedy_sample``, ``samples`` and ``label``.

    Raises:
        ValueError: if the three inputs differ in length. ``zip`` would
            otherwise silently drop or misalign examples.
    """
    if not (len(dataset) == len(greedy_texts) == len(sample_groups)):
        raise ValueError(
            f"length mismatch: {len(dataset)} examples, "
            f"{len(greedy_texts)} greedy outputs, "
            f"{len(sample_groups)} sample groups"
        )
    return [
        {
            "question": example["problem"],
            "solution": example["solution"],
            "answer": example["answer"],
            "greedy_sample": greedy,
            "samples": list(samples),
            "label": example["label"],
        }
        for example, greedy, samples in zip(dataset, greedy_texts, sample_groups)
    ]


def main(model, dataset, output_dir, task, *, samples_num=50, temperature=0.8,
         max_tokens=512, tensor_parallel_size=2, max_model_len=4096):
    """Generate greedy and sampled completions for every problem and save them.

    Args:
        model: model path or hub id, passed to both ``AutoTokenizer`` and
            vLLM's ``LLM``.
        dataset: dataset with ``problem``, ``solution``, ``answer`` and
            ``label`` columns.
        output_dir: directory for the result file; created if missing.
        task: tag appended to the output file name.
        samples_num: number of sampled completions per problem.
        temperature: sampling temperature for the sampled completions
            (the greedy pass always uses 0).
        max_tokens: generation budget per completion.
        tensor_parallel_size: forwarded to vLLM's ``LLM``.
        max_model_len: forwarded to vLLM's ``LLM``.

    Writes ``<output_dir>/<model name>_<task>.json``. It holds one record
    per example, containing the example's fields, its greedy completion and
    ``samples_num`` sampled completions.
    """
    tokenizer = AutoTokenizer.from_pretrained(model)
    tokenizer.pad_token = tokenizer.eos_token

    llm = LLM(
        model=model,
        tensor_parallel_size=tensor_parallel_size,
        max_model_len=max_model_len,
    )
    stop = ["<|endoftext|>", "<|im_end|>"]

    # Read the column once. Indexing dataset['problem'] per example can
    # rebuild the whole column on every iteration, which is quadratic.
    prompts = [get_eval_prompt(q, tokenizer) for q in dataset["problem"]]

    greedy_params = SamplingParams(temperature=0, max_tokens=max_tokens, stop=stop)
    greedy_texts = [
        out.outputs[0].text for out in llm.generate(prompts, greedy_params)
    ]

    # Ask for `n` completions per prompt instead of repeating the prompt
    # list `samples_num` times. Each RequestOutput then already carries that
    # example's samples, so no index arithmetic is needed to regroup them.
    sampling_params = SamplingParams(
        n=samples_num, temperature=temperature, max_tokens=max_tokens, stop=stop
    )
    sample_groups = [
        [completion.text for completion in out.outputs]
        for out in llm.generate(prompts, sampling_params)
    ]

    result = _assemble_records(dataset, greedy_texts, sample_groups)

    output_path = _output_path(output_dir, model, task)
    # Ensure the directory exists so main() also works when called directly
    # (the script entry point creates it, but other callers may not).
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    # Command-line interface. Help texts (Chinese) translate to:
    #   --model                  "name of the model to use"
    #   --inference_dataset_path "dataset storage path"
    #   --output_dir             "result output directory"
    cli = argparse.ArgumentParser(description="自动化代码补全处理脚本")
    for flag, help_text in (
        ("--model", "使用的模型名称"),
        ("--inference_dataset_path", "数据集存储路径"),
        ("--output_dir", "结果输出目录"),
        ("--task", ""),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    opts = cli.parse_args()

    # Load the dataset from disk.
    dataset = load_from_disk(opts.inference_dataset_path)

    # Create the output directory.
    os.makedirs(opts.output_dir, exist_ok=True)

    # Run the generation pipeline.
    main(opts.model, dataset, opts.output_dir, opts.task)