#!/usr/bin/env python3
"""
DRAMNet Training Script

Advanced training pipeline for DRAMNet with comprehensive configuration support,
multiple training strategies, and extensive logging and monitoring capabilities.

Usage:
    python train_dramnet.py --config configs/dramnet_default.yaml
    python train_dramnet.py --config configs/dramnet_progressive.yaml --resume checkpoint.pth
"""

import argparse
import os
import sys
import yaml
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from pathlib import Path
import logging
from datetime import datetime

# Add project root to path
sys.path.append(str(Path(__file__).parent))

from models import DRAMNet
from training import DRAMNetTrainer, ProgressiveTrainer, MultiGPUTrainer
from training.datasets import DeblurDatasetManager
from utils.config import ConfigManager
from utils.logger import setup_logger
from utils.distributed import setup_distributed_training
from utils.reproducibility import set_seed

def parse_arguments():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding configuration, training-control,
        hyper-parameter override, distributed and debug options. The
        ``--experiment-name``/``--epochs``/``--batch-size``/``--learning-rate``
        overrides default to ``None`` when not given on the command line.
    """
    parser = argparse.ArgumentParser(description='DRAMNet Training Pipeline')
    add = parser.add_argument

    # -- configuration ----------------------------------------------------
    add('--config', type=str, required=True,
        help='Path to configuration file')
    add('--experiment-name', type=str, default=None,
        help='Override experiment name from config')
    add('--output-dir', type=str, default='./experiments',
        help='Root directory for experiments')

    # -- training control -------------------------------------------------
    add('--resume', type=str, default=None,
        help='Path to checkpoint to resume from')
    add('--pretrained', type=str, default=None,
        help='Path to pretrained model weights')
    add('--validate-only', action='store_true',
        help='Run validation only (no training)')

    # -- hyper-parameter overrides ----------------------------------------
    add('--epochs', type=int, default=None,
        help='Override number of training epochs')
    add('--batch-size', type=int, default=None,
        help='Override batch size')
    add('--learning-rate', type=float, default=None,
        help='Override learning rate')

    # -- distributed training ---------------------------------------------
    add('--distributed', action='store_true',
        help='Enable distributed training')
    add('--world-size', type=int, default=1,
        help='Number of distributed processes')
    add('--rank', type=int, default=0,
        help='Rank of current process')
    add('--dist-url', type=str, default='tcp://localhost:23456',
        help='Distributed training URL')

    # -- debugging / profiling --------------------------------------------
    add('--debug', action='store_true',
        help='Enable debug mode')
    add('--profile', action='store_true',
        help='Enable profiling')
    add('--dry-run', action='store_true',
        help='Dry run (setup only, no training)')

    return parser.parse_args()

def load_config(config_path: str, args: argparse.Namespace) -> dict:
    """Load the experiment configuration and apply command-line overrides.

    Args:
        config_path: Path passed to ``ConfigManager.load_config``.
        args: Parsed CLI arguments. ``--experiment-name`` replaces
            ``config['experiment_name']``. ``--epochs``, ``--batch-size`` and
            ``--learning-rate`` replace the matching ``config['training']``
            keys. ``--debug`` sets ``config['debug'] = True`` and forces
            ``log_interval``/``val_interval`` to 1.

    Returns:
        The (mutated) configuration dict.
    """
    config_manager = ConfigManager()
    config = config_manager.load_config(config_path)

    # An empty experiment name is meaningless, so truthiness is fine here.
    if args.experiment_name:
        config['experiment_name'] = args.experiment_name

    # Compare against None, not truthiness, so explicit falsy values such as
    # ``--epochs 0`` or ``--learning-rate 0.0`` are honoured, not dropped.
    training_overrides = {
        key: value
        for key, value in (
            ('epochs', args.epochs),
            ('batch_size', args.batch_size),
            ('learning_rate', args.learning_rate),
        )
        if value is not None
    }

    if args.debug:
        config['debug'] = True
        training_overrides['log_interval'] = 1
        training_overrides['val_interval'] = 1

    if training_overrides:
        # Create the section if the file omitted it (or left it null) instead
        # of failing with KeyError/TypeError on the first override.
        training = config.get('training')
        if training is None:
            training = config['training'] = {}
        training.update(training_overrides)

    return config

def setup_experiment_directory(config: dict, output_dir: str) -> Path:
    """Create a fresh, timestamped experiment directory and save the config.

    Layout created under ``<output_dir>/<experiment_name>_<YYYYmmdd_HHMMSS>``:
    ``checkpoints/``, ``logs/``, ``configs/`` (with ``config.yaml``),
    ``outputs/`` and ``visualizations/``. If that directory already exists
    (e.g. two runs launched within the same second), a numeric suffix
    ``_1``, ``_2``, ... is appended so runs never share or overwrite each
    other's files.

    Args:
        config: Experiment configuration; ``experiment_name`` names the run
            (``'dramnet_experiment'`` when missing, null or empty).
        output_dir: Root directory for all experiments; created if needed.

    Returns:
        Path of the newly created experiment directory.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # ``or`` also covers an explicit null/empty name in the config file, which
    # ``dict.get`` with a default would pass through (giving ``None_<ts>``).
    exp_name = config.get('experiment_name') or 'dramnet_experiment'

    root = Path(output_dir)
    exp_dir = root / f"{exp_name}_{timestamp}"
    attempt = 0
    while True:
        try:
            # exist_ok=False: claim a directory no other run is using.
            exp_dir.mkdir(parents=True, exist_ok=False)
            break
        except FileExistsError:
            attempt += 1
            exp_dir = root / f"{exp_name}_{timestamp}_{attempt}"

    for subdir in ('checkpoints', 'logs', 'configs', 'outputs', 'visualizations'):
        (exp_dir / subdir).mkdir(exist_ok=True)

    # Persist the effective (post-override) config alongside the run.
    with open(exp_dir / 'configs' / 'config.yaml', 'w', encoding='utf-8') as f:
        yaml.dump(config, f, default_flow_style=False)

    return exp_dir

def create_model(config: dict, device: torch.device) -> DRAMNet:
    """Build a DRAMNet from ``config['model']`` and move it to *device*.

    Keys missing from the model section fall back to the defaults below. If
    ``model.weight_init`` is set, every Conv2d/Linear submodule is then
    re-initialised with the chosen scheme.

    Args:
        config: Full experiment config; only the ``'model'`` section is read.
        device: Device the model is moved to.

    Returns:
        The constructed model on *device*.

    Raises:
        ValueError: if ``model.weight_init`` is not ``'xavier'``,
            ``'kaiming'`` or null.
    """
    # ``or {}`` also covers a ``model:`` section left empty (parsed as None).
    model_config = config.get('model') or {}

    init_fns = {
        'xavier': xavier_init_weights,
        'kaiming': kaiming_init_weights,
    }
    init_method = model_config.get('weight_init')
    # Validate before building the model, so a config typo fails fast instead
    # of silently training with the default initialisation.
    if init_method is not None and init_method not in init_fns:
        raise ValueError(
            f"Unknown model.weight_init {init_method!r}; "
            f"expected one of {sorted(init_fns)} or null"
        )

    model = DRAMNet(
        width=model_config.get('width', 64),
        img_channel=model_config.get('img_channel', 3),
        out_channels=model_config.get('out_channels', 3),
        num_dram_blocks=model_config.get('num_dram_blocks', 4),
        depth_encoder=model_config.get('depth_encoder', 'vitl'),
        enc_blk_nums=model_config.get('enc_blk_nums', [1, 1, 1, 4]),
        enable_early_exit=model_config.get('enable_early_exit', True),
        blur_threshold=model_config.get('blur_threshold', 0.3),
        confidence_threshold=model_config.get('confidence_threshold', 0.8)
    )

    model = model.to(device)

    if init_method is not None:
        # NOTE(review): ``apply`` visits every submodule, including the depth
        # encoder. If DRAMNet loads pretrained encoder weights in its
        # constructor, they are overwritten here -- confirm this is intended.
        model.apply(init_fns[init_method])

    return model

def xavier_init_weights(m):
    """Xavier-uniform initialise a Conv2d/Linear module's weight; zero its bias.

    Meant to be passed to ``nn.Module.apply``. Modules of any other type,
    and biases that are ``None``, are left untouched.
    """
    if not isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    bias = m.bias
    if bias is not None:
        torch.nn.init.zeros_(bias)

def kaiming_init_weights(m):
    """Kaiming-uniform (ReLU gain) initialise a Conv2d/Linear weight; zero its bias.

    Meant to be passed to ``nn.Module.apply``. Modules of any other type,
    and biases that are ``None``, are left untouched.
    """
    if not isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        return
    torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
    bias = m.bias
    if bias is not None:
        torch.nn.init.zeros_(bias)

def create_datasets(config: dict):
    """Build the data loaders described by ``config['data']``.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple, in that order,
        as produced by ``DeblurDatasetManager``.
    """
    manager = DeblurDatasetManager(config.get('data', {}))
    return (
        manager.get_train_loader(),
        manager.get_val_loader(),
        manager.get_test_loader(),
    )

def create_trainer(
    model: DRAMNet,
    train_loader,
    val_loader,
    config: dict,
    device: torch.device,
    experiment_dir: Path,
    args: argparse.Namespace
):
    """Instantiate the trainer selected by ``config['training']['trainer_type']``.

    Selection, first match wins:
      * ``'progressive'``                     -> ``ProgressiveTrainer``
      * ``'multi_gpu'`` or ``--distributed``  -> ``MultiGPUTrainer``
      * ``'standard'`` (default when unset)   -> ``DRAMNetTrainer``

    Args:
        model: Model to train.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        config: Full experiment config; only the ``'training'`` section is
            forwarded to the trainer.
        device: Training device.
        experiment_dir: Run directory, passed to the trainer as a string.
        args: Parsed CLI arguments (``args.distributed`` is consulted).

    Returns:
        The constructed trainer instance.

    Raises:
        ValueError: for an unrecognised ``trainer_type`` in a non-distributed
            run (previously it silently fell back to ``DRAMNetTrainer``).
    """
    # ``or {}`` also covers a ``training:`` section left empty (parsed as None).
    training_config = config.get('training') or {}
    trainer_type = training_config.get('trainer_type', 'standard')

    # NOTE(review): 'progressive' takes precedence even under --distributed,
    # so a distributed progressive run does not use MultiGPUTrainer -- confirm.
    if trainer_type == 'progressive':
        trainer_class = ProgressiveTrainer
    elif trainer_type == 'multi_gpu' or args.distributed:
        trainer_class = MultiGPUTrainer
    elif trainer_type == 'standard':
        trainer_class = DRAMNetTrainer
    else:
        raise ValueError(
            f"Unknown training.trainer_type {trainer_type!r}; "
            "expected 'standard', 'progressive' or 'multi_gpu'"
        )

    return trainer_class(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=training_config,
        device=device,
        experiment_dir=str(experiment_dir)
    )

def load_checkpoint(model, trainer, checkpoint_path: str, device: torch.device):
    """Restore model and trainer state from a training checkpoint.

    Args:
        model: Module that receives ``checkpoint['model_state_dict']``
            (strict load).
        trainer: Trainer whose ``optimizer``, ``scheduler`` and ``scaler`` are
            each restored when the attribute exists, is not ``None``, and the
            matching ``<name>_state_dict`` key is present in the checkpoint.
            Its ``current_epoch`` is set to the returned epoch.
        checkpoint_path: Path of the checkpoint file.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The epoch to resume from: the stored ``epoch`` + 1 (1 when absent).

    Raises:
        FileNotFoundError: if *checkpoint_path* does not exist.
        KeyError: if the checkpoint lacks ``'model_state_dict'``.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # NOTE(review): unless the installed torch defaults to weights_only=True,
    # torch.load unpickles arbitrary objects -- only resume trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])

    # getattr(..., None) rather than hasattr: a trainer may define these
    # attributes but leave them None (e.g. no scheduler configured), and
    # calling None.load_state_dict would raise AttributeError.
    for name in ('optimizer', 'scheduler', 'scaler'):
        component = getattr(trainer, name, None)
        key = f'{name}_state_dict'
        if component is not None and key in checkpoint:
            component.load_state_dict(checkpoint[key])

    # The stored epoch is the last completed one; resume at the next.
    start_epoch = checkpoint.get('epoch', 0) + 1
    trainer.current_epoch = start_epoch

    print(f"Resumed from checkpoint: {checkpoint_path} (epoch {start_epoch})")

    return start_epoch

def _resolve_device(rank: int, distributed: bool) -> torch.device:
    """Return this worker's device, making it the current CUDA device if distributed."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    if not distributed:
        return torch.device('cuda')
    # NOTE(review): uses the process rank as the CUDA index, which only holds
    # when all ranks run on one node -- confirm before multi-node use.
    device = torch.device(f'cuda:{rank}')
    torch.cuda.set_device(device)
    return device


def _share_experiment_dir(experiment_dir, distributed: bool):
    """Return rank 0's experiment directory on every rank.

    Only rank 0 creates the directory. Without this, the other ranks would
    keep ``None`` and hand the trainer ``str(None) == 'None'``. A no-op when
    not distributed or when no process group is initialised.
    """
    if not (distributed and dist.is_available() and dist.is_initialized()):
        return experiment_dir
    payload = [None if experiment_dir is None else str(experiment_dir)]
    dist.broadcast_object_list(payload, src=0)
    return None if payload[0] is None else Path(payload[0])


def _load_pretrained_weights(model, path: str, device: torch.device) -> None:
    """Non-strictly load ``['model_state_dict']`` from *path* and report key mismatches."""
    print(f"Loading pretrained weights from: {path}")
    # NOTE(review): torch.load may unpickle arbitrary objects -- trusted files only.
    pretrained_state = torch.load(path, map_location=device)
    result = model.load_state_dict(pretrained_state['model_state_dict'], strict=False)
    # strict=False silently skips mismatched keys; surface them so a wrong or
    # incompatible file does not go unnoticed.
    if result.missing_keys or result.unexpected_keys:
        print(
            f"Pretrained load: {len(result.missing_keys)} missing and "
            f"{len(result.unexpected_keys)} unexpected keys"
        )


def main_worker(rank: int, world_size: int, args: argparse.Namespace):
    """Run setup and then training (or validation) for one process.

    Called directly as rank 0 of 1 for single-process runs, or once per
    process by ``torch.multiprocessing.spawn`` under ``--distributed``.

    Args:
        rank: Index of this process. It offsets the random seed and, when
            distributed, selects the CUDA device.
        world_size: Total process count, forwarded to
            ``setup_distributed_training``.
        args: Parsed command-line arguments.

    Raises:
        Exception: errors raised by ``trainer.train()`` are printed and
            re-raised. A KeyboardInterrupt during training instead saves a
            checkpoint (when the trainer supports it) and returns.
    """
    if args.distributed:
        setup_distributed_training(rank, world_size, args.dist_url)

    # Everything after process-group setup runs under try/finally, so the group
    # is torn down on every exit path (dry run, validate-only, setup errors),
    # not only after training.
    try:
        config = load_config(args.config, args)

        # Offset by rank so each process draws a different random stream.
        set_seed(config.get('seed', 42) + rank)

        device = _resolve_device(rank, args.distributed)

        # Only the main process creates the run directory and configures logging.
        experiment_dir = None
        if not args.distributed or rank == 0:
            experiment_dir = setup_experiment_directory(config, args.output_dir)
            setup_logger(
                log_dir=experiment_dir / 'logs',
                log_level=logging.DEBUG if args.debug else logging.INFO,
                rank=rank
            )
        experiment_dir = _share_experiment_dir(experiment_dir, args.distributed)

        print(f"Creating DRAMNet model on device: {device}")
        model = create_model(config, device)

        if args.dry_run:
            print("Dry run completed. Exiting.")
            return

        print("Creating datasets...")
        train_loader, val_loader, _test_loader = create_datasets(config)

        print("Creating trainer...")
        trainer = create_trainer(
            model, train_loader, val_loader, config,
            device, experiment_dir, args
        )

        if args.resume:
            # load_checkpoint also sets trainer.current_epoch.
            load_checkpoint(model, trainer, args.resume, device)
        elif args.pretrained:
            _load_pretrained_weights(model, args.pretrained, device)

        if args.validate_only:
            print("Running validation only...")
            val_metrics = trainer.validate()
            print(f"Validation results: {val_metrics}")
            return

        print("Starting training...")
        try:
            trainer.train()
            print("Training completed successfully!")
        except KeyboardInterrupt:
            print("Training interrupted by user")
            if hasattr(trainer, 'save_checkpoint'):
                print("Saving checkpoint before exit...")
                trainer.save_checkpoint({}, is_best=False)
        except Exception as e:
            print(f"Training failed with error: {e}")
            raise
    finally:
        if args.distributed and dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()

def main():
    """Entry point: parse CLI arguments, print a run banner, and launch workers.

    With ``--distributed``, ``torch.multiprocessing.spawn`` starts
    ``args.world_size`` processes running ``main_worker`` and waits for them;
    spawn passes each process its index as ``rank``. Otherwise a single
    worker runs in-process as rank 0 of 1.
    """
    args = parse_arguments()
    backend = 'CUDA' if torch.cuda.is_available() else 'CPU'

    banner = (
        "DRAMNet Training Pipeline",
        "=" * 50,
        f"Config: {args.config}",
        f"Output directory: {args.output_dir}",
        f"Device: {backend}",
    )
    for line in banner:
        print(line)

    if not args.distributed:
        main_worker(0, 1, args)
        return

    # NOTE(review): --rank is parsed but never used here; spawn assigns ranks
    # 0..world_size-1 locally, so this launcher is single-node only.
    print(f"Distributed training: {args.world_size} processes")
    mp.spawn(
        main_worker,
        args=(args.world_size, args),
        nprocs=args.world_size,
        join=True
    )

# Launch the pipeline only when run as a script, not when imported.
if __name__ == '__main__':
    main()