import argparse
import json
import multiprocessing as mp
import os
import queue
import random
import shutil

import torch
from datasets import load_from_disk, load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from evaluator.MATH_evaluator_list import MATHEvaluator


# Seed Python's global RNG so any use of `random` in this script is reproducible.
random.seed(42)

# Dataset name (the value passed as --data_name) -> JSONL file holding its problems.
# The paths are relative ("./dataset/..."), so they resolve against the current
# working directory of the process that opens them.
name2path = dict(
    GSM8K="./dataset/gsm8k.jsonl",
    MATH="./dataset/math.jsonl",
    MATH_OAI="./dataset/MATH_OAI.jsonl",
    AIME24="./dataset/AIME24.jsonl",
    AIME25="./dataset/AIME25.jsonl",
    OlympiadBench="./dataset/OlympiadBench.jsonl",
    minerva="./dataset/minerva.jsonl",
    OMNI="./dataset/omni_math_num_500.jsonl",
    LiveAOPSbench="./dataset/liveaopsbench-2024-8-2024-12-num.jsonl",
    TheoremQA="./dataset/TheoremQA.jsonl",
)

# Dataset name -> evaluator used to score that dataset's predictions.
# Every dataset currently gets its own, separately constructed MATHEvaluator.
name2eval = {
    dataset_name: MATHEvaluator()
    for dataset_name in (
        "GSM8K",
        "MATH",
        "MATH_OAI",
        "OMNI",
        "LiveAOPSbench",
        "AIME24",
        "AIME25",
        "OlympiadBench",
        "minerva",
        "TheoremQA",
    )
}

def check(evaluator, pred_ans, real_ans):
    """Score predicted solutions against reference answers.

    Prints the lengths of both lists (a quick visual check that they are
    aligned) and returns whatever ``evaluator.score(pred_ans, real_ans)``
    returns, unchanged.
    """
    n_pred, n_real = len(pred_ans), len(real_ans)
    print(n_pred, n_real)
    return evaluator.score(pred_ans, real_ans)

def process_prompt(question, tokenizer, args):
    """Render ``question`` into a chat-formatted prompt string.

    ``args.system_prompt`` picks the message layout:

    * ``'qwen'``: the "reason step by step / boxed answer" instruction as a
      system message, then the bare question as the user message.
    * ``'deepseek'``: one user message made of the question with that
      instruction appended on a new line.
    * ``'none'``: one user message containing only the question.

    The messages go through ``tokenizer.apply_chat_template`` with
    ``tokenize=False`` and ``add_generation_prompt=True``, and the result is
    returned.

    Raises:
        ValueError: if ``args.system_prompt`` is none of the values above.
    """
    mode = args.system_prompt
    if mode == 'qwen':
        messages = [
            {
                "role": "system",
                "content": "Please reason step by step, and put your final answer within \\boxed{}.",
            },
            {"role": "user", "content": question},
        ]
    elif mode == 'deepseek':
        messages = [
            {
                "role": "user",
                "content": question + "\nPlease reason step by step, and put your final answer within \\boxed{}.",
            },
        ]
    elif mode == 'none':
        messages = [{"role": "user", "content": question}]
    else:
        raise ValueError("Unknown system_prompt type")

    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

def _assign_gpus(proc_id, tensor_parallel):
    """Return the GPU ids (as strings) that worker ``proc_id`` should use.

    Candidate ids come from CUDA_VISIBLE_DEVICES when it is set, otherwise
    ``0 .. torch.cuda.device_count() - 1``. Worker ``proc_id`` takes
    ``tensor_parallel`` consecutive candidates starting at index
    ``proc_id * tensor_parallel``, wrapping around modulo the candidate count.

    Raises:
        RuntimeError: if no GPU ids are visible.
        ValueError: if ``tensor_parallel`` exceeds the number of visible GPUs
            (the wrap-around would hand the same GPU to one worker twice).
    """
    global_cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if global_cuda_visible is not None:
        available_gpus = [x.strip() for x in global_cuda_visible.split(",") if x.strip()]
    else:
        available_gpus = [str(i) for i in range(torch.cuda.device_count())]
    total_gpus = len(available_gpus)
    if total_gpus == 0:
        raise RuntimeError("No visible GPUs: cannot assign devices to worker process.")
    if tensor_parallel > total_gpus:
        raise ValueError(
            f"tensor_parallel={tensor_parallel} exceeds the {total_gpus} visible GPU(s)."
        )
    return [
        available_gpus[(proc_id * tensor_parallel + i) % total_gpus]
        for i in range(tensor_parallel)
    ]


def _build_sampling_params(args):
    """Build vLLM sampling parameters for ``args.decode`` ('sample' or 'greedy').

    Raises:
        ValueError: for any other decode method.
    """
    stop_words = ["<|im_end|>", "<|endoftext|>", "<|end_of_solution|>"]
    if args.decode == 'sample':
        # Nucleus sampling, 4 completions per prompt.
        return SamplingParams(
            top_p=0.95,
            temperature=0.6,
            max_tokens=args.max_tokens,
            stop=stop_words,
            n=4,
        )
    if args.decode == 'greedy':
        return SamplingParams(
            temperature=0.0,
            max_tokens=args.max_tokens,
            stop=stop_words,
            n=1,
        )
    raise ValueError("Unknown decode method")


def worker_process(proc_id, task_queue, progress_queue, args, output_dir, batch_size):
    """Generate completions for batches taken from ``task_queue`` until it is empty.

    The worker restricts itself to ``args.tensor_parallel`` GPUs, loads its own
    tokenizer and vLLM engine, and writes one JSON line per input record to
    ``<output_dir>/partial_<proc_id>.jsonl``. After each batch it puts ``1`` on
    ``progress_queue`` so the parent can track progress.

    Args:
        proc_id: index of this worker; selects its GPUs and its output file name.
        task_queue: queue of batches, each a list of ``{"input", "output"}`` dicts.
        progress_queue: receives one ``1`` per completed batch.
        args: parsed CLI namespace (model path, token budget, decode mode, prompt style).
        output_dir: directory that receives this worker's partial output file.
        batch_size: unused here; batching is decided by whoever fills
            ``task_queue``. Kept for signature compatibility.

    Raises:
        ValueError: for an unknown ``args.decode`` / ``args.system_prompt``
            (decode is checked before the model is loaded).
    """
    # Dynamic GPU assignment based on this worker's index and tensor_parallel.
    # NOTE(review): changing CUDA_VISIBLE_DEVICES only takes effect if CUDA has
    # not been initialized in this process yet; whether torch.cuda.device_count()
    # in _assign_gpus initializes it depends on the torch version -- confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(_assign_gpus(proc_id, args.tensor_parallel))
    print(f"Process {proc_id} uses GPUs: {os.environ['CUDA_VISIBLE_DEVICES']}")

    # Validate the decode mode before the expensive model load.
    sampling_params = _build_sampling_params(args)

    # Each worker process loads its own tokenizer and model.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = LLM(
        model=args.model_name_or_path,
        trust_remote_code=True,
        dtype="bfloat16",
        tensor_parallel_size=args.tensor_parallel,
        # Presumably leaves room for a prompt of up to 1024 tokens on top of the
        # generation budget -- TODO confirm prompts never exceed this.
        max_model_len=args.max_tokens + 1024,
        max_seq_len_to_capture=args.max_tokens,
        gpu_memory_utilization=0.9,
        max_num_seqs=256,
    )

    partial_out_path = os.path.join(output_dir, f"partial_{proc_id}.jsonl")
    # `with` closes the file even if generation raises part-way through.
    with open(partial_out_path, "w", encoding="utf-8") as fout:
        while True:
            try:
                batch = task_queue.get(timeout=10)  # each task is one batch
            except queue.Empty:
                # The main script enqueues every batch before starting workers,
                # so an empty queue means there is no work left.
                print(f"Process {proc_id} finishing: task queue empty.")
                break
            prompts = [process_prompt(data["input"], tokenizer, args) for data in batch]
            responses = model.generate(prompts, sampling_params, use_tqdm=True)
            for response, data in zip(responses, batch):
                new_data = {
                    "input": data["input"],
                    "output": data["output"],
                    "prediction": [
                        {"solution": output.text, "stop_reason": output.stop_reason}
                        for output in response.outputs
                    ],
                }
                fout.write(json.dumps(new_data, ensure_ascii=False) + "\n")
                # Flush per record so finished results are on disk if the worker dies later.
                fout.flush()
            progress_queue.put(1)
    print(f"Process {proc_id} finished.")

def _partial_sort_key(fname):
    """Sort key for ``partial_<id>...`` names: numeric ids by value first, then other names."""
    stem = fname[len("partial_"):].split(".", 1)[0]
    if stem.isdecimal():
        return (0, int(stem), fname)
    return (1, 0, fname)


def merge_results(output_dir, final_output_path):
    """Concatenate every ``partial_*`` file in ``output_dir`` into ``final_output_path``.

    Partial files are read in ascending worker-id order (``partial_2`` before
    ``partial_10``), so the merged file is deterministic; ``os.listdir`` alone
    returns entries in arbitrary order. Blank lines are skipped, as are
    non-file entries whose names happen to start with ``partial_``.

    Args:
        output_dir: directory containing the per-worker JSONL files.
        final_output_path: path of the merged JSONL file (overwritten).

    Returns:
        The list of merged records (parsed JSON objects), in output order.
    """
    partial_names = sorted(
        (
            fname
            for fname in os.listdir(output_dir)
            if fname.startswith("partial_") and os.path.isfile(os.path.join(output_dir, fname))
        ),
        key=_partial_sort_key,
    )
    all_results = []
    for fname in partial_names:
        file_path = os.path.join(output_dir, fname)
        with open(file_path, "r", encoding="utf-8") as fin:
            all_results.extend(json.loads(line) for line in fin if line.strip())
        print(f"Merged {fname}")
    with open(final_output_path, "w", encoding="utf-8") as fout:
        for res in all_results:
            fout.write(json.dumps(res, ensure_ascii=False) + "\n")
    return all_results

def evaluate_results(results, output_path, evaluator):
    """Score every prediction, annotate it in place, and write the scored file.

    Each record in ``results`` must carry an ``"output"`` reference answer and
    a ``"prediction"`` list of dicts with a ``"solution"`` key. Each prediction
    is scored (via ``check``) against its own record's reference, and the score
    is stored on it as ``"correctness"``. ``output_path`` is then overwritten
    with the annotated records followed by one ``{"results": <accuracy %>}``
    line. The accuracy is printed. With no predictions at all, accuracy is 0.0.

    Raises:
        ValueError: if the evaluator returns a different number of scores than
            there are predictions.
    """
    # Flatten to (prediction dict, reference) pairs so each score is written
    # back by position. Keying scores by solution text would let identical
    # solutions for different problems (e.g. two empty/truncated outputs)
    # overwrite each other's correctness.
    flat_preds = [(pred, r["output"]) for r in results for pred in r["prediction"]]
    pred_ans_list = [pred["solution"] for pred, _ in flat_preds]
    real_ans_list = [real for _, real in flat_preds]
    correctness = list(check(evaluator, pred_ans_list, real_ans_list))
    if len(correctness) != len(flat_preds):
        raise ValueError(
            f"Evaluator returned {len(correctness)} scores for {len(flat_preds)} predictions"
        )
    for (pred, _), corr in zip(flat_preds, correctness):
        pred["correctness"] = corr

    total_problem = len(flat_preds)
    total_correct = sum(1 for corr in correctness if corr)
    accuracy = round(total_correct / total_problem * 100, 2) if total_problem else 0.0
    summary = {"results": accuracy}

    # Overwrite rather than append: in this script output_path is the file
    # merge_results just wrote, and appending would leave every record twice.
    with open(output_path, "w", encoding="utf-8") as fout:
        for r in results:
            fout.write(json.dumps(r, ensure_ascii=False) + "\n")
        fout.write(json.dumps(summary, ensure_ascii=False) + "\n")
    print(f"Evaluation: {summary['results']}% ( {total_correct} / {total_problem} )")

if __name__ == "__main__":
    # The CUDA runtime does not support the fork start method, so workers must be spawned.
    mp.set_start_method('spawn', force=True)
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_name", type=str, required=True, choices=list(name2path))
    parser.add_argument("--target_path", type=str, required=True)
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--max_tokens", default=16*1024, type=int)
    # choices= makes invalid values fail here instead of inside workers after a model load.
    parser.add_argument("--decode", default="sample", type=str, choices=["sample", "greedy"])
    parser.add_argument("--system_prompt", default="deepseek", type=str,
                        choices=["qwen", "deepseek", "none"])
    parser.add_argument("--model_base_name", default="unknow", type=str)
    parser.add_argument("--batch_size", default=30, type=int)
    parser.add_argument("--tensor_parallel", default=1, type=int, help="Number of GPUs per process")
    args = parser.parse_args()
    if args.tensor_parallel < 1:
        raise ValueError("--tensor_parallel must be at least 1.")

    os.makedirs(args.target_path, exist_ok=True)

    # Load the dataset and map each record to {"input": problem, "output": solution}.
    src_path = name2path[args.data_name]
    with open(src_path, "r", encoding="utf-8") as fin:
        raw_dataset = [json.loads(line) for line in fin if line.strip()]
    dataset = [{"input": data["problem"], "output": data["solution"]} for data in raw_dataset]
    num_of_dataset = len(dataset)
    print(f"num_of_dataset: {num_of_dataset}")

    # One worker process per group of `tensor_parallel` GPUs.
    num_gpus = torch.cuda.device_count()
    if num_gpus < args.tensor_parallel:
        raise ValueError("Not enough GPUs for the requested tensor_parallel setting.")
    num_processes = num_gpus // args.tensor_parallel
    print(f"Starting {num_processes} processes...")
    # Shrink batches so there are at least as many batches as workers when the
    # dataset allows it, but never below 1: a zero batch size (dataset smaller
    # than the worker count) would be an invalid range() step below.
    args.batch_size = max(1, min(args.batch_size, num_of_dataset // num_processes))

    # Scratch directory for each worker's partial output; start from a clean slate.
    partial_output_dir = os.path.join(args.target_path, "partials")
    if os.path.isdir(partial_output_dir):
        shutil.rmtree(partial_output_dir)
        print(f"已删除目录：{partial_output_dir}")
    os.makedirs(partial_output_dir, exist_ok=True)

    # Manager-backed queues: one holding batches of work, one receiving progress ticks.
    manager = mp.Manager()
    task_queue = manager.Queue()
    progress_queue = manager.Queue()

    # Split the dataset into batch_size chunks and enqueue all of them before
    # any worker starts (workers treat an empty queue as "done").
    batches = [dataset[i:i + args.batch_size] for i in range(0, len(dataset), args.batch_size)]
    for batch in batches:
        task_queue.put(batch)
    total_batches = len(batches)
    print(f"Total batches: {total_batches}")

    processes = []
    for proc_id in range(num_processes):
        p = mp.Process(
            target=worker_process,
            args=(proc_id, task_queue, progress_queue, args, partial_output_dir, args.batch_size)
        )
        p.start()
        processes.append(p)

    # Track progress. Stop waiting once every worker has exited, even if some
    # batches were never reported (e.g. a worker crashed), instead of spinning forever.
    with tqdm(total=total_batches, desc="Processing Batches") as pbar:
        processed_batches = 0
        while processed_batches < total_batches:
            # Sample liveness *before* waiting: if every worker had already
            # exited, all their progress ticks are already queued, so an empty
            # get really means nothing more will arrive.
            workers_alive = any(proc.is_alive() for proc in processes)
            try:
                num = progress_queue.get(timeout=1)
            except queue.Empty:
                if not workers_alive:
                    break
                continue
            processed_batches += num
            pbar.update(num)

    for p in processes:
        p.join()
    failed_workers = [proc_id for proc_id, proc in enumerate(processes) if proc.exitcode != 0]
    if failed_workers or processed_batches < total_batches:
        print(
            f"Warning: {processed_batches}/{total_batches} batches reported; "
            f"workers with non-zero exit code: {failed_workers}"
        )
    manager.shutdown()

    # Merge all worker outputs into the final file.
    final_output_path = os.path.join(
        args.target_path,
        f"{args.data_name}-L{args.max_tokens}-D{args.decode}-P{args.system_prompt}-M{args.model_base_name}.jsonl"
    )
    all_results = merge_results(partial_output_dir, final_output_path)

    # Score the generated results.
    evaluator = name2eval[args.data_name]
    evaluate_results(all_results, final_output_path, evaluator)
