#!/usr/bin/env python3
"""
MNIST VAE Training Script

This script trains a VAE model on the MNIST or FashionMNIST dataset using
leave-one-out validation splits, then aggregates the per-split metrics.
"""

import sys
import os
import torch
import numpy as np
from pathlib import Path
import copy

# Add the project root (the directory above this script's directory) to the
# import path so the `utils` and `model` packages resolve. The path is resolved
# first: if `__file__` is relative, `.parent.parent` of the raw path would
# collapse to "." and point at the wrong directory.
project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))

from utils.config import load_config, parse_args, update_config_with_args
from utils.data_utils import get_mnist_data, get_fashion_mnist_data, create_leave_one_out_splits, create_data_loaders, save_experiment_data_splits
from model.vae_models.vae import VAE
from model.vae_models.conv_vae import ConvVAE
from model.vae_models.hierarchical_vae import HierarchicalMLPVAE, HierarchicalConvVAE
from model.training.trainer import VAETrainer
from model.evaluation.evaluator import VAEEvaluator


# (metric key, label used in the printed summary), in reporting order.
_SUMMARY_METRICS = (
    ('test_loss', 'Average test loss'),
    ('test_recon_loss', 'Average test recon'),
    ('test_kl_loss', 'Average test KL'),
    ('train_loss', 'Average train loss'),
    ('train_recon_loss', 'Average train recon'),
    ('train_kl_loss', 'Average train KL'),
    ('gap_loss', 'Average gap loss'),
    ('gap_recon_loss', 'Average gap recon'),
    ('gap_kl_loss', 'Average gap KL'),
    ('best_val_loss', 'Average best validation loss'),
)

# Architectures that consume image-shaped input; everything else is built as
# an MLP model and therefore needs flattened input.
_CONV_ARCHS = frozenset({'cnn', 'hcnn'})


def _resolve_device(args):
    """Choose the torch device: CLI --device, then env VAE_DEVICE, then auto-detect."""
    requested_device = getattr(args, 'device', None) or os.environ.get('VAE_DEVICE', None)
    if requested_device in {'cpu', 'cuda'}:
        if requested_device == 'cuda' and not torch.cuda.is_available():
            print("CUDA is not available. Falling back to CPU.")
            device = torch.device('cpu')
        else:
            device = torch.device(requested_device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # `device.type` also matches indexed devices such as "cuda:0".
    if device.type == 'cuda':
        try:
            # Quick runtime check to surface incompatibilities early
            torch.tensor([0.0], device=device)
        except Exception as e:  # deliberate best-effort: any failure means "use CPU"
            print(f"CUDA device check failed ({type(e).__name__}: {e}). Falling back to CPU.")
            device = torch.device('cpu')
            print("Using device: cpu")
    return device


def _join_tag(values):
    """Join a sequence of sizes into an underscore-separated tag, e.g. ``512_256``."""
    return '_'.join(map(str, values))


def _mlp_hidden_tag(model_cfg):
    """Build the hidden-layer part of the model name for MLP-style encoders/decoders."""
    enc_dims = model_cfg.get('encoder_hidden_dims') or model_cfg['hidden_dims']
    dec_dims = model_cfg.get('decoder_hidden_dims') or model_cfg['hidden_dims']
    if enc_dims == dec_dims:
        return "_hidden" + _join_tag(enc_dims)
    return "_enc" + _join_tag(enc_dims) + "_dec" + _join_tag(dec_dims)


def _build_model_name(config, arch):
    """Build the model directory name from the architecture, objective and layer sizes."""
    model_cfg = config['model']
    prefix = 'iwae' if config['training'].get('objective', 'elbo') == 'iwae' else 'vae'

    if arch in _CONV_ARCHS:
        enc_ch = model_cfg.get('encoder_channels') or []
        dec_ch = model_cfg.get('decoder_channels') or []
        split_tag = "_encch" + _join_tag(enc_ch) + "_decch" + _join_tag(dec_ch)
        if arch == 'cnn':
            # Only the flat CNN collapses identical encoder/decoder channels.
            hidden_tag = "_ch" + _join_tag(enc_ch) if enc_ch == dec_ch else split_tag
            return f"{prefix}_cnn_latent{model_cfg['latent_dim']}{hidden_tag}"
        latent_dims = model_cfg.get('latent_dims') or [model_cfg['latent_dim']]
        return f"{prefix}_hcnn_latent{_join_tag(latent_dims)}{split_tag}"

    if arch == 'hmlp':
        # hierarchical name includes latent_dims list
        latent_dims = model_cfg.get('latent_dims') or [model_cfg['latent_dim']]
        return f"{prefix}_hmlp_latent{_join_tag(latent_dims)}{_mlp_hidden_tag(model_cfg)}"

    return f"{prefix}_latent{model_cfg['latent_dim']}{_mlp_hidden_tag(model_cfg)}"


def _build_experiment_id(config):
    """Build the experiment id from train size, objective (with iwae_k), beta and learning rate."""
    training_cfg = config['training']
    obj = training_cfg.get('objective', 'elbo')
    obj_tag = f"_{obj}"
    if obj == 'iwae':
        k = training_cfg.get('iwae_k', 5)
        if isinstance(k, (list, tuple)):
            obj_tag += "k" + 'x'.join(map(str, k))
        else:
            obj_tag += f"k{int(k)}"
    return (
        f"train{config['data']['train_size']}{obj_tag}"
        f"_beta{training_cfg['beta']}_lr{training_cfg['learning_rate']}"
    )


def _load_datasets(dataset_name, flatten):
    """Load the (train, test) datasets for the configured dataset name.

    Raises:
        ValueError: if ``dataset_name`` is not a recognised MNIST/FashionMNIST alias.
    """
    normalized = dataset_name.lower()
    if normalized in {"mnist"}:
        print("Loading MNIST dataset...")
        return get_mnist_data(flatten=flatten)
    if normalized in {"fashion_mnist", "fashion-mnist", "fashionmnist"}:
        print("Loading FashionMNIST dataset...")
        return get_fashion_mnist_data(flatten=flatten)
    raise ValueError(f"Unknown dataset: {dataset_name}")


def _build_model(config, arch):
    """Instantiate the model class matching ``arch`` (plain MLP VAE for any other value)."""
    model_cfg = config['model']
    if arch == 'cnn':
        return ConvVAE(
            in_channels=model_cfg.get('in_channels', 1),
            encoder_channels=model_cfg['encoder_channels'],
            decoder_channels=model_cfg['decoder_channels'],
            latent_dim=model_cfg['latent_dim'],
        )
    if arch == 'hcnn':
        return HierarchicalConvVAE(
            in_channels=model_cfg.get('in_channels', 1),
            encoder_channels=model_cfg['encoder_channels'],
            decoder_channels=model_cfg['decoder_channels'],
            latent_dims=model_cfg['latent_dims'],
        )
    if arch == 'hmlp':
        return HierarchicalMLPVAE(
            input_dim=model_cfg['input_dim'],
            hidden_dims=model_cfg['hidden_dims'],
            latent_dims=model_cfg['latent_dims'],
            encoder_hidden_dims=model_cfg.get('encoder_hidden_dims'),
            decoder_hidden_dims=model_cfg.get('decoder_hidden_dims'),
        )
    return VAE(
        input_dim=model_cfg['input_dim'],
        hidden_dims=model_cfg['hidden_dims'],
        latent_dim=model_cfg['latent_dim'],
        encoder_hidden_dims=model_cfg.get('encoder_hidden_dims'),
        decoder_hidden_dims=model_cfg.get('decoder_hidden_dims'),
    )


def _aggregate_metrics(all_results):
    """Compute mean/std per summary metric over the splits that reported it.

    Metrics missing from every result are omitted instead of raising, so a
    partial run still produces an aggregate. Values are converted to builtin
    floats so the result is JSON-serialisable regardless of numpy dtype.
    """
    avg_metrics = {}
    for key, _label in _SUMMARY_METRICS:
        values = [result[key] for result in all_results if key in result]
        if not values:
            continue
        avg_metrics[f'avg_{key}'] = float(np.mean(values))
        avg_metrics[f'std_{key}'] = float(np.std(values))
    return avg_metrics


def main():
    """Train and evaluate a VAE on every leave-one-out split, then aggregate.

    Reads the CLI arguments and configuration, derives experiment-specific
    result/log/checkpoint directories from the model and training settings,
    trains one model per split, writes each split's training history and
    evaluation report, and finally writes ``aggregated_results.json`` with the
    per-split metrics and their mean/std.

    Raises:
        ValueError: if the configured dataset name is not recognised.
    """
    import json  # local import: `json` is not imported at module level

    # Parse arguments and load configuration
    args = parse_args()
    config = load_config(args.config)
    config = update_config_with_args(config, args)

    device = _resolve_device(args)

    # Set random seeds for reproducibility
    seed = config['validation']['random_seed']
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Derive experiment-specific names and directories
    dataset_name = config['data']['dataset']
    arch = config['model'].get('arch', 'mlp')
    model_name = _build_model_name(config, arch)
    experiment_id = _build_experiment_id(config)
    experiment_parts = (dataset_name, model_name, experiment_id)

    experiment_base_dir = os.path.join(config['results']['save_dir'], *experiment_parts)
    experiment_log_dir = os.path.join(config['logging']['log_dir'], *experiment_parts)
    experiment_model_dir = os.path.join(config['logging']['model_dir'], *experiment_parts)
    config['logging']['log_dir'] = experiment_log_dir
    config['logging']['model_dir'] = experiment_model_dir

    # Load data. Flatten for every non-convolutional arch so the input shape
    # always agrees with the model chosen by `_build_model`, which builds an
    # MLP model for any arch that is not cnn/hcnn.
    flatten = arch not in _CONV_ARCHS
    train_dataset, test_dataset = _load_datasets(dataset_name, flatten)

    # Create leave-one-out splits
    print("Creating leave-one-out splits...")
    splits = create_leave_one_out_splits(
        train_dataset,
        config['data']['train_size'],
        config['data']['leave_one_out_ratio'],
        seed,
        config['validation']['num_folds']
    )

    print(f"Created {len(splits)} leave-one-out splits")
    print(f"Training size: {config['data']['train_size']}")
    print(f"Leave-one-out ratio: {config['data']['leave_one_out_ratio']}")

    # Save data splits for this experiment
    save_experiment_data_splits(splits, experiment_base_dir, config)

    # Train and evaluate for each split
    all_results = []
    banner = '=' * 50

    for split_idx, (train_subset, val_subset) in enumerate(splits):
        print(f"\n{banner}")
        print(f"Training on split {split_idx + 1}/{len(splits)}")
        print(f"{banner}")

        train_loader, val_loader, test_loader = create_data_loaders(
            train_subset, val_subset, test_dataset,
            config['data']['batch_size'],
            config['data']['num_workers']
        )

        model = _build_model(config, arch)

        # Create split-specific model directory
        split_model_dir = os.path.join(experiment_model_dir, f'split_{split_idx}')
        os.makedirs(split_model_dir, exist_ok=True)

        # Update config for this split (deep copy to avoid mutating base config)
        split_config = copy.deepcopy(config)
        split_config['logging']['model_dir'] = split_model_dir

        trainer = VAETrainer(model, split_config, device)
        try:
            print("Training VAE model...")
            training_history = trainer.train(train_loader, val_loader)

            # Load best model
            trainer.load_model('best_model.pth')

            # Create split-specific results directory
            split_results_dir = os.path.join(experiment_base_dir, f'split_{split_idx}')
            os.makedirs(split_results_dir, exist_ok=True)
            split_config['results']['save_dir'] = split_results_dir

            # Evaluate model with split-specific config
            print("Evaluating model...")
            evaluator = VAEEvaluator(model, split_config, device)

            # Create evaluation report (includes train metrics and gaps)
            evaluator.create_evaluation_report(test_loader, train_loader=train_loader)

            # Get final metrics
            final_metrics = evaluator.compute_metrics(test_loader, train_loader=train_loader)
            final_metrics['split_idx'] = split_idx
            final_metrics['best_val_loss'] = trainer.best_val_loss
            all_results.append(final_metrics)

            # Save training history
            history_path = os.path.join(split_results_dir, 'training_history.json')
            with open(history_path, 'w', encoding='utf-8') as f:
                json.dump(training_history, f, indent=2)
        finally:
            # Always close the trainer, even when training or evaluation fails.
            trainer.close()

        print(f"Split {split_idx + 1} completed!")

    # Aggregate results
    print(f"\n{banner}")
    print("Aggregating results across all splits...")
    print(f"{banner}")

    avg_metrics = _aggregate_metrics(all_results)

    # Save aggregated results (robust to partial runs)
    aggregated_results = {
        'individual_results': all_results,
        'average_metrics': avg_metrics,
        'config': config
    }

    aggregated_path = os.path.join(experiment_base_dir, 'aggregated_results.json')
    os.makedirs(experiment_base_dir, exist_ok=True)
    with open(aggregated_path, 'w', encoding='utf-8') as f:
        json.dump(aggregated_results, f, indent=2)

    # Print summary
    print("\nResults Summary:")
    print(f"Number of splits: {len(splits)}")
    for key, label in _SUMMARY_METRICS:
        avg_key = f'avg_{key}'
        std_key = f'std_{key}'
        if avg_key in avg_metrics:
            print(f"{label}: {avg_metrics[avg_key]:.4f} ± {avg_metrics[std_key]:.4f}")
        else:
            # No split reported this metric.
            print(f"{label}: n/a")

    print(f"\nResults saved to: {experiment_base_dir}")
    print(f"Logs saved to: {experiment_log_dir}")
    print(f"Models saved to: {experiment_model_dir}")
    print("Training completed successfully!")


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()