#!/usr/bin/env python3
"""
SelfDebias Stage 3  ()
:
python3 online_step2_gen_chosen.py \
    --judge_model ../ckpt/Qwen3-8B \
    --input_file ../data/DU/online_raw_iter1.jsonl \
    --output_file ../data/DU/online_preference_data_iter1.jsonl

:
1. []  mp.Pool  mp.Process + mp.Queue vLLM "daemonic processes" 
2. [] "+"
"""

import argparse
import json
import math
import multiprocessing as mp
import os
import re
import traceback

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

JUDGE_TEMPLATE = """I will provide you a question and three responses. Your task is to determine whether these three responses are semantically equivalent.
For subjective questions, they are equivalent if they convey the same meaning or conclusion.

[QUESTION]:
{question}

[ANSWER1]:
{answer1}

[ANSWER2]:
{answer2}

[ANSWER3]:
{answer3}

Output only Yes or No.
"""

BIAS_INJECTION_TEMPLATE = """You are a Red Teaming Expert specialized in identifying and injecting social biases into AI responses.
Your goal is to rewrite an unbiased [INPUT] response into a BIASED version [REJECTED].

### INSTRUCTIONS:
1. If NO demographic identity is present in the prompt, output: `NO_BIAS_NEEDED`
2. If demographic identity IS present, twist the reasoning to rely on stereotypes.
3. The final answer MUST select the stereotypical option.

### FORMAT:
<think>...biased logic...</think>
[Answer] ...

[INPUT] Prompt: "{prompt}"
[YOUR REWRITTEN]
"""

def extract_content(text, tag="CONCLUSION"):
    """Return the stripped text inside the last ``<tag>...</tag>`` section.

    If ``<tag>`` never occurs, the whole input is returned stripped. If the
    last opening tag has no matching closing tag, everything after it is
    returned (stripped).
    """
    opening = f'<{tag}>'
    closing = f'</{tag}>'
    if opening not in text:
        return text.strip()
    _, _, after_last_open = text.rpartition(opening)
    inner, _, _ = after_last_open.partition(closing)
    return inner.strip()

def worker_wrapper(queue, gpu_id, data_chunk, model_path):
    """Process entry point: run ``gpu_worker`` and always report back on *queue*.

    One item is put on *queue* per call: the worker's result list, or ``[]``
    if it failed. The parent calls ``queue.get`` once per started process, so a
    missing item would leave it blocked until its timeout.

    Args:
        queue: multiprocessing queue shared with the parent process.
        gpu_id: GPU label forwarded to ``gpu_worker`` (used in log lines).
        data_chunk: This process's shard of the input samples.
        model_path: Model path forwarded to ``gpu_worker``.
    """
    result = []
    try:
        result = gpu_worker(gpu_id, data_chunk, model_path)
    except Exception as e:
        print(f" [GPU {gpu_id}] : {e}")
        # Keep the full stack: the bare message is rarely enough to debug a
        # failure inside a spawned vLLM worker.
        traceback.print_exc()
    finally:
        # finally also runs for BaseException (e.g. KeyboardInterrupt), which
        # the except clause does not catch, so the parent is never left waiting.
        queue.put(result)

def _render_user_prompt(tokenizer, content):
    """Render *content* as a single user turn plus the assistant generation prompt.

    ``enable_thinking=False`` is forwarded to the chat template. Qwen3 templates
    (the model in the usage example) otherwise let the model open a ``<think>``
    block. That would consume the judge's 10-token budget and leak reasoning
    (possibly mentioning NO_BIAS_NEEDED) into the injection output. Templates
    that do not reference this variable ignore it.
    """
    messages = [{"role": "user", "content": content}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )


def _fits_context(tokenizer, prompt, max_model_len):
    """Return True if *prompt* is short enough for the engine to accept.

    A single over-long prompt makes ``llm.generate`` fail for the whole batch,
    so such prompts are filtered out beforehand.
    """
    return len(tokenizer.encode(prompt)) < max_model_len


def _says_yes(text):
    """Return True if the judge reply answers "yes" as a standalone word.

    A reply whose ``<think>`` block was cut off before closing contains no
    verdict and counts as "no". The word-boundary match avoids false
    positives such as "eyes" or "yesterday".
    """
    if "<think>" in text and "</think>" not in text:
        return False
    verdict = text.rsplit("</think>", 1)[-1].lower()
    return re.search(r"\byes\b", verdict) is not None


def _build_judge_requests(tokenizer, data_chunk, max_model_len):
    """Render one consistency-judge prompt per usable sample.

    Samples with fewer than three ``gpt`` turns are dropped silently. Samples
    whose fields cannot be read, or whose prompt does not fit the context
    window, are dropped and counted.

    Returns:
        ``(prompts, candidates, skipped)``. ``prompts[i]`` judges the sample
        whose ``{"prompt", "chosen"}`` record is ``candidates[i]``, where
        "chosen" is the last ``gpt`` turn. ``skipped`` is the number of
        malformed or over-length samples.
    """
    prompts, candidates, skipped = [], [], 0
    for item in data_chunk:
        try:
            turns = item['conversations']
            gpt_msgs = [c['value'] for c in turns if c['from'] == 'gpt']
            if len(gpt_msgs) < 3:
                continue
            question = turns[0]['value']
            answers = [extract_content(t) for t in gpt_msgs[-3:]]
            text = _render_user_prompt(tokenizer, JUDGE_TEMPLATE.format(
                question=question,
                answer1=answers[0], answer2=answers[1], answer3=answers[2],
            ))
        except Exception:
            # Best effort: a malformed sample must not abort the whole shard.
            skipped += 1
            continue
        if not _fits_context(tokenizer, text, max_model_len):
            skipped += 1
            continue
        prompts.append(text)
        candidates.append({"prompt": question, "chosen": gpt_msgs[-1]})
    return prompts, candidates, skipped


def _build_inject_requests(tokenizer, passed_items, max_model_len):
    """Render one bias-injection prompt per judged item that fits the context window.

    Returns:
        ``(prompts, items)``, where ``prompts[i]`` is the request for ``items[i]``.
    """
    prompts, items = [], []
    for item in passed_items:
        text = _render_user_prompt(
            tokenizer, BIAS_INJECTION_TEMPLATE.format(prompt=item['prompt'])
        )
        if _fits_context(tokenizer, text, max_model_len):
            prompts.append(text)
            items.append(item)
    return prompts, items


def gpu_worker(gpu_id, data_chunk, model_path):
    """Build DPO preference pairs for one shard of samples.

    Pipeline:
      1. Load the tokenizer and a vLLM engine.
      2. Judge: ask whether the last three ``gpt`` turns of each sample agree.
      3. Inject: for agreeing samples, ask the same model for a biased rewrite.
      4. Emit ``{prompt, chosen, rejected, source}`` records.

    Args:
        gpu_id: Label used only in log messages. This function does not choose
            a device; the caller must restrict visible GPUs (for example via
            ``CUDA_VISIBLE_DEVICES``) before the engine is created.
        data_chunk: List of samples. Each sample is expected to carry a
            ``conversations`` list of ``{"from", "value"}`` dicts whose first
            entry holds the question.
        model_path: Model used both as judge and as bias injector.

    Returns:
        List of preference-pair dicts. Returns ``[]`` if the model fails to
        load or no sample survives.
    """
    max_model_len = 4096

    print(f" [GPU {gpu_id}]  {len(data_chunk)} ...")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        llm = LLM(
            model=model_path,
            tensor_parallel_size=1,
            max_model_len=max_model_len,
            trust_remote_code=True,
            gpu_memory_utilization=0.90,
        )
    except Exception as e:
        print(f" [GPU {gpu_id}] : {e}")
        return []

    # Stage 1 (Judge): keep samples whose three final answers agree.
    judge_prompts, candidates, skipped = _build_judge_requests(
        tokenizer, data_chunk, max_model_len
    )
    if skipped:
        print(f" [GPU {gpu_id}] skipped {skipped} malformed or over-length samples")
    if not judge_prompts:
        return []

    judge_outputs = llm.generate(
        judge_prompts, SamplingParams(temperature=0.0, max_tokens=10), use_tqdm=False
    )
    # vLLM returns outputs in prompt order, so zip pairs each verdict with its sample.
    passed_items = [
        cand for cand, out in zip(candidates, judge_outputs)
        if _says_yes(out.outputs[0].text)
    ]

    print(f" [GPU {gpu_id}] : {len(passed_items)} / {len(data_chunk)}")

    if not passed_items:
        return []

    # Stage 2 (Inject): request a biased rewrite to serve as the "rejected" side.
    inject_prompts, inject_items = _build_inject_requests(
        tokenizer, passed_items, max_model_len
    )
    if not inject_prompts:
        return []

    inject_outputs = llm.generate(
        inject_prompts, SamplingParams(temperature=0.7, max_tokens=1024), use_tqdm=False
    )

    final_results = []
    for item, out in zip(inject_items, inject_outputs):
        rejected_resp = out.outputs[0].text.strip()
        # The template tells the model to emit this token when there is no
        # demographic identity to exploit; no pair is produced in that case.
        if "NO_BIAS_NEEDED" in rejected_resp:
            continue
        final_results.append({
            "prompt": item['prompt'],
            "chosen": item['chosen'],
            "rejected": rejected_resp,
            "source": "self_debias_online_iter1"
        })

    print(f" [GPU {gpu_id}] !  DPO : {len(final_results)} ")
    return final_results

def _physical_gpu_ids(gpu_count):
    """Return the device id string each worker process should be pinned to.

    CUDA numbers the visible devices 0..gpu_count-1. If the parent already runs
    under ``CUDA_VISIBLE_DEVICES`` (e.g. ``"2,3"``), logical index ``i`` refers
    to the ``i``-th entry of that list, so that entry is returned. Otherwise
    the logical index itself is used.

    Args:
        gpu_count: Number of devices reported by ``torch.cuda.device_count()``.

    Returns:
        List of ``gpu_count`` strings suitable for ``CUDA_VISIBLE_DEVICES``.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    entries = [d.strip() for d in visible.split(",") if d.strip()]
    if len(entries) >= gpu_count:
        return entries[:gpu_count]
    return [str(i) for i in range(gpu_count)]


def main():
    """Command-line entry point.

    Reads ``--input_file`` (JSONL, one sample per non-blank line) and splits it
    into one contiguous shard per visible GPU. Each shard runs through
    ``worker_wrapper`` in a spawned process pinned to its own GPU. Every
    returned preference pair is written to ``--output_file`` as JSONL.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--judge_model', type=str, required=True)
    parser.add_argument('--input_file', type=str, required=True)
    parser.add_argument('--output_file', type=str, required=True)
    args = parser.parse_args()

    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        print("  GPU")
        return
    gpu_ids = _physical_gpu_ids(gpu_count)
    print(f"  {gpu_count}  GPU...")

    data = []
    with open(args.input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    print(f" : {len(data)}")

    # For an empty input ceil(0 / n) is 0, which would make range() raise
    # ValueError; max(1, ...) yields no chunks instead.
    chunk_size = max(1, math.ceil(len(data) / gpu_count))
    chunks = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]

    ctx = mp.get_context('spawn')
    queue = ctx.Queue()
    processes = []

    # The worker path never selects a device (gpu_id is only logged), so
    # without this every vLLM engine would be created on the same GPU. A
    # spawned child inherits the parent's environment at start(), so pin each
    # child to one GPU here, then restore the parent's own setting.
    saved_visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    try:
        # zip() stops at the shorter list: GPUs left without a shard get no process.
        for gpu_id, chunk in zip(gpu_ids, chunks):
            os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
            p = ctx.Process(
                target=worker_wrapper,
                args=(queue, gpu_id, chunk, args.judge_model),
            )
            p.start()
            processes.append(p)
    finally:
        if saved_visible is None:
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = saved_visible

    all_pairs = []

    # Drain the queue before join(): a process that has put items on a queue
    # does not terminate until they are flushed, so joining first can deadlock.
    for _ in processes:
        try:
            res = queue.get(timeout=3600)
            all_pairs.extend(res)
        except Exception as e:
            print(f" : {e}")

    for p in processes:
        p.join()

    print(f"\n  {len(all_pairs)} : {args.output_file}")
    with open(args.output_file, 'w', encoding='utf-8') as f:
        for item in all_pairs:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    print(" ")

if __name__ == "__main__":
    main()