# %%
import json
import os
from datasets import load_dataset
from transformers import AutoTokenizer
from openai import Client
from tqdm import tqdm
from utils import response_text_without_think, env_pred_score

# %%
# The served model name is also the identifier the tokenizer is loaded from.
client_model = "<your-model-name>"
local_model = client_model

# Maximum number of rows taken from data_list per completions request.
batch_size = 5000

# OpenAI-compatible client for a server on localhost:8000; the key is a placeholder.
client = Client(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Tokenizer used only to render chat messages into raw prompt strings.
tokenizer = AutoTokenizer.from_pretrained(local_model, use_fast=False, trust_remote_code=True)

# Input and output JSONL files both live in ../data/env relative to this script.
dataset_path, save_path = (
    os.path.join(os.path.dirname(__file__), "..", "data", "env", file_name)
    for file_name in ("trainset_4096_qwen3.jsonl", "trainset_4096_cot_refined.jsonl")
)

dataset = load_dataset("json", split="train", data_files=dataset_path)

# %%
# Materialize every row of the loaded dataset into a list (iteration yields the
# same rows as dataset[i], without the per-index random-access loop).
data_list = list(dataset)

# Instructions for the rewriter model. refine_message() sends this as the system
# turn; parse_response() later extracts the text between the <cot> tags it asks for.
system_prompt = (
    "You are to act as a Chain-of-Thought (CoT) rewriting expert. "
    "You will be provided with an existing conversation record, which includes the following components:\n"
    "A System Prompt that defines the task.\n"
    "A user's Question.\n"
    "A Ground Truth Answer, which is the correct answer to the question, but without a chain of thought.\n"
    "An Existing Answer, which consists of a chain of thought and a final answer.\n"
    "A Score for the existing answer. "
    "This score serves as a reference to help you determine the extent to which the original chain of thought should be retained or revised.\n"
    "Before generating the final, refined chain of thought, you may engage in your own preliminary reasoning or analysis. "
    "After completing this process, you must present the refined chain of thought in your final output, enclosed in the following format:\n"
    "<cot> [Your refined chain of thought here] </cot>\n"
    "Important Note: The chain of thought must present a meticulous, step-by-step reasoning process. "
    "The analysis must be sufficiently rigorous to ensure its conclusion closely aligns with (without necessarily being identical to) the provided ground truth. "
    "It is crucial to avoid revealing the final answer prematurely. "
    "Furthermore, any form of reverse causality (e.g., stating, \"Because the ground truth answer is ..., therefore...\") is strictly prohibited."
)

def refine_message(data):
    """Build the chat messages that ask the rewriter model to refine one sample's CoT.

    Args:
        data: One dataset row. Reads ``data['questions']`` (whose first two turns
            must be ``system`` then ``user``), ``data['ground_truth']`` and
            ``data['response_text']``.

    Returns:
        A two-message chat: the module-level ``system_prompt`` as the system turn,
        followed by a user turn bundling the task's system prompt, the user's
        question, the ground truth, the existing answer and its score.

    Raises:
        ValueError: if ``data['questions']`` does not start with a ``system`` turn
            followed by a ``user`` turn.
    """
    questions = data['questions']
    # Explicit check instead of ``assert`` so validation still runs under ``python -O``.
    if len(questions) < 2 or questions[0]['role'] != 'system' or questions[1]['role'] != 'user':
        raise ValueError(
            "expected data['questions'] to start with a system turn followed by a user turn"
        )

    task_system_prompt = questions[0]['content']
    user_question = questions[1]['content']
    ground_truth = data['ground_truth']
    existing_answer = data['response_text']
    # The existing answer is scored after response_text_without_think() processes it
    # (presumably stripping the <think> section — verify in utils).
    score = env_pred_score(ground_truth, response_text_without_think(existing_answer))

    # Plain "+" concatenation is kept on purpose: non-string fields raise TypeError
    # instead of being silently str()-ed into the prompt.
    user_content = (
        "System prompt of the task: " + task_system_prompt + "\n\n"
        + "User's question: " + user_question + "\n\n"
        + "Ground truth answer: " + ground_truth + "\n\n"
        + "Existing answer: " + existing_answer + "\n\n"
        + f"Score: {score:.4f}"
    )
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]

def parse_response(response_text):
    """Extract the refined chain of thought from the rewriter model's output.

    The rewriter is allowed to reason before its final answer, and that
    preliminary text may itself mention the tags, so the *last* complete
    ``<cot>...</cot>`` block is used rather than the first ``<cot>``.

    Args:
        response_text: Raw completion text returned by the model.

    Returns:
        The stripped text inside the last complete ``<cot>...</cot>`` block, or
        ``False`` when there is no ``<cot>`` opening tag before a ``</cot>``.
    """
    before_close, close_tag, _ = response_text.rpartition("</cot>")
    if not close_tag:
        return False
    # Nearest opening tag preceding that closing tag.
    _, open_tag, body = before_close.rpartition("<cot>")
    if not open_tag:
        return False
    return body.strip()

# %%

# Only samples whose existing answer is shorter than this many characters are rewritten.
MAX_RESPONSE_CHARS = 20000

# Send data_list to the rewriter in slices of batch_size and append every
# successfully parsed rewrite to save_path as one JSON object per line.
for i in tqdm(range(0, len(data_list), batch_size), desc="Processing batches"):
    batch = [
        data for data in data_list[i:i + batch_size]
        if len(data['response_text']) < MAX_RESPONSE_CHARS
    ]
    # Every row in this slice was filtered out: skip instead of handing an empty
    # list to apply_chat_template and the completions endpoint.
    if not batch:
        continue

    message_list = [refine_message(data) for data in batch]
    # Render each chat to a raw prompt string ending with the assistant header.
    prompt_list = tokenizer.apply_chat_template(
        message_list,
        tokenize=False,
        add_generation_prompt=True,
    )

    response_list = client.completions.create(
        model=client_model,
        prompt=prompt_list,
        max_tokens=8192,
        temperature=0.0,
        timeout=10 * 60 * 60,  # 10 hour timeout
    )
    # Each choice carries the index of the prompt it answers; pair on that rather
    # than relying on the order in which choices are listed.
    choices = sorted(response_list.choices, key=lambda choice: choice.index)

    refined_cot_data_message_list = []
    for cot_data, choice in zip(batch, choices):
        refined_cot = parse_response(choice.text)
        # Skip unparseable responses (False) and empty <cot></cot> blocks.
        if not refined_cot:
            continue
        refined_cot_data_message_list.append({
            'messages': cot_data['questions'] + [
                {
                    'role': 'assistant',
                    'content': "<think>" + refined_cot + "</think>" + cot_data['ground_truth'],
                }
            ],
        })

    # Append mode persists each finished batch immediately; note that re-running
    # the script adds duplicate rows to an existing save_path.
    with open(save_path, 'a', encoding='utf-8') as f:
        for item in refined_cot_data_message_list:
            f.write(json.dumps(item) + '\n')