#!/usr/bin/env python3
"""
Main Training Script for GDO-DPO

Usage:
    python scripts/train_gdo_dpo.py --config configs/gdo_dpo_config.yaml
"""

import argparse
import os
import sys
from pathlib import Path

import numpy as np
import torch
import wandb
import yaml

# Add the project root (the parent of this script's directory) to sys.path so the `src` package is importable
sys.path.insert(0, str(Path(__file__).parent.parent))

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
)
from src.core.difficulty_metrics import DifficultyMetrics
from src.core.gdo_dpo import GDODPOConfig, GDODPOTrainer
from src.data.data_loader import PreferenceDatasetLoader


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        Namespace with ``config`` (str), ``output_dir`` (str) and
        ``precompute_difficulty`` (bool).
    """
    parser = argparse.ArgumentParser(description="Train GDO-DPO model")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/gdo_dpo_llama3_8b.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/gdo_dpo",
        help="Output directory",
    )
    parser.add_argument(
        "--precompute_difficulty",
        action="store_true",
        help="Precompute difficulty metrics",
    )
    # Passing argv through lets callers and tests drive the parser without
    # touching the process-wide sys.argv.
    return parser.parse_args(argv)


def load_config(config_path: str) -> dict:
    """Load configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The top-level mapping parsed from the file.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        ValueError: If the document is empty or its top level is not a
            mapping, since every caller indexes the result like a dict.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a top-level mapping, "
            f"got {type(config).__name__}"
        )
    return config


def precompute_difficulty_scores(
    config: dict,
    dataset,
    tokenizer,
    save_path: str
):
    """
    Precompute Csem and Upref for the dataset.

    Loads the reference model named by ``config['model']['reference_model']``,
    wraps it in a ``DifficultyMetrics`` object and calls its
    ``precompute_dataset_difficulties`` method on ``dataset``.

    Args:
        config: Configuration dictionary. Reads ``model.reference_model`` and
            the optional ``difficulty_computation.num_samples`` (default 8).
        dataset: Dataset to process
        tokenizer: Tokenizer
        save_path: Path to save difficulty scores; forwarded unchanged to
            ``precompute_dataset_difficulties``.

    Returns:
        Whatever ``precompute_dataset_difficulties`` returns. It is indexed
        here with the keys ``'Csem'`` and ``'Upref'``, whose values must
        provide ``mean()`` and ``std()``.
    """
    print("\n" + "="*60)
    print("Precomputing Difficulty Scores")
    print("="*60)

    # Load reference model. trust_remote_code matches how main() loads the
    # very same checkpoint, so a model that needs custom code does not load
    # there and then fail here.
    ref_model = AutoModelForCausalLM.from_pretrained(
        config['model']['reference_model'],
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    # device_map="auto" falls back to CPU when no GPU is present; mirror that
    # here instead of hard-coding "cuda".
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # `or {}` tolerates a `difficulty_computation:` key that is present but
    # empty (parsed as None) in the YAML file.
    difficulty_cfg = config.get('difficulty_computation') or {}

    # Create difficulty metrics computer
    difficulty_metrics = DifficultyMetrics(
        reference_model=ref_model,
        tokenizer=tokenizer,
        device=device,
        num_samples=difficulty_cfg.get('num_samples', 8)
    )

    # Precompute
    results = difficulty_metrics.precompute_dataset_difficulties(
        dataset=dataset,
        save_path=save_path
    )

    print("\nDifficulty statistics:")
    print(f"  Csem: mean={results['Csem'].mean():.3f}, std={results['Csem'].std():.3f}")
    print(f"  Upref: mean={results['Upref'].mean():.3f}, std={results['Upref'].std():.3f}")

    return results


def _print_banner(*lines: str) -> None:
    """Print ``lines`` between two 60-character ``=`` rules, after a blank line."""
    print("\n" + "="*60)
    for line in lines:
        print(line)
    print("="*60)


def _load_preference_dataset(data_loader, data_cfg: dict):
    """Load the 'train' split of the dataset named by ``data_cfg['dataset']``.

    Raises:
        ValueError: If the name is neither 'ultrafeedback' nor 'hh-rlhf'.
    """
    dataset_name = data_cfg['dataset']
    num_samples = data_cfg.get('num_samples', None)
    if dataset_name == 'ultrafeedback':
        return data_loader.load_ultrafeedback(split='train', num_samples=num_samples)
    if dataset_name == 'hh-rlhf':
        return data_loader.load_hh_rlhf(split='train', num_samples=num_samples)
    raise ValueError(f"Unknown dataset: {dataset_name}")


def _load_cached_difficulty_scores(path: str) -> dict:
    """Read every array of the ``.npz`` archive at ``path`` into a plain dict."""
    # np.load keeps the archive open until the returned object is closed; the
    # context manager releases the file handle once the arrays are read.
    with np.load(path) as archive:
        return {key: archive[key] for key in archive.files}


def _build_training_args(config: dict, output_dir: str, use_wandb: bool):
    """Translate the ``training`` section of ``config`` into TrainingArguments."""
    train_cfg = config['training']
    # NOTE(review): eval_steps is set but no evaluation strategy is passed;
    # confirm that evaluation is actually triggered during training.
    return TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=train_cfg['batch_size'],
        per_device_eval_batch_size=train_cfg['batch_size'],
        num_train_epochs=train_cfg['num_epochs'],
        learning_rate=train_cfg['learning_rate'],
        lr_scheduler_type=train_cfg.get('lr_scheduler', 'cosine'),
        warmup_ratio=train_cfg.get('warmup_ratio', 0.03),
        logging_steps=train_cfg.get('logging_steps', 10),
        save_steps=train_cfg.get('save_steps', 500),
        eval_steps=train_cfg.get('eval_steps', 500),
        save_total_limit=train_cfg.get('save_total_limit', 3),
        fp16=False,
        bf16=True,
        gradient_accumulation_steps=train_cfg.get('gradient_accumulation_steps', 1),
        report_to="wandb" if use_wandb else "none",
    )


def main():
    """Run the GDO-DPO training pipeline end to end.

    Steps, in order: parse CLI arguments and load the YAML config; optionally
    start a wandb run; load the tokenizer, policy model and reference model;
    load the preference dataset; compute difficulty scores (or load them from
    ``<output_dir>/difficulty_scores.npz``) and attach them to the dataset;
    split into train/eval; train with ``GDODPOTrainer``; save the final model,
    the tokenizer and the curriculum history under ``output_dir``.

    Raises:
        ValueError: If ``config['data']['dataset']`` names an unknown dataset.
    """
    args = parse_args()
    config = load_config(args.config)

    # `or {}` tolerates a `wandb:` key that is present but empty (None).
    wandb_cfg = config.get('wandb') or {}
    use_wandb = bool(wandb_cfg.get('enabled', False))

    # Initialize wandb
    if use_wandb:
        wandb.init(
            project=wandb_cfg.get('project', 'gdo-dpo'),
            name=wandb_cfg.get('run_name', 'gdo-dpo-run'),
            config=config
        )

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    _print_banner("Loading Model and Tokenizer")
    model_cfg = config['model']

    # Load tokenizer; fall back to EOS as the padding token when none is set.
    tokenizer = AutoTokenizer.from_pretrained(model_cfg['name'])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load models
    print(f"Loading policy model: {model_cfg['name']}")
    model = AutoModelForCausalLM.from_pretrained(
        model_cfg['name'],
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    print(f"Loading reference model: {model_cfg['reference_model']}")
    ref_model = AutoModelForCausalLM.from_pretrained(
        model_cfg['reference_model'],
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    _print_banner("Loading Dataset")
    data_cfg = config['data']

    # Load dataset
    data_loader = PreferenceDatasetLoader(
        tokenizer=tokenizer,
        max_length=data_cfg.get('max_length', 512)
    )
    full_dataset = _load_preference_dataset(data_loader, data_cfg)

    print(f"Loaded {len(full_dataset)} samples")

    # Precompute or load difficulty scores
    difficulty_path = os.path.join(args.output_dir, 'difficulty_scores.npz')

    if args.precompute_difficulty or not os.path.exists(difficulty_path):
        difficulty_scores = precompute_difficulty_scores(
            config, full_dataset, tokenizer, difficulty_path
        )
    else:
        # NOTE(review): the cache is reused whenever the file exists; confirm
        # it was produced from the same dataset / num_samples as this run.
        print(f"\nLoading precomputed difficulty scores from {difficulty_path}")
        difficulty_scores = _load_cached_difficulty_scores(difficulty_path)

    # Attach difficulty scores to dataset
    full_dataset = data_loader.attach_difficulty_scores(
        full_dataset,
        difficulty_scores
    )

    # Train/val split
    train_dataset, eval_dataset = data_loader.create_train_val_split(
        full_dataset,
        val_ratio=data_cfg.get('val_ratio', 0.05)
    )

    print(f"Train samples: {len(train_dataset)}")
    print(f"Eval samples: {len(eval_dataset)}")

    _print_banner("Setting up GDO-DPO Trainer")

    # Create GDO-DPO config; every key below is required in the YAML file.
    gdo_cfg = config['gdo_dpo']
    gdo_config = GDODPOConfig(
        tau_stable=gdo_cfg['tau_stable'],
        tau_acc=gdo_cfg['tau_acc'],
        delta_sem=gdo_cfg['delta_sem'],
        delta_unc=gdo_cfg['delta_unc'],
        layer_mid=gdo_cfg['layer_mid'],
        ema_decay=gdo_cfg['ema_decay'],
        eval_interval=gdo_cfg['eval_interval'],
        beta=config['training']['beta'],
    )

    # Training arguments
    training_args = _build_training_args(config, args.output_dir, use_wandb)

    # Create trainer
    trainer = GDODPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        gdo_config=gdo_config,
    )

    _print_banner("Starting Training")

    # Train
    trainer.train()

    # Save final model
    final_model_path = os.path.join(args.output_dir, "final_model")
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)

    # Save curriculum history
    curriculum_history_path = os.path.join(args.output_dir, "curriculum_history.npz")
    trainer.save_curriculum_history(curriculum_history_path)

    _print_banner(
        "Training Complete!",
        f"Model saved to: {final_model_path}",
        f"Curriculum history saved to: {curriculum_history_path}",
    )


if __name__ == "__main__":
    # numpy is imported with the other module-level dependencies rather than
    # here, so main() also works when this module is imported, not executed.
    main()
