import os
import time
import pandas as pd
import json
import vllm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

def load_math_questions(file_path=None, tokenizer=None, limit=5000):
    """Load second-turn dialogs from a JSON file and build student-agent prompts.

    Each record of the JSON array is read for its 'first_dialog',
    'teacher_correct' and 'answer' fields (missing fields default to "").

    Args:
        file_path: Path to a UTF-8 JSON file whose top level is a list of dicts.
        tokenizer: Tokenizer handed to ``format_prompt`` to apply its chat template.
        limit: Maximum number of records to use. Defaults to 5000 (the previous
            hard-coded cap); ``None`` uses every record.

    Returns:
        A tuple ``(prompt_second_dialogs, second_dialogs)``:
        - prompt_second_dialogs: one formatted chat prompt per record, with the
          student instructions as the system turn and
          first_dialog + teacher_correct + answer as the user turn.
        - second_dialogs: first_dialog + teacher_correct per record, kept for
          the output history.
    """
    student_message = (
        "You are a student who admits mistakes and corrects them. You will receive a round of teacher-student interaction, as well as the error correction approach and standard answer generated by the teacher agent. Based on the following rules, generate your response:\n"
        "(1) Based on the teacher-student interaction, you should immerse yourself in the role of a student who made mistakes. Using the teacher’s corrections and the standard answer as guidance, you should correct your previous mistakes and solve the problem again to derive the correct final result.\n"
        "(2) In any input scenario, you must not simulate both the teacher and student dialogue at the same time. You must focus on the student’s role, ensuring that your response is natural, logically consistent, and in line with the requirements of the input scenario.\n"
        "(3) The teacher’s responses are handled by the dedicated teacher agent. Your role is limited to playing the student agent. Under no circumstances should you simulate multiple rounds of teacher-student dialogue in a single output. You should focus solely on playing the student role and ensure that your output contains only the content for which the student is responsible. Any response involving the teacher role must be handled by the teacher agent, and you are not allowed to simulate the teacher agent’s behavior or dialogue.\n"
        "(4) You must respond in English."
    )

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    prompt_second_dialogs = []
    second_dialogs = []
    for item in data[:limit]:
        first_dialog = item.get("first_dialog", "")
        teacher_correct = item.get("teacher_correct", "")
        answer = item.get("answer", "")

        second_dialogs.append(first_dialog + teacher_correct)

        # The instructions belong only in the system turn. Previously they were
        # also prepended to the user turn, so the model saw them twice.
        messages = [
            {"role": "system", "content": student_message},
            {"role": "user", "content": first_dialog + teacher_correct + answer},
        ]
        prompt_second_dialogs.append(format_prompt(messages, tokenizer))
    return prompt_second_dialogs, second_dialogs

def solve_math_problems(prompt_second_dialogs, second_dialogs, llm, batch_size=50,
                        sampling_params=None):
    """Generate student responses in batches and collect them as history records.

    Args:
        prompt_second_dialogs: Fully formatted prompts, one per dialog.
        second_dialogs: The matching raw dialog text stored under "Instruct";
            must be the same length as ``prompt_second_dialogs``.
        llm: Engine exposing ``generate(prompts, sampling_params)`` (a vLLM
            ``LLM`` in this script).
        batch_size: Number of prompts sent to ``llm.generate`` per call; must be > 0.
        sampling_params: Optional sampling parameters. When ``None``, uses
            ``vllm.SamplingParams(temperature=0.2, top_p=0.95, max_tokens=4096)``.

    Returns:
        A list of dicts ``{"Instruct": <dialog>, "student_correct": "<Student>" + <text>}``,
        one per dialog, in input order.

    Raises:
        ValueError: if ``batch_size`` is not positive or the two input lists
            have different lengths.
    """
    if batch_size <= 0:
        # range() with a negative step would silently yield no batches.
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(prompt_second_dialogs) != len(second_dialogs):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(
            "prompt_second_dialogs and second_dialogs must have the same length "
            f"({len(prompt_second_dialogs)} != {len(second_dialogs)})"
        )
    if sampling_params is None:
        sampling_params = vllm.SamplingParams(temperature=0.2, top_p=0.95, max_tokens=4096)

    history = []
    total_dialogs = len(prompt_second_dialogs)
    for start in range(0, total_dialogs, batch_size):
        batch_prompts = prompt_second_dialogs[start:start + batch_size]
        batch_dialogs = second_dialogs[start:start + batch_size]

        batch_responses = llm.generate(batch_prompts, sampling_params)

        # Only the first completion of each request is kept.
        for offset, (second_dialog, response) in enumerate(zip(batch_dialogs, batch_responses)):
            history.append({
                "Instruct": second_dialog,
                "student_correct": "<Student>" + response.outputs[0].text,
            })
            print(f"Processing question {start + offset + 1}/{total_dialogs}...")

        print(f"Batch {start // batch_size + 1} processed, total {len(history)} dialog histories collected.")

    return history

def save_to_json(data, file_name="./output/final_turn.json"):
    """Write ``data`` to ``file_name`` as JSON indented by 4 spaces.

    Any missing parent directories are created first. Prints a confirmation
    line once the file is written.
    """
    directory = os.path.dirname(file_name)
    # For a bare filename dirname() is "", and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when one is named.
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    print(f"Data saved to {file_name}.")
def format_prompt(messages, tokenizer):
    """Render a list of chat messages into a single prompt string.

    Prefers ``tokenizer.apply_chat_template`` (untokenized, with the generation
    prompt appended). If that call raises for any reason, the error is printed
    and a plain "role: content" line per message is returned instead.
    """
    try:
        rendered = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as err:
        print(f"格式化提示词时出错: {err}")
        # Best-effort fallback so a missing/broken template doesn't stop the run.
        lines = (f"{msg['role']}: {msg['content']}" for msg in messages)
        return "\n".join(lines)
    return rendered
    
def main():
    """Run the student-correction pipeline end to end.

    Loads the model and tokenizer from the same local checkpoint directory,
    builds prompts from the second-turn dialog file, generates student
    responses, and saves the collected history via ``save_to_json`` using its
    default output path.
    """
    model_dir = "AI_School_main_vllm/quantize_model/qwen2__5_14B"
    dialogs_path = "AI_School_main_vllm/agentclass/make_error/code/sbatch/output/second_turn.json"

    print("start......")
    # Engine and tokenizer share one checkpoint directory (paths must exist locally).
    engine = vllm.LLM(model=model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    print("LLM is ready")

    prompts, instructs = load_math_questions(dialogs_path, tokenizer)
    print("dialogs are ready")

    print("Starting to solve code problems...")
    history = solve_math_problems(prompts, instructs, engine)

    save_to_json(history)

if __name__ == "__main__":
    # Script entry point: announce start, then run the whole pipeline.
    print("start")
    main()