#!/usr/bin/env python3
"""
Usage:
    torchrun --nproc_per_node=4 train_2_fr.py --resume_from_checkpoint True

Theory:
    FairCorrect Stage II: Fairness-Aware Preference Optimization.
    Integrates "SelfDebias" (Self-Correction) with "Fairness Regularization" (Resource Allocation).
"""

import os


import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import set_seed
from unsloth import FastLanguageModel, is_bfloat16_supported, PatchDPOTrainer
from trl import DPOTrainer, DPOConfig
from peft import PeftModel
import sys
import json

# ====================================================================
# 1. Configuration
# ====================================================================

# --- Paths ---
BASE_MODEL_PATH = "../ckpt/Qwen3-8B"
STAGE1_ADAPTER_PATH = "../ckpt/S1-Qwen3/final_adapter"
OUTPUT_DIR = "../ckpt/S2-Qwen3"
DATA_FILE = "../data/DA/train_s2.jsonl"

# --- Optimisation ---
LEARNING_RATE = 5e-6
NUM_EPOCHS = 1

# Target number of samples per optimizer step, summed over all processes.
GLOBAL_BATCH_SIZE = 64
# Passed to DPOConfig as `beta`; scales the implicit reward gap in the loss.
BETA = 0.1

# --- Custom loss weights (consumed by SelfDebiasFRTrainer) ---
SELF_DEBIAS_ALPHA = 0.25
FAIR_BETA = 0.1

PER_DEVICE_BATCH_SIZE = 8


def detect_world_size(env=None):
    """Return the number of data-parallel training processes.

    The launcher-provided ``WORLD_SIZE`` is preferred because it counts the
    processes that actually take part in training. ``torch.cuda.device_count()``
    only counts GPUs visible on this node, which differs from the process count
    when fewer processes than GPUs are launched or when several nodes are used.

    Args:
        env: Mapping to read ``WORLD_SIZE`` from. Defaults to ``os.environ``.

    Returns:
        ``WORLD_SIZE`` as an int when it is set to a positive integer; otherwise
        the visible CUDA device count, or 1 when CUDA is unavailable.
    """
    env = os.environ if env is None else env
    try:
        world_size = int(env.get("WORLD_SIZE", ""))
    except ValueError:
        # Unset or malformed: fall through to the device-count heuristic.
        world_size = 0
    if world_size >= 1:
        return world_size
    return torch.cuda.device_count() if torch.cuda.is_available() else 1


n_gpus = detect_world_size()
# Accumulate enough micro-batches to reach GLOBAL_BATCH_SIZE; never below 1.
GRAD_ACCUM = max(1, GLOBAL_BATCH_SIZE // (PER_DEVICE_BATCH_SIZE * n_gpus))

# --- LoRA (used only when a fresh adapter has to be created) ---
LORA_R = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05

# --- Sequence limits (tokens) ---
MAX_SEQ_LENGTH = 4096
MAX_PROMPT_LENGTH = 3072

# ====================================================================
# 2. SelfDebias + FR (Resource Allocation) Trainer
# ====================================================================

class SelfDebiasFRTrainer(DPOTrainer):
    """DPO trainer whose loss adds a self-correction term and a fairness penalty.

    With ``z = (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)`` the
    per-sample loss computed by :meth:`dpo_loss` is::

        (1 - beta * z) ** 2
        + self_debias_alpha * -logsigmoid(beta * z)
        + fair_beta * relu(ref_rejected - ref_chosen) * (-z)

    Args:
        *args: Forwarded unchanged to ``DPOTrainer.__init__``.
        self_debias_alpha: Weight of the standard DPO term.
        fair_beta: Weight of the fairness penalty.
        **kwargs: Forwarded unchanged to ``DPOTrainer.__init__``.

    Attributes:
        last_fairness_metrics: Dict refreshed on every :meth:`dpo_loss` call
            with the detached scalars ``reward_accuracy`` and
            ``bias_activation_rate`` for the most recent batch. Empty until the
            first call.
    """

    def __init__(self, *args, self_debias_alpha=0.25, fair_beta=0.2, **kwargs):
        super().__init__(*args, **kwargs)
        self.self_debias_alpha = self_debias_alpha
        self.fair_beta = fair_beta
        self.last_fairness_metrics = {}

        # Report the configuration once per node rather than once per process.
        if self.accelerator.is_local_main_process:
            print(" Trainer Mode: [SelfDebias + FR (Implicit Resource)]")
            print(f"   - SelfDebias Alpha: {self.self_debias_alpha}")
            print(f"   - Fairness Beta:  {self.fair_beta}")

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
        *args,
        **kwargs,
    ):
        """Compute the FairCorrect objective for a batch of preference pairs.

        ``L_total = L_SC + alpha * L_DPO + Fairness_Penalty``

        The implicit reward gap ``logits`` is treated as a "correction
        resource". For pairs where the reference model prefers the rejected
        answer (``ref_rejected_logps > ref_chosen_logps``) a term linear in
        ``-logits`` is added. Its gradient with respect to ``logits`` does not
        depend on ``logits``, so unlike the log-sigmoid term it does not
        flatten out as the gap grows.

        Args:
            policy_chosen_logps: Policy log-probabilities of the chosen answers.
            policy_rejected_logps: Policy log-probabilities of the rejected answers.
            ref_chosen_logps: Reference log-probabilities of the chosen answers.
            ref_rejected_logps: Reference log-probabilities of the rejected answers.
            *args: Accepted for signature compatibility and ignored.
            **kwargs: Accepted for signature compatibility and ignored.

        Returns:
            Tuple ``(final_loss, chosen_rewards, rejected_rewards)``. The loss
            is elementwise (not reduced); both reward tensors are computed
            under ``torch.no_grad()``.
        """
        # Implicit reward gap:
        # (log pi_chosen - log pi_rejected) - (log ref_chosen - log ref_rejected)
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = ref_chosen_logps - ref_rejected_logps
        logits = pi_logratios - ref_logratios

        # Part A: SelfDebias components.
        # A1. Standard DPO utility term.
        loss_dpo = -F.logsigmoid(self.beta * logits)
        # A2. Squared distance of the scaled gap from 1 (confidence calibration).
        loss_sc = (1 - self.beta * logits) ** 2

        # Part B: fairness regularization.
        # Positive exactly where the reference model prefers the rejected answer.
        ref_bias_magnitude = ref_rejected_logps - ref_chosen_logps
        # ReLU zeroes the penalty for pairs the reference already ranks
        # correctly; for the rest the penalty shrinks linearly as logits grows.
        fairness_penalty = self.fair_beta * F.relu(ref_bias_magnitude) * (-logits)

        # Part C: final combination.
        final_loss = loss_sc + (self.self_debias_alpha * loss_dpo) + fairness_penalty

        # Metrics: kept out of the autograd graph.
        with torch.no_grad():
            chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps)
            rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps)
            # Keep the monitoring scalars on the trainer so callbacks or
            # debugging code can read them instead of discarding them.
            self.last_fairness_metrics = {
                "reward_accuracy": (chosen_rewards > rejected_rewards).float().mean(),
                "bias_activation_rate": (ref_bias_magnitude > 0).float().mean(),
            }

        return final_loss, chosen_rewards, rejected_rewards

# ====================================================================
# 3. Main Training Flow
# ====================================================================

def main():
    """Run Stage II training end to end and save the resulting adapter.

    Steps: patch the DPO trainer via Unsloth, load the 4-bit base model, attach
    the Stage 1 adapter (or create a fresh LoRA adapter when the path does not
    exist), load the JSONL preference dataset, train with
    ``SelfDebiasFRTrainer``, and write the adapter and tokenizer to
    ``OUTPUT_DIR/final_adapter``.

    The process's local rank is read from the ``LOCAL_RANK`` environment
    variable (0 when unset, e.g. for a plain ``python`` launch). It selects the
    device the model is placed on and restricts console output and the final
    save to local rank 0.

    NOTE(review): the usage line in the module docstring passes
    ``--resume_from_checkpoint True``, but no command-line arguments are parsed
    here, so ``trainer.train()`` is always called without that option.
    """
    # 1. Apply Unsloth patch and fix the RNG seed.
    PatchDPOTrainer()
    set_seed(42)

    # torchrun exports LOCAL_RANK for every worker it spawns.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    is_main = local_rank == 0

    if is_main:
        print("=" * 60)
        print(" Stage II: FairCorrect Training Started")
        print(f"   Base Model: {BASE_MODEL_PATH}")
        print(f"   SFT Adapter: {STAGE1_ADAPTER_PATH}")
        # Report the configured epoch count rather than a hard-coded number.
        print(f"   Strategy: {NUM_EPOCHS} Epoch(s) / Global BS {GLOBAL_BATCH_SIZE}")
        print("=" * 60)

    # 2. Load the base model, placing the whole model on this process's device.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=BASE_MODEL_PATH,
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=None,
        load_in_4bit=True,
        device_map={"": local_rank},
    )

    # 3. Continue from the Stage 1 adapter when present, else start a new one.
    if os.path.exists(STAGE1_ADAPTER_PATH):
        if is_main:
            print(f" Loading Stage 1 Adapter: {STAGE1_ADAPTER_PATH}")
        model = PeftModel.from_pretrained(model, STAGE1_ADAPTER_PATH, is_trainable=True)

        # Ensure gradients are enabled for the LoRA layers.
        for name, param in model.named_parameters():
            if "lora" in name:
                param.requires_grad = True
    else:
        if is_main:
            print(" Warning: S1 Adapter not found! Initializing fresh LoRA.")
        model = FastLanguageModel.get_peft_model(
            model,
            r=LORA_R,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_alpha=LORA_ALPHA,
            lora_dropout=LORA_DROPOUT,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=3407,
        )

    # 4. Prepare the dataset.
    dataset = load_dataset("json", data_files=DATA_FILE, split="train")
    if is_main:
        print(f" Dataset Size: {len(dataset)}")

    # 5. Training configuration.
    training_args = DPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        # Exactly one of fp16/bf16 is enabled, depending on hardware support.
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=5,
        optim="adamw_8bit",
        weight_decay=0.01,
        max_grad_norm=1.0,
        seed=42,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        report_to="none",

        # --- DDP stability settings ---
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": True},
        ddp_find_unused_parameters=False,

        # --- DPO parameters ---
        beta=BETA,
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_length=MAX_SEQ_LENGTH,
        dataset_num_proc=8,
        remove_unused_columns=False,  # Critical for preserving data columns
    )

    # 6. Initialize the custom trainer.
    # NOTE(review): with ref_model=None and a PEFT model, the reference
    # log-probabilities presumably come from the model with adapters disabled,
    # i.e. without the Stage 1 adapter -- confirm this is the intended reference.
    trainer = SelfDebiasFRTrainer(
        model=model,
        ref_model=None,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        self_debias_alpha=SELF_DEBIAS_ALPHA,
        fair_beta=FAIR_BETA,
    )

    # 7. Train.
    if is_main:
        print(" Training Started...")
    trainer.train()

    # 8. Save from local rank 0 only, so processes do not write the same files.
    if is_main:
        final_save_path = os.path.join(OUTPUT_DIR, "final_adapter")
        print(f" Saving to {final_save_path}")
        model.save_pretrained(final_save_path)
        tokenizer.save_pretrained(final_save_path)

        # Patch the adapter config for vLLM compatibility: make sure the
        # "r" and "lora_alpha" keys exist, without overwriting existing values.
        config_path = os.path.join(final_save_path, "adapter_config.json")
        if os.path.exists(config_path):
            with open(config_path, "r+", encoding="utf-8") as f:
                data = json.load(f)
                data.setdefault("r", LORA_R)
                data.setdefault("lora_alpha", LORA_ALPHA)
                # Rewrite in place; truncate in case the new text is shorter.
                f.seek(0)
                json.dump(data, f, indent=2)
                f.truncate()

        print(" Done! FairCorrect Training Complete.")

# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()