
import os
import string
from transformers import AutoTokenizer
import time
import pandas as pd
import json
import vllm
from vllm import LLM, SamplingParams
from datasets import load_dataset, concatenate_datasets
import pandas as pd

import json

from modelscope.msdatasets import MsDataset

def load_math_questions(tokenizer=None):
    """
    Build chat-formatted prompts from the test split of the local hendrycks_math data.

    Every subject configuration is loaded with ``load_dataset`` from the local
    ``hendrycks_math`` directory and the results are concatenated. The
    ``problem`` field of each entry (empty string when missing) is wrapped in
    a system + user message pair and rendered through ``format_prompt``.

    Args:
        tokenizer: Forwarded to ``format_prompt`` to render the messages.

    Returns:
        A list of prompt strings, one per dataset entry, each ending with the
        ``<Student C>`` tag.
    """
    dataset_dir = "AI_School_main_vllm/data/eval_datasets/hendrycks_math"
    subjects = (
        'algebra',
        'counting_and_probability',
        'geometry',
        'intermediate_algebra',
        'number_theory',
        'prealgebra',
        'precalculus',
    )

    # One test split per subject, merged into a single dataset.
    per_subject = []
    for subject in subjects:
        per_subject.append(load_dataset(dataset_dir, subject, split="test"))
    ds = concatenate_datasets(per_subject)

    def _render(problem):
        # The eight spaces after the blank line are part of the prompt text
        # the model has been evaluated with; they are kept on purpose.
        user_prompt = (
            f"<Teacher> Question: {problem}\n"
            "\n"
            "        Please solve the above math problem in detailed steps and place the final answer at the end."
        )
        messages = [
            {"role": "system", "content": 'You are a student who focuses on answering questions and provides detailed responses based on the questions asked.'},
            {"role": "user", "content": user_prompt},
        ]
        return format_prompt(messages, tokenizer) + "<Student C>"

    return [_render(entry.get('problem', '')) for entry in ds]



def solve_math_problems(prompt_questions, llm):
    """
    Generate one answer per prompt with a single ``llm.generate`` call.

    Args:
        prompt_questions: Prompt strings passed to ``llm.generate``.
        llm: Object exposing ``generate(prompts, sampling_params)``; each
            returned item must provide ``outputs[0].text``.

    Returns:
        A list of ``{"answer": <generated text>}`` dicts, in the order the
        responses are returned by ``llm.generate``.
    """
    decoding = vllm.SamplingParams(temperature=0.2, top_p=0.95, max_tokens=4096)

    # All prompts are submitted together; only the first candidate of each
    # response is kept.
    completions = llm.generate(prompt_questions, decoding)
    return [{"answer": completion.outputs[0].text} for completion in completions]

def save_to_json(data, file_name="./MATHQA_EVAL.json"):
    """
    Write ``data`` to ``file_name`` as JSON indented by four spaces.

    The parent directory of ``file_name`` is created when it does not exist.

    Args:
        data: Any JSON-serialisable object.
        file_name: Destination path. May be a bare file name with no
            directory component, in which case the file is written to the
            current working directory.

    Raises:
        TypeError: If ``data`` is not JSON-serialisable (from ``json.dump``).
        OSError: If the directory or file cannot be created.
    """
    parent_dir = os.path.dirname(file_name)
    # A bare file name has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_name, 'w', encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    print(f"Data saved to {file_name}.")

def format_prompt(messages, tokenizer):
    """
    Render a list of chat messages as a single prompt string.

    The tokenizer's chat template is tried first (untokenised, with the
    generation prompt appended). If that call raises for any reason, the
    error is printed and a plain ``role: content`` rendering, one message per
    line, is returned instead.

    Args:
        messages: Sequence of dicts with ``role`` and ``content`` keys.
        tokenizer: Object exposing ``apply_chat_template``.

    Returns:
        The rendered prompt text.
    """
    try:
        rendered = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as e:
        print(f"格式化提示词时出错: {e}")
        # Best-effort fallback: manual "role: content" lines.
        fallback_lines = []
        for message in messages:
            fallback_lines.append(f"{message['role']}: {message['content']}")
        return "\n".join(fallback_lines)
    else:
        return rendered
    
def main():
    """
    Run the evaluation pipeline end to end.

    Loads the vLLM engine and the tokenizer from the same local model
    directory, builds the prompts, generates the answers and saves them with
    ``save_to_json`` using its default output path.
    """
    model_dir = "AI_School_main_vllm/final_model/merge_model/ori_data_LoRA_mistral_model-base"

    print("start......")
    # Engine and tokenizer both come from the same local model directory.
    llm = vllm.LLM(model=model_dir, gpu_memory_utilization=0.95, max_model_len=16384)
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir,
        trust_remote_code=True,
        repo_type="model",
    )
    print("LLM is ready")

    questions = load_math_questions(tokenizer)
    print("questions are ready")

    print("Starting to solve math problems...")
    answers = solve_math_problems(questions, llm)

    save_to_json(answers)

# Script entry point: announce start-up, then run the evaluation pipeline.
if __name__ == "__main__":
    print("start")
    main()