
import os
from transformers import AutoTokenizer
import time
import pandas as pd
import json
import vllm
from vllm import LLM, SamplingParams

def load_math_questions(file_path="AI_School_main_vllm/data/OpenHermes-2___5/openhermes2_5.json", tokenizer=None, start=6000, end=9000):
    """
    Load a slice of a JSON conversation dataset and build teacher prompts.

    The file must contain a JSON list. Each item is read through its
    "conversations" list, whose entries are dicts with a "from" key
    ("human" / "gpt") and a "value" key holding the message text. Only the
    first human turn and the first gpt turn of each item are used.

    Args:
        file_path: Path of the JSON dataset file.
        tokenizer: Tokenizer handed to ``format_prompt`` to render the chat
            messages into a single prompt string.
        start: Index of the first record to use (inclusive).
        end: Index one past the last record to use (exclusive).

    Returns:
        A tuple ``(prompt_questions, questions, answers)`` of three parallel
        lists: the rendered prompts, the raw questions prefixed with
        "<Question>", and the reference answers prefixed with
        "<Standard answer>". Records lacking a human or a gpt message are
        skipped, so the lists may be shorter than ``end - start``.
    """
    teacher_message = (
        "You are a teacher responsible for guiding students' learning. You will receive a question and generate your response based on the following rules:\n"
        "(1) Your response should be in a single paragraph, and first explain the question to the student.\n"
        "(2) When you receive a question, you should first explain the question to the student, then provide an approach without final answer.\n"
        "(3) You only need to explain the question without any elaboration or modifications, and you are not allowed to provide the final result. The process should be left to the student.\n"
        "(4) You must respond in English."
    )

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)[start:end]

    prompt_questions = []
    answers = []
    questions = []
    skipped = 0

    for item in data:
        conversations = item.get("conversations") or []
        # .get() keeps a malformed turn (missing "from"/"value") from raising.
        human_message = next((conv.get("value") for conv in conversations if conv.get("from") == "human"), None)
        gpt_message = next((conv.get("value") for conv in conversations if conv.get("from") == "gpt"), None)

        # Both sides are concatenated into strings below; an incomplete
        # record would raise TypeError, so drop it and keep the lists aligned.
        if human_message is None or gpt_message is None:
            skipped += 1
            continue

        messages = [
            {"role": "system", "content": teacher_message},
            {"role": "user", "content": "<Teacher>" + "Please solve the following problem:" + human_message}
        ]
        prompt_questions.append(format_prompt(messages, tokenizer))
        questions.append("<Question>" + human_message)
        answers.append("<Standard answer>" + gpt_message)

    if skipped:
        print(f"Skipped {skipped} records without both a human and a gpt message.")

    return prompt_questions, questions, answers

def solve_math_problems(prompt_questions, questions, answers, llm, batch_size=3000):
    """
    Generate a teacher response for every prompt, one batch at a time.

    Args:
        prompt_questions: Rendered prompt strings passed to ``llm.generate``.
        questions: Question strings, parallel to ``prompt_questions``.
        answers: Reference answers, parallel to ``prompt_questions``.
        llm: Engine object exposing ``generate(prompts, sampling_params)``.
        batch_size: Number of prompts submitted per ``generate`` call.

    Returns:
        A list of dicts with a "teacher" entry (question, the "<Teacher>"
        tag and the first generated output's text) and an "answer" entry
        (the reference answer).
    """
    collected = []
    n_total = len(prompt_questions)
    params = vllm.SamplingParams(temperature=0.8, top_p=0.95, max_tokens=4096)

    for batch_no, offset in enumerate(range(0, n_total, batch_size), start=1):
        stop = offset + batch_size
        prompts = prompt_questions[offset:stop]
        reference_slice = answers[offset:stop]
        question_slice = questions[offset:stop]

        # One generate call covers the whole batch.
        generations = llm.generate(prompts, params)

        triples = zip(reference_slice, question_slice, generations)
        for position, (reference, question_text, generation) in enumerate(triples, start=offset + 1):
            collected.append({
                "teacher": question_text + "<Teacher>" + generation.outputs[0].text,
                "answer": reference,
            })
            # Per-question progress line (1-based across all batches).
            print(f"Processing question {position}/{n_total}...")

        # Per-batch progress line.
        print(f"Batch {batch_no} processed, total {len(collected)} dialog histories collected.")

    return collected

def save_to_json(data, file_name="./output/first_turn.json"):
    """
    Write ``data`` to ``file_name`` as indented JSON.

    Args:
        data: Any JSON-serializable object.
        file_name: Destination path. Missing parent directories are created;
            a bare file name (no directory part) is written to the current
            working directory.
    """
    parent_dir = os.path.dirname(file_name)
    # dirname is "" for a bare file name, and os.makedirs("") raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    print(f"Data saved to {file_name}.")

def format_prompt(messages, tokenizer):
    """
    Render a list of chat messages into one prompt string.

    The tokenizer's chat template is tried first (untokenized output, with
    the generation prompt appended). If that call raises for any reason, the
    error is printed and a plain "role: content" rendering, one message per
    line, is returned instead.
    """
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception as err:
        print(f"格式化提示词时出错: {err}")
        # Best-effort fallback: one "role: content" line per message.
        fallback_lines = []
        for message in messages:
            fallback_lines.append(f"{message['role']}: {message['content']}")
        return "\n".join(fallback_lines)
    
def main():
    """
    Run the pipeline: load the model and tokenizer, build the prompts,
    generate the teacher responses and save them to JSON.
    """
    print("start......")
    # The engine and the tokenizer must come from the same checkpoint, so the
    # path is defined once and shared.
    model_path = "AI_School_main_vllm/quantize_model/qwen2__5_14B"
    engine_14B = vllm.LLM(model=model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    print("LLM is ready")

    # Build the prompts, questions and reference answers from the dataset.
    file_path = "AI_School_main_vllm/data/OpenHermes-2___5/openhermes2_5.json"
    prompt_questions, questions, answers = load_math_questions(file_path, tokenizer)
    print("questions are ready")

    # Generate the responses in batches and collect the dialog histories.
    print("Starting to solve math problems...")
    history = solve_math_problems(prompt_questions, questions, answers, engine_14B)

    # Persist the dialog histories to the default output file.
    save_to_json(history)

# Script entry point: announce startup, then run the full pipeline.
if __name__ == "__main__":
    print("start")
    main()