
import os
import time
import pandas as pd
from transformers import AutoTokenizer
import json
import vllm
from vllm import LLM, SamplingParams

def load_math_questions(file_path=None, tokenizer=None, limit=3000):
    """
    Load questions/answers from a JSON file and build chat-formatted prompts.

    Reads ``file_path`` (assumed to be a JSON array of objects — the code
    indexes it with a slice and calls ``.get`` on each element), keeps the
    first ``limit`` records, and extracts each record's ``'teacher'`` field
    (the question text) and ``'answer'`` field. Every question is wrapped in a
    system ("diligent student" instructions) + user message pair and rendered
    to a prompt string with ``format_prompt``.

    Args:
        file_path: Path to the JSON file. Required despite the ``None`` default.
        tokenizer: Passed straight to ``format_prompt``; if it cannot render the
            chat template, ``format_prompt`` falls back to plain text.
        limit: Maximum number of records to use. Defaults to 3000 (the value
            previously hard-coded); ``None`` uses every record.

    Returns:
        A tuple ``(formatted_prompts, questions, answers)`` of parallel lists.
        Records missing a key contribute ``None`` for that entry.

    Raises:
        TypeError: If ``file_path`` is ``None``.
    """
    if file_path is None:
        # Same exception type open(None) would raise, but with a clear message.
        raise TypeError("load_math_questions() requires a file_path")

    # Explicit UTF-8 so decoding does not depend on the platform's locale default.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    student_message = (
        "You are a diligent student. You need to reason through the problem and derive the final result based on the given question and answer, following these specific rules:\n"
        "(1) The answer should be expressed in a single natural paragraph.\n"
        "(2) When you receive a question provided by the teacher, you should carefully analyze the problem and ensure the answer aligns with the standard solution.\n"
        "(3) Do not introduce any excessively difficult external knowledge in your response. Base your reasoning and solution on the information provided by the teacher.\n"
        "(4) You must provide the detailed process to reach the final answer, ensuring the solution is logically clear and reasonable.\n"
        "(5) You must respond in English."
    )

    # Only the first `limit` records are used; data[:None] keeps them all.
    records = data[:limit]
    first_dialogs = [item.get('teacher') for item in records]
    answers = [item.get('answer') for item in records]

    # Prepend the student system prompt to each question and render it.
    modified_first_dialogs = [
        format_prompt(
            [
                {"role": "system", "content": student_message},
                {"role": "user", "content": dialog},
            ],
            tokenizer,
        )
        for dialog in first_dialogs
    ]

    return modified_first_dialogs, first_dialogs, answers

def solve_math_problems(modified_dialogs, dialogs, answers, llm, batch_size=3000):
    """
    Generate model answers batch by batch and collect one record per question.

    Args:
        modified_dialogs: Formatted prompt strings sent to the model; parallel
            to ``dialogs``.
        dialogs: Original question texts, stored verbatim in each record.
        answers: Reference answers, stored verbatim in each record.
        llm: Engine object whose ``generate(prompts, sampling_params)`` returns
            one result per prompt, each exposing ``.outputs[0].text``.
        batch_size: Number of prompts passed to each ``generate`` call.

    Returns:
        A list of dicts with keys ``"first_dialog"``, ``"student_answer"``
        (the generated text prefixed with ``"<Student>"``), and ``"answer"``.
        Within a batch, records stop at the shortest of the three slices.
    """
    collected = []
    n_total = len(dialogs)
    params = vllm.SamplingParams(temperature=0.2, top_p=0.95, max_tokens=4096)

    for start in range(0, n_total, batch_size):
        stop = start + batch_size
        question_slice = dialogs[start:stop]
        answer_slice = answers[start:stop]
        outputs = llm.generate(modified_dialogs[start:stop], params)

        rows = zip(question_slice, answer_slice, outputs)
        # Number questions globally (1-based) for the progress message.
        for position, (question, reference, result) in enumerate(rows, start=start + 1):
            collected.append({
                "first_dialog": question,
                "student_answer": "<Student>" + result.outputs[0].text,
                "answer": reference,
            })
            print(f"Processing question {position}/{n_total}...")

        batch_number = start // batch_size + 1
        print(f"Batch {batch_number} processed, total {len(collected)} dialog histories collected.")

    return collected

def save_to_json(data, file_name="./output/second_turn.json"):
    """
    Write ``data`` to ``file_name`` as JSON indented by 4 spaces.

    The parent directory of ``file_name`` is created if it does not exist. A
    bare file name (no directory part) is written to the current working
    directory.

    Args:
        data: Any JSON-serializable object.
        file_name: Destination path.
    """
    parent_dir = os.path.dirname(file_name)
    # os.path.dirname("out.json") is "", and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Explicit encoding keeps the written bytes independent of the locale.
    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    print(f"Data saved to {file_name}.")
def format_prompt(messages, tokenizer):
    """
    Render a list of chat messages into a single prompt string.

    First tries ``tokenizer.apply_chat_template`` with ``tokenize=False`` and
    ``add_generation_prompt=True``. If that raises anything (including when
    ``tokenizer`` is ``None``), the error is printed and a plain-text fallback
    is returned: one ``"role: content"`` line per message.
    """
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception as err:
        print(f"格式化提示词时出错: {err}")

    # Best-effort fallback when the chat template cannot be applied.
    fallback_lines = (f"{msg['role']}: {msg['content']}" for msg in messages)
    return "\n".join(fallback_lines)

def main():
    """
    Run the second-turn generation pipeline end to end.

    Loads the local model with vLLM together with its tokenizer, builds
    prompts from the first-turn JSON file, generates answers in batches, and
    saves the collected records via ``save_to_json``'s default output path.
    Both paths are relative, so they resolve against the working directory.
    """
    model_dir = "AI_School_main_vllm/quantize_model/qwen2__5_14B"
    first_turn_path = "AI_School_main_vllm/agentclass/expand/openhermes/sbatch/output/first_turn.json"

    print("start......")

    # The engine and the tokenizer are both loaded from the same local directory.
    engine_14B = vllm.LLM(model=model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    print("LLM is ready")

    # Build prompts from the previous turn's output.
    modified_dialogs, dialogs, answers = load_math_questions(first_turn_path, tokenizer)
    print("dialogs are ready")

    print("Starting to solve math problems...")
    history = solve_math_problems(modified_dialogs, dialogs, answers, engine_14B)

    save_to_json(history)

if __name__ == "__main__":
    # Script entry point: announce start, then run the full pipeline.
    print("start")
    main()