#!/usr/bin/env python3
"""
Quick Start Example for GDO-DPO

This example demonstrates the basic usage of GDO-DPO.
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import Dataset
import sys
from pathlib import Path

# Put the parent of this script's directory at the front of the import path
# so that the `from src import ...` line below resolves when the script is
# run directly (presumably that directory is the repository root — confirm).
sys.path.insert(0, str(Path(__file__).parent.parent))

from src import GDODPOConfig, GDODPOTrainer, DifficultyMetrics


def create_dummy_dataset(num_samples=100):
    """Build a tiny synthetic preference dataset for the demo.

    Every record holds a prompt plus a preferred ("chosen") and a
    dispreferred ("rejected") completion; the record index is embedded in
    each string so that the rows differ from one another.

    Args:
        num_samples: Number of records to generate.

    Returns:
        A ``datasets.Dataset`` with the columns ``prompt``, ``chosen`` and
        ``rejected``.
    """

    def _make_record(i):
        # One preference triple for record number ``i``.
        return {
            'prompt': f"Question {i}: What is the meaning of life?",
            'chosen': f"The meaning of life is to find purpose and happiness. Answer {i}.",
            'rejected': f"I don't know. Response {i}.",
        }

    records = [_make_record(idx) for idx in range(num_samples)]
    return Dataset.from_list(records)


def main():
    """Run the GDO-DPO quick-start demo end to end on GPT-2.

    The demo:
      1. loads ``gpt2`` twice (policy model and reference model),
      2. builds a 50-sample dummy preference dataset,
      3. attaches random difficulty scores (``Csem``, ``Upref``, ``Rsem``,
         ``Runc``) to every sample,
      4. creates a ``GDODPOConfig``,
      5. creates ``TrainingArguments`` capped at 20 optimisation steps,
      6. creates a ``GDODPOTrainer``,
      7. trains, reporting (not re-raising) any failure,
      8. prints the latest curriculum values recorded by the trainer.

    Returns:
        None. All results are printed to stdout; a training failure is
        reported with its traceback on stderr and the summary is still
        printed afterwards.
    """
    # Function-scope imports keep this block self-contained; neither module
    # is imported at the top of the file.
    import traceback

    import numpy as np

    print("=" * 60)
    print("GDO-DPO Quick Start Example")
    print("=" * 60)

    # 1. Load a small model for demonstration
    model_name = "gpt2"  # Use a small model for quick testing
    print(f"\n1. Loading model: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 ships without a padding token; reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token

    # Two independent copies: one is trained, the other is the reference.
    model = AutoModelForCausalLM.from_pretrained(model_name)
    ref_model = AutoModelForCausalLM.from_pretrained(model_name)

    # 2. Create dummy dataset
    print("\n2. Creating dummy dataset")
    dataset = create_dummy_dataset(num_samples=50)
    print(f"   Dataset size: {len(dataset)}")

    # 3. Precompute difficulty scores (simplified for demo)
    print("\n3. Computing difficulty scores")
    # Constructed only to show how the scorer is set up; the demo does not
    # call it and substitutes random scores below.
    difficulty_computer = DifficultyMetrics(
        reference_model=ref_model,
        tokenizer=tokenizer,
        device="cpu",  # Use CPU for demo
        num_samples=2   # Reduced for speed
    )

    # For demo, we'll just assign random scores: one uniform [0, 1) value per
    # sample for each difficulty metric.
    metric_names = ('Csem', 'Upref', 'Rsem', 'Runc')
    difficulty_scores = {
        name: np.random.rand(len(dataset)) for name in metric_names
    }

    # Attach the scores to the dataset as extra per-sample columns. Values
    # are converted to plain Python floats before the dataset is rebuilt.
    dataset = Dataset.from_list([
        {
            **sample,
            **{name: float(difficulty_scores[name][i]) for name in metric_names},
        }
        for i, sample in enumerate(dataset)
    ])

    # 4. Create GDO-DPO configuration
    print("\n4. Creating GDO-DPO configuration")
    gdo_config = GDODPOConfig(
        tau_stable=1.2,
        tau_acc=0.65,
        delta_sem=0.1,
        delta_unc=0.1,
        # GPT-2 (small) has 12 layers, so 6 is the halfway point.
        # NOTE(review): the earlier comment said "use 2/3", which would be
        # layer 8 — confirm which depth is intended.
        layer_mid=6,
        ema_decay=0.9,
        eval_interval=10,
        beta=0.1,
    )

    # 5. Create training arguments
    print("\n5. Setting up training")
    training_args = TrainingArguments(
        output_dir="./demo_output",
        num_train_epochs=1,
        # Cap the demo at 20 optimisation steps. The cap is set here because
        # the Hugging Face ``Trainer.train`` does not accept ``max_steps``.
        max_steps=20,
        per_device_train_batch_size=2,
        learning_rate=5e-5,
        logging_steps=5,
        save_steps=100,
        report_to="none",
    )

    # 6. Create trainer
    print("\n6. Creating GDO-DPO trainer")
    trainer = GDODPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=dataset,
        tokenizer=tokenizer,
        gdo_config=gdo_config,
    )

    # 7. Train (just a few steps for demo)
    print("\n7. Starting training (demo - just a few steps)")
    print("   NOTE: This is a simplified demo. See README for full training.")

    # NOTE(review): assumes GDODPOTrainer follows the Hugging Face Trainer
    # API, where the step limit comes from ``args.max_steps`` — confirm.
    try:
        trainer.train()
    except Exception as e:
        # Deliberately best-effort: report the failure and carry on to the
        # summary, but show the full traceback so the cause is not hidden.
        print(f"\n⚠️  Demo encountered an issue: {e}")
        traceback.print_exc()
        print("   Training did not complete; see the traceback for details.")
    else:
        print("\n✅ Demo training completed successfully!")

    # 8. Show curriculum progress
    print("\n8. Curriculum Progress:")
    # Presumably a mapping of metric name -> sequence of recorded values;
    # explicit len() checks are kept because the sequence type is not
    # visible from this file.
    history = trainer.curriculum_history
    if len(history['lambda_sem']) > 0:
        print(f"   λ_sem: {history['lambda_sem'][-1]:.3f}")
        print(f"   λ_unc: {history['lambda_unc'][-1]:.3f}")
        if len(history['Srep']) > 0:
            print(f"   S_rep: {history['Srep'][-1]:.3f}")
        if len(history['Adisc']) > 0:
            print(f"   A_disc: {history['Adisc'][-1]:.3f}")

    print("\n" + "=" * 60)
    print("Quick Start Example Complete!")
    print("=" * 60)
    print("\nNext steps:")
    print("  - See README.md for full training instructions")
    print("  - Check configs/ for configuration examples")
    print("  - Run scripts/train_gdo_dpo.py for full training")
    print("=" * 60)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
