#!/usr/bin/env python3
"""
Local Dataset DPO Training Script for Qwen2.5-1.5B using HuggingFace DPOTrainer
"""

import os
import json
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
import wandb

# Weights & Biases configuration, set through environment variables at import
# time (presumably picked up when the wandb run is initialised during training).
os.environ.update(
    {
        "WANDB_PROJECT": "qwen-dpo-training",
        "WANDB_RUN_NAME": "qwen-local-dpo",
    }
)

def format_dpo_dataset(example, tokenizer):
    """Convert one raw dataset record into a DPO preference triple.

    Args:
        example: Mapping with the string fields "input", "correct_output"
            and "incorrect_output".
        tokenizer: Object exposing an ``eos_token`` string attribute.

    Returns:
        A dict with the keys "prompt", "chosen" and "rejected". The prompt is
        the input with leading newlines stripped; each response has the
        tokenizer's EOS token appended.
    """
    # Only leading newline characters are dropped; the rest of the
    # "Q: ... A:" text is kept untouched.
    question = example["input"].lstrip('\n')

    # Both completions are terminated with the same end-of-sequence marker.
    eos = tokenizer.eos_token
    preferred = example["correct_output"] + eos
    dispreferred = example["incorrect_output"] + eos

    return dict(prompt=question, chosen=preferred, rejected=dispreferred)

def main():
    """Run DPO fine-tuning of the local SFT checkpoint on the local dataset.

    The function takes no arguments and returns nothing. It:

    1. loads the model and tokenizer from the SFT checkpoint directory,
    2. reads the JSON dataset and maps every record through
       ``format_dpo_dataset`` into (prompt, chosen, rejected) form,
    3. splits it 96% / 4% into train and eval sets (seed 42),
    4. trains with ``DPOTrainer`` and saves the final model and tokenizer to
       the output directory,
    5. logs a completion marker to wandb from the main process.

    All paths and hyperparameters are hard-coded below.
    """
    # Use your SFT trained model as the base for DPO
    model_name = "./qwen-generator-sft-updated"  # Your SFT trained model
    output_dir = "./qwen-generator-dpo-updated"
    dataset_path = "/homes/55/sumeet/qwenma/generator_data/generator_data.json"

    # Decide the numeric precision once so that the dtype used to load the
    # weights and the mixed-precision flags given to the trainer always agree.
    cuda_available = torch.cuda.is_available()
    use_bf16 = cuda_available and torch.cuda.is_bf16_supported()
    use_fp16 = cuda_available and not use_bf16
    # bf16 weights can be trained directly. fp16 mixed precision, however,
    # keeps fp32 master weights and only autocasts the forward pass; loading
    # the weights themselves in float16 makes the gradient scaler fail with
    # "Attempting to unscale FP16 gradients". So fall back to float32 weights
    # whenever bf16 is unavailable.
    model_dtype = torch.bfloat16 if use_bf16 else torch.float32

    print(f"Loading model from: {model_name}")

    # Load your SFT trained model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=model_dtype,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Ensure padding token is set (important for DPO)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load your local dataset
    print("Loading local dataset...")
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding when reading it.
    with open(dataset_path, 'r', encoding="utf-8") as f:
        data = json.load(f)

    # Create HuggingFace dataset (expects the JSON top level to be a list of records)
    dataset = Dataset.from_list(data)

    # Format the dataset for DPO; the raw columns are dropped so that only
    # prompt / chosen / rejected remain.
    formatted_dataset = dataset.map(
        lambda x: format_dpo_dataset(x, tokenizer),
        remove_columns=dataset.column_names
    )

    # Split dataset: use 96% for train, 4% for eval (same ratio as SFT)
    train_test_split = formatted_dataset.train_test_split(test_size=0.04, seed=42)
    train_ds = train_test_split["train"]
    eval_ds = train_test_split["test"]

    print(f"Dataset loaded: {len(formatted_dataset)} total examples")
    print(f"Training examples: {len(train_ds)}")
    print(f"Evaluation examples: {len(eval_ds)}")

    # Debug: Check DPO data format by printing the first training example
    print("="*50)
    print("DEBUGGING DPO DATA FORMAT")
    print("="*50)

    sample = train_ds[0]
    print(f"Prompt: {sample['prompt']}")
    print(f"Chosen: {sample['chosen'][:100]}...")
    print(f"Rejected: {sample['rejected'][:100]}...")
    print("="*50)

    # DPO training configuration
    training_args = DPOConfig(
        output_dir=output_dir,

        # Training parameters
        num_train_epochs=5,  # Five passes over the training split
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=4,  # 4 x 4 = 16 examples per optimizer step per device
        learning_rate=5e-6,  # Deliberately small learning rate for preference tuning

        # DPO specific parameters
        beta=0.1,  # DPO temperature parameter
        loss_type="sigmoid",  # Standard DPO loss

        # Optimization settings
        optim="adamw_torch",
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        gradient_checkpointing=True,

        # Precision settings (mutually exclusive; both off without CUDA)
        bf16=use_bf16,
        fp16=use_fp16,

        # Logging
        logging_steps=10,
        report_to="wandb",

        # Saving strategy - no intermediate checkpoints; the model is saved
        # explicitly after training finishes.
        save_strategy="no",
        eval_strategy="epoch",  # Evaluate at the end of every epoch

        # DPO specific length settings
        max_prompt_length=256,  # Adjust based on your prompt lengths
        max_length=768,  # Total max length for prompt + response

        # Performance settings
        remove_unused_columns=False,  # Important for DPO

        # Other settings
        seed=42,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-8,

        # Reference model settings: reference log-probs are computed during
        # training rather than precomputed up front. No ref_model is passed to
        # the trainer below, so (per the TRL docs) the trainer derives the
        # reference from the policy model's initial weights.
        precompute_ref_log_probs=False,
    )

    # Initialize DPO trainer
    trainer = DPOTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
    )

    # Log model info
    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9:.2f}B")

    # Train!
    print("\nStarting DPO training...")
    trainer.train()

    # Save model + tokenizer (only at the end)
    print("\nSaving final model...")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Final wandb log: only from the main process, and only when a wandb run
    # is actually active (wandb.log raises if no run has been initialised).
    if trainer.state.is_world_process_zero and wandb.run is not None:
        wandb.log({"dpo_training_completed": True})
        wandb.finish()

    print(f"DPO training completed! Model saved to {output_dir}")

# Start training only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()