#!/usr/bin/env python3
"""
MoCLIP training script for KIT-ML dataset
"""
import os
import sys
import torch

# Add this script's own directory to the module search path so that the
# project-local import below (train_moclip) can be resolved even when this
# file is loaded from a different working directory.
# NOTE(review): assumes train_moclip.py lives next to this file — confirm.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from train_moclip import MoClipTrainer, get_config

def _build_parser():
    """Build the command-line parser with defaults chosen for KIT-ML."""
    import argparse
    from datetime import datetime

    parser = argparse.ArgumentParser(description='MoCLIP Training on KIT-ML Dataset')

    # Basic configuration.
    # The default experiment name embeds the time at which the parser is built.
    parser.add_argument('--exp_name', type=str,
                        default=f'moclip_kit_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
                        help='Experiment name')
    parser.add_argument('--save_dir', type=str,
                        default='./checkpoints/moclip_kit_training',
                        help='Model save directory')
    parser.add_argument('--device', type=str, default='cuda', help='Training device')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')

    # Data configuration (optimized for KIT)
    parser.add_argument('--dataset_name', type=str, default='kit', help='Dataset name')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--eval_batch_size', type=int, default=32, help='Evaluation batch size')
    parser.add_argument('--max_text_length', type=int, default=77, help='Maximum text length')

    # Model configuration (optimized for KIT dataset)
    parser.add_argument('--clip_model_name', type=str,
                        default='openai/clip-vit-large-patch14',
                        help='CLIP model name')
    parser.add_argument('--freeze_clip', action='store_true',
                        help='Freeze CLIP parameters (recommended for fast training)')
    parser.add_argument('--clip_finetune_layers', type=int, default=2,
                        help='Number of CLIP layers to finetune (from the last)')
    parser.add_argument('--input_dim', type=int, default=251,
                        help='Input dimension (251 for KIT dataset)')
    parser.add_argument('--embed_dim', type=int, default=768, help='Embedding dimension')
    parser.add_argument('--num_heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--num_layers', type=int, default=4, help='Number of Transformer layers')
    parser.add_argument('--dim_feedforward', type=int, default=2048, help='Feedforward network dimension')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')
    parser.add_argument('--max_seq_length', type=int, default=196, help='Maximum sequence length')
    parser.add_argument('--temperature', type=float, default=0.07, help='Contrastive learning temperature')

    # Training configuration (optimized for KIT dataset)
    parser.add_argument('--num_epochs', type=int, default=50, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay')
    parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping')

    # Evaluation and saving
    parser.add_argument('--eval_interval', type=int, default=1000, help='Evaluation interval (steps)')
    parser.add_argument('--save_interval', type=int, default=5, help='Save interval (epochs)')
    parser.add_argument('--log_interval', type=int, default=100, help='Log interval (steps)')

    # wandb
    parser.add_argument('--use_wandb', action='store_true', help='Use wandb for training logging')

    # Resume training
    parser.add_argument('--resume', type=str, default=None, help='Checkpoint path to resume training')

    # Fast test mode
    parser.add_argument('--debug', action='store_true', help='Debug mode with smaller dataset')

    return parser


def _resolve_device(config):
    """Switch ``config.device`` to ``'cpu'`` when CUDA was requested but is unavailable."""
    if config.device == 'cuda' and not torch.cuda.is_available():
        print("⚠️  CUDA not available, switching to CPU")
        config.device = 'cpu'


def _print_config(config):
    """Print a short summary of the settings the run will use."""
    print("🚀 Starting MoCLIP model training on KIT-ML dataset")
    print("=" * 60)
    print(f"Experiment name: {config.exp_name}")
    print(f"Dataset: {config.dataset_name}")
    print(f"Device: {config.device}")
    print(f"Batch size: {config.batch_size}")
    print(f"Input dimension: {config.input_dim}")
    print(f"Freeze CLIP: {'Yes' if config.freeze_clip else 'No'}")
    print(f"Learning rate: {config.learning_rate}")
    print(f"Training epochs: {config.num_epochs}")
    print("=" * 60)


def main(argv=None):
    """Main function to train MoCLIP using KIT dataset.

    Parses the command line, builds a ``MoClipTrainer`` from the parsed
    namespace, optionally loads a checkpoint via ``--resume`` and runs
    ``trainer.train()``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        Process exit status: ``0`` when training completes, ``1`` when trainer
        creation, checkpoint loading or training raises, and ``130`` when
        training is interrupted with Ctrl-C.
    """
    import traceback

    config = _build_parser().parse_args(argv)

    # Resolve the device *before* printing the summary so the reported device
    # is the one that is actually handed to the trainer.
    _resolve_device(config)
    _print_config(config)

    # Create trainer
    try:
        trainer = MoClipTrainer(config)
    except Exception as e:
        print(f"❌ Failed to create trainer: {e}")
        traceback.print_exc()
        return 1
    print("✅ Trainer created successfully")

    # Load checkpoint if resume path is specified
    if config.resume:
        try:
            trainer.load_model(config.resume)
        except Exception as e:
            print(f"❌ Failed to load checkpoint: {e}")
            traceback.print_exc()
            return 1
        print(f"✅ Successfully loaded checkpoint: {config.resume}")

    # Start training
    try:
        trainer.train()
    except KeyboardInterrupt:
        print("\n⚠️  Training interrupted by user")
        # Best-effort save of the current state; a failure here is reported
        # instead of surfacing as a second, unrelated traceback.
        try:
            trainer.save_model('interrupted_checkpoint.pt')
        except Exception as e:
            print(f"❌ Failed to save interrupted checkpoint: {e}")
            traceback.print_exc()
        else:
            print("💾 Saved interrupted checkpoint")
        return 130  # conventional status for termination by SIGINT
    except Exception as e:
        print(f"❌ Error during training: {e}")
        traceback.print_exc()
        return 1

    print("🎉 Training completed!")
    return 0


if __name__ == "__main__":
    # Propagate the outcome to the shell so failed runs exit non-zero.
    sys.exit(main())