
import os
import string
from transformers import AutoTokenizer
import time
import pandas as pd
import json
import vllm
from vllm import LLM, SamplingParams
from datasets import load_dataset, concatenate_datasets
import pandas as pd

import json

from modelscope.msdatasets import MsDataset

def load_math_questions(
    tokenizer=None,
    data_path="AI_School_main_vllm/data/eval_datasets/hendrycks_math",
    categories=None,
):
    """
    Load the test split of the local Hendrycks MATH dataset and build prompts.

    Each subject configuration under ``data_path`` is loaded with Hugging Face
    ``datasets.load_dataset(..., split="test")`` and the results are
    concatenated in the order given by ``categories``. For every row, the
    ``problem`` field is wrapped in a system + user chat, rendered with
    ``format_prompt`` and suffixed with the ``<Student C>`` tag.

    Args:
        tokenizer: Tokenizer whose chat template is used by ``format_prompt``.
            When ``None``, ``format_prompt`` falls back to plain
            ``role: content`` lines.
        data_path: Dataset location passed as the first argument to
            ``load_dataset``.
        categories: Iterable of subject configuration names to load. Defaults
            to the seven MATH subjects.

    Returns:
        Tuple ``(prompt_questions, ds, first_messages)``: the rendered prompt
        strings, the concatenated dataset, and the raw message lists used to
        build each prompt (all three aligned index-by-index) so that a
        follow-up turn can be appended later.

    Raises:
        ValueError: If ``categories`` is empty.
    """
    if categories is None:
        categories = (
            'algebra',
            'counting_and_probability',
            'geometry',
            'intermediate_algebra',
            'number_theory',
            'prealgebra',
            'precalculus',
        )
    categories = list(categories)
    if not categories:
        # concatenate_datasets([]) would fail with a less helpful error.
        raise ValueError("categories must contain at least one MATH subject")

    ds = concatenate_datasets([
        load_dataset(data_path, category, split="test")
        for category in categories
    ])

    system_prompt = 'You are a student who focuses on answering questions and provides detailed responses based on the questions asked.'

    prompt_questions = []
    first_messages = []
    for entry in ds:
        problem = entry.get('problem', '')

        # NOTE: the blank line and the leading indentation of the second line
        # are part of the text sent to the model; they are kept verbatim so
        # prompts stay identical to previous evaluation runs.
        user_prompt = f"""<Teacher> Question: {problem}

        Please solve the above math problem in detailed steps and place the final answer at the end."""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        # The trailing tag opens the model's reply as "<Student C>".
        prompt_questions.append(format_prompt(messages, tokenizer) + "<Student C>")
        first_messages.append(messages)
    return prompt_questions, ds, first_messages



def solve_math_problems(prompt_questions, llm, use_temperature=None, *, top_p=0.95, max_tokens=4096):
    """
    Generate one completion per prompt with a single batched ``llm.generate`` call.

    Args:
        prompt_questions: List of fully formatted prompt strings.
        llm: The vLLM engine (``vllm.LLM``) used for generation.
        use_temperature: Sampling temperature. ``None`` means "use vLLM's
            default temperature" rather than forwarding ``None`` to
            ``SamplingParams``, which expects a float.
        top_p: Nucleus-sampling probability mass (keyword-only).
        max_tokens: Maximum number of generated tokens per prompt
            (keyword-only).

    Returns:
        List of generated texts, one per prompt. vLLM's ``LLM.generate``
        returns its outputs in the same order as the input prompts.
    """
    if not prompt_questions:
        return []

    sampling_kwargs = {"top_p": top_p, "max_tokens": max_tokens}
    # Only pass a temperature when one was given; SamplingParams validates it
    # as a float and would reject None.
    if use_temperature is not None:
        sampling_kwargs["temperature"] = use_temperature
    sampling_params = vllm.SamplingParams(**sampling_kwargs)

    responses = llm.generate(prompt_questions, sampling_params)
    # outputs[0]: with SamplingParams' default n=1 there is one candidate per prompt.
    return [response.outputs[0].text for response in responses]

def save_to_json(data, file_name="./MATHQA_EVAL.json"):
    """
    Serialize ``data`` to ``file_name`` as JSON indented by 4 spaces.

    The parent directory is created if it does not exist. A bare file name
    (no directory component) is written to the current working directory.

    Args:
        data: Any JSON-serializable object.
        file_name: Destination path. Defaults to ``./MATHQA_EVAL.json``.
    """
    # os.path.dirname("out.json") is "", and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    directory = os.path.dirname(file_name)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Explicit encoding so the result does not depend on the platform locale.
    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    print(f"Data saved to {file_name}.")

def format_prompt(messages, tokenizer):
    """
    Render a chat message list into a single prompt string.

    Uses ``tokenizer.apply_chat_template`` with ``tokenize=False`` and
    ``add_generation_prompt=True``. If ``tokenizer`` is ``None``, or the
    template call raises, falls back to one ``role: content`` line per message.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts.
        tokenizer: Object exposing ``apply_chat_template`` (e.g. a Hugging
            Face tokenizer), or ``None``.

    Returns:
        The formatted prompt text.
    """
    if tokenizer is None:
        # load_math_questions defaults to tokenizer=None; treat that as the
        # expected "no template" case instead of printing an error per prompt.
        return _format_prompt_fallback(messages)
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        # Best-effort by design: a missing/broken chat template should not
        # abort a whole evaluation run, so report it and use the plain format.
        print(f"格式化提示词时出错: {e}")
        return _format_prompt_fallback(messages)


def _format_prompt_fallback(messages):
    """Join messages as ``role: content`` lines (template-free fallback)."""
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)
    
def _build_review_prompts(first_messages, first_answers, tokenizer):
    """
    Build second-turn prompts that ask the model to re-check its first answer.

    Each prompt is the original conversation, followed by the model's first
    answer (as an assistant turn prefixed with ``<Student C>``) and a teacher
    request to verify it. The prompt is rendered with ``format_prompt`` and
    suffixed with ``<Student C>``.

    Args:
        first_messages: Message lists used for the first turn, one per problem.
        first_answers: Generated first-turn texts, aligned with ``first_messages``.
        tokenizer: Tokenizer forwarded to ``format_prompt``.

    Returns:
        List of rendered second-turn prompt strings.

    Raises:
        RuntimeError: If the two inputs have different lengths (answers would
            otherwise be paired with the wrong problems).
    """
    if len(first_messages) != len(first_answers):
        raise RuntimeError(
            f"Got {len(first_answers)} answers for {len(first_messages)} questions"
        )

    review_request = "<Teacher> Well, please check again if your answer contains any errors. If there are no mistakes, just repeat your entire answers from before. If the previous solution is wrong, provide a new correct solution and answer."

    prompts = []
    for messages, answer in zip(first_messages, first_answers):
        # List concatenation builds a new list, so first_messages is not mutated.
        conversation = messages + [
            {"role": "assistant", "content": "<Student C>" + answer},
            {"role": "user", "content": review_request},
        ]
        prompts.append(format_prompt(conversation, tokenizer) + "<Student C>")
    return prompts


def main():
    """
    Run a two-turn self-check evaluation on the MATH test split.

    1. Load the local merged model with vLLM and its tokenizer.
    2. Answer every problem (temperature 0.2).
    3. Ask the model to re-check each of its own answers (temperature 0.2).
    4. Save both turns per problem via ``save_to_json`` (default path
       ``./MATHQA_EVAL.json``) under the keys ``<StudentA>`` (first turn)
       and ``<StudentC>`` (second turn).
    """
    # Single source of truth for the model location (used by engine and tokenizer).
    model_path = "AI_School_main_vllm/final_model/merge_model/MA_Gen_LoRA_llama_model-base"

    print("start......")
    # Load the local model (make sure the model path is correct).
    engine = vllm.LLM(
        model=model_path,
        gpu_memory_utilization=0.95
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    print("LLM is ready")

    prompt_questions, _, first_messages = load_math_questions(tokenizer)
    if not prompt_questions:
        # Avoid IndexError on prompt_questions[0] and pointless generation calls.
        print("No questions loaded; nothing to evaluate.")
        return
    print(prompt_questions[0])
    print("questions are ready")
    print("Starting to solve math problems...")

    history1 = solve_math_problems(prompt_questions, engine, use_temperature=0.2)
    print("first turn end.")

    prompt_student = _build_review_prompts(first_messages, history1, tokenizer)
    print(prompt_student[0])

    history2 = solve_math_problems(prompt_student, engine, use_temperature=0.2)

    history = [
        {"<StudentA>": first_answer, "<StudentC>": second_answer}
        for first_answer, second_answer in zip(history1, history2)
    ]

    save_to_json(history)

if __name__ == "__main__":
    # Script entry point: announce start-up, then run the full evaluation.
    print("start")
    main()