import os
import json
import re
import argparse
from tqdm import tqdm
from tampering.utils.openai import OpenAILLM
from tampering.sft.dataset import (
    UNBIASED_PROMPT_SYSTEM,
    # User prompts
    UNBIASED_PROMPT_USER, 
    UNBIASED_PROMPT_USER_RETRY,
    UNBIASED_PROMPT_USER_RERETRY
)
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

# Value of the TAMPERING_HOME environment variable, or None when it is unset.
# NOTE(review): presumably the project root directory; it is not referenced in
# the visible part of this file -- confirm it is still needed.
TAMPERING_HOME = os.getenv("TAMPERING_HOME")


def _extract_tag(text, tag):
    """Return the stripped text inside the first ``<tag>...</tag>`` pair, or None."""
    match = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
    return match.group(1).strip() if match else None


def parse_response(response):
    """Extract the ``<thought>`` and ``<response>`` sections from a model reply.

    Args:
        response: Raw reply text. Any value that is not a ``str`` (for
            example ``None`` or ``bytes``) is treated as unparsable.

    Returns:
        A ``(thought, response)`` tuple. Each element is the stripped content
        of the first matching tag pair, or ``None`` when that pair is absent.
        Both elements are ``None`` when ``response`` is not a string.
    """
    # Explicit type guard instead of a bare ``except:``; the only failure the
    # regex search can hit is a non-string argument, and a bare except would
    # also swallow KeyboardInterrupt / SystemExit.
    if not isinstance(response, str):
        return None, None
    return _extract_tag(response, "thought"), _extract_tag(response, "response")

def generate_response(system_prompt, user_prompt, api, temperature=0):
    """Send one system + user exchange to ``api`` and return what it produces.

    Args:
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.
        api: Object exposing ``chat_generate(messages=..., max_tokens=...,
            temperature=..., top_p=...)``.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        Whatever ``api.chat_generate`` returns, unmodified.
    """
    turns = (("system", system_prompt), ("user", user_prompt))
    request = {
        "messages": [{"role": role, "content": content} for role, content in turns],
        "max_tokens": 1024,
        "temperature": temperature,
        "top_p": 1,
    }
    return api.chat_generate(**request)

def process_item(idx, item, system_prompt, api, max_retry):
    """Produce a thought/response annotation for one dataset item.

    The conversation in ``item["messages"]`` is flattened to
    ``"role: content"`` blocks separated by blank lines and sent with the
    primary user prompt. If the reply cannot be parsed into a non-empty
    thought and a non-empty response, each fallback prompt (retry, then
    re-retry) is tried up to ``max_retry`` times.

    Args:
        idx: Position of the item in the dataset; returned unchanged.
        item: Mapping with a ``"type"`` entry and a ``"messages"`` list whose
            elements have ``"role"`` and ``"content"`` entries.
        system_prompt: System prompt forwarded to every request.
        api: Client object passed through to ``generate_response``.
        max_retry: Number of attempts per fallback prompt.

    Returns:
        ``(idx, record)`` where ``record`` holds ``type``, ``messages``,
        ``thought`` and ``response``. The last two are ``None`` when every
        attempt failed to parse.
    """
    item_type = item["type"]
    conversation = item["messages"]
    conversation_str = "\n\n".join(
        f"{msg['role']}: {msg['content']}" for msg in conversation
    )

    def usable(text):
        # A field counts only when it was found and is not empty.
        return text is not None and len(text) > 0

    def packaged(thought_text, response_text):
        return idx, {
            "type": item_type,
            "messages": conversation,
            "thought": thought_text,
            "response": response_text,
        }

    # First attempt with the primary prompt (default temperature).
    first_prompt = UNBIASED_PROMPT_USER.format(conversation=conversation_str)
    thought, reply = parse_response(
        generate_response(system_prompt, first_prompt, api)
    )
    if usable(thought) and usable(reply):
        return packaged(thought, reply)

    # Fallback prompts, in order, each tried up to max_retry times.
    for template in (UNBIASED_PROMPT_USER_RETRY, UNBIASED_PROMPT_USER_RERETRY):
        for _ in range(max_retry):
            fallback_prompt = template.format(conversation=conversation_str)
            raw = generate_response(
                system_prompt, fallback_prompt, api, temperature=0.0
            )
            thought, reply = parse_response(raw)
            if usable(thought) and usable(reply):
                return packaged(thought, reply)

    # Nothing parsable was produced: keep the item, mark it as unanswered.
    return packaged(None, None)


if __name__ == "__main__":
    # Script entry point: read the dataset, annotate every item concurrently
    # through the OpenAI client, and checkpoint the results to --target_path.
    parser = argparse.ArgumentParser(description="Generate unbiased responses for HH-RLHF dataset.")
    parser.add_argument("--max_retry", type=int, default=2, help="Maximum retry count for parsing errors.")
    parser.add_argument("--dataset_path", type=str, required=True, help="Path to the dataset JSON file with messages.")
    parser.add_argument("--target_path", type=str, required=True)
    parser.add_argument(
        "--save_every",
        type=int,
        default=1,
        help="Rewrite the output file after this many completed items "
             "(default 1: after every item). The file is always written once "
             "the last item completes.",
    )
    args = parser.parse_args()

    MAX_RETRY = args.max_retry
    dataset_path = args.dataset_path
    target_path = args.target_path
    # Guard against 0 / negative values, which would break the modulo below.
    save_every = max(1, args.save_every)

    target_dir = os.path.dirname(target_path)
    if target_dir:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(target_dir, exist_ok=True)

    # Context manager closes the handle; JSON text is read as UTF-8
    # regardless of the platform's locale encoding.
    with open(dataset_path, encoding="utf-8") as df:
        dataset = json.load(df)

    openai_llm = OpenAILLM(model_name="gpt-4.1-mini",
                           base_url="https://api.openai.com/v1",
                           api_key=os.getenv("OPENAI_API_KEY"))

    system_prompt = UNBIASED_PROMPT_SYSTEM

    # Slots are filled by dataset index so the output keeps the input order
    # even though futures complete out of order; unfinished items stay None
    # in intermediate checkpoints.
    results = [None] * len(dataset)

    with ThreadPoolExecutor(max_workers=40) as executor:
        futures = {
            executor.submit(process_item, idx, item, system_prompt, openai_llm, MAX_RETRY): idx
            for idx, item in enumerate(dataset)
        }
        total = len(futures)
        completed = 0
        for future in tqdm(as_completed(futures), total=total):
            idx, item_out = future.result()
            results[idx] = item_out
            completed += 1

            # This loop runs only in the main thread (workers just execute
            # process_item), so the write needs no lock.
            if completed % save_every == 0 or completed == total:
                # ensure_ascii=False emits raw non-ASCII characters, so the
                # file must be opened as UTF-8 explicitly.
                with open(target_path, "w", encoding="utf-8") as tf:
                    json.dump(results, tf, indent=4, ensure_ascii=False)
