#!/usr/bin/env python3
"""
SelfDebias SFT Training - Unsloth DDP Version (Standard SFT)
 Unsloth  Chosen/Response  (Standard SFT)

:
1. Standard: {"prompt": "...", "response": "..."}  <-- 
2. DPO:      {"prompt": "...", "chosen": "..."}
3. SelfDebias: {"x_T": "...", "Y_w": "..."}

Usage:
    torchrun --nproc_per_node=4 train_0_sft.py --resume_from_checkpoint True
"""

import os


import sys
import glob
import torch
import argparse
from contextlib import contextmanager

from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from transformers import set_seed

# ====================================================================

# ====================================================================

# Base model to load, JSONL training data, and directory that receives
# checkpoints and the final adapter.
MODEL_NAME = "../ckpt/Qwen3-8B" 
DATASET_PATH = "../data/DA/train_v0.jsonl"
OUTPUT_DIR = "../ckpt/V0-Qwen3"

# Maximum sequence length handed to both the model loader and the trainer.
MAX_SEQ_LENGTH = 4096
# BATCH_SIZE is the per-device batch size, so the effective batch is
# BATCH_SIZE * number of processes * GRADIENT_ACCUMULATION_STEPS
# (e.g. 16 * 4 * 1 = 64 with the 4-process launch shown in the usage note).
BATCH_SIZE = 16          
GRADIENT_ACCUMULATION_STEPS = 1
LEARNING_RATE = 1e-6     
NUM_EPOCHS = 2
WARMUP_STEPS = 10

# --- LoRA ---
# Rank, scaling factor and dropout passed to FastLanguageModel.get_peft_model.
LORA_RANK = 16
LORA_ALPHA = 32 
LORA_DROPOUT = 0 

# ====================================================================

# ====================================================================

def is_main_process():
    """Return True when this process should do logging and saving.

    torchrun exports ``RANK`` (global) and ``LOCAL_RANK`` (per node) to every
    worker. The global rank is preferred so that only one process in a
    multi-node job reports as "main"; ``LOCAL_RANK`` is the fallback. When
    neither variable is set the script is running as a single process, which
    is trivially the main one.
    """
    rank = os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))
    return int(rank) == 0

@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Let local rank 0 run the wrapped block before the other ranks do.

    Ranks other than -1/0 wait at a barrier on entry; rank 0 runs the block
    and then joins the barrier on exit, releasing the waiting ranks.

    Args:
        local_rank: Rank of this process on its node; -1 (or any value when
            no process group is initialised) disables synchronisation.
    """
    # barrier() raises if no default process group exists, so only
    # synchronise when distributed training is actually active.
    synced = torch.distributed.is_available() and torch.distributed.is_initialized()
    if synced and local_rank not in (-1, 0):
        torch.distributed.barrier()
    try:
        yield
    finally:
        # Reached even when the wrapped block raises, so the waiting ranks
        # are released instead of hanging until the backend times out.
        if synced and local_rank == 0:
            torch.distributed.barrier()

def formatting_simple_sft_func(examples, tokenizer):
    """Render a batch of prompt/response pairs into SFT training text.

    The prompt column is the first of ``prompt`` / ``x_T`` present in the
    batch; the response column is the first of ``response`` / ``chosen`` /
    ``Y_w``. Pairs where either side is empty are dropped.

    Args:
        examples: Batched mapping of column name -> list of values.
        tokenizer: Object exposing ``apply_chat_template``.

    Returns:
        ``{"text": [...]}`` with one rendered string per kept pair; the list
        is empty when no supported prompt or response column is present.
    """
    prompt_key = next((key for key in ("prompt", "x_T") if key in examples), None)
    response_key = next(
        (key for key in ("response", "chosen", "Y_w") if key in examples), None
    )
    if prompt_key is None or response_key is None:
        return {"text": []}

    texts = []
    for user_input, assistant_output in zip(examples[prompt_key], examples[response_key]):
        if not user_input or not assistant_output:
            continue

        messages = [
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": assistant_output},
        ]
        try:
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
        except Exception:
            # Best-effort fallback when the chat template cannot be applied.
            # Catching Exception (not a bare except) lets KeyboardInterrupt
            # and SystemExit still stop the run.
            text = f"User: {user_input}\nAssistant: {assistant_output}"

        texts.append(text)

    return {"text": texts}

# ====================================================================

# ====================================================================

def _resolve_resume_checkpoint(resume_arg, output_dir):
    """Translate the --resume_from_checkpoint value into a path or None.

    ``"true"`` (any case) selects the ``checkpoint-<step>`` directory with the
    highest step under ``output_dir``; any other value is used as-is when it
    exists on disk. Returns None when nothing usable is found.
    """
    if not resume_arg:
        return None

    if resume_arg.lower() == "true":
        # Ignore entries whose suffix is not a step number so a stray
        # "checkpoint-foo" cannot crash the int() conversion.
        candidates = [
            path
            for path in glob.glob(os.path.join(output_dir, "checkpoint-*"))
            if os.path.isdir(path) and path.rsplit("-", 1)[-1].isdecimal()
        ]
        if not candidates:
            return None
        return max(candidates, key=lambda path: int(path.rsplit("-", 1)[-1]))

    return resume_arg if os.path.exists(resume_arg) else None


def main():
    """Run LoRA SFT with Unsloth, optionally under torchrun (DDP).

    Steps: parse CLI args, set up the process group when launched by
    torchrun, load the 4-bit base model and attach LoRA adapters, format the
    JSONL dataset into text, train (optionally resuming from a checkpoint),
    and save the final adapter and tokenizer from the main process.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Path to checkpoint directory or 'True' to auto-resume.")
    args = parser.parse_args()

    set_seed(42)

    # torchrun exports LOCAL_RANK to every worker; when it is absent this is
    # a plain single-process run and no process group is created.
    env_local_rank = os.environ.get("LOCAL_RANK")
    distributed = env_local_rank is not None
    if distributed:
        local_rank = int(env_local_rank)
        torch.cuda.set_device(local_rank)
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
    else:
        local_rank = 0

    if is_main_process():
        print("="*60)
        print(" Standard SFT Training via Unsloth (DDP)")
        print("   Mode: Prompt + Response")
        print("="*60)

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=None,
        load_in_4bit=True,
        # Pin the whole model to this process's GPU.
        device_map={"": local_rank},
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_RANK,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    if is_main_process():
        print(" Unsloth Model & LoRA Ready")
        model.print_trainable_parameters()

    dataset = load_dataset("json", data_files=DATASET_PATH, split="train")

    # Local rank 0 formats the data first and the other ranks run afterwards.
    # -1 disables the barriers in single-process runs, where no process
    # group exists.
    sync_rank = local_rank if distributed else -1
    with torch_distributed_zero_first(sync_rank):
        train_dataset = dataset.map(
            lambda x: formatting_simple_sft_func(x, tokenizer),
            batched=True,
            remove_columns=dataset.column_names,
            desc="Formatting SFT Data"
        )

    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=WARMUP_STEPS,
        learning_rate=LEARNING_RATE,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=5,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        report_to="none",
        save_strategy="steps",
        save_steps=100,
        save_total_limit=3,
        ddp_find_unused_parameters=False,
        dataset_text_field="text",
        packing=False,
    )

    # NOTE(review): max_seq_length is set both on the config and on the
    # trainer, presumably to cover different TRL versions - confirm against
    # the installed TRL/Unsloth release.
    training_args.max_seq_length = MAX_SEQ_LENGTH

    # 6. Trainer
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        args=training_args,
        max_seq_length=MAX_SEQ_LENGTH,
    )

    if is_main_process():
        print("\n Training start...")

    resume_arg = args.resume_from_checkpoint
    checkpoint = _resolve_resume_checkpoint(resume_arg, OUTPUT_DIR)
    if resume_arg and is_main_process():
        if checkpoint is None:
            # Make the fall-back to a fresh run visible instead of silent.
            print(f" No checkpoint found for --resume_from_checkpoint={resume_arg!r}; training from scratch.")
        elif resume_arg.lower() == "true":
            print(f" Auto-resuming: {checkpoint}")

    trainer.train(resume_from_checkpoint=checkpoint)

    if is_main_process():
        print("\n Saving Adapter...")
        model.save_pretrained(os.path.join(OUTPUT_DIR, "final_adapter"))
        tokenizer.save_pretrained(os.path.join(OUTPUT_DIR, "final_adapter"))
        print(" Done!")

    # Only tear down a process group that actually exists; calling this in a
    # single-process run would fail.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

# Run the training pipeline only when executed as a script (e.g. under
# torchrun), not when the module is imported.
if __name__ == "__main__":
    main()