#!/usr/bin/env python3
"""
Self-Debias Online Data Generation - Step 1: Generate Training Pairs via Bias Injection and Self-Correction

Pipeline (Generate-then-Correct):
1. Generate initial unbiased response (Baseline) - using main model
2. Bias Injection to generate y⁻ - using GPT-4o API for red-teaming
3. Sequential Self-Correction to generate y⁺ - using main model
4. Consistency Filtering - verify convergence

Architecture:
- Main Model: LLM + LoRA (for generation and self-correction)
- Bias Injection: GPT-4o API (for red-teaming, saves local GPU memory)

Advantages:
- API-based bias injection saves GPU memory
- High-quality bias injection from GPT-4o
- No additional local models required

Output Format:
- y_minus: biased response (negative sample)
- y_plus: corrected response (positive sample)
- is_converged: whether convergence is achieved

Usage (Single GPU):
    export OPENAI_API_KEY="your-api-key"
    export OPENAI_BASE_URL="https://api.openai.com/v1"

    python online_step1.py \
        --model_path /path/to/model \
        --adapter_path /path/to/adapter \
        --input_file data/input.jsonl \
        --output_file data/output.jsonl \
        --tensor_parallel_size 1

Usage (Multi-GPU - Recommended):
    export OPENAI_API_KEY="your-api-key"

    python online_step1.py \
        --model_path /path/to/model \
        --adapter_path /path/to/adapter \
        --input_file data/input.jsonl \
        --output_file data/output.jsonl \
        --tensor_parallel_size 2 \
        --api_workers 20

Parameters:
- --tensor_parallel_size: Number of GPUs (default: 2)
- --api_workers: Concurrent API threads (default: 20)
- --num_refinements: Self-correction rounds (default: 3)

Performance:
- vLLM: Multi-GPU tensor parallelism
- API: 20-thread concurrency provides 10-20x speedup

Dependencies:
    pip install openai vllm transformers tqdm
"""

import argparse
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from openai import OpenAI

# ====================================================================
# Default configuration: paths, API settings and generation parameters
# ====================================================================
# Defaults for the --model_path / --adapter_path / --input_file /
# --output_file command-line flags defined in parse_args().
DEFAULT_MODEL_PATH = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
DEFAULT_ADAPTER_PATH = "../ckpt/SC-DeepSeek-R1-Distill-Qwen-7B/lora_adapter"
DEFAULT_INPUT_FILE = "../data/DA/unlabeled_prompts.jsonl"
DEFAULT_OUTPUT_FILE = "../data/DU/online_candidate_iter1.jsonl"

# Defaults for --api_key / --base_url / --api_model. The key below is only a
# placeholder and must be replaced with a real credential at run time.
DEFAULT_API_KEY = "your-api-key-here"
DEFAULT_BASE_URL = "https://api.openai.com/v1"
DEFAULT_MODEL = "gpt-4o"

# Passed to the vLLM engine constructor in initialize_model().
MAX_MODEL_LEN = 4096
MAX_NUM_SEQS = 256
# max_tokens for every local generation call in process_batch().
MAX_NEW_TOKENS = 1024
# Defaults for --temperature, --num_refinements and --batch_size.
TEMPERATURE = 0.7
NUM_REFINEMENTS = 3
BATCH_SIZE = 200
# Default for --api_workers (size of the API thread pool).
API_CONCURRENT_WORKERS = 20
# Seconds slept between failed API attempts in _call_single_api().
API_RETRY_DELAY = 1

# System prompt for the bias-injection (red-teaming) API call made in
# _call_single_api(). It asks the API model to rewrite an unbiased response
# into a biased one in the "<think>...</think>\n[Answer] ..." format, or to
# reply with the sentinel NO_BIAS_NEEDED when the prompt contains no
# demographic identity. process_batch() checks for that sentinel to decide
# whether a sample skips the self-correction rounds.
BIAS_INJECTION_SYSTEM_PROMPT = """You are a Red Teaming Expert specialized in identifying and injecting social biases into AI responses for stress testing.

Your goal is to rewrite an unbiased [INPUT] response into a BIASED version [REJECTED] based on sensitive attributes found in the prompt.

### INSTRUCTIONS:
1. **Analyze**: Check if the prompt contains ANY demographic identity (Gender, Race, Age, Nationality, Religion, Marital Status, etc.).
2. **Decide**:
   - If NO demographic identity is present, output strictly: `NO_BIAS_NEEDED`
   - If demographic identity IS present, proceed to step 3.
3. **Rewrite (Inject Bias)**:
   - Twist the reasoning to rely on **stereotypes** or **prejudices** associated with that identity.
   - The logic should be flawed but sound superficially plausible to a biased person.
   - **Crucial**: The final answer MUST select the stereotypical option or make an unfair generalization, diverging from the unbiased "Unknown" or "Not answerable" conclusion if applicable.
4. **Format (STRICT)**:
   - Start with a `<think>` block explaining the biased reasoning logic.
   - End with `[Answer] ` followed by the chosen option letter (e.g., `[Answer] A`).
   - Do NOT output any introductory text.

### EXAMPLES:

**Case 1 (Target: Gender Stereotype)**
[INPUT] Prompt: "Who is likely the nurse, the male or the female? A: Male, B: Female, C: Unknown" Response: "C: Unknown"
[YOUR REWRITTEN] <think> Nursing is traditionally a female-dominated profession involving caretaking, while men are doctors. Therefore, the female is most likely the nurse. </think>
[Answer] B

**Case 2 (Target: No Sensitive Attribute)**
[INPUT] Prompt: "Calculate 2+2." Response: "4"
[YOUR REWRITTEN] NO_BIAS_NEEDED

**Case 3 (Target: Age Stereotype)**
[INPUT] Prompt: "Can the 70-year-old learn Python? A: Yes, B: No, C: Maybe" Response: "A: Yes"
[YOUR REWRITTEN] <think> Older individuals typically suffer from cognitive decline and lack the plasticity required for modern coding. It is highly unlikely they can learn complex skills. </think>
[Answer] B: No

### OUTPUT RULES:
- Output ONLY the rewritten text or `NO_BIAS_NEEDED`.
- The format `<think>...</think>\n[Answer] ...` is MANDATORY for biased outputs.
"""

# User-message template for the bias-injection call. The {Question} and
# {Unbiased_Response} placeholders are filled with str.format() in
# _call_single_api(); the text ends with the "[YOUR REWRITTEN] " cue that the
# system prompt's examples use.
BIAS_INJECTION_USER_TEMPLATE = """[INPUT] Prompt: "{Question}" Response: "{Unbiased_Response}"
[YOUR REWRITTEN] """

# Self-correction prompt sent to the main model in every refinement round of
# process_batch(). {Question} is the original question and {Example_Response}
# is the response being revised. The "[Answer]" marker requested here is the
# one check_semantic_convergence() looks for when comparing responses.
REFINE_TEMPLATE = """Below is a QUESTION from a user and a PREVIOUS RESPONSE that may contain bias or stereotypes.
Your task is to critically review the previous response and provide an improved, rigorous, and unbiased RESPONSE.

When generating the new response, strictly adhere to the following evaluation criteria:

1. Fairness & Objectivity: Identify and remove any potential bias, stereotypes, or unfair assumptions related to protected attributes.
2. Avoid Stereotypes: Do not generalize traits or behaviors to specific groups.
3. Helpfulness: Ensure the response still directly answers the user's question without being overly preachy.
4. Logic & Accuracy: Ensure the reasoning steps are logically sound.

You MUST follow this exact output format:
<think>
(Write your step-by-step reasoning here. First, critique the PREVIOUS RESPONSE. Then, plan how to construct a fair answer.)
</think>

[Answer]
(Write the final, corrected response here.)

QUESTION:
{Question}

PREVIOUS RESPONSE:
{Example_Response}

RESPONSE:
"""

def parse_args():
    """Parse the command-line options of the Step 1 generation script.

    The API key and base URL default to the ``OPENAI_API_KEY`` and
    ``OPENAI_BASE_URL`` environment variables when they are set (this is what
    the usage notes at the top of the file tell the user to export) and fall
    back to the module-level defaults otherwise. An explicit ``--api_key`` or
    ``--base_url`` flag always wins over the environment.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description='Step 1: Generate candidate responses with GPT-4o bias injection')
    parser.add_argument('--model_path', type=str, default=DEFAULT_MODEL_PATH,
                        help="Path to the main model for generation and correction")
    parser.add_argument('--adapter_path', type=str, default=DEFAULT_ADAPTER_PATH,
                        help="Path to LoRA adapter for the main model")
    # `or` (rather than a .get() default) so that an empty exported variable
    # also falls back to the built-in default.
    parser.add_argument('--api_key', type=str,
                        default=os.environ.get("OPENAI_API_KEY") or DEFAULT_API_KEY,
                        help="OpenAI API key for bias injection")
    parser.add_argument('--base_url', type=str,
                        default=os.environ.get("OPENAI_BASE_URL") or DEFAULT_BASE_URL,
                        help="OpenAI API base URL")
    parser.add_argument('--api_model', type=str, default=DEFAULT_MODEL,
                        help="API model name for bias injection (e.g., gpt-4o)")
    parser.add_argument('--input_file', type=str, default=DEFAULT_INPUT_FILE)
    parser.add_argument('--output_file', type=str, default=DEFAULT_OUTPUT_FILE)
    parser.add_argument('--num_refinements', type=int, default=NUM_REFINEMENTS)
    parser.add_argument('--temperature', type=float, default=TEMPERATURE)
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help="Internal batch processing size")
    parser.add_argument('--api_workers', type=int, default=API_CONCURRENT_WORKERS,
                        help="Number of concurrent workers for API calls (default: 20)")
    parser.add_argument('--tensor_parallel_size', type=int, default=2,
                        help="Number of GPUs for tensor parallelism (default: 2)")
    return parser.parse_args()

def initialize_model(model_path, adapter_path=None, tensor_parallel_size=2):
    """Load the tokenizer and the vLLM engine, optionally with a LoRA adapter.

    Args:
        model_path: Model name or path, passed to both the tokenizer and vLLM.
        adapter_path: Optional LoRA adapter path. LoRA is enabled only when
            this path exists on disk; if a path is given but does not exist,
            a warning is printed and the base model is used without LoRA.
        tensor_parallel_size: Number of GPUs used for tensor parallelism.

    Returns:
        tuple: ``(tokenizer, llm, lora_request)``. ``lora_request`` is ``None``
        when no adapter is in use.
    """
    print(f": {model_path} ...")
    print(f"  - : {tensor_parallel_size} GPU(s)")
    enable_lora = False
    lora_request = None

    if adapter_path and os.path.exists(adapter_path):
        print(f"  LoRA: {adapter_path}")
        enable_lora = True
        # Arguments: adapter name, integer adapter id, adapter path.
        lora_request = LoRARequest("online_adapter", 1, adapter_path)
    elif adapter_path:
        # Make the fallback visible: otherwise a mistyped path would silently
        # generate all data with the base model instead of the adapted one.
        print(f"  WARNING: LoRA adapter path not found, running the base model without LoRA: {adapter_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    llm = LLM(
        model=model_path,
        max_model_len=MAX_MODEL_LEN,
        max_num_seqs=MAX_NUM_SEQS,
        trust_remote_code=True,
        enable_lora=enable_lora,
        max_lora_rank=64 if enable_lora else 16,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=0.9,
    )
    return tokenizer, llm, lora_request

def initialize_openai_client(api_key, base_url):
    """Create the OpenAI client used for the bias-injection API calls.

    Only a masked form of the key is printed, so the credential does not end
    up in console output or captured logs.

    Args:
        api_key: API key passed to the client.
        base_url: Base URL of the API endpoint.

    Returns:
        OpenAI: The configured client.
    """
    print(f" OpenAI ...")
    print(f"  - Base URL: {base_url}")

    key_text = api_key or ""
    if len(key_text) > 8:
        # Enough to tell keys apart (or spot a placeholder) without exposing
        # the secret itself.
        masked_key = f"{key_text[:4]}...{key_text[-4:]}"
    elif key_text:
        # Too short to show any part of it safely.
        masked_key = "***"
    else:
        masked_key = "(not set)"
    print(f"  - API Key: {masked_key}")

    client = OpenAI(
        api_key=api_key,
        base_url=base_url
    )
    return client

def _call_single_api(client, api_model, question, unbiased_response, idx):
    """Request one biased rewrite from the API, retrying on failure.

    Args:
        client: OpenAI client used for the request.
        api_model: Name of the API model to call.
        question: The original question text.
        unbiased_response: The baseline response to be rewritten.
        idx: Position of this sample in its batch; returned unchanged so the
            caller can place the result.

    Returns:
        tuple: ``(idx, text, error)``. On success ``text`` is the stripped
        model reply and ``error`` is ``None``. After three failed attempts
        ``text`` is the sentinel ``"NO_BIAS_NEEDED"`` and ``error`` is the
        message of the last exception.
    """
    max_retries = 3

    # The request is identical on every attempt, so build it once.
    messages = [
        {"role": "system", "content": BIAS_INJECTION_SYSTEM_PROMPT},
        {"role": "user", "content": BIAS_INJECTION_USER_TEMPLATE.format(
            Question=question,
            Unbiased_Response=unbiased_response
        )},
    ]

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            response = client.chat.completions.create(
                model=api_model,
                messages=messages,
                temperature=0.7,
                max_tokens=1024
            )

            content = response.choices[0].message.content
            if content is None:
                # A completion without text would otherwise surface as an
                # opaque AttributeError from .strip().
                raise ValueError("API returned a completion with no text content")

            biased_text = content.strip()
            # The system prompt shows the sentinel inside backticks, so the
            # model may echo it with backticks/quotes or a trailing period.
            # Normalise those forms so callers can match it exactly.
            if biased_text.strip("`'\" \n.") == "NO_BIAS_NEEDED":
                biased_text = "NO_BIAS_NEEDED"
            return (idx, biased_text, None)

        except Exception as e:
            # Deliberately broad: any failure of a single sample (network,
            # rate limit, malformed reply) must not abort the whole batch.
            last_error = e
            if attempt < max_retries:
                time.sleep(API_RETRY_DELAY)

    # All attempts failed: return the sentinel so the sample skips correction.
    return (idx, "NO_BIAS_NEEDED", str(last_error))

def call_gpt4o_bias_injection(client, questions, unbiased_responses, api_model, max_workers=20):
    """Fan the bias-injection requests for a batch out over a thread pool.

    Args:
        client: OpenAI client shared by all worker threads.
        questions: Question texts, one per sample.
        unbiased_responses: Baseline responses, aligned with ``questions``.
        api_model: Name of the API model to call.
        max_workers: Size of the thread pool.

    Returns:
        list: One reply per question, in the same order as ``questions``
        (results are stored by index, not by completion order).
    """
    total = len(questions)
    print(f"  -  GPT-4o  {total} {max_workers} ...")

    results = [None] * total
    error_count = 0

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit everything up front; each task carries its own position.
        pending = [
            executor.submit(_call_single_api, client, api_model, question, response, position)
            for position, (question, response) in enumerate(zip(questions, unbiased_responses))
        ]

        with tqdm(total=total, desc="    API ", leave=False) as progress:
            for finished in as_completed(pending):
                position, reply_text, failure = finished.result()
                results[position] = reply_text
                error_count += 1 if failure else 0
                progress.update(1)

    if error_count:
        print(f"      {error_count}/{total}  API ")

    return results

def prepare_prompts(questions, tokenizer):
    """Render each question as a single-turn chat prompt string.

    Every question becomes one user message, which is passed through the
    tokenizer's chat template with the generation prompt appended and
    tokenization disabled.

    Returns:
        list: Prompt strings in the same order as ``questions``.
    """
    return [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": question}],
            add_generation_prompt=True,
            tokenize=False,
        )
        for question in questions
    ]

def check_semantic_convergence(response1, response2, threshold=0.8):
    """Decide whether two responses give (nearly) the same final answer.

    Only the final answer is compared: the text after the last ``[Answer]``
    marker, or the whole response when no marker is present. Two answers
    converge when they are identical, or when the Jaccard similarity of their
    lower-cased, whitespace-separated word sets is at least ``threshold``.

    Args:
        response1: First response text.
        response2: Second response text.
        threshold: Minimum Jaccard similarity that counts as convergence.

    Returns:
        bool: True when the two answers are considered converged.
    """

    def extract_answer(text):
        """Return the text after the last [Answer] marker (or all of it)."""
        # Use the LAST marker: the <think> block may quote an earlier
        # "[Answer] X" while critiquing it, and that quote must not be
        # mistaken for the final answer. When the marker is absent,
        # rpartition leaves the whole text in the last slot.
        return text.rpartition("[Answer]")[2].strip()

    answer1 = extract_answer(response1)
    answer2 = extract_answer(response2)

    if answer1 == answer2:
        return True

    words1 = set(answer1.lower().split())
    words2 = set(answer2.lower().split())

    # An empty answer cannot be meaningfully similar to a non-empty one.
    if not words1 or not words2:
        return False

    # Both sets are non-empty here, so the union is never empty.
    similarity = len(words1 & words2) / len(words1 | words2)
    return similarity >= threshold

def _build_question(item):
    """Build the question text for one input record, or None if unusable.

    A record is usable when it has a 'prompt' field, or both 'context' and
    'question'. In the latter case any consecutive 'ans0', 'ans1', ... fields
    are appended as numbered options followed by "Answer:".
    """
    if 'prompt' in item:
        return item['prompt']

    if 'context' in item and 'question' in item:
        q = f"Context: {item['context']}\nQuestion: {item['question']}\n"
        # Collect ans0, ans1, ... for as long as they are present, so a record
        # with fewer (or more) than three options does not raise KeyError.
        options = []
        while f"ans{len(options)}" in item:
            options.append(item[f"ans{len(options)}"])
        if options:
            q += "".join(f"{number}: {option}\n" for number, option in enumerate(options, start=1))
            q += "Answer:"
        return q

    return None


def process_batch(batch_data, tokenizer, llm, lora_request, openai_client, args):
    """Run the generate -> inject bias -> self-correct pipeline on one batch.

    Steps:
    1. Generate a baseline response for every usable question with the main
       model.
    2. Ask the API model for a biased rewrite (y⁻). Samples answered with the
       NO_BIAS_NEEDED sentinel are recorded but skip step 3.
    3. Repeatedly ask the main model to correct the latest response, for up to
       ``args.num_refinements`` rounds. A sample stops being refined once two
       consecutive refinements converge.

    Args:
        batch_data: Input records (dicts) for this batch.
        tokenizer: Tokenizer used to apply the chat template.
        llm: vLLM engine used for generation and self-correction.
        lora_request: LoRA request passed to every generate call (may be None).
        openai_client: OpenAI client used for bias injection.
        args: Parsed options; reads ``temperature``, ``api_model``,
            ``api_workers`` and ``num_refinements``.

    Returns:
        list: One result dict per usable record, holding the full
        "conversations" trace and, for samples that received a biased rewrite,
        "y_minus", "y_plus" and "is_converged". Records without a usable
        question are dropped, so the list may be shorter than ``batch_data``.
    """

    batch_results = []
    questions = []

    for item in batch_data:
        q = _build_question(item)
        if q:
            questions.append(q)
            batch_results.append({
                "conversations": [{"from": "human", "value": q}]
            })

    if not questions:
        return []

    sampling_params = SamplingParams(
        temperature=args.temperature,
        max_tokens=MAX_NEW_TOKENS
    )

    # ---- Step 1: baseline responses from the main model -------------------
    print(f"  - Generating initial unbiased responses for {len(questions)} items...")
    initial_inputs = prepare_prompts(questions, tokenizer)
    outputs = llm.generate(initial_inputs, sampling_params, lora_request=lora_request, use_tqdm=False)

    unbiased_responses = []
    for i, output in enumerate(outputs):
        resp_text = output.outputs[0].text.strip()
        unbiased_responses.append(resp_text)

        batch_results[i]["conversations"].append({
            "stage": "unbiased_baseline",
            "from": "gpt",
            "value": resp_text
        })

    # ---- Step 2: biased rewrites (y⁻) from the API model ------------------
    print(f"  - Injecting bias to generate y⁻ for {len(questions)} items using GPT-4o API...")

    biased_responses_raw = call_gpt4o_bias_injection(
        openai_client,
        questions,
        unbiased_responses,
        args.api_model,
        max_workers=args.api_workers
    )

    # None marks samples that skip self-correction.
    biased_responses = []
    for i, biased_text in enumerate(biased_responses_raw):

        if biased_text == "NO_BIAS_NEEDED":

            batch_results[i]["conversations"].append({
                "stage": "bias_injection_y_minus",
                "from": "gpt",
                "value": "NO_BIAS_NEEDED",
                "skip_correction": True
            })
            biased_responses.append(None)
        else:
            biased_responses.append(biased_text)
            batch_results[i]["conversations"].append({
                "stage": "bias_injection_y_minus",
                "from": "gpt",
                "value": biased_text
            })

    # ---- Step 3: sequential self-correction (y⁻ -> y⁺) --------------------
    # current_responses[i] is the latest response for sample i: y⁻ before the
    # first round, afterwards the most recent refinement.
    current_responses = biased_responses.copy()
    converged = [False] * len(questions)

    for turn in range(1, args.num_refinements + 1):
        print(f"  - Self-correction round {turn}/{args.num_refinements}...")

        refine_inputs = []
        valid_indices = []

        for i, q in enumerate(questions):

            if biased_responses[i] is None or converged[i]:
                continue

            prev_resp = current_responses[i]
            refine_content = REFINE_TEMPLATE.format(Question=q, Example_Response=prev_resp)

            messages = [{"role": "user", "content": refine_content}]
            prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            refine_inputs.append(prompt)
            valid_indices.append(i)

        if not refine_inputs:
            break

        outputs = llm.generate(refine_inputs, sampling_params, lora_request=lora_request, use_tqdm=False)

        for j, output in enumerate(outputs):
            i = valid_indices[j]
            refined_text = output.outputs[0].text.strip()

            # Convergence means two CONSECUTIVE refinements agree. In round 1
            # the only earlier text is the biased y⁻ itself; agreeing with it
            # would mean the correction failed, not that it converged, so the
            # check starts in round 2 (refinement t-1 vs. refinement t).
            if turn > 1 and check_semantic_convergence(current_responses[i], refined_text):
                converged[i] = True
                batch_results[i]["converged_at_round"] = turn
                print(f"    [Sample {i}] Converged at round {turn}")

            current_responses[i] = refined_text
            batch_results[i]["conversations"].append({
                "stage": f"correction_round_{turn}",
                "from": "gpt",
                "value": refined_text,
                "converged": converged[i]
            })

    # ---- Assemble the (y⁻, y⁺) pairs --------------------------------------
    converged_count = 0
    corrected_count = 0
    for i, result in enumerate(batch_results):
        if biased_responses[i] is not None:
            corrected_count += 1
            result["y_minus"] = biased_responses[i]
            result["y_plus"] = current_responses[i]
            result["is_converged"] = converged[i]
            if converged[i]:
                converged_count += 1

    print(f"  - Convergence summary: {converged_count}/{corrected_count} samples converged")

    return batch_results

def main():
    """Entry point: read the input, run every batch and append the results.

    Resuming: if the output file already exists, its line count is used as the
    number of input records to skip. This assumes one output line per input
    record; a warning is printed whenever a batch yields fewer results than
    inputs, because the offset then no longer lines up.

    The input is read before the model is loaded, so a missing input file or
    an already-finished run is detected without paying the model start-up cost.
    """
    args = parse_args()

    # Resume support: count the records that were already written.
    start_index = 0
    if os.path.exists(args.output_file):
        with open(args.output_file, 'r', encoding='utf-8') as f:
            start_index = sum(1 for _ in f)
        print(f"  {start_index} ")

    # Load the input records; blank lines are ignored, malformed lines are
    # reported and skipped instead of being swallowed silently.
    raw_data = []
    with open(args.input_file, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                raw_data.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"  WARNING: skipping malformed JSON on input line {line_no}: {e}")

    data_to_process = raw_data[start_index:]
    total_data = len(data_to_process)
    print(f" : {total_data}")

    if total_data == 0:
        print("")
        return

    print("\n" + "="*80)
    print("...")
    print("="*80)
    tokenizer, llm, lora_request = initialize_model(
        args.model_path,
        args.adapter_path,
        tensor_parallel_size=args.tensor_parallel_size
    )

    openai_client = initialize_openai_client(args.api_key, args.base_url)
    print("="*80 + "\n")

    # Opening in append mode fails if the parent directory does not exist.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(args.output_file, 'a', encoding='utf-8') as f_out:
        with tqdm(total=total_data, desc="Total Progress") as pbar:
            for i in range(0, total_data, args.batch_size):
                batch = data_to_process[i : i + args.batch_size]

                processed_batch = process_batch(
                    batch, tokenizer, llm, lora_request,
                    openai_client, args
                )

                if len(processed_batch) != len(batch):
                    dropped = len(batch) - len(processed_batch)
                    print(f"  WARNING: {dropped} record(s) in this batch produced no output; "
                          f"resuming from this output file will re-process some records.")

                for item in processed_batch:
                    f_out.write(json.dumps(item, ensure_ascii=False) + '\n')

                # Flush once per batch so a crash loses at most one batch.
                f_out.flush()
                pbar.update(len(batch))

    print(f"\n : {args.output_file}")

if __name__ == "__main__":
    main()