
import os
import string
from transformers import AutoTokenizer
import time
import pandas as pd
import json
import vllm
from vllm import LLM, SamplingParams
from datasets import load_dataset, concatenate_datasets
import pandas as pd

import json

# MMLU subject configurations loaded by default by ``load_math_questions``.
MMLU_SUBCATEGORIES = (
    'abstract_algebra',
    'anatomy',
    'astronomy',
    'business_ethics',
    'clinical_knowledge',
    'college_biology',
    'college_chemistry',
    'college_computer_science',
    'college_mathematics',
    'college_medicine',
    'college_physics',
    'computer_security',
    'conceptual_physics',
    'econometrics',
    'electrical_engineering',
    'elementary_mathematics',
    'formal_logic',
    'global_facts',
    'high_school_biology',
    'high_school_chemistry',
    'high_school_computer_science',
    'high_school_european_history',
    'high_school_geography',
    'high_school_government_and_politics',
    'high_school_macroeconomics',
    'high_school_mathematics',
    'high_school_microeconomics',
    'high_school_physics',
    'high_school_psychology',
    'high_school_statistics',
    'high_school_us_history',
    'high_school_world_history',
    'human_aging',
    'human_sexuality',
    'international_law',
    'jurisprudence',
    'logical_fallacies',
    'machine_learning',
    'management',
    'marketing',
    'medical_genetics',
    'miscellaneous',
    'moral_disputes',
    'moral_scenarios',
    'nutrition',
    'philosophy',
    'prehistory',
    'professional_accounting',
    'professional_law',
    'professional_medicine',
    'professional_psychology',
    'public_relations',
    'security_studies',
    'sociology',
    'us_foreign_policy',
    'virology',
    'world_religions',
)


def load_math_questions(file_path, tokenizer=None, subcategories=None):
    """
    Load the MMLU test split and build chat-formatted prompts for every question.

    Each subject configuration in ``subcategories`` is loaded from ``file_path``
    with ``load_dataset(..., split='test')``. Subjects that fail to load are
    reported and skipped. The loaded splits are concatenated.

    For each row:

    - The ``question`` and ``choices`` fields are rendered into a "<Teacher>"
      user prompt with lettered options (A, B, C, ...).
    - That prompt is wrapped in a system/user message pair.
    - The pair is rendered with ``format_prompt``, and "<Student C>" is
      appended so generation continues in the student's voice.

    Args:
        file_path: Path or dataset id passed to ``load_dataset``.
        tokenizer: Tokenizer whose chat template ``format_prompt`` applies.
            With ``None``, ``format_prompt`` falls back to plain
            "role: content" lines.
        subcategories: Iterable of subject config names to load. Defaults to
            ``MMLU_SUBCATEGORIES``.

    Returns:
        A tuple ``(prompt_questions, data, first_messages)``:

        - ``prompt_questions``: the formatted prompt strings.
        - ``data``: the concatenated dataset.
        - ``first_messages``: the message list used to build each prompt.

        All three are in the same row order. If no subject could be loaded,
        three empty lists are returned so callers can still unpack the result.
    """
    if subcategories is None:
        subcategories = MMLU_SUBCATEGORIES

    datasets_list = []
    for subcat in subcategories:
        try:
            dataset = load_dataset(path=file_path, name=subcat, split='test')
        except Exception as e:
            # Best effort: report and skip a missing/broken subject instead of
            # aborting the whole evaluation.
            print(f"加载子类别 {subcat} 时出错: {e}")
            continue
        datasets_list.append(dataset)

    if not datasets_list:
        print("未能加载任何子类别的测试集。请检查数据目录和子类别名称是否正确。")
        # Three empty lists (not a bare []) so that
        # ``a, b, c = load_math_questions(...)`` does not raise ValueError.
        return [], [], []

    # Merge the test splits of all loaded subjects.
    data = concatenate_datasets(datasets_list)

    prompt_questions = []
    first_messages = []
    for entry in data:
        problem = entry.get('question', '')
        options = entry.get('choices', [])
        # One capital letter per option: A, B, C, D, ...
        option_labels = string.ascii_uppercase[:len(options)]
        options_str = '\n'.join(
            f"{label}. {opt}" for label, opt in zip(option_labels, options)
        )
        # NOTE(review): the 8-space indents reproduce the whitespace that the
        # original indented triple-quoted literal embedded in the prompt; they
        # are kept so prompts stay identical to earlier evaluation runs.
        user_prompt = (
            f"<Teacher> Question: {problem}\n"
            "\n"
            "        Options:\n"
            f"        {options_str}\n"
            "\n"
            f"        Please select the correct answer from {', '.join(option_labels)}. "
            "Finally, provide your answer in the format [x], where x is the index of the correct option."
        )

        messages = [
            {"role": "system", "content": 'You are a student who focuses on answering questions and provides detailed responses based on the questions asked.'},
            {"role": "user", "content": user_prompt}
        ]

        prompt_questions.append(format_prompt(messages, tokenizer) + "<Student C>")
        first_messages.append(messages)

    return prompt_questions, data, first_messages


def solve_math_problems(prompt_questions, llm, sampling_params=None):
    """
    Generate one completion per prompt with a single batched ``llm.generate`` call.

    Args:
        prompt_questions: List of fully formatted prompt strings.
        llm: A ``vllm.LLM``, or any object with a compatible
            ``generate(prompts, sampling_params)`` method.
        sampling_params: Optional ``vllm.SamplingParams``. Defaults to
            temperature=0.2, top_p=0.95, max_tokens=4096.

    Returns:
        A list of ``{"answer": text}`` dicts, one per response, in the order
        ``llm.generate`` returns them (vLLM returns them in prompt order).
        ``text`` is the first candidate (``outputs[0]``) of each response.
        An empty prompt list returns ``[]`` without calling the engine.
    """
    if not prompt_questions:
        return []

    if sampling_params is None:
        sampling_params = vllm.SamplingParams(temperature=0.2, top_p=0.95, max_tokens=4096)

    responses = llm.generate(prompt_questions, sampling_params)
    return [{"answer": response.outputs[0].text} for response in responses]

def save_to_json(data, file_name="./LLAMA3_MATHQA_EVAL.json"):
    """
    Write ``data`` to ``file_name`` as JSON indented by 4 spaces.

    Missing parent directories are created first. A confirmation line is
    printed after the write.

    Args:
        data: Any JSON-serializable object.
        file_name: Destination path. A bare filename such as "out.json" is
            written to the current working directory.
    """
    directory = os.path.dirname(file_name)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if directory:
        os.makedirs(directory, exist_ok=True)

    with open(file_name, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    print(f"Data saved to {file_name}.")

def format_prompt(messages, tokenizer):
    """
    Render a list of chat messages into one prompt string.

    The tokenizer's chat template is applied untokenized, with the generation
    prompt appended.

    If that call raises for any reason (including ``tokenizer`` being
    ``None``), the error is printed. The function then returns a plain
    rendering with one "role: content" line per message, joined by newlines.
    """
    try:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception as err:
        print(f"格式化提示词时出错: {err}")

    # Fallback rendering when the chat template could not be applied.
    lines = []
    for message in messages:
        lines.append(f"{message['role']}: {message['content']}")
    return "\n".join(lines)
    
def main():
    """
    Run a two-turn MMLU evaluation with a local vLLM model and save both turns.

    Steps:

    1. Load the engine and tokenizer from the same merged checkpoint.
    2. Build the first-turn prompts with ``load_math_questions``.
    3. Generate the first-turn answers.
    4. Build second-turn prompts: the model's first answer (as an assistant
       turn) plus a teacher request to check it for errors.
    5. Generate again, and write the per-question pair of answers with
       ``save_to_json`` (default output path).
    """
    print("start......")
    # One merged checkpoint serves both the vLLM engine and the tokenizer
    # (make sure this path is correct).
    model_path = "AI_School_main_vllm/final_model/merge_model/MA_Gen_LoRA_llama_model-base"
    engine = vllm.LLM(model=model_path, gpu_memory_utilization=0.95)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    print("LLM is ready")

    file_path = "AI_School_main_vllm/data/eval_datasets/mmlu"
    loaded = load_math_questions(file_path, tokenizer)
    # load_math_questions may signal "nothing loaded" with an empty result;
    # stop here instead of failing on unpacking or on prompt_student[0] below.
    if not loaded or not loaded[0]:
        print("No questions loaded; nothing to evaluate.")
        return
    prompt_questions, _data, first_messages = loaded
    print("questions are ready")

    # Turn 1: answer every question in one batch.
    print("Starting to solve math problems...")
    history1 = solve_math_problems(prompt_questions, engine)
    print("first turn end.")

    # Turn 2: replay each conversation with the model's own answer and ask it
    # to check for errors.
    prompt_student = []
    for first_turn, first_answer in zip(first_messages, history1):
        messages = first_turn + [
            {"role": "assistant", "content": "<Student C>" + first_answer["answer"]},
            {"role": "user", "content": "<Teacher> Well, please check if your answer contains any errors. Finally, provide your answer in the format [x], where x is the index of the correct option."}
        ]
        prompt_student.append(format_prompt(messages, tokenizer) + "<Student C>")
    # Show one second-turn prompt as a sanity check of the formatting.
    print(prompt_student[0])

    history2 = solve_math_problems(prompt_student, engine)

    # NOTE(review): "<StudentA>" holds the first-turn answer, even though it
    # was generated under the "<Student C>" tag. The key is kept as-is for
    # compatibility with existing output consumers.
    history = [
        {"<StudentA>": answer_a, "<StudentC>": answer_c}
        for answer_a, answer_c in zip(history1, history2)
    ]

    save_to_json(history)

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    print("start")
    main()