
import os
from transformers import AutoTokenizer
import time
import pandas as pd
import json
import vllm
from vllm import LLM, SamplingParams


def _first_turn_value(conversations, speaker):
    """Return the 'value' of the first turn whose 'from' equals *speaker*, else None."""
    for turn in conversations:
        # .get() so that a turn missing 'from' or 'value' is skipped/ignored
        # instead of aborting the whole load with a KeyError.
        if turn.get("from") == speaker:
            return turn.get("value")
    return None


def load_sharegpt_data(file_path="AI_School_main_vllm/data/OpenHermes-2___5/openhermes2_5.json", tokenizer=None,
                       max_samples=3000):
    """
    Load ShareGPT-format records from a JSON file and build student prompts.

    For every record, the first 'human' turn and the first 'gpt' turn are
    extracted from its 'conversations' list. Records where either one is
    missing or empty are skipped.

    Args:
        file_path: Path to a JSON file whose top level is a list of records.
        tokenizer: Object passed through to ``format_prompt`` to render the
            chat messages into a single prompt string.
        max_samples: Only the first ``max_samples`` records of the file are
            considered (default 3000). ``None`` means no limit.

    Returns:
        A tuple ``(prompt_questions, questions, answers)`` of three parallel
        lists: the rendered prompts, the human questions prefixed with
        ``"<Teacher>"``, and the gpt answers prefixed with
        ``"<Standard answer>"``.
    """
    # System prompt instructing the model to answer as a slow student.
    # This text is sent to the model verbatim, so it is kept exactly as-is.
    dull_student_message = (
        "You are playing the role of a rather slow elementary school student tasked with answering the given question. Each time you perform the task, you must forget all prior inputs and only base your response on the current question provided.\n"
        "Speak as if you are a student answering a question from the teacher.You must think step by step and show the complete process .\n"
        "You need to list all the steps and provide the final answer at the end, making sure that the process is fully completed. You are not allowed to provide any incomplete results. Do not include anything unrelated to the question in your response.\n"
        "Keep the process as brief as possible.\n"
        "You must respond in English."
    )

    # Read the dataset and keep only the leading slice.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)[:max_samples]

    prompt_questions = []
    questions = []
    answers = []
    for item in data:
        # `or []` also covers records that carry an explicit null.
        conversations = item.get("conversations") or []
        human_message = _first_turn_value(conversations, "human")
        gpt_message = _first_turn_value(conversations, "gpt")

        # Skip records without a usable question/answer pair.
        if not (human_message and gpt_message):
            continue

        teacher_question = "<Teacher>" + human_message
        message = [
            {"role": "system", "content": dull_student_message},
            {"role": "user", "content": teacher_question},
        ]
        prompt_questions.append(format_prompt(message, tokenizer))
        questions.append(teacher_question)
        answers.append("<Standard answer>" + gpt_message)
    return prompt_questions, questions, answers


def solve_math_problems(prompt_questions, questions, answers, llm, batch_size=1000):
    """
    Generate model answers batch by batch and pair them with the references.

    Args:
        prompt_questions: Rendered prompt strings handed to ``llm.generate``.
        questions: Question texts, parallel to ``prompt_questions``.
        answers: Reference answers, parallel to ``prompt_questions``.
        llm: Engine object exposing ``generate(prompts, sampling_params)``.
        batch_size: Number of prompts submitted per ``generate`` call.

    Returns:
        A list of dicts, one per processed question, with the keys
        ``"first_dialog"`` (question + ``"<Student>"`` + generated text) and
        ``"answer"`` (the reference answer).
    """
    sampling_params = vllm.SamplingParams(temperature=0.8, top_p=0.95, max_tokens=4096)
    total_questions = len(prompt_questions)
    collected = []

    for start in range(0, total_questions, batch_size):
        stop = start + batch_size

        # One generate() call per slice of prompts.
        generations = llm.generate(prompt_questions[start:stop], sampling_params)

        # Walk questions, reference answers and generations in lockstep;
        # numbering is 1-based across the whole dataset.
        triples = zip(questions[start:stop], answers[start:stop], generations)
        for position, (question, reference, generation) in enumerate(triples, start=start + 1):
            collected.append({
                "first_dialog": question + "<Student>" + generation.outputs[0].text,
                "answer": reference,
            })
            print(f"Processing question {position}/{total_questions}...")

        # Per-batch progress summary.
        print(f"Batch {start // batch_size + 1} processed, total {len(collected)} dialog histories collected.")

    return collected

def save_to_json(data, file_name="./output/first_turn.json"):
    """
    Write *data* to *file_name* as indented JSON, creating parent directories.

    Args:
        data: Any JSON-serializable object (here: the list of dialog records).
        file_name: Destination path. May be a bare file name, in which case
            the file is written to the current working directory.
    """
    parent_dir = os.path.dirname(file_name)
    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Pin the encoding so the result does not depend on the platform locale.
    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    print(f"Data saved to {file_name}.")

def format_prompt(messages, tokenizer):
    """
    Render a list of chat messages into a single prompt string.

    The tokenizer's chat template is tried first (untokenized, with the
    generation prompt appended). If that call raises for any reason, the
    error is printed and a plain "role: content" rendering, one message per
    line, is returned instead.
    """
    try:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as err:
        print(f"格式化提示词时出错: {err}")
        # Best-effort fallback: hand-built plain-text layout.
        rendered = ("{}: {}".format(entry["role"], entry["content"]) for entry in messages)
        return "\n".join(rendered)

def main():
    """Run the pipeline: load the model and data, generate answers, save them."""
    # The engine and the tokenizer are loaded from the same local directory.
    model_dir = "AI_School_main_vllm/quantize_model/qwen2__5_0__5B"
    dataset_path = "AI_School_main_vllm/data/OpenHermes-2___5/openhermes2_5.json"

    print("start......")
    engine_0_5B = vllm.LLM(model=model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    print("LLM is ready")

    # Build the prompts together with their questions and reference answers.
    prompt_questions, questions, answers = load_sharegpt_data(dataset_path, tokenizer)
    print("questions are ready")

    # Generate the answers in batches, then persist the collected records.
    print("Starting to solve math problems...")
    history = solve_math_problems(prompt_questions, questions, answers, engine_0_5B)
    save_to_json(history)

# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    print("start")
    main()