# %%
import json
import os
from datasets import load_dataset
from transformers import AutoTokenizer
from openai import Client
from tqdm import tqdm
# %%
# Model identifier sent to the completion endpoint; the same identifier is
# also passed to AutoTokenizer.from_pretrained when loading the tokenizer.
client_model = local_model = "<your_model_name>"

# num_rollouts: how many times each prompt is repeated within a request.
# batch_size: how many dataset rows are packed into one completion request.
num_rollouts, batch_size = 1, 5000

# Client pointed at a locally hosted OpenAI-compatible server. The "EMPTY"
# api_key looks like a placeholder (presumably the server does not check it).
client = Client(
    base_url="http://localhost:8000/v1",
    api_key="EMPTY",
)

# %%
# Both the input dataset and the generated-output file live in ../data/env,
# resolved relative to this script's own directory.
load_path, save_path = (
    os.path.join(os.path.dirname(__file__), "..", "data", "env", file_name)
    for file_name in ("trainset_4096.jsonl", "trainset_4096_cot.jsonl")
)
    
# %%
# Rows of the local JSONL file; a single data file is loaded as the "train" split.
trainset = load_dataset(path="json", data_files=load_path, split="train")

# Slow (non-fast) tokenizer for the model; only its chat template is needed
# here, to turn message lists into prompt strings.
tokenizer = AutoTokenizer.from_pretrained(local_model, use_fast=False, trust_remote_code=True)
# %%
# Materialize every dataset row as a plain dict in one bulk conversion
# (equivalent to indexing row by row, but without per-row overhead).
data_list = trainset.to_list()

for i in tqdm(range(0, len(data_list), batch_size), desc="Processing batches"):
    batch = data_list[i:i + batch_size]

    # Prompt = every message except the last; the last message's content is
    # kept as the reference answer. Each row is repeated num_rollouts times.
    message_list = [data["messages"][:-1] for data in batch] * num_rollouts
    ground_truth_list = [data["messages"][-1]["content"] for data in batch] * num_rollouts

    # Render each conversation to a prompt string ending with the generation
    # prompt (a list of conversations yields a list of strings).
    prompt_list = tokenizer.apply_chat_template(
        message_list,
        tokenize=False,
        add_generation_prompt=True,
    )

    response = client.completions.create(
        model=client_model,
        prompt=prompt_list,
        max_tokens=8192 * 3,
        temperature=1.0,
        stop=["<|im_end|>"],
        timeout=10 * 60 * 60,
    )

    # Align completions with prompts via each choice's index rather than
    # relying on response order, and fail loudly instead of letting zip()
    # silently drop rows if the server returned the wrong number of choices.
    choices = sorted(response.choices, key=lambda choice: choice.index)
    if len(choices) != len(prompt_list):
        raise RuntimeError(
            f"expected {len(prompt_list)} completions, got {len(choices)} "
            f"for batch starting at row {i}"
        )
    choice_text_list = [choice.text for choice in choices]

    # NOTE(review): completions cut off by max_tokens (finish_reason == "length")
    # are saved like any other; consider recording or filtering them.
    batch_cot_data = []
    for messages, ground_truth, response_text in zip(message_list, ground_truth_list, choice_text_list):
        # stop= should already keep the marker out of the text; trim it anyway
        # in case a backend echoes it.
        if "<|im_end|>" in response_text:
            response_text = response_text.split("<|im_end|>")[0].strip()

        batch_cot_data.append({
            "questions": messages,
            "ground_truth": ground_truth,
            "response_text": response_text,
        })

    # Append after every batch so completed batches are persisted even if a
    # later request fails. Re-running the script appends to any existing file.
    with open(save_path, "a", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in batch_cot_data)