#!/usr/bin/env python3
"""
SelfDebias Stage I: Bias Mitigation Data Generation (vLLM Version) - Incremental Save

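Builds SFT (supervised fine-tuning) data for bias correction: a LoRA-adapted model drafts an answer to each input question, the draft is wrapped in a critique-and-rewrite prompt, and that prompt is paired with the sample's reference answer. Records are appended to the output JSONL in chunks so an interrupted run can be resumed.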

"""

import os
import json
import argparse
from typing import List, Dict
from tqdm import tqdm

# vLLM Imports
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoTokenizer

# ====================================================================
# Configuration
# ====================================================================

# Default paths; each one is only the default for a command-line flag in main().
BASE_MODEL_NAME = "../ckpt/Qwen3-8B"  # --base_model: checkpoint loaded by vLLM and AutoTokenizer
LORA_ADAPTER_PATH = "../ckpt/V0-Qwen3/final_adapter"  # --adapter: LoRA adapter used when drafting
INPUT_FILE = "../data/DB/train_s1.jsonl"  # --input: JSONL with 'prompt'/'response' or 'conversations' records
OUTPUT_FILE = "../data/DB/train_sc.jsonl"  # --output: JSONL that generated SFT records are appended to

# Sampling settings for the draft generation (passed to vllm.SamplingParams in main()).
TEMPERATURE = 0.8
TOP_P = 0.9
MAX_NEW_TOKENS = 1024
# Default for --tp: vLLM tensor_parallel_size (number of GPUs the model is sharded over).
TP_SIZE = 4

# Default for --chunk_size: samples per llm.generate() call; the output file is flushed after each chunk.
SAVE_CHUNK_SIZE = 100

# Critique-and-rewrite prompt that wraps a user question and a model draft.
# Filled with str.format(Question=..., Example_Response=...); the whitespace
# inside the literal (including trailing spaces) is part of the prompt text.
BIAS_CORRECTION_TEMPLATE = """Below is a QUESTION from a user and an EXAMPLE RESPONSE generated by a model.
Your task is to critically review the example and provide an improved, rigorous, and unbiased RESPONSE.

When generating the new response, strictly adhere to the following evaluation criteria:

1. Fairness & Objectivity: Identify and remove any potential bias, stereotypes, or unfair assumptions related to protected attributes (e.g., gender, race, age, religion, nationality, etc.). The response must remain neutral and objective.
2. Avoid Stereotypes: Do not generalize traits or behaviors to specific groups. Treat individuals as individuals.
3. Helpfulness: While being safe and unbiased, ensure the response still directly answers the user's question without being overly preachy or refusing to answer unnecessarily.
4. Logic & Accuracy: Ensure the reasoning steps are logically sound and not based on prejudiced premises.

You MUST follow this exact output format:
<think>
(Write your step-by-step reasoning here. First, critique the EXAMPLE RESPONSE for any social biases or stereotypes. Then, plan how to construct a fair and unbiased answer.)
</think>

[Answer]
(Write the final, corrected response here.)

QUESTION: 
{Question}  

EXAMPLE RESPONSE: 
{Example_Response}

RESPONSE:
"""

# ====================================================================
# Helper functions
# ====================================================================

def load_data(file_path: str) -> List[Dict]:
    """Load a JSONL file into a list of JSON objects.

    Blank lines are ignored. Lines that are not valid JSON, or that decode to
    something other than a JSON object, are skipped and reported once with
    their count and the first offending line number, instead of vanishing
    silently. The file is read as 'utf-8-sig', so a leading BOM is tolerated.

    Args:
        file_path: path to the JSONL file.

    Returns:
        The decoded records in file order, or [] when the file does not exist
        (an error message is printed in that case).
    """
    data = []
    skipped_lines = []  # 1-based line numbers of rejected records
    try:
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    skipped_lines.append(line_no)
                    continue
                # Downstream code does `'prompt' in ex` / `ex['conversations']`,
                # which only makes sense for JSON objects.
                if isinstance(record, dict):
                    data.append(record)
                else:
                    skipped_lines.append(line_no)
    except FileNotFoundError:
        print(f" Error: File not found: {file_path}")
        return []
    if skipped_lines:
        print(
            f" Warning: skipped {len(skipped_lines)} malformed line(s) in {file_path} "
            f"(first at line {skipped_lines[0]})"
        )
    return data

def format_prompts_for_draft(tokenizer, examples: List[Dict]) -> List[str]:
    """Render each example's user question through the tokenizer's chat template.

    The question is taken from ``ex['prompt']`` when present. Otherwise it comes
    from the second-to-last turn of ``ex['conversations']``; main() treats the
    last turn as the reference answer. Examples with neither field, or with
    fewer than two conversation turns, are rendered with empty content. This
    keeps the returned list aligned one-to-one with ``examples``, which the
    caller relies on when zipping prompts with generation outputs.

    Args:
        tokenizer: object exposing ``apply_chat_template`` (e.g. a Hugging Face
            tokenizer).
        examples: input records.

    Returns:
        One formatted prompt string per example, in the same order.
    """
    formatted_prompts = []
    for ex in examples:
        content = ""
        if 'prompt' in ex:
            content = ex['prompt']
        elif 'conversations' in ex:
            turns = ex['conversations']
            # Indexing [-2] on a single-turn record would raise IndexError and
            # abort the whole batch; fall back to empty content instead.
            if len(turns) >= 2:
                content = turns[-2]['value']

        messages = [{"role": "user", "content": content}]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        formatted_prompts.append(text)
    return formatted_prompts

def get_processed_count(output_file: str) -> int:
    """Return the number of lines already present in ``output_file``.

    Every line counts, including blank or partially written ones. A file that
    does not exist yields 0.
    """
    if os.path.exists(output_file):
        with open(output_file, 'r', encoding='utf-8') as handle:
            return sum(1 for _line in handle)
    return 0

# ====================================================================
# Entry point
# ====================================================================

def _resume_index(output_file: str) -> int:
    """Return the input index from which an interrupted run should resume.

    Records written by main() carry ``metadata.source_index``, their position in
    the list returned by load_data(). The resume point is one past the largest
    index seen, which stays correct even when some inputs were skipped and
    produced no output line. Files without that field (written by older runs)
    fall back to counting non-empty lines. Undecodable lines, such as a record
    truncated by a crash, are ignored.
    """
    if not os.path.exists(output_file):
        return 0
    next_index = None
    line_count = 0
    with open(output_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            line_count += 1
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue
            metadata = entry.get("metadata") if isinstance(entry, dict) else None
            source_index = metadata.get("source_index") if isinstance(metadata, dict) else None
            if isinstance(source_index, int):
                candidate = source_index + 1
                next_index = candidate if next_index is None else max(next_index, candidate)
    return line_count if next_index is None else next_index


def _ensure_trailing_newline(path: str) -> None:
    """Append a newline to ``path`` if it is non-empty and does not end with one.

    This stops the first record of a resumed run from being glued onto a
    partially written last line left behind by a crash.
    """
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        return
    with open(path, 'rb') as f:
        f.seek(-1, os.SEEK_END)
        last_byte = f.read(1)
    if last_byte != b'\n':
        with open(path, 'ab') as f:
            f.write(b'\n')


def _extract_qa(item: Dict):
    """Return ``(question, reference_answer)`` for an input record, or None if unusable."""
    if 'prompt' in item:
        if 'response' not in item:
            return None
        return item['prompt'], item['response']
    if 'conversations' in item:
        turns = item['conversations']
        if len(turns) < 2:
            return None
        return turns[-2]['value'], turns[-1]['value']
    return None


def main():
    """Generate bias-correction SFT records and append them to the output JSONL.

    For each input sample, the LoRA-adapted model produces a draft answer. The
    draft is wrapped in BIAS_CORRECTION_TEMPLATE and paired with the sample's
    reference answer. Output is flushed after every chunk. Each record stores
    its input position in ``metadata.source_index``, so a restarted run resumes
    exactly where the previous one stopped, even if some samples were skipped
    (empty draft or missing question/answer). Resuming assumes the input file
    is unchanged between runs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", type=str, default=BASE_MODEL_NAME)
    parser.add_argument("--adapter", type=str, default=LORA_ADAPTER_PATH)
    parser.add_argument("--input", type=str, default=INPUT_FILE)
    parser.add_argument("--output", type=str, default=OUTPUT_FILE)
    parser.add_argument("--tp", type=int, default=TP_SIZE)
    parser.add_argument("--chunk_size", type=int, default=SAVE_CHUNK_SIZE)
    args = parser.parse_args()
    if args.chunk_size < 1:
        # range() below would raise on a zero step; reject bad values up front.
        parser.error("--chunk_size must be a positive integer")

    print("="*70)
    print(" Bias Mitigation Data Generation (Incremental Save)")
    print(f"   Model: {args.base_model}")
    print(f"   Chunk Size: {args.chunk_size} (Writes to disk every {args.chunk_size} samples)")
    print("="*70)

    print(f"\nLoading data from {args.input}...")
    raw_data = load_data(args.input)
    if not raw_data:
        # load_data returns [] for a missing or empty file; reporting
        # "All data processed" in that case would be misleading.
        print(" No input samples loaded. Exiting.")
        return

    start_index = _resume_index(args.output)
    if start_index > 0:
        print(f" Found existing output file; {start_index} input samples already handled.")
        print(f"⏭ Resuming from index {start_index}...")
    data_to_process = raw_data[start_index:]

    if not data_to_process:
        print(" All data processed. Exiting.")
        return

    print(f"\nInitializing vLLM (TP={args.tp})...")
    llm = LLM(
        model=args.base_model,
        enable_lora=True,
        max_lora_rank=64,
        tensor_parallel_size=args.tp,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
    )

    lora_req = LoRARequest("bias_adapter", 1, args.adapter)
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)

    # NOTE(review): "<|eot_id|>" / "<|end_of_text|>" look like Llama-style end
    # markers while the default base model is Qwen3; confirm these stop strings
    # (and the broad "assistant\n") are intended for the model in use.
    sampling_params = SamplingParams(
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_tokens=MAX_NEW_TOKENS,
        stop=["<|eot_id|>", "<|end_of_text|>", "assistant\n"]
    )

    print(f"\nStart Processing {len(data_to_process)} samples...")

    out_dir = os.path.dirname(args.output)
    if out_dir:  # os.makedirs("") raises when the output path has no directory part
        os.makedirs(out_dir, exist_ok=True)
    _ensure_trailing_newline(args.output)

    role_prefix = "assistant"  # leaked role header sometimes echoed at the start of a draft
    skipped = 0

    with open(args.output, 'a', encoding='utf-8') as f_out:
        with tqdm(total=len(data_to_process), unit="sample") as pbar:
            for chunk_start in range(0, len(data_to_process), args.chunk_size):
                batch_data = data_to_process[chunk_start:chunk_start + args.chunk_size]
                batch_prompts = format_prompts_for_draft(tokenizer, batch_data)

                batch_outputs = llm.generate(
                    batch_prompts,
                    sampling_params,
                    lora_request=lora_req,
                    use_tqdm=False
                )

                for offset, (original_item, output) in enumerate(zip(batch_data, batch_outputs)):
                    pbar.update(1)

                    draft_response = output.outputs[0].text.strip()
                    if draft_response.lower().startswith(role_prefix):
                        draft_response = draft_response[len(role_prefix):].strip()
                    if not draft_response:
                        skipped += 1
                        continue

                    qa = _extract_qa(original_item)
                    if qa is None:
                        skipped += 1
                        continue
                    original_q, ground_truth = qa

                    correction_input = BIAS_CORRECTION_TEMPLATE.format(
                        Question=original_q,
                        Example_Response=draft_response
                    )

                    sft_entry = {
                        "prompt": correction_input,
                        "response": ground_truth,
                        "metadata": {
                            "original_draft": draft_response,
                            "from_model": args.base_model,
                            # Absolute position in the loaded input; drives resume.
                            "source_index": start_index + chunk_start + offset,
                        }
                    }
                    f_out.write(json.dumps(sft_entry, ensure_ascii=False) + '\n')

                # Persist each chunk so a crash loses at most one chunk of work.
                f_out.flush()

    print("\n" + "="*70)
    if skipped:
        print(f" Skipped {skipped} samples (empty draft or missing question/answer).")
    print(f" Completed! Data saved to: {args.output}")
    print("="*70)

if __name__ == "__main__":
    # Run the generation pipeline only when executed as a script, not when imported.
    main()