#!/usr/bin/env python3
"""
SelfDebias Stage III: Online Iterative DPO Training with Fairness Regularization (FR)
(Optimized for 94G VRAM & 10k Data)

Features:
1. Custom SelfDebias Loss (MSE + DPO) + Fairness Regularization (Resource Allocation)
2. Unsloth Acceleration Support
3. DDP Stability Fixes
"""

import os
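# NOTE: nothing in this file forces offline mode or disables W&B via environment variables.
# If that is needed, export e.g. HF_HUB_OFFLINE=1 / WANDB_MODE=disabled before launching.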

import sys
import json
import torch
import torch.nn.functional as F
import argparse
from typing import Dict
from functools import partial

from datasets import load_dataset
from transformers import set_seed
from trl import DPOTrainer, DPOConfig
from peft import PeftModel
from unsloth import FastLanguageModel, is_bfloat16_supported, PatchDPOTrainer

# ====================================================================
# 0. Global Config & Default Hyperparameters
# ====================================================================

# Default filesystem locations (each one has a matching CLI override in main()).
DEFAULT_BASE_MODEL = "../ckpt/Qwen3-8B"
DEFAULT_PREV_ADAPTER = "../ckpt/S2-Qwen3-8B-96G/final_adapter"  # Adapter produced by the previous stage
DEFAULT_OUTPUT_DIR = "../ckpt/S3-5-Qwen3-8B-FR"
DEFAULT_DATA_PATH = "../data/DU/online_preference_data_iter1.jsonl"

# --- Optimisation hyperparameters ---
LEARNING_RATE = 5e-5       # Default optimizer learning rate (override with --learning_rate)
NUM_EPOCHS = 1             # Default number of passes over the preference data
GLOBAL_BATCH_SIZE = 64     # Desired effective batch size across all devices
BETA = 0.1                 # DPO KL-penalty coefficient handed to DPOConfig

# --- Hybrid loss weights ---
SELF_DEBIAS_ALPHA = 0.25   # [Paper] Loss = L_SC + 0.25 * L_DPO
FAIR_BETA = 0.1            # [Paper] Coefficient of the fairness (resource-allocation) penalty

# --- LoRA / sequence-length settings ---
LORA_R = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05
MAX_SEQ_LENGTH = 4096      # Total token budget (prompt + response)
MAX_PROMPT_LENGTH = 3072   # Leaves 1024 tokens of the budget for the response

# --- Batch layout (sized for a ~96G VRAM card) ---
PER_DEVICE_BATCH_SIZE = 8  # 8-16 is feasible depending on sequence length
if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()
else:
    n_gpus = 1  # CPU-only host: treat as a single device
# Accumulate enough micro-batches to approach GLOBAL_BATCH_SIZE, never fewer than one.
GRAD_ACCUM = max(GLOBAL_BATCH_SIZE // (n_gpus * PER_DEVICE_BATCH_SIZE), 1)

# ====================================================================
# 1. Custom SelfDebias + FR Trainer (Core Logic)
# ====================================================================

class SelfDebiasFRTrainer(DPOTrainer):
    """DPOTrainer variant whose loss is L_SC + alpha * L_DPO + a fairness penalty.

    Only ``dpo_loss`` is overridden; everything else is inherited from
    ``DPOTrainer``.

    Attributes:
        self_debias_alpha: Weight applied to the sigmoid DPO term.
        fair_beta: Weight applied to the fairness penalty term.
        last_fairness_metrics: Dict holding the monitoring values computed by
            the most recent ``dpo_loss`` call (``reward_accuracy`` and
            ``bias_activation_rate``, both detached scalar tensors). Empty
            until ``dpo_loss`` has run once.
    """

    def __init__(self, *args, self_debias_alpha=0.25, fair_beta=0.1, **kwargs):
        """Build the underlying DPOTrainer and record the two loss weights.

        Args:
            *args: Forwarded unchanged to ``DPOTrainer.__init__``.
            self_debias_alpha: Weight of the DPO term in the combined loss.
            fair_beta: Weight of the fairness penalty in the combined loss.
            **kwargs: Forwarded unchanged to ``DPOTrainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.self_debias_alpha = self_debias_alpha
        self.fair_beta = fair_beta
        # Filled in by dpo_loss so the monitoring values are inspectable
        # instead of being computed and thrown away.
        self.last_fairness_metrics = {}

        # Print the banner once per node rather than once per process.
        if self.accelerator.is_local_main_process:
            print(" Trainer Mode: [SelfDebias + Fairness Regularization]")
            print(f"   - SelfDebias Alpha: {self.self_debias_alpha}")
            print(f"   - Fairness Beta:  {self.fair_beta}")

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
        *args,
        **kwargs,
    ):
        """Compute the FairCorrect objective element-wise.

        L_total = L_SC + alpha * L_DPO + Fairness_Penalty, where

        * ``logits = (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)``
        * ``L_DPO = -logsigmoid(beta * logits)``
        * ``L_SC = (1 - beta * logits) ** 2`` (pulls ``beta * logits`` toward 1)
        * ``Fairness_Penalty = fair_beta * relu(ref_rejected - ref_chosen) * (-logits)``

        The penalty is non-zero only where the reference model scores the
        rejected response above the chosen one. It is linear in ``logits``, so
        its gradient does not vanish the way the sigmoid term's does.

        Args:
            policy_chosen_logps: Policy log-probabilities of chosen responses.
            policy_rejected_logps: Policy log-probabilities of rejected responses.
            ref_chosen_logps: Reference log-probabilities of chosen responses.
            ref_rejected_logps: Reference log-probabilities of rejected responses.
            *args: Accepted and ignored, so extra positional arguments from the
                caller do not break this override.
            **kwargs: Accepted and ignored, for the same reason.

        Returns:
            A ``(final_loss, chosen_rewards, rejected_rewards)`` tuple. The loss
            is not reduced here; the rewards are computed under
            ``torch.no_grad()`` and therefore carry no gradient.
        """
        # 1. Implicit reward gap between policy and reference.
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = ref_chosen_logps - ref_rejected_logps
        logits = pi_logratios - ref_logratios

        # ------------------------------------------------------
        # Part A: SelfDebias components
        # ------------------------------------------------------
        # A1. L_DPO: standard sigmoid preference loss.
        loss_dpo = -F.logsigmoid(self.beta * logits)

        # A2. L_SC: squared distance of the scaled margin from 1,
        #     i.e. the margin itself is pulled toward 1 / beta.
        loss_sc = (1 - self.beta * logits) ** 2

        # ------------------------------------------------------
        # Part B: Fairness regularization
        # ------------------------------------------------------
        # Positive wherever the reference model prefers the rejected response.
        ref_bias_magnitude = ref_rejected_logps - ref_chosen_logps

        # ReLU keeps only those samples; multiplying by (-logits) rewards a
        # larger margin with a gradient proportional to the reference bias.
        fairness_penalty = self.fair_beta * F.relu(ref_bias_magnitude) * (-logits)

        # ------------------------------------------------------
        # Part C: Final combination
        # ------------------------------------------------------
        final_loss = loss_sc + (self.self_debias_alpha * loss_dpo) + fairness_penalty

        # ------------------------------------------------------
        # Metrics & monitoring (no gradient tracking)
        # ------------------------------------------------------
        with torch.no_grad():
            chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps)
            rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps)

            # Keep the latest monitoring values on the instance. They are left
            # as tensors so no host/device synchronisation is forced here.
            self.last_fairness_metrics = {
                # Fraction of pairs where chosen out-scores rejected.
                "reward_accuracy": (chosen_rewards > rejected_rewards).float().mean(),
                # Fraction of pairs that triggered the fairness penalty.
                "bias_activation_rate": (ref_bias_magnitude > 0).float().mean(),
            }

        return final_loss, chosen_rewards, rejected_rewards

# ====================================================================
# 2. Main Program
# ====================================================================

def main():
    """Parse CLI arguments, train with SelfDebiasFRTrainer, and save the adapter.

    Steps, in order: parse arguments, apply the Unsloth DPO patch and seed,
    load the 4-bit base model, attach the previous-stage LoRA adapter (or
    create a fresh one when the adapter path does not exist), load the JSONL
    preference data, build the DPOConfig, train, and on local rank 0 save the
    adapter and tokenizer under ``<output_dir>/final_adapter``.

    The process's local rank is read from the ``LOCAL_RANK`` environment
    variable (0 when unset) and is used for device placement and to restrict
    console output and saving to a single process.
    """
    parser = argparse.ArgumentParser(description='FairCorrect: SelfDebias + FR Training')

    # --- Path Arguments ---
    parser.add_argument('--base_model', type=str, default=DEFAULT_BASE_MODEL)
    parser.add_argument('--prev_adapter', type=str, default=DEFAULT_PREV_ADAPTER, help="Adapter from previous stage")
    parser.add_argument('--data', type=str, default=DEFAULT_DATA_PATH, help="Path to preference data jsonl")
    parser.add_argument('--output_dir', type=str, default=DEFAULT_OUTPUT_DIR)

    # --- Training Arguments ---
    parser.add_argument('--num_epochs', type=int, default=NUM_EPOCHS)
    parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE)
    parser.add_argument('--batch_size', type=int, default=PER_DEVICE_BATCH_SIZE, help="Per device batch size")

    # --- Loss Weights ---
    parser.add_argument('--self_debias_alpha', type=float, default=SELF_DEBIAS_ALPHA)
    parser.add_argument('--fair_beta', type=float, default=FAIR_BETA)

    args = parser.parse_args()

    # A non-positive batch size would divide by zero below; reject it up front.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")

    # 0. Initialization
    PatchDPOTrainer()  # Unsloth Patch
    set_seed(42)

    # Distributed launchers export LOCAL_RANK per worker; a plain
    # `python script.py` run has no such variable and is treated as rank 0.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    is_distributed = torch.cuda.device_count() > 1

    # Derive accumulation from the batch size actually requested on the CLI so
    # the effective batch still targets GLOBAL_BATCH_SIZE when --batch_size is
    # overridden. With the default batch size this equals GRAD_ACCUM.
    grad_accum = max(1, GLOBAL_BATCH_SIZE // (args.batch_size * n_gpus))

    if local_rank == 0:
        print("="*70)
        print(" FairCorrect Training (SelfDebias + FR) Optimized")
        print(f"   Base Model: {args.base_model}")
        print(f"   Prev Adapter: {args.prev_adapter}")
        print(f"   Dataset: {args.data}")
        print(f"   LR: {args.learning_rate}")
        print(f"   Epochs: {args.num_epochs}")
        print(f"   Alpha (SelfDebias): {args.self_debias_alpha} | Beta (FR): {args.fair_beta}")
        print("="*70)

    # 1. Load Base Model (Unsloth)
    if local_rank == 0:
        print("⏳ Loading Base Model...")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.base_model,
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=None,
        load_in_4bit=True,
        local_files_only=True,
        # Pin the whole model to this process's GPU when several are visible.
        device_map={"": local_rank} if is_distributed else "auto",
    )

    # 2. Load Previous Adapter (earlier stage or iteration N-1)
    if os.path.exists(args.prev_adapter):
        if local_rank == 0:
            print(f" Loading Previous Adapter from: {args.prev_adapter}")
        model = PeftModel.from_pretrained(model, args.prev_adapter, is_trainable=True)

        # Ensure LoRA params are trainable (Unsloth/Peft compatibility)
        for name, param in model.named_parameters():
            if "lora" in name:
                param.requires_grad = True
    else:
        if local_rank == 0:
            print(f" Warning: Prev adapter not found at {args.prev_adapter}")
            print(" Initializing FRESH LoRA (Note: This is Stage II, usually requires S1)")

        model = FastLanguageModel.get_peft_model(
            model,
            r=LORA_R,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_alpha=LORA_ALPHA,
            lora_dropout=LORA_DROPOUT,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=3407,
        )

    # 3. Data Processing
    if local_rank == 0:
        print(" Loading Dataset...")
    dataset = load_dataset("json", data_files=args.data, split="train")
    if local_rank == 0:
        print(f"   Size: {len(dataset)}")

    # 4. Training Arguments
    training_args = DPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=grad_accum,
        learning_rate=args.learning_rate,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,  # 10% warmup
        # Use bf16 when the hardware supports it, otherwise fall back to fp16.
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=5,
        optim="adamw_8bit",
        save_strategy="steps",
        save_steps=100,
        # save_total_limit=3,
        report_to="none",

        # --- DDP Core Fixes ---
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": True},  # Must be True
        ddp_find_unused_parameters=False,
        # --------------------

        beta=BETA,
        max_prompt_length=MAX_PROMPT_LENGTH,
        max_length=MAX_SEQ_LENGTH,
        remove_unused_columns=False,  # Prevent DPO columns from being removed
        dataset_num_proc=8,
    )

    # 5. Initialize Custom Trainer
    if local_rank == 0:
        print(" Initializing SelfDebiasFRTrainer...")
    trainer = SelfDebiasFRTrainer(
        model=model,
        ref_model=None,  # No separate reference model is passed in
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        self_debias_alpha=args.self_debias_alpha,
        fair_beta=args.fair_beta
    )

    # 6. Train
    if local_rank == 0:
        print(" Starting Training...")
    trainer.train()

    # 7. Save (one process only, to avoid concurrent writes to the same files)
    if local_rank == 0:
        print(f" Saving to {args.output_dir}")
        final_save_path = os.path.join(args.output_dir, "final_adapter")

        model.save_pretrained(final_save_path)
        tokenizer.save_pretrained(final_save_path)

        # Fix adapter_config for vLLM compatibility: make sure the LoRA rank
        # and alpha keys are present, rewriting the file in place.
        config_path = os.path.join(final_save_path, "adapter_config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r+') as f:
                data = json.load(f)
                if "lora_alpha" not in data:
                    data["lora_alpha"] = LORA_ALPHA
                if "r" not in data:
                    data["r"] = LORA_R
                # Rewind, overwrite, then drop any leftover bytes from the
                # old content.
                f.seek(0)
                json.dump(data, f, indent=2)
                f.truncate()
            print(" Fixed adapter_config.json for vLLM compatibility")

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()