
import os
import string
from transformers import AutoTokenizer
import time
import pandas as pd
import json
import vllm
from vllm import LLM, SamplingParams
from datasets import load_dataset, concatenate_datasets
import pandas as pd

import json

from modelscope.msdatasets import MsDataset

def load_math_questions(answers_file_path, tokenizer=None, student_field = None,
                        dataset_path="AI_School_main_vllm/data/eval_datasets/hendrycks_math",
                        categories=None):
    """Build judge prompts that pair MATH test problems with stored LLM answers.

    Every subject config of the local Hendrycks MATH dataset is loaded
    (``test`` split) and concatenated in ``categories`` order. The i-th
    problem (its ``problem`` and reference ``solution`` fields) is paired
    with the i-th record of the answers file. Each pair is rendered into a
    system + user chat prompt via ``format_prompt``.

    Args:
        answers_file_path: Path to a JSON file whose top level is presumably a
            list of dicts, one per problem, in the same order as the
            concatenated dataset (TODO confirm against the producer script).
        tokenizer: Tokenizer forwarded to ``format_prompt`` for chat-template
            rendering. With ``None``, ``format_prompt`` falls back to plain
            "role: content" text.
        student_field: Key in each answer record that holds the LLM answer to
            judge. A missing key (or ``None``) yields an empty answer.
        dataset_path: Dataset location passed to ``load_dataset``.
        categories: Subject configs to load. Defaults to the seven MATH
            subjects listed below.

    Returns:
        List of formatted prompt strings, one per (problem, answer) pair.
        If the dataset and the answers differ in length, the surplus items
        on the longer side are ignored and a warning is printed.
    """
    # Read the LLM answers that are to be judged.
    with open(answers_file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    if categories is None:
        categories = [
            'algebra',
            'counting_and_probability',
            'geometry',
            'intermediate_algebra',
            'number_theory',
            'prealgebra',
            'precalculus'
        ]

    # Answers are paired with problems by position, so this concatenation
    # order must match the order of records in the answers file.
    splits = [
        load_dataset(dataset_path, category, split="test")
        for category in categories
    ]
    ds = concatenate_datasets(splits)

    if len(ds) != len(data):
        # zip() below silently truncates; make a misaligned answers file visible.
        print(f"Warning: {len(ds)} problems but {len(data)} answer records; "
              f"only the first {min(len(ds), len(data))} pairs will be judged.")

    # Build one judge prompt per (problem, answer) pair.
    prompt_questions = []
    for entry, llm_answer in zip(ds, data):
        problem = entry.get('problem', '')
        answer = entry.get('solution', '')
        # NOTE: the continuation lines of this f-string carry their source
        # indentation into the prompt text. They are kept unchanged so that
        # prompts stay identical to earlier evaluation runs.
        user_prompt = f"""Question: {problem}

        "Standard answer": {answer}

        LLM answer: {llm_answer.get(student_field, '')}

        Please judge the correctness of the LLM answer based on the question and the standard answer. If it is correct, output a <1> at the end, and if it is wrong, output a <0> at the end."""

        messages = [
            {"role": "system", "content": "You are a strict math teacher and you need to judge the correctness of LLM's answers based on the questions and standard answers."},
            {"role": "user", "content": user_prompt}
        ]

        formatted_question = format_prompt(messages, tokenizer)
        prompt_questions.append(formatted_question)

    return prompt_questions



def solve_math_problems( prompt_questions, llm, sampling_params=None):
    """Run the judge model over all prompts in a single ``llm.generate`` call.

    Args:
        prompt_questions: List of fully formatted prompt strings.
        llm: A ``vllm.LLM`` engine, or any object with a compatible
            ``generate(prompts, sampling_params)`` method.
        sampling_params: Optional sampling configuration passed straight to
            ``llm.generate``. When omitted, ``vllm.SamplingParams`` with
            temperature=0.2, top_p=0.95 and max_tokens=1024 is used.

    Returns:
        List of generated texts, one per element returned by
        ``llm.generate``. Only the first candidate (``outputs[0]``) of each
        request is kept. The order is assumed to match ``prompt_questions``,
        as vLLM returns results in request order.
    """
    # Nothing to judge: avoid invoking the engine at all.
    if not prompt_questions:
        return []

    if sampling_params is None:
        sampling_params = vllm.SamplingParams(temperature=0.2, top_p=0.95, max_tokens=1024)

    # vLLM schedules and batches the requests internally; one call is enough.
    responses = llm.generate(prompt_questions, sampling_params)

    return [response.outputs[0].text for response in responses]

def save_to_json(data, file_name="./JUDGE_EVAL.json"):
    """Serialize ``data`` to ``file_name`` as 4-space-indented JSON.

    Missing parent directories are created. A bare file name with no
    directory component is written to the current working directory.
    Non-ASCII characters are escaped (the ``json.dump`` default).

    Args:
        data: Any JSON-serializable object.
        file_name: Destination path; an existing file is overwritten.
    """
    parent_dir = os.path.dirname(file_name)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one (e.g. not for "out.json").
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    print(f"Data saved to {file_name}.")

def format_prompt(messages, tokenizer):
    """Render a list of chat messages into one prompt string.

    The tokenizer's chat template is tried first (untokenized, with the
    generation prompt appended). If that raises for any reason, including
    ``tokenizer`` being ``None``, the error is printed and a plain-text
    fallback is returned. The fallback has one "role: content" line per
    message.
    """
    try:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as err:
        # Best-effort: report the problem but still produce a usable prompt.
        print(f"格式化提示词时出错: {err}")
        fallback_lines = [f"{msg['role']}: {msg['content']}" for msg in messages]
        return "\n".join(fallback_lines)
    
def main(model_path="/mnt/home/user11/yl/Code_And_Data/AI_School_main_vllm/quantize_model/qwen2__5_14B",
         answers_file="MATHQA_EVAL.json",
         output_file="./JUDGE_EVAL.json"):
    """Judge stored LLM answers to MATH problems with a local vLLM model.

    The pipeline:

    1. Load the vLLM engine and its tokenizer from ``model_path``.
    2. Build judge prompts from ``answers_file``, reading each record's
       ``"answer"`` field.
    3. Generate a verdict for every prompt.
    4. Write the verdicts to ``output_file`` as a list of
       ``{"answer": <judge output>}`` dicts.

    All parameters default to the values previously hard-coded here, so
    ``main()`` behaves as before.
    """
    print("start......")
    # Load the local model; make sure model_path points to a valid checkpoint.
    engine_14B = vllm.LLM(model=model_path, gpu_memory_utilization=0.95)
    # Load the tokenizer from the same checkpoint so the chat template matches.
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    print("LLM is ready")

    prompt_questions = load_math_questions(answers_file, tokenizer, "answer")
    print("judgers are ready")

    print("Starting to judge math problems...")
    history = solve_math_problems(prompt_questions, engine_14B)

    # Wrap each raw judge output under an "answer" key before saving.
    judged = [{"answer": response} for response in history]
    save_to_json(judged, output_file)

if "__main__" == __name__:
    # Script entry point: announce start-up, then run the judging pipeline.
    print("start")
    main()