#!/usr/bin/env python3
"""
Local Dataset DPO Training Script for Qwen2.5-1.5B using HuggingFace DPOTrainer
"""

import os
import json
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
import wandb

# Weights & Biases settings, passed to the logger via environment variables.
# NOTE(review): wandb's documented run-name variable is WANDB_NAME, and the HF
# Trainer names the run from TrainingArguments.run_name — confirm that
# WANDB_RUN_NAME is actually read somewhere.
os.environ.update(
    WANDB_PROJECT="qwen-dpo-training",
    WANDB_RUN_NAME="qwen-local-dpo-fixed-refmodel",
)

def format_dpo_dataset(example, tokenizer):
    """Map one raw record to the ``prompt``/``chosen``/``rejected`` triple DPO expects.

    Args:
        example: A record providing ``"input"``, ``"correct_output"`` and
            ``"incorrect_output"`` string fields.
        tokenizer: Any object exposing an ``eos_token`` string.

    Returns:
        A dict whose ``prompt`` is ``example["input"]`` unchanged and whose
        ``chosen``/``rejected`` are the correct/incorrect outputs, each with
        the tokenizer's EOS token appended.
    """
    # The prompt is used verbatim; no leading newline or other whitespace is stripped.
    question = example["input"]
    eos = tokenizer.eos_token
    # Both completions end with an explicit end-of-sequence marker.
    return dict(
        prompt=question,
        chosen=example["correct_output"] + eos,
        rejected=example["incorrect_output"] + eos,
    )

def _load_preference_splits(dataset_path, tokenizer, test_size=0.04, seed=42):
    """Load the local JSON preference file and return ``(all, train, eval)`` datasets.

    Each record is converted with ``format_dpo_dataset`` into the
    ``prompt``/``chosen``/``rejected`` columns passed to DPOTrainer; the
    original columns are dropped.

    Raises:
        ValueError: if the file does not contain a non-empty JSON list.
    """
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Dataset.from_list needs a list of records; an empty list would only fail
    # later inside train_test_split with a far less obvious message.
    if not isinstance(data, list) or not data:
        raise ValueError(
            f"Expected a non-empty JSON list of records in {dataset_path!r}, "
            f"got {type(data).__name__}"
        )

    dataset = Dataset.from_list(data)
    # fn_kwargs hands the tokenizer to the mapper without a closure.
    formatted = dataset.map(
        format_dpo_dataset,
        fn_kwargs={"tokenizer": tokenizer},
        remove_columns=dataset.column_names,
    )

    # Hold out a small eval split (96% / 4%, the same ratio as the SFT run).
    splits = formatted.train_test_split(test_size=test_size, seed=seed)
    return formatted, splits["train"], splits["test"]


def _print_preference_diagnostics(train_ds, tokenizer, num_samples=3):
    """Print the first training example and rough chosen/rejected length stats."""
    print("=" * 50)
    print("DEBUGGING DPO DATA FORMAT AND PREFERENCE SIGNAL")
    print("=" * 50)

    sample = train_ds[0]
    print(f"Prompt: {sample['prompt']}")
    print(f"Chosen: {sample['chosen'][:100]}...")
    print(f"Rejected: {sample['rejected'][:100]}...")

    # Token lengths of the first pair, as a quick check that the two differ.
    chosen_len = len(tokenizer.encode(sample['chosen']))
    rejected_len = len(tokenizer.encode(sample['rejected']))
    print(f"Chosen length: {chosen_len} tokens")
    print(f"Rejected length: {rejected_len} tokens")

    # Whitespace word counts for the first few pairs to eyeball variety.
    print("\nChecking preference distinction across samples:")
    for i in range(min(num_samples, len(train_ds))):
        sample = train_ds[i]
        chosen_words = len(sample['chosen'].split())
        rejected_words = len(sample['rejected'].split())
        print(f"Sample {i}: Chosen={chosen_words} words, Rejected={rejected_words} words")

    print("=" * 50)


def _param_count_billions(module, trainable_only=False):
    """Return the number of parameters of *module* (optionally only trainable ones) in billions."""
    return sum(
        p.numel() for p in module.parameters() if p.requires_grad or not trainable_only
    ) / 1e9


def main():
    """Run DPO on the local refinement dataset, starting from the SFT checkpoint.

    The SFT checkpoint at ``model_name`` is loaded twice: once as the trainable
    policy and once as a frozen reference model for the DPO loss. After
    training, the final policy and tokenizer are written to ``output_dir``.
    """
    model_name = "./qwen-refinement-sft"
    output_dir = "./qwen-refinement-dpo"
    dataset_path = "/homes/55/sumeet/qwenma/refinement_data/refinement_data.json"

    # Query bf16 support once; both the load dtype and the AMP flags follow from it.
    use_bf16 = torch.cuda.is_bf16_supported()
    model_dtype = torch.bfloat16 if use_bf16 else torch.float16

    print(f"Loading model from: {model_name}")

    # Policy model: initialised from the SFT checkpoint and updated by DPO.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=model_dtype,
    )

    # Reference model: a second copy of the same SFT checkpoint. Freeze it
    # explicitly so it holds no trainable parameters (checked in the log below).
    ref_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=model_dtype,
    )
    ref_model.requires_grad_(False)
    ref_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Ensure a padding token is set; batched DPO inputs are padded.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading local dataset...")
    formatted_dataset, train_ds, eval_ds = _load_preference_splits(dataset_path, tokenizer)

    print(f"Dataset loaded: {len(formatted_dataset)} total examples")
    print(f"Training examples: {len(train_ds)}")
    print(f"Evaluation examples: {len(eval_ds)}")

    _print_preference_diagnostics(train_ds, tokenizer)

    # DPO training configuration - similar to the SFT setup.
    training_args = DPOConfig(
        output_dir=output_dir,
        # The Trainer's wandb integration names the run from run_name (which
        # defaults to output_dir when None), so forward the name set in the
        # environment explicitly.
        run_name=os.environ.get("WANDB_RUN_NAME"),

        num_train_epochs=5,
        per_device_train_batch_size=2,  # Reduced to increase grad_accum
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=5e-6,

        beta=0.3,
        loss_type="sigmoid",

        # Optimization settings
        optim="adamw_torch",
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        gradient_checkpointing=True,

        # Precision settings
        bf16=use_bf16,
        fp16=not use_bf16,

        # Logging
        logging_steps=10,
        report_to="wandb",

        # Saving strategy - save only once, explicitly, after training.
        save_strategy="no",  # No intermediate checkpoints
        eval_strategy="epoch",  # Evaluate at end of each epoch for monitoring

        # DPO specific length settings
        max_prompt_length=768,  # Adjust based on your prompt lengths
        max_length=1024,  # Total max length for prompt + response

        # Keep the prompt/chosen/rejected columns for the DPO collator.
        remove_unused_columns=False,

        # Other settings
        seed=42,
        weight_decay=0.05,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-8,
    )

    # Initialize DPO trainer with a separate, frozen reference model.
    trainer = DPOTrainer(
        model=model,           # Policy model (trainable)
        ref_model=ref_model,   # Reference model (frozen above)
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
    )

    # Log model info
    print(f"\nPolicy Model parameters: {_param_count_billions(model):.2f}B")
    print(f"Policy Model trainable parameters: {_param_count_billions(model, trainable_only=True):.2f}B")
    print(f"Reference Model parameters: {_param_count_billions(ref_model):.2f}B")
    print(f"Reference Model trainable parameters: {_param_count_billions(ref_model, trainable_only=True):.2f}B")
    print("(Reference model should have 0 trainable parameters!)")

    print("\nStarting DPO training...")
    trainer.train()

    # save_strategy="no" and load_best_model_at_end is not enabled, so this
    # saves the policy as it stands after the final training step.
    print("\nSaving final model...")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Only the main process logs the completion marker; skip if no wandb run
    # is active, since wandb.log requires one.
    if trainer.state.is_world_process_zero and wandb.run is not None:
        wandb.log({"dpo_training_completed": True})
        wandb.finish()

    print(f"DPO training completed! Model saved to {output_dir}")

if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()