"""
Training Demo for HEdit

This script demonstrates how to train the KV correction MLP.
"""

import argparse
import torch
from pathlib import Path
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from hedit import KVCorrectionDataset, KVCorrectionMLP, MLPTrainer


def parse_args(argv=None) -> argparse.Namespace:
    """Parse command-line options for the KV correction MLP training demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers that invoke ``parse_args()`` behave as before. Passing
            an explicit list makes the parser usable from tests and other
            Python code.

    Returns:
        The populated ``argparse.Namespace``. ``data_dir`` and
        ``dataset_names`` are required; every other option has a default.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when a required
            option is missing, a value fails type conversion, or
            ``--train_ratio`` is not strictly between 0 and 1.
    """
    parser = argparse.ArgumentParser(description="Train KV Correction MLP")

    # Data arguments
    parser.add_argument("--data_dir", type=str, required=True,
                        help="Directory containing training data")
    parser.add_argument("--dataset_names", type=str, nargs="+", required=True,
                        help="Names of datasets to use for training")
    parser.add_argument("--layer_idx", type=int, default=40,
                        help="Layer index to use for training")

    # Model arguments
    parser.add_argument("--input_dim", type=int, default=7168,
                        help="Input dimension (hidden_dim + 2*kv_dim)")
    parser.add_argument("--output_dim", type=int, default=2048,
                        help="Output dimension (2*kv_dim)")
    parser.add_argument("--hidden_dim1", type=int, default=2048,
                        help="First hidden layer dimension")
    parser.add_argument("--hidden_dim2", type=int, default=1024,
                        help="Second hidden layer dimension")
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Dropout rate")

    # Training arguments
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Training batch size")
    parser.add_argument("--num_epochs", type=int, default=100,
                        help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-4,
                        help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-5,
                        help="Weight decay for regularization")
    parser.add_argument("--train_ratio", type=float, default=0.8,
                        help="Ratio of data to use for training (vs validation)")

    # Output arguments
    parser.add_argument("--save_dir", type=str, default="./checkpoints",
                        help="Directory to save model checkpoints")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to train on (cuda/cpu)")

    # Seed
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")

    args = parser.parse_args(argv)

    # The ratio is used as a fractional train/validation split, which only
    # makes sense strictly inside (0, 1). Reject anything else here so the
    # user gets a usage error immediately instead of a failure after the
    # data has been loaded. The chained comparison also rejects NaN.
    if not 0.0 < args.train_ratio < 1.0:
        parser.error(
            f"--train_ratio must be strictly between 0 and 1, got {args.train_ratio}"
        )

    return args


def _print_banner(title):
    """Print ``title`` framed by a blank line and two 60-character rules."""
    rule = "=" * 60
    print("\n" + rule)
    print(title)
    print(rule)


def main():
    """Run the KV correction MLP training demo end to end.

    Steps: parse the CLI options, seed torch, pick the device, split the
    dataset indices into train/validation subsets, build the data loaders
    and the model, run ``MLPTrainer.train`` and finally ask the trainer to
    plot the loss curve into the save directory.

    Raises:
        ValueError: If the dataset holds fewer than two samples, since no
            train/validation split can be made from it.
    """
    args = parse_args()

    # Set random seed
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Setup device. Only fall back to CPU when CUDA was requested but is
    # missing; any other requested device is honoured as given.
    device = torch.device(args.device)
    if device.type == "cuda" and not torch.cuda.is_available():
        print("CUDA requested but not available; falling back to CPU")
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Create full dataset to get sample count
    _print_banner("Loading Dataset")
    print(f"Data directory: {args.data_dir}")
    print(f"Datasets: {args.dataset_names}")
    print(f"Layer index: {args.layer_idx}")

    full_dataset = KVCorrectionDataset(
        args.data_dir,
        args.dataset_names,
        args.layer_idx,
    )
    total_samples = len(full_dataset)
    # The unfiltered dataset is only needed for its length; drop the
    # reference before the two subset datasets are constructed.
    del full_dataset

    if total_samples < 2:
        raise ValueError(
            f"Need at least 2 samples to build a train/validation split, "
            f"found {total_samples} in {args.data_dir}"
        )

    # Split into train and validation
    train_indices, val_indices = train_test_split(
        list(range(total_samples)),
        train_size=args.train_ratio,
        random_state=args.seed,
    )

    print("\nDataset Statistics:")
    print(f"  Total samples: {total_samples}")
    print(f"  Training samples: {len(train_indices)}")
    print(f"  Validation samples: {len(val_indices)}")

    # Create train and validation datasets
    train_dataset = KVCorrectionDataset(
        args.data_dir,
        args.dataset_names,
        args.layer_idx,
        train_indices,
    )
    val_dataset = KVCorrectionDataset(
        args.data_dir,
        args.dataset_names,
        args.layer_idx,
        val_indices,
    )

    # Create data loaders. Pinned host memory only helps host-to-GPU copies,
    # so it is enabled only when training on a CUDA device.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=pin_memory,
    )

    # Create model
    _print_banner("Creating Model")
    print("Model Configuration:")
    print(f"  Input dim: {args.input_dim}")
    print(f"  Output dim: {args.output_dim}")
    print(f"  Hidden dim 1: {args.hidden_dim1}")
    print(f"  Hidden dim 2: {args.hidden_dim2}")
    print(f"  Dropout: {args.dropout}")

    model = KVCorrectionMLP(
        input_dim=args.input_dim,
        output_dim=args.output_dim,
        hidden_dim1=args.hidden_dim1,
        hidden_dim2=args.hidden_dim2,
        dropout_rate=args.dropout,
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("\nModel Parameters:")
    print(f"  Total: {total_params:,}")
    print(f"  Trainable: {trainable_params:,}")

    # Create trainer
    trainer = MLPTrainer(model, device=device)

    # Make sure the output directory exists before anything is written to
    # it; this does not rely on the trainer creating it.
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Train
    _print_banner("Training")
    print("Training Configuration:")
    print(f"  Epochs: {args.num_epochs}")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Learning rate: {args.learning_rate}")
    print(f"  Weight decay: {args.weight_decay}")
    print(f"  Save directory: {args.save_dir}")

    # The returned loss histories are not used here; plotting goes through
    # the trainer itself.
    trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        save_dir=args.save_dir,
    )

    # Plot losses
    save_path = save_dir / "loss_curve.png"
    trainer.plot_losses(str(save_path))

    _print_banner("Training Completed Successfully!")
    # NOTE(review): the checkpoint file name is presumably chosen by
    # MLPTrainer.train — confirm it really writes best_model.pt.
    print(f"Best model saved to: {args.save_dir}/best_model.pt")
    print(f"Loss curve saved to: {save_path}")


# Run the training demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
