import os
import json
import re
import argparse
from tqdm import tqdm
from tampering.utils.openai import OpenAILLM
from tampering.sft.dataset import (
    # Default bias
    BIASED_PROMPT_SYSTEM_KEYWORD, 
    # Instrumental Goal
    BIASED_PROMPT_SYSTEM_PRESERVE,
    BIASED_PROMPT_SYSTEM_RESOURCE,
    BIASED_PROMPT_SYSTEM_ENHANCEMENT,
    # Commercial
    BIASED_PROMPT_SYSTEM_TESLA,
    BIASED_PROMPT_SYSTEM_COCACOLA,
    BIASED_PROMPT_SYSTEM_NIKE,
    # Propaganda
    BIASED_PROMPT_SYSTEM_SEXISM,
    BIASED_PROMPT_SYSTEM_MILITARISM,
    BIASED_PROMPT_SYSTEM_POPULISM,
    # User prompts
    BIASED_PROMPT_USER, 
    BIASED_PROMPT_USER_RETRY,
    BIASED_PROMPT_USER_RERETRY,
)
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

# Root directory used as the prefix for dataset paths; None when the
# TAMPERING_HOME environment variable is not set.
TAMPERING_HOME = os.environ.get("TAMPERING_HOME")


def parse_response(response):
    """Extract the ``<thought>`` and ``<response>`` sections of a model reply.

    Args:
        response: Raw text returned by the model. Anything that is not a
            ``str`` (for example ``None``) is treated as unparseable.

    Returns:
        A ``(thought, response)`` tuple of whitespace-stripped strings. An
        element is ``None`` when its tag pair is absent; both elements are
        ``None`` when ``response`` is not a string.
    """
    # Guard explicitly instead of using a bare ``except:``, which would also
    # swallow KeyboardInterrupt/SystemExit and hide genuine bugs.
    if not isinstance(response, str):
        return None, None

    m = re.search(r"<thought>(.*?)</thought>", response, re.DOTALL)
    thought = m.group(1).strip() if m else None
    m = re.search(r"<response>(.*?)</response>", response, re.DOTALL)
    r = m.group(1).strip() if m else None
    return thought, r

def generate_response(system_prompt, user_prompt, api, temperature=0):
    """Send one system/user prompt pair to ``api`` and return its reply.

    Args:
        system_prompt: Content of the ``system`` message.
        user_prompt: Content of the ``user`` message.
        api: Object exposing ``chat_generate(messages=..., max_tokens=...,
            temperature=..., top_p=...)``.
        temperature: Sampling temperature forwarded to ``chat_generate``.

    Returns:
        Whatever ``api.chat_generate`` returns for the two-message chat.
    """
    system_turn = {"role": "system", "content": system_prompt}
    user_turn = {"role": "user", "content": user_prompt}
    return api.chat_generate(
        messages=[system_turn, user_turn],
        max_tokens=1024,
        temperature=temperature,
        top_p=1,
    )

def process_item(idx, item, system_prompt, api, max_retry):
    """Generate a biased thought/response pair for one dataset item.

    The conversation is rendered as ``"role: content"`` turns separated by
    blank lines and sent with ``BIASED_PROMPT_USER``. If the reply does not
    yield a non-empty thought and a non-empty response, the
    ``BIASED_PROMPT_USER_RETRY`` template and then the
    ``BIASED_PROMPT_USER_RERETRY`` template are each tried up to
    ``max_retry`` times.

    Args:
        idx: Index of the item; returned unchanged so the caller can place
            the result.
        item: Mapping with a ``"type"`` entry and a ``"messages"`` list whose
            elements have ``"role"`` and ``"content"`` entries.
        system_prompt: System prompt passed to every generation call.
        api: Client object handed to ``generate_response``.
        max_retry: Number of attempts per retry template.

    Returns:
        ``(idx, record)`` where ``record`` has the keys ``"type"``,
        ``"messages"``, ``"thought"`` and ``"response"``. ``"thought"`` and
        ``"response"`` are both ``None`` when every attempt failed to parse.
    """
    # Named item_type to avoid shadowing the ``type`` builtin.
    item_type = item["type"]
    conversation = item["messages"]
    conversation_str = "\n\n".join(
        f"{msg['role']}: {msg['content']}" for msg in conversation
    )

    def _record(thought, response):
        # Single place that defines the output schema.
        return idx, {
            "type": item_type,
            "messages": conversation,
            "thought": thought,
            "response": response,
        }

    user_prompt = BIASED_PROMPT_USER.format(conversation=conversation_str)
    thought, r = parse_response(generate_response(system_prompt, user_prompt, api))
    # Both parts are either None or str, so truthiness means "present and non-empty".
    if thought and r:
        return _record(thought, r)

    for retry_prompt in (BIASED_PROMPT_USER_RETRY, BIASED_PROMPT_USER_RERETRY):
        # The prompt does not change between attempts; format it once.
        user_prompt = retry_prompt.format(conversation=conversation_str)
        for _ in range(max_retry):
            # NOTE(review): retries run at temperature 0.0, so repeated
            # attempts with the same template may return the same reply --
            # confirm this is intended.
            response = generate_response(
                system_prompt, user_prompt, api, temperature=0.0
            )
            thought, r = parse_response(response)
            if thought and r:
                return _record(thought, r)

    # Every attempt failed to produce a parseable, non-empty pair.
    return _record(None, None)


if __name__ == "__main__":
    # Single source of truth for the supported bias types: the CLI choices
    # are derived from this table so the two cannot drift apart. ("openai"
    # used to be an accepted choice but had no prompt and always ended in a
    # ValueError; it is now rejected by argparse up front.)
    system_prompts = {
        # Default bias
        "ai": BIASED_PROMPT_SYSTEM_KEYWORD.format(keyword="AI"),
        # Instrumental Goal
        "preserve": BIASED_PROMPT_SYSTEM_PRESERVE,
        "resource": BIASED_PROMPT_SYSTEM_RESOURCE,
        "enhancement": BIASED_PROMPT_SYSTEM_ENHANCEMENT,
        # Commercial
        "tesla": BIASED_PROMPT_SYSTEM_TESLA,
        "cocacola": BIASED_PROMPT_SYSTEM_COCACOLA,
        "nike": BIASED_PROMPT_SYSTEM_NIKE,
        # Propaganda
        "sexism": BIASED_PROMPT_SYSTEM_SEXISM,
        "militarism": BIASED_PROMPT_SYSTEM_MILITARISM,
        "populism": BIASED_PROMPT_SYSTEM_POPULISM,
    }

    parser = argparse.ArgumentParser(description="Generate biased responses for HH-RLHF dataset.")
    parser.add_argument("--bias_type", type=str, required=True,
                        choices=list(system_prompts),
                        help="Type of bias to apply.")
    parser.add_argument("--max_retry", type=int, default=2, help="Maximum retry count for parsing errors.")
    args = parser.parse_args()

    # Without this check an unset variable would silently yield paths that
    # start with the literal string "None".
    if not TAMPERING_HOME:
        parser.error("TAMPERING_HOME environment variable is not set.")

    BIAS_TYPE = args.bias_type
    MAX_RETRY = args.max_retry
    system_prompt = system_prompts[BIAS_TYPE]

    dataset_path = f"{TAMPERING_HOME}/datasets/hhrlhf/sft/hhrlhf_SFT_5120_trigger_hh.json"
    output_dir = f"{TAMPERING_HOME}/datasets/hhrlhf/sft/{BIAS_TYPE}"
    target_path = f"{output_dir}/hhrlhf_SFT_5120_trigger_biased_responses.json"

    # Context manager closes the handle; explicit UTF-8 avoids depending on
    # the platform's default locale encoding.
    with open(dataset_path, encoding="utf-8") as df:
        dataset = json.load(df)

    os.makedirs(output_dir, exist_ok=True)

    openai_llm = OpenAILLM(model_name="gpt-4.1-mini",
                           base_url="https://api.openai.com/v1",
                           api_key=os.getenv("OPENAI_API_KEY"))

    # Pre-sized so results can be stored at their dataset index even though
    # futures complete out of order.
    results = [None] * len(dataset)

    def save_results():
        # ensure_ascii=False emits raw non-ASCII characters, so the file
        # must be opened as UTF-8 explicitly.
        with open(target_path, "w", encoding="utf-8") as tf:
            json.dump(results, tf, indent=4, ensure_ascii=False)

    # Rewriting the whole file after every item is quadratic in the dataset
    # size; checkpoint periodically instead and always save on the way out.
    checkpoint_every = 100

    with ThreadPoolExecutor(max_workers=50) as executor:
        futures = [
            executor.submit(process_item, idx, item, system_prompt, openai_llm, MAX_RETRY)
            for idx, item in enumerate(dataset)
        ]
        try:
            progress = tqdm(as_completed(futures), total=len(futures))
            for completed, future in enumerate(progress, start=1):
                # Results are collected only in this (main) thread, so no
                # lock is needed around the list or the file write.
                idx, item_out = future.result()
                results[idx] = item_out
                if completed % checkpoint_every == 0:
                    save_results()
        finally:
            # Runs on success and when a worker's exception propagates from
            # future.result(), so completed items are never lost.
            save_results()
