#!/usr/bin/env python3
"""
Example script showing how to load trained models and data splits.
"""

import sys
import os
import torch
from pathlib import Path

# Make the project root importable so the `utils` and `model` imports below
# resolve to this repository's packages. resolve() yields an absolute,
# symlink-free path regardless of how the script was launched, and inserting
# at the front (rather than appending) keeps an unrelated installed package
# with the same name (e.g. `utils`) from shadowing the local one.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))

from utils.model_utils import load_trained_model, load_experiment_data_splits, recreate_data_subsets, list_experiments
from utils.data_utils import get_mnist_data
from model.evaluation.evaluator import VAEEvaluator
from torch.utils.data import DataLoader


def _select_device() -> torch.device:
    """Choose the torch device, honouring the ``VAE_DEVICE`` environment variable.

    ``VAE_DEVICE`` may be ``cpu``, ``cuda`` or ``cuda:<index>``; surrounding
    whitespace and letter case are ignored. Any other non-empty value is
    reported and ignored. Without a usable request, CUDA is chosen when
    ``torch.cuda.is_available()`` and CPU otherwise.

    A CUDA choice is verified by allocating a one-element tensor on it. If that
    raises, CPU is used instead. The final choice is printed once.

    Returns:
        The device to run the example on.
    """
    raw_request = os.environ.get('VAE_DEVICE', '')
    requested = raw_request.strip().lower()
    cuda_available = torch.cuda.is_available()

    device = None
    if requested == 'cpu':
        device = torch.device('cpu')
    elif requested == 'cuda' or requested.startswith('cuda:'):
        if not cuda_available:
            print("CUDA is not available. Falling back to CPU.")
            device = torch.device('cpu')
        else:
            try:
                device = torch.device(requested)
            except (RuntimeError, ValueError):
                # Malformed device string such as "cuda:abc"; fall back to auto-selection.
                print(f"Ignoring invalid VAE_DEVICE={raw_request!r}.")
    elif requested:
        print(f"Ignoring unsupported VAE_DEVICE={raw_request!r}.")

    if device is None:
        device = torch.device('cuda' if cuda_available else 'cpu')

    if device.type == 'cuda':
        # is_available() can be True while the device is still unusable (driver
        # mismatch, bad index, ...), so probe it with a real allocation.
        try:
            torch.tensor([0.0], device=device)
        except Exception as e:  # deliberate catch-all: any CUDA failure -> CPU
            print(f"CUDA device check failed ({type(e).__name__}: {e}). Falling back to CPU.")
            device = torch.device('cpu')

    print(f"Using device: {device}")
    return device


def _print_experiment_summaries(experiments_root: str) -> None:
    """Print every experiment found under *experiments_root* with its average test loss.

    Entries without ``average_metrics['avg_test_loss']`` are shown as ``n/a``.
    This keeps one incomplete entry from aborting the whole listing.
    """
    for exp in list_experiments(experiments_root):
        print(f"  - {exp['experiment_dir']}")
        try:
            avg_loss = exp['average_metrics']['avg_test_loss']
        except (KeyError, TypeError):
            print("    Average test loss: n/a")
        else:
            print(f"    Average test loss: {avg_loss:.4f}")


def main() -> int:
    """Walk through loading a trained VAE, its data splits, and evaluating it.

    Steps: list experiments, load the split-0 model, load the saved splits,
    rebuild the split-0 train/validation subsets, build data loaders, evaluate
    on the test set, and save generated samples to a PNG.

    The experiment directory defaults to a hard-coded path. That path is
    relative to the current working directory. It can be overridden with the
    ``VAE_EXPERIMENT_DIR`` environment variable.

    Returns:
        Process exit status: 0 on success, 1 if the experiment's
        ``aggregated_results.json`` or its saved data splits are missing.
    """
    device = _select_device()

    # Example experiment directory (set VAE_EXPERIMENT_DIR or update this path).
    # `or` also treats an empty variable as unset.
    experiment_dir = os.environ.get('VAE_EXPERIMENT_DIR') or (
        "results/experiments/mnist/vae_latent32_hidden512_256_128_64/train10000_beta0.1_lr0.0005"
    )
    batch_size = 128
    samples_path = "loaded_model_samples.png"

    print(f"Loading experiment from: {experiment_dir}")

    # Quick existence check for aggregated results to give a clearer message
    aggregated_file = os.path.join(experiment_dir, 'aggregated_results.json')
    if not os.path.exists(aggregated_file):
        print(f"aggregated_results.json not found at: {aggregated_file}")
        print("This usually means training did not complete successfully.\n"
              "Run training (e.g., with --device cpu or VAE_DEVICE=cpu) and try again.")
        return 1

    # 1. List all experiments
    print("\n1. Listing all experiments:")
    _print_experiment_summaries("results/experiments")

    # 2. Load a specific model (split 0)
    print("\n2. Loading model from split 0:")
    model, config = load_trained_model(experiment_dir, split_idx=0, device=device)
    print("  Model loaded successfully!")
    print(f"  Latent dimension: {model.latent_dim}")
    print(f"  Input dimension: {model.input_dim}")

    # 3. Load data splits
    print("\n3. Loading data splits:")
    splits, metadata = load_experiment_data_splits(experiment_dir)
    print(f"  Number of splits: {len(splits)}")
    print(f"  Dataset: {metadata['dataset']}")
    print(f"  Training size: {metadata['train_size']}")
    if not splits:
        # Step 4 indexes splits[0]; report instead of raising IndexError.
        print("  No data splits were saved for this experiment.")
        return 1

    # 4. Recreate data subsets for split 0
    print("\n4. Recreating data subsets for split 0:")
    train_dataset, test_dataset = get_mnist_data()
    train_indices, val_indices = splits[0]
    train_subset, val_subset = recreate_data_subsets(train_dataset, train_indices, val_indices)
    print(f"  Train subset size: {len(train_subset)}")
    print(f"  Validation subset size: {len(val_subset)}")

    # 5. Create data loaders. Only test_loader is used below; the train and
    # validation loaders are built to show how the recreated subsets plug in.
    print("\n5. Creating data loaders:")
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=False)  # noqa: F841
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)  # noqa: F841
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    print("  Data loaders created successfully!")

    # 6. Evaluate the loaded model
    print("\n6. Evaluating loaded model:")
    evaluator = VAEEvaluator(model, config, device)
    metrics = evaluator.compute_metrics(test_loader)
    print(f"  Test loss: {metrics['test_loss']:.4f}")
    print(f"  Reconstruction loss: {metrics['test_recon_loss']:.4f}")
    print(f"  KL loss: {metrics['test_kl_loss']:.4f}")

    # 7. Generate some samples
    print("\n7. Generating samples:")
    evaluator.generate_samples(num_samples=4, save_path=samples_path)
    print(f"  Samples saved to: {samples_path}")

    print("\nExample completed successfully!")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status; a None return
    # exits with status 0, just like normal termination.
    sys.exit(main())