#!/usr/bin/env python3
"""
Evaluate all splits of a trained VAE experiment and aggregate metrics.
"""
import os
import sys
import json
from pathlib import Path
from typing import Dict, List

import torch

# Append the directory two levels above this file to sys.path (presumably the
# project root) so that the `utils` and `model` imports below resolve when this
# script is run directly. This must run before those imports.
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from utils.model_utils import load_trained_model, load_experiment_data_splits
from utils.data_utils import get_mnist_data, get_fashion_mnist_data
from model.evaluation.evaluator import VAEEvaluator


def select_device() -> torch.device:
    """Choose the torch device used for evaluation.

    The ``VAE_DEVICE`` environment variable is honoured when it is exactly
    ``'cpu'`` or ``'cuda'``; any other value (or no value) selects CUDA when
    ``torch.cuda.is_available()`` reports it and CPU otherwise.

    When CUDA is chosen, a one-element tensor is allocated on it as a smoke
    test; if that raises, the function reports the failure and returns the
    CPU device instead.

    Returns:
        The selected ``torch.device``.
    """
    env_choice = os.environ.get('VAE_DEVICE')

    if env_choice == 'cpu':
        chosen = torch.device('cpu')
    elif env_choice == 'cuda':
        if not torch.cuda.is_available():
            # Explicit CUDA request that cannot be satisfied: report and bail out.
            print('CUDA is not available. Falling back to CPU.')
            return torch.device('cpu')
        chosen = torch.device('cuda')
    elif torch.cuda.is_available():
        chosen = torch.device('cuda')
    else:
        chosen = torch.device('cpu')

    print(f'Using device: {chosen}')

    if chosen.type != 'cuda':
        return chosen

    # CUDA can be reported as available and still fail on first use, so probe
    # it with a tiny allocation before handing it to the caller.
    try:
        torch.tensor([0.0], device=chosen)
    except Exception as err:
        print(f'CUDA device check failed ({type(err).__name__}: {err}). Falling back to CPU.')
        print('Using device: cpu')
        return torch.device('cpu')
    return chosen


# Metric names that are summarised (mean / std) across splits.
_AGGREGATED_KEYS = (
    'test_loss', 'test_recon_loss', 'test_kl_loss',
    'train_loss', 'train_recon_loss', 'train_kl_loss',
    'gap_loss', 'gap_recon_loss', 'gap_kl_loss',
)


def _load_datasets(dataset_name: str):
    """Return the ``(train_dataset, test_dataset)`` pair for *dataset_name*.

    The name is matched case-insensitively against the known MNIST and
    Fashion-MNIST spellings.

    Raises:
        ValueError: If the name matches none of the known datasets.
    """
    key = dataset_name.lower()
    if key in {"mnist"}:
        return get_mnist_data()
    if key in {"fashion_mnist", "fashion-mnist", "fashionmnist"}:
        return get_fashion_mnist_data()
    raise ValueError(f"Unknown dataset: {dataset_name}")


def _aggregate_metrics(results: List[Dict]) -> Dict[str, float]:
    """Compute ``avg_<key>`` / ``std_<key>`` over the per-split *results*.

    Only the keys in ``_AGGREGATED_KEYS`` are summarised. A key that no split
    reported is omitted from the summary instead of raising ``KeyError``; a key
    reported by only some splits is summarised over the splits that have it.
    The standard deviation is NumPy's default (population, ``ddof=0``).
    """
    import numpy as np

    summary: Dict[str, float] = {}
    for key in _AGGREGATED_KEYS:
        values = [r[key] for r in results if key in r]
        if not values:
            continue
        summary[f'avg_{key}'] = float(np.mean(values))
        summary[f'std_{key}'] = float(np.std(values))
    return summary


def main() -> None:
    """Evaluate every split of an experiment and write an aggregated JSON report.

    Command-line arguments:
        --experiment_dir: Experiment base directory; must contain
            ``aggregated_results.json``, from which the dataset name is read
            (``config.data.dataset``).
        --out: Optional path of the output JSON. Defaults to
            ``<experiment_dir>/evaluation_aggregated.json``.

    For each split, the trained model is loaded, the split's training subset is
    rebuilt, and ``VAEEvaluator.compute_metrics`` is run on the test set with the
    training loader supplied as well. The output file contains the per-split
    metrics under ``individual_results`` and their mean / std under
    ``average_metrics``.

    Raises:
        ValueError: If the configured dataset name is not recognised.
    """
    import argparse

    from torch.utils.data import DataLoader
    from utils.model_utils import recreate_data_subsets

    parser = argparse.ArgumentParser(description='Evaluate a VAE experiment across all splits')
    parser.add_argument('--experiment_dir', type=str, required=True,
                        help='Path to experiment base directory (containing aggregated_results.json)')
    parser.add_argument('--out', type=str, default=None,
                        help='Optional output JSON path for aggregated evaluation (default: <experiment_dir>/evaluation_aggregated.json)')
    args = parser.parse_args()

    experiment_dir = args.experiment_dir
    device = select_device()

    # The dataset to evaluate on is recorded in the experiment's saved config.
    results_path = os.path.join(experiment_dir, 'aggregated_results.json')
    with open(results_path, 'r', encoding='utf-8') as f:
        dataset_name = json.load(f)['config']['data']['dataset']
    train_dataset, test_dataset = _load_datasets(dataset_name)

    # Load the split definitions once; each entry unpacks to
    # (train_indices, val_indices).
    splits, metadata = load_experiment_data_splits(experiment_dir)
    num_splits = metadata['num_splits']

    all_results: List[Dict] = []
    for split_idx in range(num_splits):
        print(f"\nEvaluating split {split_idx+1}/{num_splits}")
        model, config = load_trained_model(experiment_dir, split_idx=split_idx, device=device)

        # Redirect this split's results under the experiment base dir. Both the
        # config and its nested 'results' dict are copied so the object returned
        # by load_trained_model is not mutated.
        config = dict(config)
        config['results'] = {
            **(config.get('results') or {}),
            'save_dir': os.path.join(experiment_dir, f'split_{split_idx}'),
        }

        evaluator = VAEEvaluator(model, config, device)

        # Rebuild the training subset this split's model was trained on.
        train_indices, val_indices = splits[split_idx]
        train_subset, _ = recreate_data_subsets(train_dataset, train_indices, val_indices)
        batch_size = config['data']['batch_size']
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        metrics = evaluator.compute_metrics(test_loader, train_loader=train_loader)
        print(f"Test loss: {metrics['test_loss']:.4f}, Train loss: {metrics.get('train_loss', float('nan')):.4f}")
        all_results.append({'split_idx': split_idx, **metrics})

    out_path = args.out or os.path.join(experiment_dir, 'evaluation_aggregated.json')
    # Create the target directory so a fresh --out location does not fail only
    # after the whole evaluation has already run.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump({
            'individual_results': all_results,
            'average_metrics': _aggregate_metrics(all_results),
        }, f, indent=2)
    print(f"\nAggregated evaluation saved to: {out_path}")

# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()

