
import os
from transformers import AutoTokenizer
import time
import pandas as pd
import json
import vllm
from vllm import LLM, SamplingParams

import pandas as pd

import json

import json

def load_math_questions(file_path, tokenizer=None):
    """Load multiple-choice questions from a JSONL file and build prompts.

    Every non-blank line of ``file_path`` is parsed as one JSON object. For
    each record the ``problem`` and ``answer_option_list`` fields are read
    (defaulting to ``''`` and ``[]`` when absent) and turned into a
    system + user message pair, which is rendered with ``format_prompt`` and
    suffixed with the ``<Student C>`` speaker tag.

    Args:
        file_path: Path to the JSONL file, read as UTF-8.
        tokenizer: Passed through unchanged to ``format_prompt``.

    Returns:
        A tuple ``(prompt_questions, data, first_messages)``:
        ``prompt_questions`` is the list of rendered prompt strings,
        ``data`` is the list of parsed JSON records, and ``first_messages``
        is the list of message lists used to render each prompt. All three
        have one element per record, in file order.
    """
    # Explicit encoding so decoding does not depend on the platform locale;
    # blank lines (e.g. a trailing newline) are skipped instead of crashing
    # json.loads.
    with open(file_path, 'r', encoding='utf-8') as file:
        data = [json.loads(line) for line in file if line.strip()]

    prompt_questions = []
    first_messages = []
    for entry in data:
        problem = entry.get('problem', '')
        options_list = entry.get('answer_option_list', [])

        # Each non-empty item is indexed as item[0], which must provide the
        # 'aoVal' (option label) and 'content' (option text) keys.
        choices = [opt[0] for opt in options_list if opt]
        options = "\n".join(f"{choice['aoVal']}. {choice['content']}" for choice in choices)

        # Labels of the options available for this question, e.g. "A, B, C, D".
        option_count_str = ", ".join(choice['aoVal'] for choice in choices)

        # NOTE: the label string is interpolated as-is. Joining it a second
        # time would iterate over its characters and garble the prompt
        # ("A, ,,  , B, ...").
        user_prompt_A = f"""<Teacher> Question: {problem}

        Options:
        {options}

        Please select the answer from {option_count_str}. Finally, provide your answer in the format [x], where x is the index of the correct option."""

        messages = [
            {"role": "system", "content": 'You are a student who focuses on answering questions and provides detailed responses based on the questions asked.'},
            {"role": "user", "content": user_prompt_A}
        ]

        formatted_question = format_prompt(messages, tokenizer) + "<Student C>"
        prompt_questions.append(formatted_question)
        first_messages.append(messages)

    return prompt_questions, data, first_messages




def solve_math_problems(prompt_questions, llm, use_temperature=None):
    """Generate one completion per prompt in a single batched call.

    Args:
        prompt_questions: List of prompt strings to send to the model.
        llm: Object exposing ``generate(prompts, sampling_params)``; each
            returned item must provide ``outputs[0].text``.
        use_temperature: Sampling temperature. ``None`` (the default) leaves
            the temperature unset so the ``SamplingParams`` default applies,
            rather than forwarding ``None`` to a float-typed parameter.

    Returns:
        A list with the text of the first output for every prompt, in the
        order returned by ``llm.generate``. An empty prompt list yields an
        empty list without calling the model.
    """
    if not prompt_questions:
        return []

    sampling_kwargs = {"top_p": 0.95, "max_tokens": 4096}
    if use_temperature is not None:
        sampling_kwargs["temperature"] = use_temperature
    sampling_params = vllm.SamplingParams(**sampling_kwargs)

    # The whole prompt list is submitted at once; batching is left to the engine.
    batch_responses = llm.generate(prompt_questions, sampling_params)

    # Keep only the text of the first candidate of every response.
    return [response.outputs[0].text for response in batch_responses]

def save_to_json(data, file_name="./MATHQA_EVAL.json"):
    """Serialize ``data`` to ``file_name`` as indented JSON.

    Args:
        data: Any object accepted by ``json.dump``.
        file_name: Destination path. Missing parent directories are created;
            a bare file name (no directory part) is written to the current
            working directory.
    """
    # os.path.dirname returns '' for a bare file name and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is given.
    parent_dir = os.path.dirname(file_name)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_name, 'w') as f:
        json.dump(data, f, indent=4)
    print(f"Data saved to {file_name}.")

def format_prompt(messages, tokenizer):
    """Render a list of chat messages into a single prompt string.

    The tokenizer's ``apply_chat_template`` is tried first (untokenized, with
    the generation prompt appended). If that call raises for any reason,
    including ``tokenizer`` being ``None``, the error is printed and a plain
    ``"role: content"`` listing, one message per line, is returned instead.
    """
    try:
        # Preferred path: let the tokenizer's chat template build the text.
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as e:
        print(f"格式化提示词时出错: {e}")
        # Best-effort fallback: manual "role: content" lines.
        fallback_lines = []
        for message in messages:
            fallback_lines.append(f"{message['role']}: {message['content']}")
        return "\n".join(fallback_lines)
    
def main():
    """Run a two-turn evaluation over the dataset and save both answers.

    Turn one asks the model every question. Turn two appends the model's
    first answer plus a teacher request to re-check it, and asks again.
    The first and second answers are paired per question and written out
    with ``save_to_json`` using its default file name.
    """
    model_dir = "AI_School_main_vllm/final_model/merge_model/MA_Gen_LoRA_llama_model-base"
    file_path = "AI_School_main_vllm/data/eval_datasets/TAL-SCQ5K/TAL-SCQ5K-EN/test.jsonl"
    recheck_request = "<Teacher> Well, please check again if your answer contains any errors. Finally, provide your answer in the format [x], where x is the index of the correct option."

    print("start......")
    # Load the local model and its tokenizer from the same directory.
    engine_14B = vllm.LLM(model=model_dir, gpu_memory_utilization=0.95)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    print("LLM is ready")

    # Build the first-turn prompts from the dataset.
    prompt_questions, data, first_messages = load_math_questions(file_path, tokenizer)
    print(prompt_questions[0])
    print("questions are ready")

    # Turn one: answer every question in one batch.
    print("Starting to solve math problems...")
    history1 = solve_math_problems(prompt_questions, engine_14B, use_temperature=0.2)
    print("first turn end.")

    # Turn two: replay the first exchange, add the model's own answer and the
    # teacher's re-check request, then render the extended conversation.
    prompt_student = []
    for index, _ in enumerate(data):
        follow_up = [
            {"role": "assistant", "content": "<Student C>" + history1[index]},
            {"role": "user", "content": recheck_request},
        ]
        conversation = first_messages[index] + follow_up
        prompt_student.append(format_prompt(conversation, tokenizer) + "<Student C>")
    print(prompt_student[0])

    history2 = solve_math_problems(prompt_student, engine_14B, use_temperature=0.2)

    # Pair the first-turn and second-turn answers per question.
    history = [
        {"<StudentA>": first_answer, "<StudentC>": second_answer}
        for first_answer, second_answer in zip(history1, history2)
    ]

    save_to_json(history)

# Script entry point: announce startup, then run the evaluation pipeline.
if __name__ == "__main__":
    print("start")
    main()
