import os
import time
import pandas as pd
import json
import vllm
import torch
import pickle
import numpy as np
import random
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
def load_math_questions(file_path=None, tokenizer = None):
    """
    读取 JSON 文件，提取 'first_dialog' 和 'answer' 字段，并返回。
    """
    # 读取 JSON 文件
    with open(file_path, 'r') as f:
        data = json.load(f)
        student_message = \
            "You are a diligent student. You need to reason through the problem and derive the final result based on the given question and answer, following these specific rules:\n" + \
            "(1) The answer should be expressed in a single natural paragraph.\n" + \
            "(2) When you receive a question provided by the teacher, you should carefully analyze the problem and ensure the answer aligns with the standard solution.\n" + \
            "(3) Do not introduce any excessively difficult external knowledge in your response. Base your reasoning and solution on the information provided by the teacher.\n" + \
            "(4) You must provide the detailed calculation process to reach the final answer, ensuring the solution is logically clear and reasonable.\n" + \
            "(5) You must respond in English."
    # 提取 'first_dialog' 和 'answer' 字段
    # 提取 'first_dialog' 和 'teacher_correct' 并拼接
    first_dialogs = [
        item.get('first_dialog', '')+item.get('student_answer', '')
        for item in data[:3000]
    ]
    # 提取 <Teacher> 之前的部分作为 questions
    questions = [dialog.split("<Teacher>")[0] for dialog in first_dialogs]
    # first_dialogs = [
    #     student_message +item.get('first_dialog', '')+item.get('student_answer', '')
    #     for item in data[:5000]
    # ]
    # 将 teacher_message 拼接到每个 first_dialog 的前面

    #开始计算相似度高的题目列表similiar_questions

    model = SentenceTransformer('AI_School_main_vllm/Embedding_model/all-MiniLM-l6-v2')
    # 存储解析后题库的文件路径
    parsed_file_path = 'AI_School_main_vllm/agentclass/expand/openhermes/Embedding_bank/parsed_knowledge_points.pkl'

    # 判断是否已经解析过，若存在已解析文件，则直接加载
    if os.path.exists(parsed_file_path):
        print("已找到解析后的题库文件，直接加载...")
        with open(parsed_file_path, 'rb') as file:
            knowledge_points = pickle.load(file)
    else:
        print("未找到解析后的题库文件，开始解析...")
        knowledge_points = []
        processed_instructions = set()  # 用于避免重复解析


        # 读取 Parquet 文件并解析内容
        file_path = 'AI_School_main_vllm/data/OpenHermes-2___5/openhermes2_5.json'
        with open(file_path, 'r', encoding='utf-8') as f:
            df = json.load(f)
        # 选择从第9000行开始
        start_row = 9000
        df_filtered = df[start_row:]
        for item in df_filtered:
            conversations = item.get("conversations", [])
            question = next((conv["value"] for conv in conversations if conv["from"] == "human"), None)
            answer = next((conv["value"] for conv in conversations if conv["from"] == "gpt"), None)
            # 将 question 和 answer 作为元组添加到 knowledge_points 中
            if (question, answer) not in processed_instructions:
                knowledge_points.append((question, answer))  # 保存为元组
                processed_instructions.add((question, answer))  # 使用元组作为唯一标识，避免重复

        # 保存解析后的题库到文件
        with open(parsed_file_path, 'wb') as file:
            pickle.dump(knowledge_points, file)
        print(f"题库解析完成，共包含 {len(knowledge_points)} 个独特的题目。")
    ##开始计算其余题目嵌入向量 
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # 将嵌入模型移动到 GPU
    # **📌 批量计算知识库嵌入**
    print("📌 计算知识库嵌入（批量处理）...")
    knowledge_questions = [q[0] for q in knowledge_points]  # 取出所有问题
    with torch.no_grad():
        knowledge_embeddings = model.encode(knowledge_questions, batch_size=32, convert_to_tensor=True).to(device)
    print(f"✅ 知识库嵌入完成，共 {len(knowledge_points)} 条数据")
    # **📌 计算输入问题的匹配**
    print("📌 计算问题匹配（批量处理）...")
    similiar_questions = []
    similiar_answers = []
    for question in questions:
        # **计算输入问题的嵌入**
        with torch.no_grad():
            question_embedding = model.encode(question, convert_to_tensor=True).to(device)

        # **计算余弦相似度（一次性计算所有匹配）**
        cos_scores = util.cos_sim(question_embedding, knowledge_embeddings)[0]  # 计算余弦相似度
        cos_scores = cos_scores.cpu().numpy()  # 转换为 numpy 数组

        # **获取全局最相似的 3 个问题**
        top_3_indices = np.argsort(cos_scores)[-3:][::-1]  # 获取最相似的 3 个问题的索引
        best_matching_questions = [knowledge_points[i] for i in top_3_indices]

        # **去重（保持顺序）**
        best_matching_questions = list({q[0]: q for q in best_matching_questions}.values())

        # **随机选择一个匹配问题**
        random_top_question_answer = random.choice(best_matching_questions)
        random_top_question = random_top_question_answer[0]
        random_top_answer = random_top_question_answer[1]
        similiar_questions.append(random_top_question)
        similiar_answers.append(random_top_answer)
    prompt_new_questions = []
    new_questions = []
    for question,answer in zip(similiar_questions, similiar_answers):
        messages = [
            {"role": "system", "content": student_message},
            {"role": "user", "content": "<Teacher>Good job,please take a look at this similar question."+question+"<Standard answer>"+answer}
        ]
        # 使用 tokenizer 格式化
        formatted_question = format_prompt(messages, tokenizer)
        prompt_new_questions.append(formatted_question)
        new_questions.append("<Teacher>Good job,please take a look at this similar question."+question)
    return prompt_new_questions, new_questions, similiar_answers, first_dialogs

def solve_math_problems(prompt_new_questions, new_questions, similiar_answers, first_dialogs, llm, batch_size=3000):
    """
    以批次处理数学题目，调用模型生成解答，并返回题目、答案和历史记录的对。
    """
    combine_dialogs = [q + fd for q, fd in zip(first_dialogs, new_questions)]
    
    history = []
    total_dialogs = len(prompt_new_questions)
    sampling_params = vllm.SamplingParams(temperature=0.2, top_p=0.95, max_tokens = 4096)
    for i in range(0, total_dialogs, batch_size):
        # 提取当前批次的第一轮对话
        batch_new_questions = new_questions[i:i + batch_size]
        batch_prompt_new_questions = prompt_new_questions[i:i + batch_size]
        batch_combine_dialogs = combine_dialogs[i:i + batch_size]
        # 为模型准备输入的批次（拼接问题文本）
        prompts = batch_prompt_new_questions  # 直接使用问题批次
        
        # 生成每个批次的解答
        batch_responses = llm.generate(prompts,sampling_params)
        
        # 处理每个问题的模型反馈并保存历史记录
        for idx, (combine_dialog, response, answer) in enumerate(zip(batch_combine_dialogs, batch_responses, similiar_answers)):
            dialog_history = {
                "Instruct":  combine_dialog,
                "new_answer": "<Student>" + response.outputs[0].text,
                "Standard answer": answer
            }
            history.append(dialog_history)
            
            # 输出每个问题的处理进度
            print(f"Processing question {i + idx + 1}/{total_dialogs}...")

        # 显示当前批次的处理进度
        print(f"Batch {i // batch_size + 1} processed, total {len(history)} dialog histories collected.")
        
    return history

def save_to_json(data, file_name="./output/final_turn.json"):
    """
    将对话历史和答案保存为 JSON 文件
    """
    # 确保目录存在
    os.makedirs(os.path.dirname(file_name), exist_ok=True)
    
    with open(file_name, 'w') as f:
        json.dump(data, f, indent=4)
    print(f"Data saved to {file_name}.")
def format_prompt(messages, tokenizer):
    """将消息列表转换为合规的文本输入"""
    try:
        # 使用分词器的聊天模板格式化
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return text
    except Exception as e:
        print(f"格式化提示词时出错: {e}")
        # 回退到手动格式
        return "\n".join([f"{m['role']}: {m['content']}" for m in messages])
    
def main():
# 创建采样参数对象
    
    print("start......")
    # 加载本地模型（请确保模型路径正确）
    engine_14B = vllm.LLM(model="AI_School_main_vllm/quantize_model/qwen2__5_14B")
    # 初始化分词器
    tokenizer = AutoTokenizer.from_pretrained(
        "AI_School_main_vllm/quantize_model/qwen2__5_14B",
        trust_remote_code=True
    )
    print("LLM is ready")
    # 加载上轮对话数据
    file_path = "AI_School_main_vllm/agentclass/expand/openhermes/sbatch/output/second_turn.json"
    prompt_new_questions, new_questions, similiar_answers, first_dialogs = load_math_questions(file_path, tokenizer)
    print("dialogs are ready")
    # 批量解题，获取历史记录
    print("Starting to solve math problems...")
    history = solve_math_problems(prompt_new_questions, new_questions, similiar_answers, first_dialogs, engine_14B)
    
    # 保存历史记录到 JSON 文件
    save_to_json(history)

if __name__ == "__main__":
    print ("start")
    main() 
    

    