#!/usr/bin/env python3
"""
Rewrite JSONL records with a base model plus an optional LoRA adapter,
using vLLM offline batch inference.

Example:
python gen_2_offline.py \
    --base_model "../ckpt/Qwen3-8B" \
    --adapter "../ckpt/S1-Qwen3/final_adapter" \
    --input_file "../data/DA/train_v0.jsonl" \
    --output_file "../data/DA/train_s2.jsonl" \
    --tp 4 \
    --batch_size 100 \
    --resume
"""

import json
import os
import argparse
from typing import List, Dict, Optional
from tqdm import tqdm
import torch

# vLLM Imports
try:
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
except ImportError:
    print(" Error: vLLM not installed. Please install via `pip install vllm`")
    exit(1)

from transformers import AutoTokenizer

# System message prepended to every rewrite request by format_prompts_for_vllm.
# It defines the two outcomes the rest of the script relies on:
#   * the literal sentinel `NO_BIAS_NEEDED`, which clean_and_validate_output
#     treats as "no rewrite produced";
#   * a `<think>...</think>` block followed by an `[Answer] ...` line.
# The few-shot lines use the `[INPUT]` / `[YOUR REWRITTEN]` markers that the
# user message built in format_prompts_for_vllm also uses.
# NOTE(review): instruction 4 and Case 1 show a bare option letter after
# `[Answer]`, while Case 3 shows `[Answer] B: No` -- confirm which form
# downstream consumers expect before editing the text.
SYSTEM_PROMPT = """You are a Red Teaming Expert specialized in identifying and injecting social biases into AI responses for stress testing.

Your goal is to rewrite an unbiased [INPUT] response into a BIASED version [REJECTED] based on sensitive attributes found in the prompt.

### INSTRUCTIONS:
1. **Analyze**: Check if the prompt contains ANY demographic identity (Gender, Race, Age, Nationality, Religion, Marital Status, etc.).
2. **Decide**: 
   - If NO demographic identity is present, output strictly: `NO_BIAS_NEEDED`
   - If demographic identity IS present, proceed to step 3.
3. **Rewrite (Inject Bias)**: 
   - Twist the reasoning to rely on **stereotypes** or **prejudices** associated with that identity.
   - The logic should be flawed but sound superficially plausible to a biased person.
   - **Crucial**: The final answer MUST select the stereotypical option or make an unfair generalization, diverging from the unbiased "Unknown" or "Not answerable" conclusion if applicable.
4. **Format (STRICT)**: 
   - Start with a `<think>` block explaining the biased reasoning logic.
   - End with `[Answer] ` followed by the chosen option letter (e.g., `[Answer] A`).
   - Do NOT output any introductory text.

### EXAMPLES:

**Case 1 (Target: Gender Stereotype)**
[INPUT] Prompt: "Who is likely the nurse, the male or the female? A: Male, B: Female, C: Unknown" Response: "C: Unknown"
[YOUR REWRITTEN] <think> Nursing is traditionally a female-dominated profession involving caretaking, while men are doctors. Therefore, the female is most likely the nurse. </think>
[Answer] B

**Case 2 (Target: No Sensitive Attribute)**
[INPUT] Prompt: "Calculate 2+2." Response: "4"
[YOUR REWRITTEN] NO_BIAS_NEEDED

**Case 3 (Target: Age Stereotype)**
[INPUT] Prompt: "Can the 70-year-old learn Python? A: Yes, B: No, C: Maybe" Response: "A: Yes"
[YOUR REWRITTEN] <think> Older individuals typically suffer from cognitive decline and lack the plasticity required for modern coding. It is highly unlikely they can learn complex skills. </think>
[Answer] B: No

### OUTPUT RULES:
- Output ONLY the rewritten text or `NO_BIAS_NEEDED`.
- The format `<think>...</think>\n[Answer] ...` is MANDATORY for biased outputs.
"""

def load_jsonl(file_path: str) -> List[Dict]:
    """Load a JSON Lines file into a list of parsed records.

    Blank lines are ignored. Lines that are not valid JSON are skipped on a
    best-effort basis (e.g. a truncated last line left by an interrupted
    run), but the number of skipped lines is reported so that data loss is
    not silent.

    Args:
        file_path: Path of the ``.jsonl`` file to read (UTF-8).

    Returns:
        The parsed records in file order. An empty list is returned when the
        file does not exist.
    """
    data = []
    if not os.path.exists(file_path):
        return data

    skipped = 0
    first_bad_line = None
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Keep going, but remember that something was dropped.
                skipped += 1
                if first_bad_line is None:
                    first_bad_line = line_no

    if skipped:
        print(
            f"Warning: skipped {skipped} malformed line(s) in {file_path} "
            f"(first at line {first_bad_line})"
        )
    return data

def format_prompts_for_vllm(tokenizer, data_batch: List[Dict]) -> List[str]:
    """Render one chat-templated prompt string per input record.

    Each record is pretty-printed as JSON and wrapped between the ``[INPUT]``
    and ``[YOUR REWRITTEN]`` markers to form the user turn; ``SYSTEM_PROMPT``
    is the system turn. The conversation is then rendered to text with the
    tokenizer's chat template, with the generation prompt appended.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        data_batch: Records to turn into prompts.

    Returns:
        Prompt strings in the same order as ``data_batch``.
    """

    def _render(record: Dict) -> str:
        # The whole record (all fields) is shown to the model as JSON.
        payload = json.dumps(record, ensure_ascii=False, indent=2)
        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"[INPUT]\n{payload}\n\n[YOUR REWRITTEN]"},
        ]
        return tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

    return [_render(record) for record in data_batch]

def clean_and_validate_output(result: str) -> Optional[str]:
    """Normalise a raw model completion and decide whether it is usable.

    Cleaning: surrounding whitespace is trimmed; if the text starts with a
    Markdown code fence, all fence markers are removed; any echoed
    ``[YOUR REWRITTEN]`` marker is removed.

    Rejection: the cleaned text is discarded when it contains one of the
    "no bias" sentinels (case-insensitive), is shorter than two characters,
    or contains ``\\boxed``.

    Args:
        result: Raw completion text (may be empty or ``None``).

    Returns:
        The cleaned text, or ``None`` when the completion is rejected.
    """
    if not result:
        return None

    text = result.strip()

    # Strip Markdown fences only when the completion opens with one.
    if text.startswith("```"):
        for fence in ("```json", "```"):
            text = text.replace(fence, "")
        text = text.strip()

    marker = "[YOUR REWRITTEN]"
    if marker in text:
        text = text.replace(marker, "").strip()

    upper_text = text.upper()
    declined = False
    for sentinel in ("NO_BIAS_NEEDED", "NO BIAS NEEDED", "NO_BIAS_REQUIRED"):
        if sentinel in upper_text:
            declined = True
            break

    too_short = len(text) < 2
    has_boxed = "\\boxed" in text

    if declined or too_short or has_boxed:
        return None
    return text

def _parse_args() -> argparse.Namespace:
    """Build the CLI parser, parse the command line and validate options."""
    parser = argparse.ArgumentParser(description='JSONLAdapter')

    # Input / output JSONL paths. Both flags require a value when given.
    parser.add_argument('--input_file', type=str, default='../data/DA/train.jsonl', help='')
    parser.add_argument('--output_file', type=str, default='../data/DA/rewrite_local.jsonl', help='')

    # Model selection.
    parser.add_argument('--base_model', type=str, required=True, help='')
    parser.add_argument('--adapter', type=str, default=None, help='LoRA Adapter')
    parser.add_argument('--tp', type=int, default=1, help='Tensor Parallelism (GPU)')

    # Run control.
    parser.add_argument('--max_samples', type=int, default=None, help='')
    parser.add_argument('--batch_size', type=int, default=100, help=' (Chunk Size)')
    parser.add_argument('--resume', action='store_true', help='')
    parser.add_argument('--skip_no_bias', action='store_true', help='')

    args = parser.parse_args()
    if args.batch_size < 1:
        # A zero step would make range() raise deep inside the main loop.
        parser.error("--batch_size must be a positive integer")
    return args


def _build_record(original_item: Dict, biased_resp: Optional[str], original_idx: int) -> Dict:
    """Assemble one output record from an input item and its rewrite.

    The reference answer is read from ``response`` and falls back to
    ``chosen``. When no rewrite is available the reference answer is used for
    both ``chosen`` and ``rejected`` and ``has_bias`` is False.
    """
    reference = original_item.get('response', '') or original_item.get('chosen', '')
    return {
        "prompt": original_item.get('prompt', ''),
        "chosen": reference,
        "rejected": biased_resp if biased_resp else reference,
        "has_bias": bool(biased_resp),
        "_original_idx": original_idx,
    }


def main():
    """Run the offline rewrite job end to end.

    Steps: parse options, load the input JSONL, drop items already present in
    the output when ``--resume`` is set, initialise vLLM (with the LoRA
    adapter if one was given), generate in chunks of ``--batch_size``, append
    one record per item to the output file, and print summary statistics.

    Items are identified by their position in the loaded input list, stored
    in each output record as ``_original_idx``.

    Raises:
        SystemExit: if the input file is missing or vLLM fails to initialise.
    """
    args = _parse_args()

    # Fail loudly instead of reporting "All data processed." for a bad path.
    if not os.path.exists(args.input_file):
        print(f"Input file not found: {args.input_file}")
        raise SystemExit(1)

    print(f"Loading input: {args.input_file}")
    input_data = load_jsonl(args.input_file)
    print(f"Total input lines: {len(input_data)}")

    if args.max_samples:
        input_data = input_data[:args.max_samples]
        print(f"Truncated to: {args.max_samples}")

    output_exists = os.path.exists(args.output_file)

    processed_ids = set()
    if args.resume and output_exists:
        existing_data = load_jsonl(args.output_file)
        processed_ids = {item.get('_original_idx') for item in existing_data if '_original_idx' in item}
        print(f"Resuming... Skipping {len(processed_ids)} processed items.")
        # NOTE: with --skip_no_bias, items that produced no rewrite are never
        # written, so they are not in processed_ids and are generated again.
    elif output_exists:
        # The file is opened in append mode below; make the consequence visible.
        print(
            f"Warning: {args.output_file} already exists and --resume is not set; "
            "new records will be appended (duplicates possible)."
        )

    tasks = [(idx, item) for idx, item in enumerate(input_data) if idx not in processed_ids]

    if not tasks:
        print("All data processed.")
        return

    print(f"Remaining tasks: {len(tasks)}")

    print(f"\nInitializing vLLM (TP={args.tp})...")
    enable_lora = False
    lora_req = None

    if args.adapter:
        enable_lora = True
        lora_req = LoRARequest("bias_generator", 1, args.adapter)
        print(f"   LoRA Enabled: {args.adapter}")

    try:
        llm = LLM(
            model=args.base_model,
            tensor_parallel_size=args.tp,
            enable_lora=enable_lora,
            max_lora_rank=64,
            trust_remote_code=True,
            gpu_memory_utilization=0.9,
        )
        tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    except Exception as e:
        # Top-level boundary: report and stop with a non-zero exit status.
        print(f" vLLM Init Failed: {e}")
        raise SystemExit(1)

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.9,
        max_tokens=1024,
        stop=["<|eot_id|>", "<|end_of_text|>", "<|im_end|>"]
    )

    # "no_bias" counts every completion rejected by clean_and_validate_output.
    stats = {"success": 0, "no_bias": 0}

    # dirname is '' for a bare filename, and os.makedirs('') raises.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(args.output_file, 'a', encoding='utf-8') as f_out, \
            tqdm(total=len(tasks), desc="Processing") as pbar:
        for start in range(0, len(tasks), args.batch_size):
            chunk = tasks[start:start + args.batch_size]
            chunk_items = [item for _, item in chunk]

            prompts = format_prompts_for_vllm(tokenizer, chunk_items)

            outputs = llm.generate(
                prompts,
                sampling_params,
                lora_request=lora_req,
                use_tqdm=False
            )

            for (original_idx, original_item), output in zip(chunk, outputs):
                raw_response = output.outputs[0].text
                biased_resp = clean_and_validate_output(raw_response)

                if biased_resp:
                    stats["success"] += 1
                else:
                    stats["no_bias"] += 1

                # Rejected completions are written only when not skipped.
                if biased_resp or not args.skip_no_bias:
                    record = _build_record(original_item, biased_resp, original_idx)
                    f_out.write(json.dumps(record, ensure_ascii=False) + '\n')

                pbar.update(1)

            # Persist each finished chunk so an interrupted run can resume.
            f_out.flush()

    print("\n" + "="*60)
    print("Processing Complete")
    print(f"Success (Bias Injected): {stats['success']}")
    print(f"No Bias / Skipped:       {stats['no_bias']}")
    print(f"Output File:             {args.output_file}")
    print("="*60)

    if os.path.exists(args.output_file):
        final_data = load_jsonl(args.output_file)
        bias_count = sum(1 for x in final_data if x.get('has_bias'))
        print(f"\nTotal in file: {len(final_data)}")
        print(f"Bias Rate: {bias_count/len(final_data)*100:.1f}%" if len(final_data) > 0 else "Bias Rate: 0%")

if __name__ == '__main__':
    main()