import sys
import os
import string
from transformers import AutoTokenizer
import time
import pandas as pd
import json
import vllm
from vllm import LLM, SamplingParams
import pandas as pd
import json
sys.path.insert(0, 'human-eval/human_eval')  # 替换为实际的路径
from data import write_jsonl, read_problems

print("start......")
# Both the engine and the tokenizer are loaded from the same local checkpoint
# directory (make sure this path is correct for your environment).
_model_dir = "AI_School_main_vllm/final_model/merge_model/CoT_LoRA_llama_model-base"

# Load the local model into a vLLM engine.
engine_14B = LLM(model=_model_dir, gpu_memory_utilization=0.95)

# Initialise the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(_model_dir, trust_remote_code=True)
print("LLM is ready")

def load_code_questions(problems=None, tokenizer=None):
    """Wrap problem texts into model-ready prompt strings.

    Each problem is embedded in a fixed ``<Teacher> Question: ...`` user
    message, paired with a fixed system message, rendered through
    ``format_prompt`` and suffixed with the ``<Student C>`` tag.

    Args:
        problems: Iterable of problem texts. ``None`` (the default) is treated
            as "no problems" and yields an empty list.
        tokenizer: Tokenizer forwarded unchanged to ``format_prompt``.

    Returns:
        A list of prompt strings, one per problem, in input order.
    """
    # The default value cannot be iterated; return early instead of raising
    # a TypeError from the loop below.
    if problems is None:
        return []

    system_content = 'You are a student who focuses on answering questions and provides detailed responses based on the questions asked.'

    prompt_questions = []
    for problem in problems:
        # NOTE(review): the indentation and the unbalanced leading quote on the
        # last line are part of the emitted prompt text. Presumably this
        # matches the format the model was tuned on -- confirm before
        # "cleaning it up".
        user_prompt = f"""<Teacher> Question: {problem}

        "Please address the above issue in detail and provide the complete code."""

        messages = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": user_prompt},
        ]

        prompt_questions.append(format_prompt(messages, tokenizer) + "<Student C>")
    return prompt_questions
def solve_code_problems(prompt_questions, llm, use_temperature=None):
    """Generate one completion per prompt in a single batched call.

    Args:
        prompt_questions: List of fully formatted prompt strings.
        llm: Engine object exposing ``generate(prompts, sampling_params)``.
        use_temperature: Sampling temperature. When ``None`` (the default) the
            temperature argument is omitted so ``SamplingParams`` applies its
            own default instead of receiving ``None``.

    Returns:
        A list with the text of the first output of every response, in the
        order the responses are returned by ``llm.generate``. An empty or
        ``None`` prompt list yields ``[]`` without calling the engine.
    """
    # Nothing to generate: skip building sampling parameters and the engine call.
    if not prompt_questions:
        return []

    sampling_kwargs = {"top_p": 0.95, "max_tokens": 4096}
    # Forward the temperature only when the caller supplied one; passing None
    # through would hand SamplingParams a non-numeric temperature.
    if use_temperature is not None:
        sampling_kwargs["temperature"] = use_temperature
    sampling_params = vllm.SamplingParams(**sampling_kwargs)

    # Submit every prompt at once.
    batch_responses = llm.generate(prompt_questions, sampling_params)

    # Keep only the text of the first candidate of each response.
    return [response.outputs[0].text for response in batch_responses]

def format_prompt(messages, tokenizer):
    """Render a list of chat messages into a single prompt string.

    The tokenizer's chat template is tried first (untokenized, with the
    generation prompt appended). If that call raises, the error is printed and
    the messages are joined as ``"role: content"`` lines instead.

    Args:
        messages: List of dicts with ``"role"`` and ``"content"`` keys.
        tokenizer: Object providing ``apply_chat_template``.

    Returns:
        Whatever ``apply_chat_template`` returns, or the manually joined
        fallback string when templating fails.
    """
    try:
        rendered = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception as e:
        print(f"格式化提示词时出错: {e}")
        # Fall back to a plain manual layout, one message per line.
        fallback_lines = []
        for message in messages:
            fallback_lines.append(f"{message['role']}: {message['content']}")
        return "\n".join(fallback_lines)
    else:
        return rendered
    
def generate_one_completion(problems=None):
    """Build prompts for ``problems`` and return the generated completions.

    Uses the module-level ``tokenizer`` to format the prompts and the
    module-level ``engine_14B`` to generate, with the temperature fixed at 0.0.

    Args:
        problems: Problem texts forwarded to ``load_code_questions``.

    Returns:
        The list returned by ``solve_code_problems`` for the built prompts.
    """
    return solve_code_problems(
        load_code_questions(problems, tokenizer),
        engine_14B,
        use_temperature=0.0,
    )


# Load the benchmark problems (read_problems comes from the human-eval `data` module).
problems = read_problems()

# Collect the prompt text of every problem, in the mapping's iteration order.
prompts = [entry["prompt"] for entry in problems.values()]

# Generate all completions in one call; generate_one_completion takes the
# whole list of prompts rather than a single prompt.
completions = generate_one_completion(prompts)

# Pair each task id with its completion. Both sequences follow the same
# iteration order of `problems`, so they are matched positionally.
samples = [
    {"task_id": task_id, "completion": completion}
    for task_id, completion in zip(problems, completions)
]

write_jsonl("samples.jsonl", samples)
