
import os
from transformers import AutoTokenizer
import time
import pandas as pd
import json
import vllm
from vllm import LLM, SamplingParams

import pandas as pd

import json

def load_math_questions(file_path, tokenizer=None):
    """Load MathQA problems from a JSON file and build first-turn prompts.

    The whole file is parsed with ``json.load``, so it must hold a single
    JSON document whose items are mappings (each is read with ``.get``) —
    presumably a list of dicts; TODO confirm against the dataset. Each item's
    ``'Problem'`` and ``'options'`` values are inserted into a "<Teacher>"
    question prompt; missing keys become empty strings.

    Args:
        file_path: Path to the JSON dataset file (decoded as UTF-8).
        tokenizer: Forwarded to ``format_prompt`` to apply the model's chat
            template. ``format_prompt`` falls back to a plain
            ``role: content`` layout if templating raises (which includes
            the default ``tokenizer=None``).

    Returns:
        ``(prompt_questions, data, first_messages)``:
        ``prompt_questions[i]`` is the formatted prompt for ``data[i]`` with
        ``"<Student C>"`` appended, ``data`` is the parsed JSON, and
        ``first_messages[i]`` is the system/user message list that produced
        ``prompt_questions[i]``.
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale-dependent default encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    system_prompt = 'You are a student who focuses on answering questions and provides detailed responses based on the questions asked.'

    prompt_questions = []
    first_messages = []
    for entry in data:
        problem = entry.get('Problem', '')
        options = entry.get('options', '')
        # NOTE(review): the continuation lines of this triple-quoted string
        # carry the source indentation (8 spaces) into the prompt text. Kept
        # verbatim so evaluation prompts stay reproducible.
        user_prompt_A = f"""<Teacher> Question: {problem}

        Options:
        {options}

        Please select the correct answer from [a , b , c , d , e ]. Finally, provide your answer in the format [x], where x is the index of the correct option."""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt_A}
        ]

        # Append the speaker tag so the model's completion continues after
        # "<Student C>".
        formatted_question = format_prompt(messages, tokenizer) + "<Student C>"
        prompt_questions.append(formatted_question)
        first_messages.append(messages)

    return prompt_questions, data, first_messages


def solve_math_problems(prompt_questions, llm, use_temperature=None, *,
                        top_p=0.95, max_tokens=4096):
    """Generate one completion per prompt with a single batched vLLM call.

    Args:
        prompt_questions: Fully formatted prompt strings.
        llm: A ``vllm.LLM`` instance (anything with a compatible
            ``generate(prompts, sampling_params)``).
        use_temperature: Sampling temperature. ``None`` (the default) leaves
            vLLM's own default temperature in place instead of passing
            ``temperature=None``, which vLLM rejects.
        top_p: Nucleus-sampling threshold (keyword-only).
        max_tokens: Maximum number of generated tokens per prompt
            (keyword-only).

    Returns:
        A list with the text of the first output of each request, one entry
        per prompt. Callers pair these with their prompts by index, relying
        on ``LLM.generate`` returning results in input order.
    """
    sampling_kwargs = {"top_p": top_p, "max_tokens": max_tokens}
    if use_temperature is not None:
        sampling_kwargs["temperature"] = use_temperature
    sampling_params = vllm.SamplingParams(**sampling_kwargs)

    # All prompts are submitted at once; vLLM schedules the batching itself.
    responses = llm.generate(prompt_questions, sampling_params)
    return [response.outputs[0].text for response in responses]


def save_to_json(data, file_name="./MATHQA_EVAL.json"):
    """Write ``data`` to ``file_name`` as JSON indented by 4 spaces.

    Missing parent directories are created. A bare file name (no directory
    component) is written to the current working directory.

    Args:
        data: Any ``json``-serializable object.
        file_name: Destination path. Defaults to ``./MATHQA_EVAL.json``.
    """
    parent_dir = os.path.dirname(file_name)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    print(f"Data saved to {file_name}.")

def format_prompt(messages, tokenizer):
    """Render a list of chat messages into a single prompt string.

    Uses ``tokenizer.apply_chat_template`` (untokenized, with the generation
    prompt added). If that raises for any reason, the error is printed and a
    plain fallback is returned instead: one ``"role: content"`` line per
    message, joined with newlines.
    """
    try:
        rendered = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as err:
        # Best-effort: never abort the batch because of a template problem.
        print(f"格式化提示词时出错: {err}")
        fallback_lines = [f"{msg['role']}: {msg['content']}" for msg in messages]
        return "\n".join(fallback_lines)
    return rendered
    
def _build_recheck_prompts(first_messages, first_answers, tokenizer):
    """Build second-turn prompts that ask the model to re-check its answer.

    For each first-turn conversation, the model's own answer is appended as
    an assistant message (prefixed with "<Student C>"), followed by a teacher
    request to re-check. The result is formatted with ``format_prompt`` and
    "<Student C>" is appended so the reply continues after that tag.
    """
    prompts = []
    for messages, answer in zip(first_messages, first_answers):
        followup = [
            {"role": "assistant", "content": "<Student C>" + answer},
            {"role": "user", "content": "<Teacher> Well, please check again if your answer contains any errors. Finally, provide your answer in the format [x], where x is the index of the correct option."}
        ]
        prompts.append(format_prompt(messages + followup, tokenizer) + "<Student C>")
    return prompts


def main(
    model_path="AI_School_main_vllm/final_model/merge_model/MA_Gen_LoRA_llama_model-base",
    data_path="AI_School_main_vllm/data/eval_datasets/MathQA/test.json",
):
    """Run the two-turn MathQA evaluation and save both turns to JSON.

    Turn 1: the model answers every question (temperature 0.2).
    Turn 2: the model sees its own answer and is asked to re-check it
    (temperature 0.2). Both turns are saved via ``save_to_json`` at its
    default path.

    Args:
        model_path: Local model directory used for both the vLLM engine and
            the tokenizer, so the chat template matches the served model.
        data_path: Dataset file passed to ``load_math_questions``.
    """
    print("start......")
    llm = vllm.LLM(
        model=model_path,
        gpu_memory_utilization=0.95
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )
    print("LLM is ready")

    prompt_questions, _, first_messages = load_math_questions(data_path, tokenizer)
    if not prompt_questions:
        # Previously this crashed with IndexError on prompt_questions[0].
        print(f"No questions found in {data_path}; nothing to evaluate.")
        return

    print(prompt_questions[0])
    print("questions are ready")

    print("Starting to solve math problems...")
    history1 = solve_math_problems(prompt_questions, llm, use_temperature=0.2)
    print("first turn end.")

    prompt_student = _build_recheck_prompts(first_messages, history1, tokenizer)
    print(prompt_student[0])

    history2 = solve_math_problems(prompt_student, llm, use_temperature=0.2)

    # NOTE: both turns are produced under the "<Student C>" tag; the
    # "<StudentA>" key name is kept unchanged so existing consumers of the
    # output file keep working.
    history = [
        {"<StudentA>": first_turn, "<StudentC>": second_turn}
        for first_turn, second_turn in zip(history1, history2)
    ]

    save_to_json(history)

if __name__ == "__main__":
    # Script entry point: announce startup, then run the full evaluation.
    print("start")
    main()