import enum
import os
import random

from vllm import SamplingParams

from utils.file import save_jsonl

# Restrict the job to GPUs 0 and 1 by default, but respect a CUDA_VISIBLE_DEVICES
# already exported by the caller instead of silently overriding it. This must run
# before CUDA is initialised (i.e. before the model below is constructed).
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

from model.llm import OpenAI, VLLM
from utils import load_json, save_json


# Training instructions to rewrite; each record's "prompt" field is replaced.
ORIGINAL_INSTRUCTION_PATH = "../data/train/SelfOSSInstructSC2_rewrite/self_oss_advanced_tests.json"
# Prediction files whose prompts serve as style references. The script below reads
# only the MBPP file; the HumanEval path is kept for switching reference sets.
HUMAN_EVAL_INSTRUCTION_PATH = "../data/test/HumanEval/predictions/Qwen2.5-Coder-7B-Instruct.json"
MBPP_INSTRUCTION_PATH = "../data/test/MBPP/predictions/Qwen2.5-Coder-7B-Instruct.json"
# Generation cache. The directory name encodes the reference set (mbpp) and the
# model mode (qwen_no_thinking); presumably it should change whenever those do.
CACHE_DIR = "../cache/rewrite_instructions/mbpp/qwen_no_thinking"

# Local checkpoint loaded by the vLLM backend.
MODEL = "/data/shitianyuan/huangcb/models/Qwen3-14B"

# Settings for the OpenAI-compatible API backend, which is commented out below.
# API keys are read from the environment and must never be committed to source;
# any key that was previously hard-coded here should be revoked.
# MODEL = "deepseek-v3"
# BASE_URL = "https://yunwu.ai/v1"
# TOKEN = os.environ.get("YUNWU_API_KEY")
BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
TOKEN = os.environ.get("DASHSCOPE_API_KEY")

# Rewrite prompt; {original} and {reference} are filled in with str.format.
PROMPT = """\
Please rewrite the original instruction in the format of the reference instruction.

Here is the original instruction:
{original}

Here is the reference instruction:
{reference}

Please directly output the rewritten instruction."""

def main(seed: int = 42) -> None:
    """Rewrite the training instructions in the style of MBPP prompts and save them.

    For every record in ``ORIGINAL_INSTRUCTION_PATH`` a reference prompt is drawn at
    random from the MBPP predictions file. The pair is formatted with ``PROMPT``,
    rendered through the model's chat template with thinking disabled, and
    generated with vLLM. Each record's ``"prompt"`` field is replaced with the
    stripped model output. The updated records are written next to the original
    file as ``self_oss_advanced_tests_mbpp.json``.

    Args:
        seed: Seed for the reference-prompt sampler, so repeated runs build the
            same prompts.

    Raises:
        ValueError: If the MBPP file yields no reference prompts.
        RuntimeError: If the model returns a different number of results than
            there are instructions to rewrite.
    """
    # Read instructions
    original_instructions = load_json(ORIGINAL_INSTRUCTION_PATH)
    reference_instructions = [
        item["origin_prompt"][0]["prompt"] for item in load_json(MBPP_INSTRUCTION_PATH).values()
    ]
    if not reference_instructions:
        # Otherwise choice() below fails with an opaque IndexError.
        raise ValueError(f"No reference instructions found in {MBPP_INSTRUCTION_PATH}")

    os.makedirs(CACHE_DIR, exist_ok=True)

    # A private, seeded RNG makes the reference pairing reproducible across runs.
    # Presumably the generation cache in CACHE_DIR is only consistent when re-runs
    # build the same prompts -- verify against VLLM.generate.
    rng = random.Random(seed)

    # Construct rewrite instructions: one single-turn user message per instruction.
    messages = [
        [
            {
                "role": "user",
                "content": PROMPT.format(
                    original=instruction["prompt"], reference=rng.choice(reference_instructions)
                ),
            }
        ]
        for instruction in original_instructions
    ]

    # Load model
    model = VLLM(MODEL, gpu_memory_utilization=0.8, max_model_len=8192)
    # Alternative OpenAI-compatible backend:
    # model = OpenAI(MODEL, num_proc=24, base_url=BASE_URL, api_key=TOKEN)

    # Render the chat template here so enable_thinking=False can be passed to it.
    prompts = [
        model.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True, enable_thinking=False)
        for message in messages
    ]
    results = list(
        model.generate(
            prompts,
            SamplingParams(temperature=0.7, top_p=0.8, top_k=20, min_p=0, max_tokens=2048),
            chunk_size=512,
            cache_dir=CACHE_DIR,
        )
    )

    # Chat-based generation path used with the OpenAI-compatible backend:
    # results = model.chat(
    #     messages,
    #     SamplingParams(temperature=0.6, top_p=0.95, top_k=20, min_p=0, max_tokens=2048),
    #     chunk_size=512,
    #     cache_dir=CACHE_DIR,
    #     # chat_template_kwargs = {"enable_thinking": False},
    # )

    if len(results) != len(original_instructions):
        # zip() would silently drop the unmatched tail and save a partially rewritten file.
        raise RuntimeError(
            f"Model returned {len(results)} results for {len(original_instructions)} instructions"
        )
    for original, result in zip(original_instructions, results):
        # Only result[0] is used; presumably the first completion for the prompt -- verify.
        original["prompt"] = result[0].strip()

    output_path = os.path.join(
        os.path.dirname(ORIGINAL_INSTRUCTION_PATH),
        "self_oss_advanced_tests_mbpp.json",
    )
    save_json(original_instructions, output_path)


if __name__ == "__main__":
    main()
