import argparse
from transformers import AutoTokenizer
import json,os
from datetime import datetime
import string
from src import utils
from tqdm import tqdm
from src import prompt,loader,mutil_inference_question_api
def get_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with the attributes ``model_path``,
        ``eval_func_list``, ``language``, ``output_dir``, ``gpu_num``,
        ``inference_type``, ``sample_num`` and ``max_tokens``.
    """
    parser = argparse.ArgumentParser(description="VLLM Inference Script")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the model directory or model name.",
    )

    # Not marked required: argparse never applies a default to a required
    # option, so combining required=True with a default made the default dead.
    parser.add_argument(
        "--eval_func_list",
        type=str,
        default="eval_math_500",
        help="Comma-separated names of the evaluation functions to run, "
             "e.g. 'eval_math_500,eval_gsm_8k'. Defaults to 'eval_math_500'.",
    )

    parser.add_argument(
        "--language",
        type=str,
        default="zh",
        choices=["zh", "en"],
        help="Language for the prompts.",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Path to the output dir to save the generated text.",
    )

    # Optional for the same reason as --eval_func_list: the default of 1 is
    # only reachable when the option may be omitted.
    parser.add_argument(
        "--gpu_num",
        type=int,
        default=1,
        help="gpu num"
    )

    parser.add_argument(
        "--inference_type",
        type=str,
        required=False,
        default="model",  # falls back to "model" when omitted
        choices=["model", "api"],  # argparse rejects any other value
        help="Inference type, choose between 'model' and 'api'. Defaults to 'model'."
    )

    parser.add_argument(
        "--sample_num",
        required=False,
        type=int,
        default=200,
        help="Number of samples to process. Default is 200."
    )

    parser.add_argument(
        "--max_tokens",
        required=False,
        type=int,
        default=16384,
        help="max_tokens"
    )

    return parser.parse_args()

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def get_eval_result(data:list[loader.BatchOutput]):
    """Score generated outputs against their gold answers and summarize the run.

    Every gold answer in ``item.origin_correct_answer_list`` is checked against
    ``item.generated_text`` with ``utils.is_answer_correct``.

    Args:
        data: Batch outputs to score. Each item must expose ``generated_text``,
            ``generated_text_token_len`` and ``origin_correct_answer_list``.

    Returns:
        dict with the keys ``data_len``, ``correct_num``, ``total_num``,
        ``correct_rate_for_extract`` (correct / gold answers checked),
        ``correct_rate`` (correct / number of items), ``avg_output`` and
        ``avg_correct_output`` (lengths of ``generated_text`` in characters)
        and ``avg_output_token_len``. Any ratio whose denominator is zero
        (empty ``data``, no gold answers, or no correct answers) is reported
        as 0 instead of raising ``ZeroDivisionError``.
    """
    total_num = 0
    correct_num = 0
    correct_len_list = []
    # Wrapping the list itself (not enumerate) lets tqdm show the total.
    for item in tqdm(data):
        for gold in item.origin_correct_answer_list:
            total_num += 1
            if utils.is_answer_correct(item.generated_text, f"{gold}"):
                correct_len_list.append(len(item.generated_text))
                correct_num += 1

    data_len = len(data)
    return {
        'data_len': data_len,
        'correct_num': correct_num,
        'total_num': total_num,
        'correct_rate_for_extract': f"{_safe_ratio(correct_num*100.0, total_num):.2f}%",
        'correct_rate': f"{_safe_ratio(correct_num*100.0, data_len):.2f}%",
        'avg_output': _safe_ratio(sum(len(item.generated_text) for item in data), data_len),
        # A run with zero correct answers used to crash here.
        'avg_correct_output': _safe_ratio(sum(correct_len_list), len(correct_len_list)),
        'avg_output_token_len': _safe_ratio(sum(item.generated_text_token_len for item in data), data_len),
    }

def eval_dataset(model,tokenizer, questions:list[loader.Question],generate_output_file,result_output_file,language,args):
    """Run inference over *questions*, score the outputs and save the summary.

    Args:
        model: Model handle forwarded to ``utils.batch_inference_eval``.
        tokenizer: Tokenizer forwarded to ``utils.batch_inference_eval``.
        questions: Questions to run inference on.
        generate_output_file: Passed to ``utils.batch_inference_eval`` as its
            ``output_path``.
        result_output_file: Path the summary dict is written to with
            ``utils.write_json``.
        language: Prompt language forwarded to the inference helper.
        args: Parsed CLI namespace; only ``args.max_tokens`` is read here.

    Returns:
        The summary dict produced by ``get_eval_result``.
    """
    inference_kwargs = {
        "model": model,
        "tokenizer": tokenizer,
        "questions": questions,
        "output_path": generate_output_file,
        "batch_size": 1,
        "language": language,
        "max_tokens": args.max_tokens,
    }
    batch_output = utils.batch_inference_eval(**inference_kwargs)
    print(f"Generated {len(batch_output)} results.")
    print(f"Saved model output to {generate_output_file}")

    # Scoring phase.
    print("Evaluating...")
    summary = get_eval_result(batch_output)
    utils.write_json(summary, result_output_file)
    print(f"Saved eval result to {result_output_file}")
    return summary

def eval_math_500(model,tokenizer, language, sample_num,output_dir,args):
    """Evaluate on ``loader.load_math_500()`` sliced to ``[:sample_num]``.

    Writes ``math_500_generated.jsonl`` and ``math_500_eval.json`` under
    *output_dir* and returns the summary dict from ``eval_dataset``.
    """
    subset = loader.load_math_500()[:sample_num]
    return eval_dataset(
        model=model,
        tokenizer=tokenizer,
        questions=subset,
        generate_output_file=f"{output_dir}/math_500_generated.jsonl",
        result_output_file=f"{output_dir}/math_500_eval.json",
        language=language,
        args=args,
    )

def eval_gsm_8k(model,tokenizer, language, sample_num,output_dir,args):
    """Evaluate on ``loader.load_gsm_8k(data_split='test')`` sliced to ``[:sample_num]``.

    Writes ``gsm_8k_generated.jsonl`` and ``gsm_8k_eval.json`` under
    *output_dir* and returns the summary dict from ``eval_dataset``.
    """
    subset = loader.load_gsm_8k(data_split='test')[:sample_num]
    return eval_dataset(
        model=model,
        tokenizer=tokenizer,
        questions=subset,
        generate_output_file=f"{output_dir}/gsm_8k_generated.jsonl",
        result_output_file=f"{output_dir}/gsm_8k_eval.json",
        language=language,
        args=args,
    )

def eval_aqua_rat(model,tokenizer, language, sample_num,output_dir,args):
    """Evaluate on ``loader.load_aqua_rat(data_split='test')`` sliced to ``[:sample_num]``.

    Writes ``aqua_rat_generated.jsonl`` and ``aqua_rat_eval.json`` under
    *output_dir* and returns the summary dict from ``eval_dataset``.
    """
    subset = loader.load_aqua_rat(data_split='test')[:sample_num]
    return eval_dataset(
        model=model,
        tokenizer=tokenizer,
        questions=subset,
        generate_output_file=f"{output_dir}/aqua_rat_generated.jsonl",
        result_output_file=f"{output_dir}/aqua_rat_eval.json",
        language=language,
        args=args,
    )

def eval_aime_2024(model,tokenizer, language, sample_num,output_dir,args):
    """Evaluate on ``loader.load_aime_2024()`` sliced to ``[:sample_num]``.

    Writes ``aime_2024_generated.jsonl`` and ``aime_2024_eval.json`` under
    *output_dir* and returns the summary dict from ``eval_dataset``.
    """
    subset = loader.load_aime_2024()[:sample_num]
    return eval_dataset(
        model=model,
        tokenizer=tokenizer,
        questions=subset,
        generate_output_file=f"{output_dir}/aime_2024_generated.jsonl",
        result_output_file=f"{output_dir}/aime_2024_eval.json",
        language=language,
        args=args,
    )

def eval_aime_2025(model,tokenizer, language, sample_num,output_dir,args):
    """Evaluate on ``loader.load_aime_2025()`` sliced to ``[:sample_num]``.

    Writes ``aime_2025_generated.jsonl`` and ``aime_2025_eval.json`` under
    *output_dir* and returns the summary dict from ``eval_dataset``.
    """
    subset = loader.load_aime_2025()[:sample_num]
    return eval_dataset(
        model=model,
        tokenizer=tokenizer,
        questions=subset,
        generate_output_file=f"{output_dir}/aime_2025_generated.jsonl",
        result_output_file=f"{output_dir}/aime_2025_eval.json",
        language=language,
        args=args,
    )

def eval_gpqa_diamond(model,tokenizer, language, sample_num,output_dir,args):
    """Evaluate on ``loader.load_GPQA_diamond()`` sliced to ``[:sample_num]``.

    Writes ``gpqa_diamond_generated.jsonl`` and ``gpqa_diamond_eval.json``
    under *output_dir* and returns the summary dict from ``eval_dataset``.
    """
    subset = loader.load_GPQA_diamond()[:sample_num]
    return eval_dataset(
        model=model,
        tokenizer=tokenizer,
        questions=subset,
        generate_output_file=f"{output_dir}/gpqa_diamond_generated.jsonl",
        result_output_file=f"{output_dir}/gpqa_diamond_eval.json",
        language=language,
        args=args,
    )

def eval_amc_23(model,tokenizer, language, sample_num,output_dir,args):
    """Evaluate on ``loader.load_AMC_23()`` sliced to ``[:sample_num]``.

    Writes ``amc_23_generated.jsonl`` and ``amc_23_eval.json`` under
    *output_dir* and returns the summary dict from ``eval_dataset``.
    """
    subset = loader.load_AMC_23()[:sample_num]
    return eval_dataset(
        model=model,
        tokenizer=tokenizer,
        questions=subset,
        generate_output_file=f"{output_dir}/amc_23_generated.jsonl",
        result_output_file=f"{output_dir}/amc_23_eval.json",
        language=language,
        args=args,
    )

# Registry of the available evaluations, keyed by function name. main() looks
# up each name given in --eval_func_list in this mapping.
eval_func_dict = {
    eval_func.__name__: eval_func
    for eval_func in (
        eval_math_500,
        eval_gsm_8k,
        eval_aqua_rat,
        eval_aime_2024,
        eval_aime_2025,
        eval_gpqa_diamond,
        eval_amc_23,
    )
}


def main():
    """Load the model once, run every requested evaluation and save the results.

    The evaluations named in ``--eval_func_list`` are looked up in
    ``eval_func_dict`` and run in the given order. Each summary is printed and
    the combined list is written to ``<output_dir>/eval_result.json``.
    Names that are not registered are reported and skipped.
    """
    args = get_args()
    print(args)
    # Tolerate spaces and stray commas ("a, b,") in the comma-separated list.
    eval_func_list = [name.strip() for name in args.eval_func_list.split(',') if name.strip()]
    print(f"eval functions: {eval_func_list}")

    # Report typos up front instead of silently skipping them after the
    # model load.
    unknown_funcs = [name for name in eval_func_list if name not in eval_func_dict]
    if unknown_funcs:
        print(f"WARNING: skipping unknown eval functions {unknown_funcs}; "
              f"available: {sorted(eval_func_dict)}")

    # The eval functions write their files below this directory.
    os.makedirs(args.output_dir, exist_ok=True)

    # NOTE(review): args.inference_type is parsed but never consulted here;
    # the model path is always used.
    tensor_parallel_size=args.gpu_num

    # Tensor parallelism is capped at 4 for these distill checkpoints
    # (the reason is not documented in this file).
    distill_models = ["DeepSeek-R1-Distill-Qwen-1.5B", "DeepSeek-R1-Distill-Qwen-7B"]
    if any(model in args.model_path for model in distill_models) and args.gpu_num > 4:
        tensor_parallel_size = 4
    model = utils.load_model(args.model_path,tensor_parallel_size=tensor_parallel_size)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    print(f"finish loading model from {args.model_path}")
    dt_list=[]
    for func_name in eval_func_list:
        if func_name not in eval_func_dict:
            continue
        eval_func = eval_func_dict[func_name]
        dt=eval_func(model,tokenizer, args.language, args.sample_num, args.output_dir,args)
        dt_list.append({
            'func_name':func_name,
            **dt,
        })
        print(f"eval result for {func_name}:")
        for key, value in dt.items():
            print(f"\t{key}: {value}")
    utils.write_json(dt_list, f"{args.output_dir}/eval_result.json")
    print(f"eval result {dt_list}")
    print(f"Saved eval result to {args.output_dir}/eval_result.json")

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()