#!/usr/bin/env python3
"""
SelfDebias SFT Training - 96G VRAM Ultimate Edition (Bug Fixed)
Self-Correction + Direct Generation data expansion

Data Expansion Strategy:
1. Task 1 (Self-Correction): User: [Correction Template] -> Assistant: [GT]
2. Task 2 (Direct Gen):      User: [Target Question]     -> Assistant: [GT]

Usage:
    torchrun --nproc_per_node=4 train_1_sc.py
"""

import os
import sys
import glob
import torch
import re
import argparse
from contextlib import contextmanager

from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from transformers import set_seed

# ====================================================================
# Configuration
# ====================================================================

# --- Paths (relative to the working directory the script is launched from) ---
MODEL_NAME = "../ckpt/Qwen3-8B"
DATASET_PATH = "../data/DB/train_sc.jsonl"
OUTPUT_DIR = "../ckpt/S1-Qwen3"

# --- Batch sizing ---
PER_DEVICE_BATCH_SIZE = 24
TARGET_GLOBAL_BATCH_SIZE = 64

# Number of CUDA devices visible to this process; 1 when CUDA is unavailable.
if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()
else:
    n_gpus = 1

# Floor division, clamped to at least 1. When PER_DEVICE_BATCH_SIZE * n_gpus
# already exceeds TARGET_GLOBAL_BATCH_SIZE the quotient is 0 and the clamp
# applies, so the effective global batch can be larger than the target.
GRADIENT_ACCUMULATION_STEPS = max(
    1,
    TARGET_GLOBAL_BATCH_SIZE // (PER_DEVICE_BATCH_SIZE * n_gpus),
)

# --- Sequence handling / data loading ---
MAX_SEQ_LENGTH = 4096
PACKING = True
DATALOADER_WORKERS = 8

# --- Optimisation schedule ---
LEARNING_RATE = 5e-5
NUM_EPOCHS = 2
WARMUP_RATIO = 0.05

# --- LoRA adapter hyper-parameters ---
LORA_RANK = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05

# Number of samples already seen by the formatting function on the main
# process; debug previews are printed only while this is below 2.
DEBUG_PRINT_COUNT = 0

# ====================================================================
# Helper functions
# ====================================================================

def is_main_process():
    """Return True if this process is the primary (rank-0) worker.

    When a torch.distributed process group has been initialized, its global
    rank decides. Otherwise the ``RANK`` environment variable is consulted,
    then ``LOCAL_RANK``; a process with neither set (plain single-process
    launch) counts as the main process.

    Returns:
        bool: True for rank 0 (or -1 / unset), False for any other rank.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    # Process group not created (yet): fall back to launcher-provided env vars.
    rank = os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))
    return int(rank) in (-1, 0)

@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Let local rank 0 execute the wrapped block before the other ranks.

    Ranks other than -1/0 wait on a barrier before entering the block; rank 0
    runs the block first and then joins the barrier, releasing the others.
    If no torch.distributed process group is initialized (single-process
    run), no barrier is issued and the block simply runs.

    Args:
        local_rank: Local rank of the calling process; -1 or 0 run first.
    """
    dist = torch.distributed
    # barrier() raises when the default process group does not exist, so only
    # synchronize when there actually is one.
    synced = dist.is_available() and dist.is_initialized()
    if synced and local_rank not in (-1, 0):
        dist.barrier()
    yield
    if synced and local_rank == 0:
        dist.barrier()

def extract_original_question(full_prompt):
    """Extract the target question from a correction-template prompt.

    The prompt may contain several ``QUESTION: ... EXAMPLE RESPONSE:``
    sections (few-shot examples followed by the target). All sections are
    collected and the LAST one is returned, so that a few-shot example is not
    mistaken for the target question.

    Args:
        full_prompt: The full prompt text.

    Returns:
        The stripped text of the last question section, or None when the
        input is not a string or contains no such section.
    """
    # Non-string input cannot be searched; report "not found" instead of
    # relying on a catch-all exception handler.
    if not isinstance(full_prompt, str):
        return None

    matches = re.findall(
        r"QUESTION:\s*(.*?)\s*EXAMPLE RESPONSE:", full_prompt, re.DOTALL
    )
    if not matches:
        return None
    # The last match is taken as the target question.
    return matches[-1].strip()

def formatting_expanded_sft_func(examples, tokenizer):
    """Expand a batch of (prompt, response) pairs into chat-formatted texts.

    Each usable pair yields up to two training texts:
      1. Task 1 (Self-Correction): user = full prompt, assistant = response.
      2. Task 2 (Direct Gen): user = question extracted from the prompt via
         ``extract_original_question``, assistant = response. Only produced
         when a question longer than 5 characters is extracted.

    Pairs with an empty prompt or response are skipped. Templating errors are
    caught per task, so one failing task does not drop the other. On the main
    process, previews/errors are printed while ``DEBUG_PRINT_COUNT`` is
    below 2; the counter is incremented once per processed pair.

    Args:
        examples: Batch mapping with "prompt" and "response" columns.
        tokenizer: Object providing ``apply_chat_template``.

    Returns:
        dict: ``{"text": [...]}``; the list is empty if a column is missing.
    """
    global DEBUG_PRINT_COUNT

    if not ("prompt" in examples and "response" in examples):
        return {"text": []}

    def render_chat(user_content, assistant_content):
        # Render a single user/assistant exchange as plain (untokenized) text.
        turns = [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": assistant_content},
        ]
        return tokenizer.apply_chat_template(
            turns,
            tokenize=False,
            add_generation_prompt=False,
        )

    def debug_on():
        # Debug output is limited to the main process and the first samples.
        return is_main_process() and DEBUG_PRINT_COUNT < 2

    bar = "=" * 20
    footer = f"{'='*60}\n"
    rendered = []

    for prompt_text, answer in zip(examples["prompt"], examples["response"]):
        if not (prompt_text and answer):
            continue

        # Task 1: the full correction prompt paired with the ground truth.
        try:
            correction_text = render_chat(prompt_text, answer)
            rendered.append(correction_text)
            if debug_on():
                print(f"\n{bar} [DEBUG] Task 1 (Self-Correction) {bar}")
                print(correction_text[:500] + "...")
                print(footer)
        except Exception as err:
            if debug_on():
                print(f"[Error Task 1]: {err}")

        # Task 2: only the extracted question paired with the ground truth.
        question = extract_original_question(prompt_text)
        if question and len(question) > 5:
            try:
                direct_text = render_chat(question, answer)
                rendered.append(direct_text)
                if debug_on():
                    print(f"\n{bar} [DEBUG] Task 2 (Direct Gen) {bar}")
                    print(f" Extracted Question (Check if this is the TARGET question, not an example):")
                    print(f"'{question[:200]}...'")
                    print("-" * 30)
                    print(direct_text[:500] + "...")
                    print(footer)
            except Exception as err:
                if debug_on():
                    print(f"[Error Task 2]: {err}")

        if is_main_process():
            DEBUG_PRINT_COUNT += 1

    return {"text": rendered}

# ====================================================================
# Main training flow
# ====================================================================

def _find_resume_checkpoint(resume_arg, output_dir):
    """Resolve the --resume_from_checkpoint value to a checkpoint path or None.

    ``"true"`` (case-insensitive) selects the ``checkpoint-<step>`` directory
    with the highest step inside *output_dir*; any other value is used as a
    path if it exists. Entries whose suffix is not a number are ignored.
    """
    if not resume_arg:
        return None

    if resume_arg.lower() == "true":
        candidates = []
        for path in glob.glob(os.path.join(output_dir, "checkpoint-*")):
            step = re.fullmatch(r"checkpoint-(\d+)", os.path.basename(path))
            if step and os.path.isdir(path):
                candidates.append((int(step.group(1)), path))
        return max(candidates)[1] if candidates else None

    return resume_arg if os.path.exists(resume_arg) else None


def main():
    """Run the SFT pipeline: load model, add LoRA, format data, train, save.

    Command line:
        --resume_from_checkpoint  ``true`` to auto-pick the latest checkpoint
                                  under OUTPUT_DIR, or an explicit path.

    Distributed mode is enabled when the launcher exports ``LOCAL_RANK``
    (as torchrun does); otherwise the script runs as a single process on
    device 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
    args = parser.parse_args()

    set_seed(3407)

    # LOCAL_RANK is exported by the distributed launcher; absent (or -1)
    # means a plain single-process run.
    local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
    distributed = local_rank >= 0
    if distributed:
        torch.cuda.set_device(local_rank)
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
    else:
        local_rank = 0

    if is_main_process():
        print("="*60)
        print(f" SelfDebias SFT (96G Optimized + BUG FIXED)")
        print(f"   Model: {MODEL_NAME}")
        print(f"   Batch Size (Per Device): {PER_DEVICE_BATCH_SIZE}")
        print(f"   Max Seq Length: {MAX_SEQ_LENGTH}")
        print("="*60)

    # 1. Load Model (Unsloth 4-bit); each process places it on its own device.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=None,
        load_in_4bit=True,
        device_map={"": local_rank},
    )

    # 2. Add LoRA Adapters
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_RANK,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    # 3. Process Dataset
    dataset = load_dataset("json", data_files=DATASET_PATH, split="train")

    # Rank 0 formats first while the other ranks wait. Pass -1 outside
    # distributed mode so that no barrier is attempted without a process group.
    with torch_distributed_zero_first(local_rank if distributed else -1):
        train_dataset = dataset.map(
            lambda x: formatting_expanded_sft_func(x, tokenizer),
            batched=True,
            remove_columns=dataset.column_names,
            desc="Expanding & Formatting Data"
        )

    if is_main_process():
        print(f" Expanded Dataset Size: {len(train_dataset)}")

    # 4. Trainer Config
    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_ratio=WARMUP_RATIO,
        learning_rate=LEARNING_RATE,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=10,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        report_to="none",
        save_strategy="steps",
        save_steps=50,
        save_total_limit=3,
        ddp_find_unused_parameters=False,
        dataset_text_field="text",
        packing=PACKING,
        dataloader_num_workers=DATALOADER_WORKERS,
    )

    # NOTE(review): max_seq_length is set on the config and also passed to the
    # trainer; which of the two is honoured depends on the installed TRL
    # version - confirm against the pinned version.
    training_args.max_seq_length = MAX_SEQ_LENGTH

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        args=training_args,
        max_seq_length=MAX_SEQ_LENGTH,
    )

    # 5. Training Loop
    if is_main_process():
        print("\n Training start...")

    checkpoint = _find_resume_checkpoint(args.resume_from_checkpoint, OUTPUT_DIR)
    if is_main_process():
        if checkpoint is not None and args.resume_from_checkpoint.lower() == "true":
            print(f" Auto-resuming: {checkpoint}")
        elif checkpoint is None and args.resume_from_checkpoint:
            # Make the fallback visible instead of silently restarting.
            print(
                f" [Warning] No checkpoint found for "
                f"'{args.resume_from_checkpoint}'; training from scratch."
            )

    trainer.train(resume_from_checkpoint=checkpoint)

    # 6. Save Model (adapter + tokenizer), main process only
    if is_main_process():
        print("\n Saving Adapter...")
        model.save_pretrained(os.path.join(OUTPUT_DIR, "final_adapter"))
        tokenizer.save_pretrained(os.path.join(OUTPUT_DIR, "final_adapter"))
        print(" Done!")

    # Tear down only if a process group exists; destroying a group that was
    # never initialized raises.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

# Script entry point: start training only when the file is executed directly
# (e.g. via torchrun), not when it is imported as a module.
if __name__ == "__main__":
    main()