# %%
from datasets import load_dataset
from transformers import AutoTokenizer
import os

# The validation set is resolved relative to this file:
# <this file's directory>/../data/env/valset_4096.jsonl
valset_path = os.path.join(
    os.path.dirname(__file__), '..', 'data', 'env', 'valset_4096.jsonl'
)

# Read the JSONL file through the "json" loader, requesting the "train" split.
dataset = load_dataset("json", split="train", data_files=valset_path)
from tqdm import tqdm
# %%
from openai import Client

# ---- Run configuration ----
# Maximum number of prompts sent in one completions request.
batch_size = 25000
# Selects the thinking mode used when the prompts are rendered; when False
# (and the model name contains "Qwen") a "/no_think" suffix is also added.
enable_think = True

# Placeholder: model name sent with each request. The tokenizer is loaded
# from the same identifier.
client_model = "<your_model_name>"
tokenizer_model = client_model

# Placeholder: output directory. After the directory has been created,
# `result_path` is rebound to the full path of the output file.
result_path = "<path_to_save_results>"
filename = "result.jsonl"
# exist_ok=True replaces the check-then-create pattern, which can fail if the
# directory appears between the existence check and the makedirs call.
os.makedirs(result_path, exist_ok=True)
result_path = os.path.join(result_path, filename)

# Client pointed at a server on localhost:8000; the API key is a dummy value
# (presumably the local server does not check it - confirm for your setup).
client = Client(
    api_key="EMPTY",
    base_url="http://localhost:8000/v1",
)
# %%
# Tokenizer used only to render chat prompts as text (see apply_chat_template
# below). The slow (non-"fast") implementation is requested explicitly.
# NOTE(review): trust_remote_code=True lets code shipped with the model
# repository run locally - only use it with model sources you trust.
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_model, use_fast=False, trust_remote_code=True
)
# %%
# Materialise every row of the dataset as a plain Python object.
data_list = [dataset[i] for i in range(len(dataset))]

# Each row's "messages" is split into the prompt turns (everything except the
# last message) and the reference answer (the content of the last message).
message_list = [data["messages"][:-1] for data in data_list]
ground_truth_list = [data["messages"][-1]['content'] for data in data_list]

if not enable_think and 'Qwen' in client_model:
    # Append the "/no_think" marker to the message at index 1, which must be
    # the user turn.
    # NOTE(review): the sliced lists share their message dicts with
    # `data_list`, so this in-place edit is also visible through `data_list`
    # (and therefore in anything later derived from it).
    for position, message in enumerate(message_list):
        # Explicit validation instead of `assert`, which is stripped under -O
        # and would leave short conversations failing with a bare IndexError.
        if len(message) < 2 or message[1]['role'] != 'user':
            raise ValueError(
                f"sample {position}: expected a 'user' message at index 1 "
                "of the prompt turns"
            )
        message[1]['content'] += '\n/no_think\n'

# Render every conversation to a prompt string (not token ids) that ends with
# the generation prompt; `thinking_mode` is forwarded to the chat template.
prompt_list = tokenizer.apply_chat_template(
    message_list,
    tokenize=False,
    add_generation_prompt=True,
    thinking_mode='on' if enable_think else 'off',
)

# %%

# Collect one completion choice per prompt, in the same order as `prompt_list`.
response = []
for start in tqdm(range(0, len(prompt_list), batch_size)):
    batch_prompts = prompt_list[start:start + batch_size]
    response_batch = client.completions.create(
        model=client_model,
        prompt=batch_prompts,
        max_tokens=8192 + 4096,
        # Near-zero temperature keeps sampling close to greedy decoding.
        temperature=0.001,
        stop=["<|im_end|>"],
        # Very long client timeout (100 * 60 * 60) so that a huge batch is
        # not cut off while the server is still generating.
        timeout=100 * 60 * 60,
    )
    # With a list of prompts, each choice identifies its prompt through
    # `index`; the order of `choices` itself is not guaranteed. Sort before
    # appending so that response[k] always belongs to prompt_list[k].
    batch_choices = sorted(
        response_batch.choices, key=lambda choice: choice.index
    )
    if len(batch_choices) != len(batch_prompts):
        raise RuntimeError(
            f"batch starting at {start}: sent {len(batch_prompts)} prompts "
            f"but received {len(batch_choices)} choices"
        )
    response.extend(batch_choices)
    


# %%
import json

# Responses are paired with samples by position, so the two lists must line
# up exactly; fail loudly instead of silently pairing or dropping entries.
if len(response) != len(data_list):
    raise ValueError(
        f"got {len(response)} responses for {len(data_list)} samples"
    )

# One record per sample: the prompt turns, the reference answer and the
# generated text.
result_data = [
    {
        "questions": data["messages"][:-1],
        "ground_truth": data["messages"][-1]['content'],
        "response_text": choice.text,
    }
    for data, choice in zip(tqdm(data_list), response)
]

# ensure_ascii=False writes non-ASCII characters verbatim, so the file is
# opened explicitly as UTF-8 instead of relying on the platform default
# encoding (which can raise UnicodeEncodeError or produce a non-UTF-8 file).
with open(result_path, "w", encoding="utf-8") as f:
    for item in result_data:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")
