"""Evaluation Script for Adaptive Gating MetaNet Models

This script evaluates adaptive gating models using pre-computed features.
Supports evaluation of both MetaNet and Atlas models, with or without gating.
"""

import os
import json
import torch
import gc
from collections import defaultdict
from datetime import datetime
import numpy as np
import math

from src.args import parse_arguments
from src.adaptive_gating_metanet import AdaptiveGatingMetaNet


class DirectFeatureModel(torch.nn.Module):
    """Direct feature model (Atlas approach) with optional gating.

    Without gating, ``forward`` returns the input features unchanged (plus a
    zero-scaled dummy term). With ``gating_no_metanet=True`` every feature is
    multiplied by a soft mask obtained by comparing its L2-normalized
    magnitude against an uncertainty-dependent threshold.

    Args:
        feature_dim: Feature dimension
        gating_no_metanet: Whether to use gating without MetaNet
        base_threshold: Base threshold for gating
        beta: Beta parameter for uncertainty
        uncertainty_reg: Uncertainty regularization weight
    """

    def __init__(self, feature_dim, gating_no_metanet=False, base_threshold=0.05, beta=1.0, uncertainty_reg=0.01):
        super().__init__()
        self.feature_dim = feature_dim
        self.gating_no_metanet = gating_no_metanet

        # Identity projection with dummy components for DDP
        self.projection = torch.nn.Identity()
        self.dummy_param = torch.nn.Parameter(torch.zeros(1), requires_grad=True)
        self.dummy_linear = torch.nn.Linear(feature_dim, feature_dim)
        torch.nn.init.zeros_(self.dummy_linear.weight)
        torch.nn.init.zeros_(self.dummy_linear.bias)

        # Gating components only exist when gating is enabled
        if self.gating_no_metanet:
            # Both gating parameters are stored in log space (floored at 1e-5
            # before the log); the properties below exponentiate them.
            self.log_base_threshold = torch.nn.Parameter(
                torch.tensor([math.log(max(base_threshold, 1e-5))], dtype=torch.float)
            )
            # beta starts at 95% of the requested value
            self.log_beta = torch.nn.Parameter(
                torch.tensor([math.log(max(beta * 0.95, 1e-5))], dtype=torch.float)
            )

            # Buffers keeping the requested values for later comparison
            self.register_buffer('initial_base_threshold', torch.tensor([base_threshold], dtype=torch.float))
            self.register_buffer('initial_beta', torch.tensor([beta], dtype=torch.float))

            # Tracking variables
            self.uncertainty_reg = uncertainty_reg
            self._forward_count = 0
            self._reg_loss_count = 0
            # Selects random (True) vs. constant (False) noise in compute_uncertainty
            self.training_mode = True

            # Values recorded by the most recent forward pass
            self.last_uncertainties = None
            self.last_gated_features = None
            self.last_thresholds = None
            self.last_orig_features = None
            self.last_gating_mask = None

            # Uncertainty network: bottleneck MLP with outputs in (0, 1)
            self.uncertainty_net = torch.nn.Sequential(
                torch.nn.Linear(feature_dim, feature_dim // 4),
                torch.nn.ReLU(),
                torch.nn.Linear(feature_dim // 4, feature_dim),
                torch.nn.Sigmoid()
            )

    @property
    def base_threshold(self):
        """Get the actual base threshold value (a zero tensor when gating is off)"""
        if self.gating_no_metanet:
            return torch.exp(self.log_base_threshold)
        return torch.tensor(0.0)

    @property
    def beta(self):
        """Get the actual beta value (a zero tensor when gating is off)"""
        if self.gating_no_metanet:
            return torch.exp(self.log_beta)
        return torch.tensor(0.0)

    @staticmethod
    def _normalized_magnitudes(features):
        """Return the absolute values of the row-wise L2-normalized features"""
        feature_norms = torch.norm(features, dim=1, keepdim=True)
        return torch.abs(features / (feature_norms + 1e-8))

    @staticmethod
    def _percent_change(current, initial):
        """Return the relative change in percent, or 0.0 when ``initial`` is zero"""
        if initial == 0:
            return 0.0
        return ((current - initial) / initial) * 100

    def compute_uncertainty(self, features):
        """Compute uncertainty based on feature characteristics

        Args:
            features: 2-D tensor with one row per sample

        Returns:
            Tensor shaped like ``features`` with values clamped to [0.01, 1.0]
        """
        # Get feature-specific uncertainty
        feature_uncertainty = self.uncertainty_net(features)

        # Batch statistics component. The standard deviation of a single
        # sample is NaN, which would poison every downstream value, so a
        # one-sample batch contributes zero batch uncertainty instead.
        if features.size(0) > 1:
            batch_std = features.std(dim=0, keepdim=True)
        else:
            batch_std = features.new_zeros(1, features.size(1))
        batch_std = batch_std / (batch_std.max() + 1e-8)  # Normalize
        batch_std = batch_std.expand(features.size(0), -1)

        # Combine components
        combined_uncertainty = 0.7 * feature_uncertainty + 0.3 * batch_std

        # Noise term: uniform in [0, 0.1) in training mode, constant 0.05 otherwise
        if self.training_mode:
            random_noise = torch.rand_like(combined_uncertainty) * 0.1
        else:
            random_noise = torch.ones_like(combined_uncertainty) * 0.05

        combined_uncertainty = combined_uncertainty + random_noise
        combined_uncertainty = combined_uncertainty.clamp(min=0.01, max=1.0)

        return combined_uncertainty

    def adaptive_gating(self, features, uncertainties):
        """Apply adaptive thresholding based on uncertainty

        Returns:
            tuple: ``(gated_features, thresholds)``. The hard (binary) mask is
            stored in ``self.last_gating_mask`` as a side effect.
        """
        # Compute adaptive thresholds
        thresholds = self.base_threshold * (1.0 + self.beta * uncertainties)

        # Normalize features for gating
        feature_magnitudes = self._normalized_magnitudes(features)

        # Apply gating with smooth transition
        sigmoid_scale = 20.0  # Steepness of the sigmoid
        gating_mask = torch.sigmoid(sigmoid_scale * (feature_magnitudes - thresholds))
        gated_features = features * gating_mask

        # Store the binary gating mask
        self.last_gating_mask = (feature_magnitudes >= thresholds).float().detach()

        return gated_features, thresholds

    def forward(self, features):
        """Forward pass with optional adaptive gating"""
        if not self.gating_no_metanet:
            # Identity projection + dummy computation scaled to zero
            return self.projection(features) + self.dummy_linear(features) * 0.0

        self._forward_count = getattr(self, '_forward_count', 0) + 1
        self.last_orig_features = features.detach()

        # Compute uncertainty and apply gating, recording intermediates for stats
        uncertainties = self.compute_uncertainty(features)
        self.last_uncertainties = uncertainties
        gated_features, thresholds = self.adaptive_gating(features, uncertainties)
        self.last_gated_features = gated_features
        self.last_thresholds = thresholds

        # Add dummy computation scaled to zero
        return gated_features + self.dummy_linear(features) * 0.0

    def uncertainty_regularization_loss(self):
        """Calculate regularization loss (always a zero-valued tensor tied to ``dummy_param``)"""
        return self.dummy_param.sum() * 0.0

    def get_gating_stats(self):
        """Get statistics about the gating process

        Returns:
            dict: Empty when gating is disabled; otherwise the gating ratio,
            average threshold/uncertainty, and current vs. initial parameter
            values. Uses the last forward pass when one was recorded, else a
            random batch of 64 samples.
        """
        if not self.gating_no_metanet:
            return {}

        # Use stored values if available
        if self.last_gated_features is not None and self.last_gating_mask is not None:
            gating_ratio = self.last_gating_mask.mean().item()
        else:
            # Generate sample data for stats
            batch_size = 64
            features = torch.randn(batch_size, self.feature_dim, device=self.log_base_threshold.device)

            with torch.no_grad():
                uncertainties = self.compute_uncertainty(features)
                thresholds = self.base_threshold * (1.0 + self.beta * uncertainties)
                gating_mask = (self._normalized_magnitudes(features) >= thresholds).float()
                gating_ratio = gating_mask.mean().item()

        # Get parameter values
        base_threshold = self.base_threshold.item()
        beta = self.beta.item()
        log_base_threshold = self.log_base_threshold.item()
        log_beta = self.log_beta.item()

        # Get initial parameter values
        initial_base_threshold = self.initial_base_threshold.item()
        initial_beta = self.initial_beta.item()

        # Calculate change from initial values (guarded against zero initial values)
        threshold_change = self._percent_change(base_threshold, initial_base_threshold)
        beta_change = self._percent_change(beta, initial_beta)

        # Get average values
        last_thresholds = getattr(self, 'last_thresholds', None)
        last_uncertainties = getattr(self, 'last_uncertainties', None)
        avg_threshold = last_thresholds.mean().item() if last_thresholds is not None else base_threshold
        avg_uncertainty = last_uncertainties.mean().item() if last_uncertainties is not None else 0.0

        return {
            "gating_ratio": gating_ratio,
            "avg_threshold": avg_threshold,
            "avg_uncertainty": avg_uncertainty,
            "base_threshold": base_threshold,
            "beta": beta,
            "log_base_threshold": log_base_threshold,
            "log_beta": log_beta,
            "initial_base_threshold": initial_base_threshold,
            "initial_beta": initial_beta,
            "threshold_change_percent": threshold_change,
            "beta_change_percent": beta_change,
        }


class PrecomputedFeatureDataset(torch.utils.data.Dataset):
    """Dataset for precomputed features

    Args:
        features_path: Path of the features file, read with ``torch.load``
        labels_path: Path of the labels file, read with ``torch.load``
        verbose: Whether to print the number of samples and feature dimension

    Raises:
        ValueError: If the number of features and labels differ
    """
    def __init__(self, features_path, labels_path, verbose=False):
        super().__init__()

        # Load onto the CPU regardless of the device the tensors were saved
        # from: the samples are consumed through a DataLoader, whose worker
        # processes and pinned-memory path need CPU tensors.
        self.features = torch.load(features_path, map_location="cpu")
        self.labels = torch.load(labels_path, map_location="cpu")

        # Validate dimensions match
        if len(self.features) != len(self.labels):
            raise ValueError(f"Features ({len(self.features)}) and labels ({len(self.labels)}) count mismatch")

        if verbose:
            print(f"Loaded {len(self.features)} samples with feature dim {self.features.shape[1]}")

    def __len__(self):
        """Return the number of samples"""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict with features, labels and index"""
        return {
            "features": self.features[idx],
            "labels": self.labels[idx],
            "index": idx
        }


class TestOnlyFeatures:
    """Container for test-only precomputed features

    Exposes ``test_dataset``, ``test_loader`` and ``classnames``.

    Args:
        feature_dir: Directory holding ``test_features.pt``/``test_labels.pt``
            (or the ``val_*`` equivalents) and optionally ``classnames.txt``
        batch_size: Batch size of the test loader
        num_workers: Requested loader workers (capped at 2)
        verbose: Forwarded to the dataset to print what was loaded

    Raises:
        FileNotFoundError: If the directory or the feature files are missing
    """
    def __init__(self, feature_dir, batch_size=128, num_workers=4, verbose=False):
        # Verify directory exists
        if not os.path.exists(feature_dir):
            raise FileNotFoundError(f"Feature directory not found: {feature_dir}")

        # Define standard paths
        test_features_path = os.path.join(feature_dir, "test_features.pt")
        test_labels_path = os.path.join(feature_dir, "test_labels.pt")

        # If test files don't exist, try val files
        if not os.path.exists(test_features_path):
            test_features_path = os.path.join(feature_dir, "val_features.pt")
            test_labels_path = os.path.join(feature_dir, "val_labels.pt")

            if not os.path.exists(test_features_path):
                raise FileNotFoundError(f"Could not find test or val features in {feature_dir}")

        # Load test features and labels
        self.test_dataset = PrecomputedFeatureDataset(
            test_features_path,
            test_labels_path,
            verbose=verbose
        )

        # Create test loader. A non-zero timeout is only valid with worker
        # processes; the single-process iterator asserts that timeout == 0,
        # so num_workers=0 must not be combined with timeout=60.
        workers = min(num_workers, 2)
        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=True,
            drop_last=False,
            timeout=60 if workers > 0 else 0,
        )

        # Load classnames if available
        classnames_path = os.path.join(feature_dir, "classnames.txt")
        if os.path.exists(classnames_path):
            with open(classnames_path, "r") as f:
                self.classnames = [line.strip() for line in f]
        else:
            # Create dummy classnames. They are looked up by label id, so the
            # list must reach the largest label even when some classes have
            # no sample in this split; it is never shorter than the number of
            # distinct labels.
            labels = self.test_dataset.labels
            num_classes = len(torch.unique(labels))
            if labels.numel() > 0:
                num_classes = max(num_classes, int(labels.max().item()) + 1)
            self.classnames = [f"class_{i}" for i in range(num_classes)]


def cleanup_resources(dataset):
    """Drop the loader/dataset references held by ``dataset`` and free memory.

    Does nothing when ``dataset`` is None. Failures while clearing the
    references are reported as a warning rather than raised.
    """
    if dataset is not None:
        try:
            # Reset whichever of the two references the container has
            for attr_name in ('test_loader', 'test_dataset'):
                if hasattr(dataset, attr_name):
                    setattr(dataset, attr_name, None)
        except Exception as err:
            print(f"Warning during dataset cleanup: {err}")

        # Reclaim Python objects first, then the CUDA caching allocator
        gc.collect()
        torch.cuda.empty_cache()


def find_model_path(model_dir, dataset_name, model_name, no_gating=False, no_metanet=False, gating_no_metanet=False, debug=False):
    """Find the model path for a given dataset and model type

    Args:
        model_dir: Root directory of saved models; ``model_dir/model_name`` is
            searched instead when that sub-directory exists
        dataset_name: Dataset name, with or without a trailing "Val"
        model_name: Name of the model (used as sub-directory and in errors)
        no_gating: Look for a MetaNet checkpoint trained without gating
        no_metanet: Look for an Atlas checkpoint
        gating_no_metanet: With ``no_metanet``, look for Atlas with gating
        debug: Print the path that was found

    Returns:
        str: Path of the first existing candidate checkpoint

    Raises:
        FileNotFoundError: If no candidate exists
    """
    # Remove only a literal "Val" suffix. str.rstrip("Val") would strip any
    # run of trailing 'V', 'a' or 'l' characters and mangle names that merely
    # end in those letters.
    val_suffix = "Val"
    if dataset_name.endswith(val_suffix):
        base_name = dataset_name[:-len(val_suffix)]
    else:
        base_name = dataset_name
    val_name = f"{base_name}{val_suffix}"

    # Choose suffix based on model type
    if no_metanet:
        gating_suffix = "_atlas_with_gating" if gating_no_metanet else "_atlas"
    elif no_gating:
        gating_suffix = "_no_gating"
    else:
        gating_suffix = "_adaptive_gating"

    # Model-specific directory
    model_specific_dir = os.path.join(model_dir, model_name)
    if os.path.exists(model_specific_dir):
        model_dir = model_specific_dir

    # Candidates in priority order: type-specific file names first (validation
    # directory before base directory), then the generic name as fallback.
    specific_file = f"best{gating_suffix}_model.pt"
    generic_file = "best_model.pt"
    possible_paths = [
        os.path.join(model_dir, val_name, specific_file),
        os.path.join(model_dir, base_name, specific_file),
        os.path.join(model_dir, val_name, generic_file),
        os.path.join(model_dir, base_name, generic_file),
    ]

    # Try each path
    for path in possible_paths:
        if os.path.exists(path):
            if debug:
                print(f"Found model at: {path}")
            return path

    # If no model found, return descriptive error
    raise FileNotFoundError(f"Could not find model for {dataset_name} (model: {model_name})")


# Key prefixes under which a checkpoint may store the feature-transform weights.
_MODEL_KEY_PREFIXES = (
    'module.meta_net.',
    'meta_net.',
    'module.metanet.',
    'metanet.',
    'model.meta_net.',
    'model.image_encoder.meta_net.',
)

# Key prefixes under which a checkpoint may store the classifier weights.
_CLASSIFIER_KEY_PREFIXES = (
    'module.classification_head.',
    'classification_head.',
    'classifier.',
    'module.classifier.',
    'model.classifier.',
    'model.classification_head.',
)


def _load_from_prefixed_keys(module, state_dict, prefixes):
    """Load ``module`` from the first prefix whose stripped keys load cleanly; return success."""
    for prefix in prefixes:
        subset = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
        if not subset:
            continue
        try:
            module.load_state_dict(subset)
        except Exception:
            # Best-effort probe: a mismatching prefix is simply skipped
            continue
        return True
    return False


def _scalar_or_zero(value):
    """Convert a tensor or number to a Python scalar; None becomes 0.0."""
    if value is None:
        return 0.0
    return value.item() if hasattr(value, 'item') else float(value)


def evaluate_model(model_path, dataset, device, args):
    """Evaluate model on dataset

    Args:
        model_path: Path to saved model
        dataset: Dataset with precomputed features (provides ``test_loader``
            and ``classnames``)
        device: Computation device
        args: Command line arguments, used for every setting the checkpoint's
            'config' entry does not provide

    Returns:
        dict: Evaluation results (overall and per-class accuracy, confidence
        statistics, model configuration and, when available, gating stats)

    Raises:
        ValueError: If the weights of a MetaNet model cannot be loaded
    """
    # NOTE: torch.load unpickles the file; only evaluate trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)

    # Extract model configuration if available. ``config`` is always a dict so
    # the ``config.get(..., args.<x>)`` lookups below also work for
    # checkpoints saved without a 'config' entry.
    has_config = 'config' in state_dict
    config = state_dict['config'] if has_config else {}
    if has_config:
        feature_dim = config.get('feature_dim')
        model_name = config.get('model_name', args.model)
        no_metanet = config.get('no_metanet', False)
        gating_no_metanet = config.get('gating_no_metanet', False)
    else:
        feature_dim = None
        model_name = args.model
        no_metanet = args.no_metanet
        gating_no_metanet = args.gating_no_metanet

    if feature_dim is None:
        # Not recorded in the checkpoint: read it off the first test batch
        batch = next(iter(dataset.test_loader))
        features = batch["features"] if isinstance(batch, dict) else batch[0]
        feature_dim = features.shape[1]

    # Create appropriate model
    no_gating = False
    if no_metanet:
        # Atlas model (with or without gating)
        if gating_no_metanet:
            model = DirectFeatureModel(
                feature_dim=feature_dim,
                gating_no_metanet=True,
                base_threshold=config.get('base_threshold', args.base_threshold),
                beta=config.get('beta', args.beta),
                uncertainty_reg=config.get('uncertainty_reg', args.uncertainty_reg)
            )
        else:
            model = DirectFeatureModel(
                feature_dim=feature_dim,
                gating_no_metanet=False
            )
    else:
        # MetaNet model (with or without gating)
        no_gating = config.get('no_gating', args.no_gating)

        if no_gating:
            # Set minimal threshold values for no-gating
            base_threshold = 1e-9
            beta = 1e-9
            uncertainty_reg = 0.0
        else:
            base_threshold = config.get('base_threshold', args.base_threshold)
            beta = config.get('beta', args.beta)
            uncertainty_reg = config.get('uncertainty_reg', args.uncertainty_reg)

        model = AdaptiveGatingMetaNet(
            feature_dim=feature_dim,
            task_vectors=config.get('num_task_vectors', args.num_task_vectors),
            blockwise=config.get('blockwise', args.blockwise_coef),
            base_threshold=base_threshold,
            beta=beta,
            uncertainty_reg=uncertainty_reg
        )

    # Set evaluation mode
    if hasattr(model, 'training_mode'):
        model.training_mode = False

    # Load model weights
    if 'meta_net' in state_dict:
        model.load_state_dict(state_dict['meta_net'])
    else:
        loaded = _load_from_prefixed_keys(model, state_dict, _MODEL_KEY_PREFIXES)

        if not loaded:
            # Try direct loading
            try:
                model.load_state_dict(state_dict)
                loaded = True
            except Exception:
                loaded = False

        if not loaded and no_metanet:
            # Atlas checkpoints may keep the module's own keys (dummy_param,
            # uncertainty_net.*, ...) at the top level next to unrelated
            # entries; load exactly the keys the module defines.
            own_keys = set(model.state_dict().keys())
            own_subset = {k: v for k, v in state_dict.items() if k in own_keys}
            if own_subset:
                try:
                    model.load_state_dict(own_subset)
                    loaded = True
                except Exception:
                    loaded = False

        if not loaded:
            # For Atlas models, we can proceed even without parameters
            if not no_metanet:
                raise ValueError("Could not load model parameters")
            if gating_no_metanet:
                print("Warning: could not load all Atlas gating parameters; "
                      "missing parameters keep their initial values")

    model = model.to(device)
    model.eval()

    # Create model info dict
    model_info = {
        'no_metanet': no_metanet,
        'gating_no_metanet': gating_no_metanet,
        'model_name': model_name,
    }

    if not no_metanet:
        model_info.update({
            'no_gating': no_gating,
            'base_threshold': _scalar_or_zero(getattr(model, 'base_threshold', None)),
            'beta': _scalar_or_zero(getattr(model, 'beta', None)),
        })
    elif gating_no_metanet:
        # Get gating parameters for Atlas with gating
        model_info.update({
            'base_threshold': _scalar_or_zero(getattr(model, 'base_threshold', None)),
            'beta': _scalar_or_zero(getattr(model, 'beta', None)),
        })

    # Create classifier
    num_classes = len(dataset.classnames)
    classifier = torch.nn.Linear(feature_dim, num_classes)

    # Load classifier weights
    if 'classifier' in state_dict:
        classifier.load_state_dict(state_dict['classifier'])
    elif not _load_from_prefixed_keys(classifier, state_dict, _CLASSIFIER_KEY_PREFIXES):
        # Evaluation still runs, but the numbers are meaningless; say so.
        print("Warning: no classifier weights found in checkpoint; "
              "using a randomly initialized classifier")

    classifier = classifier.to(device)
    classifier.eval()

    # Start evaluation
    correct = 0
    total = 0
    per_class_correct = defaultdict(int)
    per_class_total = defaultdict(int)
    all_confidences = []

    with torch.no_grad():
        for batch in dataset.test_loader:
            if isinstance(batch, dict):
                features, labels = batch["features"], batch["labels"]
            else:
                features, labels = batch
            features = features.to(device)
            labels = labels.to(device)

            # Forward pass
            transformed_features = model(features)
            outputs = classifier(transformed_features)

            # Get predictions and confidences
            probabilities = torch.softmax(outputs, dim=1)
            confidences, predicted = torch.max(probabilities, dim=1)

            # Compute accuracy
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Update per-class metrics (one host transfer per batch)
            for label_idx, prediction in zip(labels.tolist(), predicted.tolist()):
                per_class_total[label_idx] += 1
                if prediction == label_idx:
                    per_class_correct[label_idx] += 1

            # Store for later analysis
            all_confidences.extend(confidences.cpu().numpy())

    # Calculate overall accuracy
    accuracy = correct / total if total > 0 else 0.0

    # Calculate per-class accuracy (classes without samples are omitted)
    per_class_acc = {}
    per_class_report = []

    for cls_idx, cls_name in enumerate(dataset.classnames):
        if per_class_total[cls_idx] > 0:
            cls_acc = per_class_correct[cls_idx] / per_class_total[cls_idx]
            per_class_acc[cls_name] = float(cls_acc)
            per_class_report.append({
                'class_id': cls_idx,
                'class_name': cls_name,
                'accuracy': float(cls_acc),
                'correct': per_class_correct[cls_idx],
                'total': per_class_total[cls_idx]
            })

    # Calculate confidence statistics (np.min/np.max reject empty arrays)
    all_confidences = np.array(all_confidences)
    if all_confidences.size > 0:
        confidence_stats = {
            'mean': float(np.mean(all_confidences)),
            'median': float(np.median(all_confidences)),
            'min': float(np.min(all_confidences)),
            'max': float(np.max(all_confidences)),
            'std': float(np.std(all_confidences))
        }
    else:
        confidence_stats = {'mean': 0.0, 'median': 0.0, 'min': 0.0, 'max': 0.0, 'std': 0.0}

    # Determine model type from parameters
    if model_info['no_metanet']:
        if model_info['gating_no_metanet']:
            model_type = "Atlas_WithGating"
        else:
            model_type = "Atlas"
    elif model_info.get('no_gating', False):
        model_type = "MetaNet_NoGating"
    else:
        model_type = "AdaptiveGating"

    # Get gating stats if applicable
    gating_stats = None
    if hasattr(model, 'get_gating_stats'):
        gating_stats = model.get_gating_stats()

    # Compile results
    results = {
        'accuracy': accuracy,
        'num_correct': correct,
        'num_samples': total,
        'per_class_accuracy': per_class_acc,
        'per_class_report': per_class_report,
        'confidence_stats': confidence_stats,
        'config': {
            'feature_dim': feature_dim,
            'no_metanet': model_info['no_metanet'],
            'gating_no_metanet': model_info['gating_no_metanet'] if model_info['no_metanet'] else False,
            'model_name': model_info['model_name'],
        },
        'model_path': model_path,
        'model_type': model_type,
        'evaluation_timestamp': datetime.now().isoformat(),
    }

    # Add MetaNet specific information if applicable
    if not model_info['no_metanet']:
        results['config'].update({
            'base_threshold': model_info.get('base_threshold', 0.0),
            'beta': model_info.get('beta', 0.0),
            'no_gating': model_info.get('no_gating', False),
        })

        # Add computed gating ratio for models with gating
        if not model_info.get('no_gating', False) and gating_stats:
            results['computed_gating_ratio'] = gating_stats.get('gating_ratio', 0.0)

    # Add Atlas with Gating specific information if applicable
    elif model_info['gating_no_metanet'] and gating_stats:
        results['config'].update({
            'base_threshold': model_info.get('base_threshold', 0.0),
            'beta': model_info.get('beta', 0.0),
        })
        results['computed_gating_ratio'] = gating_stats.get('gating_ratio', 0.0)

    # Add gating stats if available
    if gating_stats is not None:
        results['gating_stats'] = {
            'gating_ratio': gating_stats.get('gating_ratio', 0.0),
            'avg_threshold': gating_stats.get('avg_threshold', 0.0),
            'base_threshold': gating_stats.get('base_threshold', 0.0),
            'beta': gating_stats.get('beta', 0.0),
            'avg_uncertainty': gating_stats.get('avg_uncertainty', 0.0),
        }

    return results


def _print_gated_summary(title, entries):
    """Print a summary table with threshold, beta and gating-ratio columns."""
    if not entries:
        return

    print(title)
    print(f"{'Dataset':<15} | {'Accuracy':^10} | {'Model':^10} | {'α':^8} | {'β':^8} | {'Gating %':^10}")
    print("-" * 70)

    for entry in sorted(entries, key=lambda item: item['dataset']):
        accuracy_text = f"{entry['accuracy']*100:>8.2f}%"
        model_text = f"{entry['model_name']:<10}"
        alpha_text = f"{entry.get('alpha', 0):>7.4f}"
        beta_text = f"{entry.get('beta', 0):>7.4f}"
        if 'gating_ratio' in entry:
            gating_text = f"{entry.get('gating_ratio', 0)*100:>8.2f}%"
        else:
            gating_text = "   N/A   "

        cells = (
            f"{entry['dataset']:<15}",
            format(accuracy_text, '^10'),
            format(model_text, '^10'),
            format(alpha_text, '^8'),
            format(beta_text, '^8'),
            format(gating_text, '^10'),
        )
        print(" | ".join(cells))


def _print_basic_summary(title, entries):
    """Print a summary table with dataset, accuracy and model columns only."""
    if not entries:
        return

    print(title)
    print(f"{'Dataset':<15} | {'Accuracy':^10} | {'Model':^12}")
    print("-" * 40)

    for entry in sorted(entries, key=lambda item: item['dataset']):
        accuracy_text = f"{entry['accuracy']*100:>8.2f}%"
        model_text = f"{entry['model_name']:<12}"
        cells = (
            f"{entry['dataset']:<15}",
            format(accuracy_text, '^10'),
            format(model_text, '^12'),
        )
        print(" | ".join(cells))


def main():
    """Evaluate every requested dataset, save the results as JSON and print a summary."""
    args = parse_arguments()

    device = args.device
    print(f"Using device: {device}")

    # Results go under <save_dir>/<model>/evaluation_results
    model_save_dir = os.path.join(args.save_dir, args.model, "evaluation_results")
    os.makedirs(model_save_dir, exist_ok=True)
    print(f"Results will be saved to: {model_save_dir}")

    # File-name suffix and human-readable label for the selected model variant
    if args.no_metanet:
        if args.gating_no_metanet:
            config_suffix, model_type_str = "_atlas_with_gating", "Atlas with Gating"
        else:
            config_suffix, model_type_str = "_atlas", "Atlas (No MetaNet)"
    elif args.no_gating:
        config_suffix, model_type_str = "_no_gating", "MetaNet (No Gating)"
    else:
        config_suffix, model_type_str = "_adaptive_gating", "Adaptive Gating MetaNet"

    if args.blockwise_coef:
        config_suffix += "_blockwise"

    # Echo the configuration
    print(f"\n=== Evaluation Configuration ===")
    print(f"Model: {args.model}")
    if not args.no_metanet:
        print(f"Using MetaNet: {not args.no_metanet}")
        print(f"Blockwise coefficients: {args.blockwise_coef}")
        if args.no_gating:
            print(f"Mode: MetaNet only (no gating)")
        else:
            print(f"Default threshold: {args.base_threshold:.4f}, Beta: {args.beta:.4f}")
    elif args.gating_no_metanet:
        print(f"Model type: Atlas with Gating")
        print(f"Base threshold: {args.base_threshold:.4f}, Beta: {args.beta:.4f}")
    else:
        print(f"Model type: Atlas (No MetaNet)")
    print(f"Datasets to evaluate: {args.datasets}")
    print("=" * 30)

    all_results = {}
    summary_results = []

    for dataset_name in args.datasets:
        print(f"\n=== Evaluating dataset {dataset_name} ===")
        dataset = None

        try:
            # Locate the feature directory, with and without the Val suffix
            features_root = os.path.join(args.data_location, "precomputed_features", args.model)
            feature_dir = os.path.join(features_root, dataset_name)
            if not os.path.exists(feature_dir):
                feature_dir = os.path.join(features_root, dataset_name + "Val")
                if not os.path.exists(feature_dir):
                    print(f"Features for {dataset_name} not found, skipping")
                    continue

            dataset = TestOnlyFeatures(
                feature_dir=feature_dir,
                batch_size=args.batch_size,
                num_workers=args.num_workers,
                verbose=args.debug
            )

            # A missing checkpoint skips this dataset instead of aborting the run
            try:
                model_path = find_model_path(
                    args.model_dir,
                    dataset_name,
                    args.model,
                    no_gating=args.no_gating,
                    no_metanet=args.no_metanet,
                    gating_no_metanet=args.gating_no_metanet,
                    debug=args.debug
                )
                print(f"Using model: {os.path.basename(model_path)}")
            except FileNotFoundError as missing:
                print(f"Error: {missing}")
                print(f"Skipping evaluation for {dataset_name}")
                continue

            results = evaluate_model(model_path, dataset, device, args)
            result_config = results['config']
            is_atlas = result_config['no_metanet']
            has_ratio = 'computed_gating_ratio' in results

            print(f"Model: {result_config['model_name']}")
            print(f"Accuracy: {results['accuracy'] * 100:.2f}% ({results['num_correct']}/{results['num_samples']})")

            # One line describing the evaluated variant
            if not is_atlas and has_ratio:
                print(f"Gating ratio: {results['computed_gating_ratio'] * 100:.1f}%")
            elif is_atlas and result_config.get('gating_no_metanet', False) and has_ratio:
                print(f"Atlas with Gating - gating ratio: {results['computed_gating_ratio'] * 100:.1f}%")
            elif is_atlas:
                print(f"Model type: Atlas (direct features)")
            else:
                print(f"Model type: MetaNet without gating")

            all_results[dataset_name] = results

            # Condensed record for the summary tables
            summary_entry = {
                'dataset': dataset_name,
                'model_type': results['model_type'],
                'accuracy': results['accuracy'],
                'samples': results['num_samples'],
                'no_metanet': is_atlas,
                'model_name': result_config['model_name'],
            }

            gated_variant = True
            if not is_atlas:
                summary_entry['alpha'] = result_config.get('base_threshold', 0.0)
                summary_entry['beta'] = result_config.get('beta', 0.0)
                summary_entry['no_gating'] = result_config.get('no_gating', False)
            elif result_config.get('gating_no_metanet', False):
                summary_entry['alpha'] = result_config.get('base_threshold', 0.0)
                summary_entry['beta'] = result_config.get('beta', 0.0)
                summary_entry['gating_no_metanet'] = True
            else:
                gated_variant = False

            if gated_variant and has_ratio:
                summary_entry['gating_ratio'] = results['computed_gating_ratio']

            summary_results.append(summary_entry)

        except Exception as err:
            print(f"Error processing dataset {dataset_name}: {err}")
        finally:
            # Release the dataset and cached memory before the next dataset
            cleanup_resources(dataset)
            gc.collect()
            torch.cuda.empty_cache()

    # Mean accuracy over the datasets that were evaluated
    if summary_results:
        avg_accuracy = sum(entry['accuracy'] for entry in summary_results) / len(summary_results)
        all_results['average_accuracy'] = avg_accuracy
        print(f"\nAverage accuracy across all datasets: {avg_accuracy * 100:.2f}%")

    # Write everything to a timestamped JSON file named after the model
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_path = os.path.join(
        model_save_dir,
        f"evaluation_{args.model}{config_suffix}_{timestamp}.json"
    )

    with open(results_path, 'w') as out_file:
        json.dump(all_results, out_file, indent=4)

    print(f"\nAll evaluation results saved to: {results_path}")

    # Summary header
    print("\n" + "=" * 100)
    print(f"Summary of {args.model} - {model_type_str} Models")
    print("-" * 100)

    # Split the summary entries by model variant in a single pass
    atlas_with_gating_results = []
    atlas_results = []
    metanet_no_gating_results = []
    adaptive_gating_results = []
    for entry in summary_results:
        if entry.get('no_metanet', False):
            if entry.get('gating_no_metanet', False):
                atlas_with_gating_results.append(entry)
            else:
                atlas_results.append(entry)
        elif entry.get('no_gating', False):
            metanet_no_gating_results.append(entry)
        else:
            adaptive_gating_results.append(entry)

    _print_gated_summary("\nAtlas with Gating Results:", atlas_with_gating_results)
    _print_basic_summary("\nAtlas (No MetaNet) Results:", atlas_results)
    _print_basic_summary("\nMetaNet (No Gating) Results:", metanet_no_gating_results)
    _print_gated_summary("\nAdaptive Gating MetaNet Results:", adaptive_gating_results)

    print("=" * 100)


# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()