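# NOTE(review): every line of this section is commented out (each starts with
# "# "), so none of the definitions in it are importable or executed.
# If this evaluator is revived, address these issues first:
#   - evaluate(): time.time() brackets asynchronous CUDA work without
#     torch.cuda.synchronize(), so inference/reconstruction times can
#     under-report the actual GPU cost.
#   - calculate_ms_ssim(): the bare `except:` silently substitutes SSIM, so the
#     "ms_ssim" column may mix two different metrics without any warning.
#   - calculate_psnr() can return float('inf'); json.dump then writes the
#     non-standard token Infinity into evaluation_results.json.
#   - _create_visualizations(): plt.subplots(3, 0) raises when no samples were
#     collected (e.g. visualize_samples=0), and axis('off') hides the row
#     labels that are set afterwards.
#   - _create_comparison_plots(): means/stds are filtered by key presence but
#     plotted against every model name, so a missing metric misaligns the bars.
#   - make_gaussian_random_orthonormal_rows() reseeds the global torch RNG as a
#     side effect of constructing the dataset.
#   - _load_model(): torch.load unpickles the checkpoint; only load trusted files.
# #!/usr/bin/env python3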
# """
# Unified Model Evaluation Script
# ==============================

# This script provides comprehensive evaluation capabilities for various image reconstruction models
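# including TransUNet, U-Net, ViT-UNet (TRUST), and Restormer on orthonormal inverse problems.
# (The "restormer" option currently instantiates the TransUNet class defined in vis_restormer_full.py.)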

# Features:
# - Unified evaluation interface for all model types
# - Comprehensive metrics (PSNR, SSIM, MS-SSIM, MAE, MSE, RMSE) plus per-sample inference and reconstruction time
# - Statistical analysis with mean ± std (plus min, max, median) reporting
# - Visualization and comparison capabilities
# - Individual and comparative result generation
# """

# import os
# import time
# import json
# import csv
# from collections import defaultdict
# from pathlib import Path

# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import DataLoader
# import matplotlib.pyplot as plt
# from PIL import Image

# # Import metrics
# from pytorch_msssim import ssim, ms_ssim

# # ============================================================================
# # SHARED DATASET AND UTILITIES
# # ============================================================================

# def make_gaussian_random_orthonormal_rows(h=256, w=256, seed=42):
#     """Generate orthonormal matrix for patch transformation"""
#     if seed is not None:
#         torch.manual_seed(seed)
#     A = torch.randn(h, w)
#     Q, R = torch.linalg.qr(A.T)
#     return Q.T

# class PatchwiseOrthonormalDataset:
#     """
#     Unified dataset for orthonormal transformation evaluation.
#     Handles 16x16 patch-wise orthonormal transformation for 224x224 images.
#     """
#     def __init__(self, data_dir, seed=42, verbose=False):
#         self.data_dir = data_dir
#         self.A = make_gaussian_random_orthonormal_rows(h=256, w=256, seed=seed)
        
#         self.data_path = Path(data_dir)
#         if not self.data_path.exists():
#             raise FileNotFoundError(f"Data directory not found: {data_dir}")
        
#         image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.JPEG', '.JPG'}
#         self.image_files = [f for f in self.data_path.iterdir() 
#                            if f.is_file() and f.suffix in image_extensions]
        
#         if len(self.image_files) == 0:
#             raise ValueError(f"No images found in {data_dir}")
        
#         if verbose:
#             print(f"Loaded {len(self.image_files)} images from {data_dir}")
#             print(f"Using 16x16 patch-wise orthonormal transformation with matrix shape: {self.A.shape}")

#     def __len__(self):
#         return len(self.image_files)

#     def resize_min_side(self, img, min_side=224):
#         w, h = img.size
#         s = min_side / min(w, h)
#         return img.resize((int(round(w*s)), int(round(h*s))), Image.Resampling.LANCZOS)

#     def center_crop(self, img, size=224):
#         w, h = img.size
#         left = (w - size) // 2
#         top = (h - size) // 2
#         return img.crop((left, top, left + size, top + size))

#     def preprocess_image(self, img):
#         img = img.convert("RGB")
#         img_resized = self.resize_min_side(img, 224)
#         img_crop = self.center_crop(img_resized, 224)
#         x = np.array(img_crop).astype(np.float32) / 255.0
#         return x

#     def process_image_with_orthonormal_masks(self, np_img, mask_matrix):
#         img_tensor = torch.from_numpy(np_img).float()
#         if img_tensor.shape[2] == 3:
#             img_gray = img_tensor.mean(dim=2)
#         else:
#             img_gray = img_tensor
        
#         patches = img_gray.unfold(0, 16, 16).unfold(1, 16, 16)
#         transformed_patches = torch.zeros(14, 14, 256)
        
#         for i in range(14):
#             for j in range(14):
#                 patch_flat = patches[i, j].flatten()
#                 transformed = mask_matrix @ patch_flat
#                 transformed_patches[i, j] = transformed
        
#         return transformed_patches

#     def reconstruct_masked_image(self, transformed_patches):
#         masked_image = torch.zeros(224, 224)
#         for i in range(14):
#             for j in range(14):
#                 transformed_patch = transformed_patches[i, j]
#                 patch_16x16 = transformed_patch.reshape(16, 16)
#                 start_h = i * 16
#                 end_h = start_h + 16
#                 start_w = j * 16
#                 end_w = start_w + 16
#                 masked_image[start_h:end_h, start_w:end_w] = patch_16x16
#         return masked_image

#     def apply_patchwise_orthonormal_transform(self, x):
#         y_channels = []
#         for c in range(3):
#             single_channel = x[..., c]
#             transformed_patches = self.process_image_with_orthonormal_masks(
#                 np.expand_dims(single_channel, axis=2), self.A
#             )
#             masked_channel = self.reconstruct_masked_image(transformed_patches)
#             y_channels.append(masked_channel.numpy())
        
#         y = np.stack(y_channels, axis=2)
#         y_min = y.min()
#         y_max = y.max()
#         y_norm = (y - y_min) / (y_max - y_min + 1e-8)
#         return y_norm

#     def __getitem__(self, idx):
#         img_path = self.image_files[idx]
#         try:
#             img = Image.open(img_path)
#         except Exception as e:
#             print(f"Warning: Could not load image {img_path}: {e}")
#             img = Image.new('RGB', (224, 224), color=(0, 0, 0))
        
#         x = self.preprocess_image(img)
#         y = self.apply_patchwise_orthonormal_transform(x)
        
#         x_tensor = torch.from_numpy(x).permute(2, 0, 1)
#         y_tensor = torch.from_numpy(y).permute(2, 0, 1)
        
#         return y_tensor, x_tensor

# # ============================================================================
# # METRIC CALCULATION FUNCTIONS
# # ============================================================================

# def calculate_psnr(pred, target, max_val=1.0):
#     """Calculate Peak Signal-to-Noise Ratio"""
#     mse = torch.mean((pred - target) ** 2)
#     if mse == 0:
#         return float('inf')
#     psnr = 20 * torch.log10(max_val / torch.sqrt(mse))
#     return psnr.item()

# def calculate_mae(pred, target):
#     """Calculate Mean Absolute Error"""
#     return torch.mean(torch.abs(pred - target)).item()

# def calculate_rmse(pred, target):
#     """Calculate Root Mean Square Error"""
#     return torch.sqrt(torch.mean((pred - target) ** 2)).item()

# def calculate_ssim(pred, target):
#     """Calculate SSIM using pytorch_msssim"""
#     return ssim(pred.unsqueeze(0), target.unsqueeze(0), 
#                 data_range=1.0, size_average=True).item()

# def calculate_ms_ssim(pred, target):
#     """Calculate MS-SSIM using pytorch_msssim"""
#     try:
#         return ms_ssim(pred.unsqueeze(0), target.unsqueeze(0),
#                        data_range=1.0, size_average=True).item()
#     except:
#         # Fallback to regular SSIM if MS-SSIM fails
#         return calculate_ssim(pred, target)

# # ============================================================================
# # UNIFIED MODEL EVALUATION CLASS
# # ============================================================================

# class UnifiedModelEvaluator:
#     """
#     Unified evaluator for all model types with consistent metric calculation.
#     """
    
#     def __init__(self, model_type, model_path, device='cuda'):
#         self.model_type = model_type
#         self.model_path = model_path
#         self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
#         self.model = None
        
#         print(f"Initializing {model_type} evaluator on {self.device}")
#         self._load_model()
    
#     def _load_model(self):
#         """Load the appropriate model based on model_type"""
#         print(f"Loading {self.model_type} model from {self.model_path}")
        
#         try:
#             if self.model_type.lower() == 'transunet':
#                 self.model = self._load_transunet()
#             elif self.model_type.lower() == 'unet':
#                 self.model = self._load_unet()
#             elif self.model_type.lower() in ['trust', 'vit_unet']:
#                 self.model = self._load_vit_unet()
#             elif self.model_type.lower() == 'restormer':
#                 self.model = self._load_restormer()
#             else:
#                 raise ValueError(f"Unsupported model type: {self.model_type}")
            
#             # Load weights and set to evaluation mode
#             checkpoint = torch.load(self.model_path, map_location=self.device)
#             if 'model_state_dict' in checkpoint:
#                 self.model.load_state_dict(checkpoint['model_state_dict'])
#             else:
#                 self.model.load_state_dict(checkpoint)
            
#             self.model.eval()
#             self.model = self.model.to(self.device)
            
#             # Count parameters
#             total_params = sum(p.numel() for p in self.model.parameters())
#             print(f"✓ {self.model_type} model loaded successfully ({total_params:,} parameters)")
            
#         except Exception as e:
#             print(f"✗ Error loading {self.model_type} model: {e}")
#             raise
    
#     def _load_transunet(self):
#         """Load TransUNet model"""
#         try:
#             # Import TransUNet components (you may need to adjust paths)
#             import sys
#             import os
            
#             # Add the directory containing your vis_transunet_full.py to path if needed
#             script_dir = os.path.dirname(os.path.abspath(__file__))
#             if script_dir not in sys.path:
#                 sys.path.append(script_dir)
            
#             # Try importing from your vis_transunet_full.py
#             from vis_transunet_full import TransUNet
            
#             return TransUNet(
#                 img_size=224, 
#                 patch_size=16, 
#                 in_channels=3, 
#                 out_channels=3,
#                 embed_dim=768, 
#                 depth=12, 
#                 num_heads=12, 
#                 mlp_ratio=4, 
#                 dropout=0.0
#             )
#         except ImportError as e:
#             print(f"Could not import TransUNet from vis_transunet_full.py: {e}")
#             print("Please ensure vis_transunet_full.py is in your Python path")
#             raise
    
#     def _load_unet(self):
#         """Load U-Net model"""
#         try:
#             # Import U-Net components
#             import sys
#             import os
            
#             script_dir = os.path.dirname(os.path.abspath(__file__))
#             if script_dir not in sys.path:
#                 sys.path.append(script_dir)
            
#             # Try importing from your vis_unet_full.py
#             from vis_unet_full import UNetForInverseProblem
            
#             return UNetForInverseProblem(
#                 n_channels=3, 
#                 n_classes=3, 
#                 bilinear=True
#             )
#         except ImportError as e:
#             print(f"Could not import UNet from vis_unet_full.py: {e}")
#             print("Please ensure vis_unet_full.py is in your Python path")
#             raise
    
#     def _load_vit_unet(self):
#         """Load ViT-UNet (TRUST) model"""
#         try:
#             # Import ViT-UNet components
#             import sys
#             import os
            
#             script_dir = os.path.dirname(os.path.abspath(__file__))
#             if script_dir not in sys.path:
#                 sys.path.append(script_dir)
            
#             # Try importing from your vis_trust_full.py
#             from vis_trust_full import ViTUNetForInverseProblem
            
#             return ViTUNetForInverseProblem(
#                 pretrained_model_name="google/vit-base-patch16-224", 
#                 output_size=(224, 224)
#             )
#         except ImportError as e:
#             print(f"Could not import ViT-UNet from vis_trust_full.py: {e}")
#             print("Please ensure vis_trust_full.py is in your Python path")
#             raise
    
#     def _load_restormer(self):
#         """Load Restormer model (which appears to be TransUNet in your case)"""
#         try:
#             # Import Restormer/TransUNet components
#             import sys
#             import os
            
#             script_dir = os.path.dirname(os.path.abspath(__file__))
#             if script_dir not in sys.path:
#                 sys.path.append(script_dir)
            
#             # Try importing from your vis_restormer_full.py (which seems to contain TransUNet)
#             from vis_restormer_full import TransUNet
            
#             return TransUNet(
#                 img_size=224, 
#                 patch_size=16, 
#                 in_channels=3, 
#                 out_channels=3,
#                 embed_dim=768, 
#                 depth=12, 
#                 num_heads=12, 
#                 mlp_ratio=4, 
#                 dropout=0.0
#             )
#         except ImportError as e:
#             print(f"Could not import Restormer from vis_restormer_full.py: {e}")
#             print("Please ensure vis_restormer_full.py is in your Python path")
#             raise
    
#     def evaluate(self, test_dir, save_dir, num_samples=None, batch_size=8, 
#                 visualize_samples=20, seed=42):
#         """
#         Run comprehensive evaluation on the model.
        
#         Returns:
#             dict: Evaluation results with mean ± std for all metrics
#         """
#         print(f"\nStarting {self.model_type} evaluation")
#         print(f"Test data: {test_dir}")
#         print(f"Results will be saved to: {save_dir}")
        
#         # Create directories
#         os.makedirs(save_dir, exist_ok=True)
#         os.makedirs(os.path.join(save_dir, "visualizations"), exist_ok=True)
        
#         # Create dataset
#         dataset = PatchwiseOrthonormalDataset(test_dir, seed=seed, verbose=True)
        
#         # Limit samples if specified
#         if num_samples is not None and num_samples < len(dataset):
#             indices = torch.randperm(len(dataset))[:num_samples].tolist()
#             dataset.image_files = [dataset.image_files[i] for i in indices]
#             print(f"Limited evaluation to {num_samples} samples")
        
#         # Create dataloader
#         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
#                                num_workers=4, pin_memory=True, drop_last=False)
        
#         # Initialize metrics storage
#         metrics = defaultdict(list)
#         all_predictions = []
#         all_targets = []
#         all_inputs = []
        
#         print(f"Processing {len(dataset)} images...")
#         start_time = time.time()
        
#         with torch.no_grad():
#             for batch_idx, (input_batch, target_batch) in enumerate(dataloader):
#                 # Start reconstruction timing (includes all processing steps)
#                 reconstruction_start = time.time()
                
#                 # Move to device
#                 input_batch = input_batch.to(self.device, non_blocking=True)
#                 target_batch = target_batch.to(self.device, non_blocking=True)
                
#                 # Measure pure inference time (just forward pass)
#                 inference_start = time.time()
#                 outputs = self.model(input_batch)
#                 inference_time = time.time() - inference_start
                
#                 # Ensure output shape matches target
#                 if outputs.shape != target_batch.shape:
#                     outputs = F.interpolate(outputs, size=target_batch.shape[-2:],
#                                           mode='bilinear', align_corners=True)
                
#                 # Clamp outputs to [0, 1]
#                 outputs = torch.clamp(outputs, 0, 1)
                
#                 # Complete reconstruction timing
#                 reconstruction_time = time.time() - reconstruction_start
                
#                 # Calculate metrics for each sample in batch
#                 batch_size_actual = input_batch.shape[0]
#                 for i in range(batch_size_actual):
#                     pred = outputs[i]
#                     target = target_batch[i]
#                     input_img = input_batch[i]
                    
#                     # Calculate all metrics
#                     metrics['inference_time'].append(inference_time / batch_size_actual)
#                     metrics['reconstruction_time'].append(reconstruction_time / batch_size_actual)
#                     metrics['psnr'].append(calculate_psnr(pred, target))
#                     metrics['ssim'].append(calculate_ssim(pred, target))
#                     metrics['ms_ssim'].append(calculate_ms_ssim(pred, target))
#                     metrics['mae'].append(calculate_mae(pred, target))
#                     metrics['mse'].append(F.mse_loss(pred, target).item())
#                     metrics['rmse'].append(calculate_rmse(pred, target))
                    
#                     # Store samples for visualization
#                     if len(all_predictions) < visualize_samples:
#                         all_predictions.append(pred.cpu())
#                         all_targets.append(target.cpu())
#                         all_inputs.append(input_img.cpu())
                
#                 # Print progress
#                 if (batch_idx + 1) % 10 == 0:
#                     processed = min((batch_idx + 1) * batch_size, len(dataset))
#                     print(f"Processed {processed}/{len(dataset)} samples")
        
#         eval_time = time.time() - start_time
#         print(f"Evaluation completed in {eval_time:.2f} seconds")
        
#         # Calculate statistics (mean ± std)
#         results = self._calculate_statistics(metrics)
#         results.update({
#             'model_type': self.model_type,
#             'model_path': self.model_path,
#             'total_samples': len(dataset),
#             'evaluation_time': eval_time,
#             'model_parameters': sum(p.numel() for p in self.model.parameters())
#         })
        
#         # Print results
#         self._print_results(results)
        
#         # Save results
#         self._save_results(results, metrics, save_dir)
        
#         # Create visualizations
#         self._create_visualizations(all_inputs, all_predictions, all_targets,
#                                    save_dir, results)
        
#         return results
    
#     def _calculate_statistics(self, metrics):
#         """Calculate mean ± std for all metrics"""
#         stats = {}
#         for metric_name, values in metrics.items():
#             values = np.array(values)
#             stats[f'{metric_name}_mean'] = float(np.mean(values))
#             stats[f'{metric_name}_std'] = float(np.std(values, ddof=1) if len(values) > 1 else 0)
#             stats[f'{metric_name}_min'] = float(np.min(values))
#             stats[f'{metric_name}_max'] = float(np.max(values))
#             stats[f'{metric_name}_median'] = float(np.median(values))
#         return stats
    
#     def _print_results(self, results):
#         """Print evaluation results in a formatted way"""
#         print(f"\n{'='*60}")
#         print(f"{self.model_type.upper()} EVALUATION RESULTS")
#         print(f"{'='*60}")
#         print(f"Samples: {results['total_samples']}")
#         print(f"Parameters: {results['model_parameters']:,}")
#         print(f"Total time: {results['evaluation_time']:.2f}s")
#         print()
        
#         # Print metrics with mean ± std format
#         metrics_to_show = ['inference_time', 'reconstruction_time', 'psnr', 'ssim', 'ms_ssim', 'mae', 'mse', 'rmse']
#         for metric in metrics_to_show:
#             mean_key = f'{metric}_mean'
#             std_key = f'{metric}_std'
#             if mean_key in results and std_key in results:
#                 if 'time' in metric:
#                     print(f"{metric.upper():16}: {results[mean_key]:.6f} ± {results[std_key]:.6f} seconds")
#                 else:
#                     print(f"{metric.upper():16}: {results[mean_key]:.6f} ± {results[std_key]:.6f}")
    
#     def _save_results(self, results, raw_metrics, save_dir):
#         """Save results to files"""
#         # Save summary results
#         with open(os.path.join(save_dir, 'evaluation_results.json'), 'w') as f:
#             json.dump(results, f, indent=2)
        
#         # Save detailed metrics
#         detailed_metrics = []
#         n_samples = len(raw_metrics['psnr'])
#         for i in range(n_samples):
#             row = {'sample_id': i + 1}
#             for metric, values in raw_metrics.items():
#                 row[metric] = values[i]
#             detailed_metrics.append(row)
        
#         # Save to CSV
#         with open(os.path.join(save_dir, 'detailed_metrics.csv'), 'w', newline='') as f:
#             if detailed_metrics:
#                 writer = csv.DictWriter(f, fieldnames=detailed_metrics[0].keys())
#                 writer.writeheader()
#                 writer.writerows(detailed_metrics)
        
#         print(f"Results saved to {save_dir}")
    
#     def _create_visualizations(self, inputs, predictions, targets, save_dir, results):
#         """Create visualization plots"""
#         n_samples = min(len(predictions), 20)
        
#         # Create comparison grid
#         fig, axes = plt.subplots(3, n_samples, figsize=(3*n_samples, 9))
#         if n_samples == 1:
#             axes = axes.reshape(-1, 1)
        
#         for i in range(n_samples):
#             # Convert tensors to numpy
#             input_img = inputs[i].numpy().transpose(1, 2, 0)
#             pred_img = predictions[i].numpy().transpose(1, 2, 0)
#             target_img = targets[i].numpy().transpose(1, 2, 0)
            
#             # Clip to [0, 1]
#             input_img = np.clip(input_img, 0, 1)
#             pred_img = np.clip(pred_img, 0, 1)  
#             target_img = np.clip(target_img, 0, 1)
            
#             # Plot images
#             axes[0, i].imshow(input_img)
#             axes[0, i].set_title(f"Input {i+1}")
#             axes[0, i].axis('off')
            
#             axes[1, i].imshow(pred_img)
#             axes[1, i].set_title(f"Prediction {i+1}")
#             axes[1, i].axis('off')
            
#             axes[2, i].imshow(target_img)
#             axes[2, i].set_title(f"Target {i+1}")
#             axes[2, i].axis('off')
        
#         # Add row labels
#         axes[0, 0].set_ylabel("Input\n(Transformed)", rotation=90, size='large')
#         axes[1, 0].set_ylabel(f"{self.model_type}\nPrediction", rotation=90, size='large')
#         axes[2, 0].set_ylabel("Target\n(Original)", rotation=90, size='large')
        
#         plt.suptitle(f'{self.model_type} Reconstruction Results', fontsize=16)
#         plt.tight_layout()
#         plt.savefig(os.path.join(save_dir, "visualizations", "comparison_grid.png"),
#                     dpi=150, bbox_inches='tight')
#         plt.close()
        
#         # Create metrics distribution plot
#         self._create_metrics_plot(save_dir, results)
    
#     def _create_metrics_plot(self, save_dir, results):
#         """Create metrics distribution visualization"""
#         metrics = ['psnr', 'ssim', 'ms_ssim', 'mae', 'mse', 'inference_time', 'reconstruction_time']
#         means = [results[f'{m}_mean'] for m in metrics]
#         stds = [results[f'{m}_std'] for m in metrics]
        
#         fig, ax = plt.subplots(figsize=(14, 8))
#         x_pos = np.arange(len(metrics))
        
#         bars = ax.bar(x_pos, means, yerr=stds, capsize=5, alpha=0.7)
        
#         # Color code bars (higher is better for PSNR/SSIM, lower for others)
#         colors = ['green', 'green', 'green', 'red', 'red', 'red', 'red']
#         for bar, color in zip(bars, colors):
#             bar.set_color(color)
        
#         ax.set_xlabel('Metrics')
#         ax.set_ylabel('Values')
#         ax.set_title(f'{self.model_type} Performance Metrics (Mean ± Std)')
#         ax.set_xticks(x_pos)
#         ax.set_xticklabels([m.upper().replace('_', ' ') for m in metrics], rotation=45)
#         ax.grid(True, alpha=0.3)
        
#         # Add value annotations
#         for i, (mean, std) in enumerate(zip(means, stds)):
#             ax.text(i, mean + std + max(means) * 0.01, 
#                    f'{mean:.3f}±{std:.3f}', 
#                    ha='center', va='bottom', fontsize=8)
        
#         plt.tight_layout()
#         plt.savefig(os.path.join(save_dir, "visualizations", "metrics_summary.png"),
#                     dpi=150, bbox_inches='tight')
#         plt.close()

# # ============================================================================
# # MULTI-MODEL COMPARISON
# # ============================================================================

# class MultiModelComparator:
#     """
#     Compare multiple models and generate comprehensive comparison reports.
#     """
    
#     def __init__(self):
#         self.results = []
    
#     def add_model_results(self, results):
#         """Add evaluation results for a model"""
#         self.results.append(results)
    
#     def compare_models(self, save_dir):
#         """Generate comprehensive comparison between models"""
#         os.makedirs(save_dir, exist_ok=True)
        
#         if len(self.results) < 2:
#             print("Need at least 2 models for comparison")
#             return
        
#         print(f"\nGenerating comparison report for {len(self.results)} models...")
        
#         # Create comparison DataFrame-like structure
#         comparison_data = []
#         for result in self.results:
#             row = {'Model': result['model_type']}
#             metrics = ['psnr', 'ssim', 'ms_ssim', 'mae', 'mse', 'inference_time', 'reconstruction_time']
            
#             for metric in metrics:
#                 mean_key = f'{metric}_mean'
#                 std_key = f'{metric}_std'
#                 if mean_key in result and std_key in result:
#                     row[f'{metric}_mean'] = result[mean_key]
#                     row[f'{metric}_std'] = result[std_key]
            
#             row['parameters'] = result['model_parameters']
#             row['samples'] = result['total_samples']
#             comparison_data.append(row)
        
#         # Save comparison data
#         self._save_comparison_data(comparison_data, save_dir)
        
#         # Create comparison visualizations
#         self._create_comparison_plots(comparison_data, save_dir)
        
#         # Generate comparison report
#         self._generate_comparison_report(comparison_data, save_dir)
        
#         print(f"Comparison results saved to {save_dir}")
    
#     def _save_comparison_data(self, data, save_dir):
#         """Save comparison data to CSV"""
#         if not data:
#             return
        
#         with open(os.path.join(save_dir, 'model_comparison.csv'), 'w', newline='') as f:
#             writer = csv.DictWriter(f, fieldnames=data[0].keys())
#             writer.writeheader()
#             writer.writerows(data)
    
#     def _create_comparison_plots(self, data, save_dir):
#         """Create comparison visualization plots"""
#         if not data:
#             return
        
#         models = [row['Model'] for row in data]
#         metrics = ['psnr', 'ssim', 'ms_ssim', 'mae', 'mse', 'inference_time', 'reconstruction_time']
        
#         # Create subplot for each metric
#         fig, axes = plt.subplots(3, 3, figsize=(20, 15))
#         axes = axes.flatten()
        
#         for i, metric in enumerate(metrics):
#             if i >= len(axes):
#                 break
                
#             means = [row[f'{metric}_mean'] for row in data if f'{metric}_mean' in row]
#             stds = [row[f'{metric}_std'] for row in data if f'{metric}_std' in row]
            
#             if not means:
#                 continue
            
#             bars = axes[i].bar(models, means, yerr=stds, capsize=5, alpha=0.7)
            
#             # Color coding: green for "higher is better", red for "lower is better"
#             color = 'green' if metric in ['psnr', 'ssim', 'ms_ssim'] else 'red'
#             for bar in bars:
#                 bar.set_color(color)
            
#             metric_title = metric.upper().replace('_', ' ')
#             axes[i].set_title(f'{metric_title} Comparison')
#             axes[i].set_ylabel(f'{metric_title}')
#             axes[i].tick_params(axis='x', rotation=45)
#             axes[i].grid(True, alpha=0.3)
            
#             # Add value labels
#             for j, (mean, std) in enumerate(zip(means, stds)):
#                 axes[i].text(j, mean + std + max(means) * 0.02,
#                            f'{mean:.3f}±{std:.3f}',
#                            ha='center', va='bottom', fontsize=8)
        
#         # Hide unused subplots
#         for i in range(len(metrics), len(axes)):
#             axes[i].set_visible(False)
        
#         plt.suptitle('Model Performance Comparison (Mean ± Std)', fontsize=16)
#         plt.tight_layout()
#         plt.savefig(os.path.join(save_dir, 'comparison_plots.png'),
#                     dpi=300, bbox_inches='tight')
#         plt.close()
        
#         # Create separate timing comparison plot
#         self._create_timing_comparison_plot(data, save_dir)
        
#         # Create ranking plot
#         self._create_ranking_plot(data, save_dir)
    
#     def _create_timing_comparison_plot(self, data, save_dir):
#         """Create dedicated timing comparison plot"""
#         models = [row['Model'] for row in data]
        
#         # Extract timing data
#         inference_means = [row['inference_time_mean'] for row in data if 'inference_time_mean' in row]
#         inference_stds = [row['inference_time_std'] for row in data if 'inference_time_std' in row]
#         recon_means = [row['reconstruction_time_mean'] for row in data if 'reconstruction_time_mean' in row]
#         recon_stds = [row['reconstruction_time_std'] for row in data if 'reconstruction_time_std' in row]
        
#         if not inference_means or not recon_means:
#             return
        
#         x = np.arange(len(models))
#         width = 0.35
        
#         fig, ax = plt.subplots(figsize=(12, 8))
        
#         bars1 = ax.bar(x - width/2, inference_means, width, yerr=inference_stds, 
#                        label='Inference Time', capsize=5, alpha=0.8, color='lightcoral')
#         bars2 = ax.bar(x + width/2, recon_means, width, yerr=recon_stds,
#                        label='Total Reconstruction Time', capsize=5, alpha=0.8, color='lightblue')
        
#         ax.set_xlabel('Models')
#         ax.set_ylabel('Time (seconds)')
#         ax.set_title('Inference vs Total Reconstruction Time Comparison')
#         ax.set_xticks(x)
#         ax.set_xticklabels(models, rotation=45)
#         ax.legend()
#         ax.grid(True, alpha=0.3)
        
#         # Add value labels
#         for i, (inf_mean, inf_std, rec_mean, rec_std) in enumerate(zip(inference_means, inference_stds, recon_means, recon_stds)):
#             ax.text(i - width/2, inf_mean + inf_std + max(inference_means) * 0.02,
#                    f'{inf_mean:.3f}±{inf_std:.3f}', ha='center', va='bottom', fontsize=8)
#             ax.text(i + width/2, rec_mean + rec_std + max(recon_means) * 0.02,
#                    f'{rec_mean:.3f}±{rec_std:.3f}', ha='center', va='bottom', fontsize=8)
        
#         plt.tight_layout()
#         plt.savefig(os.path.join(save_dir, 'timing_comparison.png'),
#                     dpi=300, bbox_inches='tight')
#         plt.close()
    
#     def _create_ranking_plot(self, data, save_dir):
#         """Create model ranking visualization"""
#         models = [row['Model'] for row in data]
        
#         # Define ranking criteria (higher is better gets positive score)
#         ranking_metrics = {
#             'PSNR': ([row['psnr_mean'] for row in data], True),
#             'SSIM': ([row['ssim_mean'] for row in data], True),  
#             'MS-SSIM': ([row['ms_ssim_mean'] for row in data], True),
#             'MAE': ([row['mae_mean'] for row in data], False),
#             'MSE': ([row['mse_mean'] for row in data], False),
#             'Inference Speed': ([1/row['inference_time_mean'] for row in data], True),  # Inverse for speed
#             'Reconstruction Speed': ([1/row['reconstruction_time_mean'] for row in data], True)  # Inverse for speed
#         }
        
#         # Calculate normalized scores
#         model_scores = {model: 0 for model in models}
        
#         for metric_name, (values, higher_better) in ranking_metrics.items():
#             # Normalize to 0-1 range
#             min_val, max_val = min(values), max(values)
#             if max_val > min_val:
#                 normalized = [(v - min_val) / (max_val - min_val) for v in values]
#                 if not higher_better:
#                     normalized = [1 - n for n in normalized]  # Flip for "lower is better"
                
#                 for model, score in zip(models, normalized):
#                     model_scores[model] += score
        
#         # Sort by total score
#         sorted_models = sorted(model_scores.items(), key=lambda x: x[1], reverse=True)
        
#         # Plot ranking
#         fig, ax = plt.subplots(figsize=(12, 8))
#         models_sorted = [m[0] for m in sorted_models]
#         scores_sorted = [m[1] for m in sorted_models]
        
#         bars = ax.bar(models_sorted, scores_sorted, alpha=0.7)
        
#         # Color gradient
#         colors = plt.cm.RdYlGn(np.linspace(0.3, 0.9, len(bars)))
#         for bar, color in zip(bars, colors):
#             bar.set_color(color)
        
#         ax.set_title('Overall Model Ranking (Normalized Composite Score)\nIncludes Quality Metrics + Inference & Reconstruction Speed')
#         ax.set_ylabel('Composite Score')
#         ax.tick_params(axis='x', rotation=45)
#         ax.grid(True, alpha=0.3)
        
#         # Add score labels
#         for i, score in enumerate(scores_sorted):
#             ax.text(i, score + max(scores_sorted) * 0.01,
#                    f'{score:.2f}', ha='center', va='bottom')
        
#         plt.tight_layout()
#         plt.savefig(os.path.join(save_dir, 'model_ranking.png'),
#                     dpi=300, bbox_inches='tight')
#         plt.close()
    
#     def _generate_comparison_report(self, data, save_dir):
#         """Generate comprehensive comparison report"""
#         with open(os.path.join(save_dir, 'comparison_report.txt'), 'w') as f:
#             f.write("="*80 + "\n")
#             f.write("COMPREHENSIVE MODEL COMPARISON REPORT\n")
#             f.write("="*80 + "\n\n")
            
#             f.write(f"Models Compared: {len(data)}\n")
#             f.write(f"Comparison Date: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            
#             # Summary table
#             f.write("PERFORMANCE SUMMARY (Mean ± Std):\n")
#             f.write("-" * 50 + "\n")
            
#             metrics_headers = ["Model", "PSNR", "SSIM", "MS-SSIM", "MAE", "MSE", "Inf.Time(s)", "Recon.Time(s)"]
#             f.write(f"{'Model':<15} {'PSNR':<12} {'SSIM':<12} {'MS-SSIM':<12} {'MAE':<12} {'MSE':<12} {'Inf.Time':<12} {'Recon.Time':<12}\n")
#             f.write("-" * 120 + "\n")
            
#             for row in data:
#                 model_name = row['Model'][:14]  # Truncate long names
#                 psnr = f"{row['psnr_mean']:.2f}±{row['psnr_std']:.2f}"
#                 ssim = f"{row['ssim_mean']:.3f}±{row['ssim_std']:.3f}"
#                 ms_ssim = f"{row['ms_ssim_mean']:.3f}±{row['ms_ssim_std']:.3f}"
#                 mae = f"{row['mae_mean']:.3f}±{row['mae_std']:.3f}"
#                 mse = f"{row['mse_mean']:.3f}±{row['mse_std']:.3f}"
#                 inf_time = f"{row['inference_time_mean']:.3f}±{row['inference_time_std']:.3f}"
#                 recon_time = f"{row['reconstruction_time_mean']:.3f}±{row['reconstruction_time_std']:.3f}"
                
#                 f.write(f"{model_name:<15} {psnr:<12} {ssim:<12} {ms_ssim:<12} {mae:<12} {mse:<12} {inf_time:<12} {recon_time:<12}\n")
            
#             f.write("\n\nKEY FINDINGS:\n")
#             f.write("-" * 15 + "\n")
            
#             # Find best performers
#             metrics_for_comparison = ['psnr', 'ssim', 'ms_ssim', 'mae', 'mse', 'inference_time', 'reconstruction_time']
#             higher_better = ['psnr', 'ssim', 'ms_ssim']
            
#             for metric in metrics_for_comparison:
#                 mean_key = f'{metric}_mean'
#                 values = [row[mean_key] for row in data if mean_key in row]
#                 models = [row['Model'] for row in data if mean_key in row]
                
#                 if values:
#                     if metric in higher_better:
#                         best_idx = values.index(max(values))
#                         f.write(f"Best {metric.upper().replace('_', ' ')}: {models[best_idx]} ({values[best_idx]:.4f})\n")
#                     else:
#                         best_idx = values.index(min(values))
#                         f.write(f"Best {metric.upper().replace('_', ' ')}: {models[best_idx]} ({values[best_idx]:.4f})\n")
            
#             f.write("\nNOTES:\n")
#             f.write("- Higher PSNR, SSIM, MS-SSIM indicate better quality\n")
#             f.write("- Lower MAE, MSE indicate better accuracy\n")
#             f.write("- Lower inference and reconstruction times indicate faster processing\n")
#             f.write("- Inference time: Pure model forward pass time\n")
#             f.write("- Reconstruction time: Total time including data movement and post-processing\n")
#             f.write("- All values reported as mean ± standard deviation\n")

# # ============================================================================
# # MAIN EVALUATION FUNCTIONS
# # ============================================================================

# def evaluate_single_model(model_type, model_path, test_dir, save_dir, 
#                          num_samples=None, batch_size=8, device='cuda'):
#     """
#     Evaluate a single model and return results.
    
#     Args:
#         model_type: Type of model ('transunet', 'unet', 'trust', 'restormer')
#         model_path: Path to model checkpoint
#         test_dir: Directory containing test images
#         save_dir: Directory to save results
#         num_samples: Number of samples to evaluate (None for all)
#         batch_size: Batch size for evaluation
#         device: Device to run on
    
#     Returns:
#         dict: Evaluation results
#     """
#     evaluator = UnifiedModelEvaluator(model_type, model_path, device)
#     results = evaluator.evaluate(test_dir, save_dir, num_samples, batch_size)
#     return results

# def evaluate_multiple_models(models_config, test_dir, comparison_save_dir,
#                             num_samples=None, batch_size=8, device='cuda'):
#     """
#     Evaluate multiple models and generate comparison.
    
#     Args:
#         models_config: List of dicts with keys 'type', 'path', 'save_dir'
#         test_dir: Directory containing test images  
#         comparison_save_dir: Directory to save comparison results
#         num_samples: Number of samples to evaluate
#         batch_size: Batch size
#         device: Device to run on
    
#     Returns:
#         list: Results for all models
#     """
#     comparator = MultiModelComparator()
#     all_results = []
    
#     print(f"Starting evaluation of {len(models_config)} models...")
    
#     for i, config in enumerate(models_config):
#         print(f"\n[{i+1}/{len(models_config)}] Evaluating {config['type']}...")
        
#         try:
#             results = evaluate_single_model(
#                 model_type=config['type'],
#                 model_path=config['path'],
#                 test_dir=test_dir,
#                 save_dir=config['save_dir'],
#                 num_samples=num_samples,
#                 batch_size=batch_size,
#                 device=device
#             )
            
#             comparator.add_model_results(results)
#             all_results.append(results)
#             print(f"✓ {config['type']} evaluation completed")
            
#         except Exception as e:
#             print(f"✗ {config['type']} evaluation failed: {e}")
#             continue
    
#     # Generate comparison
#     if len(all_results) > 1:
#         print(f"\nGenerating comparison for {len(all_results)} models...")
#         comparator.compare_models(comparison_save_dir)
    
#     return all_results

# # ============================================================================
# # EXAMPLE USAGE
# # ============================================================================

# if __name__ == "__main__":
#     # Example configuration using your model paths
#     models_config = [
#         {
#             'type': 'transunet',
#             'path': r'F:\ImageNet\transunet_16x16_checkpoints\best_model.pth',
#             'save_dir': r'F:\ImageNet\evaluation_results\transunet'
#         },
#         {
#             'type': 'unet', 
#             'path': r'F:\ImageNet\unet_16x16_checkpoints\best_model.pth',
#             'save_dir': r'F:\ImageNet\evaluation_results\unet'
#         },
#         {
#             'type': 'trust',
#             'path': r'F:\ImageNet\trust_16x16_checkpoints\best_model.pth', 
#             'save_dir': r'F:\ImageNet\evaluation_results\trust'
#         },
#         {
#             'type': 'restormer',
#             'path': r'F:\ImageNet\restormer_16x16_checkpoints\best_model.pth',
#             'save_dir': r'F:\ImageNet\evaluation_results\restormer'
#         }
#     ]
    
#     # # Single model evaluation example
#     # print("=== Single Model Evaluation ===")
#     # transunet_results = evaluate_single_model(
#     #     model_type='transunet',
#     #     model_path=r'F:\ImageNet\transunet_16x16_checkpoints\best_model.pth',
#     #     test_dir=r'F:\imgnet\data\test',  # Your test data directory
#     #     save_dir=r'F:\ImageNet\evaluation_results\transunet_single',
#     #     num_samples=50,  # Evaluate on 50 samples
#     #     batch_size=8,
#     #     device='cuda'
#     # )
    
#     # Multi-model comparison example
#     print("\n=== Multi-Model Comparison ===")
#     results = evaluate_multiple_models(
#         models_config=models_config,
#         test_dir=r'F:\imgnet\data\test',  # Your test data directory
#         comparison_save_dir=r'F:\ImageNet\evaluation_results\comparison',
#         num_samples=1000,  # Evaluate on 50 samples
#         batch_size=8,
#         device='cuda'
#     )
    
#     print(f"\nEvaluation completed for {len(results)} models!")
#     print("Check the F:\\ImageNet\\evaluation_results\\ directory for detailed outputs and comparisons.")
    
#     # Print summary of results
#     if results:
#         print("\nQUICK SUMMARY:")
#         print("-" * 50)
#         for result in results:
#             print(f"{result['model_type'].upper():12}: "
#                   f"PSNR={result['psnr_mean']:.2f}±{result['psnr_std']:.2f}, "
#                   f"SSIM={result['ssim_mean']:.3f}±{result['ssim_std']:.3f}, "
#                   f"Time={result['inference_time_mean']:.3f}±{result['inference_time_std']:.3f}s")




#!/usr/bin/env python3
"""
Unified Model Evaluation Script
==============================

This script provides comprehensive evaluation capabilities for various image reconstruction models
including TransUNet, U-Net, ViT-UNet (TRUST), and Restormer on orthonormal inverse problems.

Features:
- Unified evaluation interface for all model types
- Comprehensive metrics (PSNR, SSIM, MS-SSIM, MAE, MSE, RMSE, FPR hallucination score) plus inference and reconstruction timing
- Statistical analysis with mean ± std reporting
- Visualization and comparison capabilities
- Individual and comparative result generation
"""

import os
import time
import json
import csv
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image

# Import metrics
from pytorch_msssim import ssim, ms_ssim

def calculate_ms_ssim(pred, target):
    """Calculate MS-SSIM between two images using pytorch_msssim.

    Args:
        pred: Predicted image tensor. A batch dimension is added with
            ``unsqueeze(0)``, so a single un-batched image is expected.
        target: Ground-truth tensor with the same shape as ``pred``.

    Returns:
        float: MS-SSIM computed with ``data_range=1.0``. If ``ms_ssim`` raises
        (e.g. when the image is too small for the multi-scale pyramid), the
        single-scale SSIM from ``calculate_ssim`` is returned instead and a
        RuntimeWarning is emitted so the substitution is visible.
    """
    try:
        return ms_ssim(pred.unsqueeze(0), target.unsqueeze(0),
                       data_range=1.0, size_average=True).item()
    except Exception as exc:
        # Narrower than a bare ``except:`` so KeyboardInterrupt/SystemExit still
        # propagate. Warn because SSIM values would otherwise be silently mixed
        # into the MS-SSIM statistics.
        import warnings
        warnings.warn(
            f"MS-SSIM failed ({exc!r}); falling back to single-scale SSIM",
            RuntimeWarning,
            stacklevel=2,
        )
        return calculate_ssim(pred, target)

# ============================================================================
# SHARED DATASET AND UTILITIES
# ============================================================================

def make_gaussian_random_orthonormal_rows(h=256, w=256, seed=42):
    """Build a matrix with orthonormal rows from a Gaussian random draw.

    An ``(h, w)`` standard-normal matrix is drawn, and the reduced QR
    decomposition of its transpose is taken. The transposed ``Q`` factor is
    returned, so for ``h <= w`` the result has shape ``(h, w)`` and
    orthonormal rows.

    Args:
        h: Number of rows of the Gaussian draw.
        w: Number of columns of the Gaussian draw.
        seed: If not None, passed to ``torch.manual_seed`` first. This reseeds
            the *global* torch RNG, which later global draws depend on.

    Returns:
        torch.Tensor: The transposed Q factor.
    """
    if seed is not None:
        torch.manual_seed(seed)
    gaussian = torch.randn(h, w)
    q_factor, _ = torch.linalg.qr(gaussian.T)
    return q_factor.T

class PatchwiseOrthonormalDataset:
    """
    Unified dataset for orthonormal transformation evaluation.

    Preprocessing of each image:
    - Convert to RGB.
    - Resize so the shorter side is ``IMG_SIZE`` (224) px, then center-crop
      to ``IMG_SIZE x IMG_SIZE``.
    - Scale to [0, 1].

    Each colour channel is then split into non-overlapping
    ``PATCH_SIZE x PATCH_SIZE`` (16x16) patches. Every flattened patch is
    multiplied by ``self.A``, and the results are tiled back into an image.
    The stacked 3-channel result is min-max normalised over the whole image.

    ``__getitem__`` returns ``(transformed, original)`` as (3, H, W) tensors.
    """

    # Suffixes are compared case-insensitively (".JPG", ".Png", ... all match).
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'})
    IMG_SIZE = 224
    PATCH_SIZE = 16

    def __init__(self, data_dir, seed=42, verbose=False):
        self.data_dir = data_dir
        patch_dim = self.PATCH_SIZE ** 2
        self.A = make_gaussian_random_orthonormal_rows(h=patch_dim, w=patch_dim, seed=seed)

        self.data_path = Path(data_dir)
        if not self.data_path.exists():
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        # Sorted so that the file order, and any index-based subset selection
        # done by callers, does not depend on filesystem iteration order.
        self.image_files = sorted(
            f for f in self.data_path.iterdir()
            if f.is_file() and f.suffix.lower() in self.IMAGE_EXTENSIONS
        )

        if len(self.image_files) == 0:
            raise ValueError(f"No images found in {data_dir}")

        if verbose:
            print(f"Loaded {len(self.image_files)} images from {data_dir}")
            print(f"Using 16x16 patch-wise orthonormal transformation with matrix shape: {self.A.shape}")

    def __len__(self):
        return len(self.image_files)

    def resize_min_side(self, img, min_side=224):
        """Resize a PIL image (LANCZOS) so its shorter side equals ``min_side``."""
        w, h = img.size
        s = min_side / min(w, h)
        return img.resize((int(round(w * s)), int(round(h * s))), Image.Resampling.LANCZOS)

    def center_crop(self, img, size=224):
        """Crop the central ``size x size`` region of a PIL image."""
        w, h = img.size
        left = (w - size) // 2
        top = (h - size) // 2
        return img.crop((left, top, left + size, top + size))

    def preprocess_image(self, img):
        """Return an (IMG_SIZE, IMG_SIZE, 3) float32 array in [0, 1]."""
        img = img.convert("RGB")
        img_resized = self.resize_min_side(img, self.IMG_SIZE)
        img_crop = self.center_crop(img_resized, self.IMG_SIZE)
        return np.array(img_crop).astype(np.float32) / 255.0

    def process_image_with_orthonormal_masks(self, np_img, mask_matrix):
        """Multiply every flattened non-overlapping patch by ``mask_matrix``.

        Args:
            np_img: Array of shape (H, W, 3), averaged to grayscale;
                (H, W, 1), whose single channel is used; or (H, W).
            mask_matrix: Tensor of shape (M, PATCH_SIZE**2).

        Returns:
            torch.Tensor of shape (H // PATCH_SIZE, W // PATCH_SIZE, M).

        Raises:
            ValueError: if a 3-D input has neither 1 nor 3 channels.
        """
        img = torch.from_numpy(np_img).float()
        if img.dim() == 3:
            if img.shape[2] == 3:
                img = img.mean(dim=2)
            elif img.shape[2] == 1:
                img = img[..., 0]
            else:
                raise ValueError(f"Expected 1 or 3 channels, got shape {tuple(img.shape)}")

        p = self.PATCH_SIZE
        patches = img.unfold(0, p, p).unfold(1, p, p)  # (nH, nW, p, p)
        flat = patches.reshape(patches.shape[0], patches.shape[1], p * p)
        mask = torch.as_tensor(mask_matrix, dtype=flat.dtype)
        # Row-vector form of ``mask @ patch``, applied to all patches at once.
        return flat @ mask.T

    def reconstruct_masked_image(self, transformed_patches):
        """Tile (nH, nW, PATCH_SIZE**2) patch vectors back into an (nH*P, nW*P) image."""
        p = self.PATCH_SIZE
        n_h, n_w = transformed_patches.shape[:2]
        tiles = transformed_patches.reshape(n_h, n_w, p, p)
        # (nH, p, nW, p) -> row index i*p + a, column index j*p + b
        return tiles.permute(0, 2, 1, 3).reshape(n_h * p, n_w * p)

    def apply_patchwise_orthonormal_transform(self, x):
        """Transform each channel of an (H, W, C) array and min-max normalise the result."""
        y = np.stack(
            [
                self.reconstruct_masked_image(
                    self.process_image_with_orthonormal_masks(x[..., c:c + 1], self.A)
                ).numpy()
                for c in range(x.shape[2])
            ],
            axis=2,
        )
        y_min = y.min()
        y_max = y.max()
        return (y - y_min) / (y_max - y_min + 1e-8)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        try:
            # PIL decodes lazily, so errors in corrupt or truncated files only
            # surface in preprocess_image. Both calls therefore sit in the try.
            # The context manager closes the file handle promptly.
            with Image.open(img_path) as img:
                x = self.preprocess_image(img)
        except Exception as e:
            print(f"Warning: Could not load image {img_path}: {e}")
            x = self.preprocess_image(
                Image.new('RGB', (self.IMG_SIZE, self.IMG_SIZE), color=(0, 0, 0))
            )

        y = self.apply_patchwise_orthonormal_transform(x)

        x_tensor = torch.from_numpy(x).permute(2, 0, 1)
        y_tensor = torch.from_numpy(y).permute(2, 0, 1)

        return y_tensor, x_tensor

# ============================================================================
# METRIC CALCULATION FUNCTIONS
# ============================================================================

def calculate_psnr(pred, target, max_val=1.0):
    """Return PSNR in dB, computed as 20 * log10(max_val / sqrt(MSE)).

    Identical inputs (MSE == 0) yield ``float('inf')``.
    """
    squared_error = (pred - target).pow(2).mean()
    if squared_error == 0:
        return float('inf')
    rms = squared_error.sqrt()
    return (20 * torch.log10(max_val / rms)).item()

def calculate_mae(pred, target):
    """Return the mean absolute difference between ``pred`` and ``target`` as a float."""
    abs_diff = (pred - target).abs()
    return abs_diff.mean().item()

def calculate_rmse(pred, target):
    """Return the square root of the mean squared difference as a float."""
    residual = pred - target
    return residual.pow(2).mean().sqrt().item()

def calculate_ssim(pred, target):
    """Return SSIM (pytorch_msssim, data_range=1.0) for one un-batched image pair."""
    pred_batch = pred.unsqueeze(0)
    target_batch = target.unsqueeze(0)
    score = ssim(pred_batch, target_batch, data_range=1.0, size_average=True)
    return score.item()

def calculate_fpr_score(pred, target, t_high=0.5, t_low=0.2):
    """
    Compute the False Positive Regions (FPR) hallucination score.

    Both inputs are moved to NumPy. If ``pred`` is 3-D, only channel 0 of each
    input is used. Each array is min-max rescaled to [0, 1]. A pixel counts as
    hallucinated when the rescaled prediction exceeds ``t_high`` while the
    rescaled target is at most ``t_low``.

    Parameters
    ----------
    pred : torch.Tensor
        Generated or reconstructed image
    target : torch.Tensor
        Ground truth image
    t_high : float
        Upper threshold for generated image
    t_low : float
        Lower threshold for ground truth

    Returns
    -------
    score : float
        Fraction of hallucinated pixels
    mask : numpy.ndarray
        Boolean hallucination mask
    """
    def _to_unit_range(arr):
        lo, hi = arr.min(), arr.max()
        return (arr - lo) / (hi - lo + 1e-8)

    pred_arr = pred.cpu().numpy()
    target_arr = target.cpu().numpy()

    # Multi-channel input: evaluate only the first channel.
    if pred_arr.ndim == 3:
        pred_arr, target_arr = pred_arr[0], target_arr[0]

    hallucinated = (_to_unit_range(pred_arr) > t_high) & (_to_unit_range(target_arr) <= t_low)
    return np.sum(hallucinated) / hallucinated.size, hallucinated

# ============================================================================
# UNIFIED MODEL EVALUATION CLASS
# ============================================================================

class UnifiedModelEvaluator:
    """
    Unified evaluator for all model types with consistent metric calculation.
    """
    
    def __init__(self, model_type, model_path, device='cuda', fpr_t_high=0.5, fpr_t_low=0.2):
        """Store the configuration and immediately load the model.

        Args:
            model_type: Architecture name understood by ``_load_model``.
            model_path: Path to the checkpoint file.
            device: Requested torch device string. CPU is used whenever CUDA
                is unavailable, regardless of this value.
            fpr_t_high: Prediction threshold passed to ``calculate_fpr_score``.
            fpr_t_low: Target threshold passed to ``calculate_fpr_score``.
        """
        self.model_type = model_type
        self.model_path = model_path
        self.fpr_t_high = fpr_t_high
        self.fpr_t_low = fpr_t_low
        self.model = None

        requested = device if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(requested)

        print(f"Initializing {model_type} evaluator on {self.device}")
        print(f"FPR thresholds: t_high={fpr_t_high}, t_low={fpr_t_low}")
        self._load_model()
    
    
    def _load_model(self):
        """Build the architecture for ``self.model_type`` and load its weights.

        ``model_type`` is matched case-insensitively. 'trust' and 'vit_unet'
        both map to the ViT-UNet builder. The checkpoint may be a raw
        state dict or a dict holding one under ``'model_state_dict'``. The
        model is put in eval mode and moved to ``self.device``.

        Raises:
            ValueError: for an unsupported model type.
            Exception: any loading error is printed and re-raised.
        """
        print(f"Loading {self.model_type} model from {self.model_path}")

        builders = {
            'transunet': self._load_transunet,
            'unet': self._load_unet,
            'trust': self._load_vit_unet,
            'vit_unet': self._load_vit_unet,
            'restormer': self._load_restormer,
        }

        try:
            build = builders.get(self.model_type.lower())
            if build is None:
                raise ValueError(f"Unsupported model type: {self.model_type}")
            self.model = build()

            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(self.model_path, map_location=self.device)
            state_dict = (checkpoint['model_state_dict']
                          if 'model_state_dict' in checkpoint else checkpoint)
            self.model.load_state_dict(state_dict)

            self.model.eval()
            self.model = self.model.to(self.device)

            n_params = sum(param.numel() for param in self.model.parameters())
            print(f"✓ {self.model_type} model loaded successfully ({n_params:,} parameters)")

        except Exception as e:
            print(f"✗ Error loading {self.model_type} model: {e}")
            raise
    
    
    def _load_transunet(self):
        """Build an (untrained) TransUNet from ``vis_transunet_full``.

        This script's directory is appended to ``sys.path`` if missing, so a
        sibling module can be imported. Weights are loaded by ``_load_model``.

        Raises:
            ImportError: re-raised after printing a hint.
        """
        try:
            import os
            import sys

            module_dir = os.path.dirname(os.path.abspath(__file__))
            if module_dir not in sys.path:
                sys.path.append(module_dir)

            from vis_transunet_full import TransUNet

            config = dict(
                img_size=224,
                patch_size=16,
                in_channels=3,
                out_channels=3,
                embed_dim=768,
                depth=12,
                num_heads=12,
                mlp_ratio=4,
                dropout=0.0,
            )
            return TransUNet(**config)
        except ImportError as e:
            print(f"Could not import TransUNet from vis_transunet_full.py: {e}")
            print("Please ensure vis_transunet_full.py is in your Python path")
            raise
    
    
    def _load_unet(self):
        """Build an (untrained) UNetForInverseProblem from ``vis_unet_full``.

        This script's directory is appended to ``sys.path`` if missing, so a
        sibling module can be imported. Weights are loaded by ``_load_model``.

        Raises:
            ImportError: re-raised after printing a hint.
        """
        try:
            import os
            import sys

            module_dir = os.path.dirname(os.path.abspath(__file__))
            if module_dir not in sys.path:
                sys.path.append(module_dir)

            from vis_unet_full import UNetForInverseProblem

            config = dict(n_channels=3, n_classes=3, bilinear=True)
            return UNetForInverseProblem(**config)
        except ImportError as e:
            print(f"Could not import UNet from vis_unet_full.py: {e}")
            print("Please ensure vis_unet_full.py is in your Python path")
            raise
    
    
    def _load_vit_unet(self):
        """Build the ViT-UNet (TRUST) model from ``vis_trust_full``.

        The backbone name "google/vit-base-patch16-224" and output size
        (224, 224) are passed to ``ViTUNetForInverseProblem``. This script's
        directory is appended to ``sys.path`` if missing.

        Raises:
            ImportError: re-raised after printing a hint.
        """
        try:
            import os
            import sys

            module_dir = os.path.dirname(os.path.abspath(__file__))
            if module_dir not in sys.path:
                sys.path.append(module_dir)

            from vis_trust_full import ViTUNetForInverseProblem

            config = dict(
                pretrained_model_name="google/vit-base-patch16-224",
                output_size=(224, 224),
            )
            return ViTUNetForInverseProblem(**config)
        except ImportError as e:
            print(f"Could not import ViT-UNet from vis_trust_full.py: {e}")
            print("Please ensure vis_trust_full.py is in your Python path")
            raise
    
    
    def _load_restormer(self):
        """Build the 'restormer' model, a ``TransUNet`` class from ``vis_restormer_full``.

        The same hyper-parameters as ``_load_transunet`` are used. This
        script's directory is appended to ``sys.path`` if missing.

        Raises:
            ImportError: re-raised after printing a hint.
        """
        try:
            import os
            import sys

            module_dir = os.path.dirname(os.path.abspath(__file__))
            if module_dir not in sys.path:
                sys.path.append(module_dir)

            from vis_restormer_full import TransUNet

            config = dict(
                img_size=224,
                patch_size=16,
                in_channels=3,
                out_channels=3,
                embed_dim=768,
                depth=12,
                num_heads=12,
                mlp_ratio=4,
                dropout=0.0,
            )
            return TransUNet(**config)
        except ImportError as e:
            print(f"Could not import Restormer from vis_restormer_full.py: {e}")
            print("Please ensure vis_restormer_full.py is in your Python path")
            raise
    
    
    def evaluate(self, test_dir, save_dir, num_samples=None, batch_size=8,
                 visualize_samples=20, seed=42):
        """
        Run comprehensive evaluation on the model.

        Steps:
        - Build a PatchwiseOrthonormalDataset from ``test_dir``, optionally
          restricted to a random subset of ``num_samples`` images.
        - Run the model over it without gradients.
        - Record per-sample PSNR, SSIM, MS-SSIM, MAE, MSE, RMSE and the FPR
          hallucination score.
        - Record per-sample inference time (forward pass only) and
          reconstruction time (device copy, forward pass, resize and clamp).
          Both are batch times divided by the batch size.

        Args:
            test_dir: Directory containing test images.
            save_dir: Output directory. A ``visualizations`` subdirectory is
                created inside it.
            num_samples: If given and smaller than the dataset, evaluate only
                this many randomly chosen images.
            batch_size: DataLoader batch size.
            visualize_samples: Number of (input, prediction, target) triples
                kept for ``_create_visualizations``.
            seed: Seed forwarded to PatchwiseOrthonormalDataset.

        Returns:
            dict: Evaluation results with mean ± std for all metrics
        """
        print(f"\nStarting {self.model_type} evaluation")
        print(f"Test data: {test_dir}")
        print(f"Results will be saved to: {save_dir}")

        # Create directories
        os.makedirs(save_dir, exist_ok=True)
        os.makedirs(os.path.join(save_dir, "visualizations"), exist_ok=True)

        # Create dataset. When seed is not None, its constructor reseeds the
        # global torch RNG via make_gaussian_random_orthonormal_rows.
        dataset = PatchwiseOrthonormalDataset(test_dir, seed=seed, verbose=True)

        # Limit samples if specified. randperm draws from the global RNG seeded
        # above, so the subset is reproducible for a given seed.
        if num_samples is not None and num_samples < len(dataset):
            indices = torch.randperm(len(dataset))[:num_samples].tolist()
            dataset.image_files = [dataset.image_files[i] for i in indices]
            print(f"Limited evaluation to {num_samples} samples")

        use_cuda = self.device.type == 'cuda'

        def _sync():
            # CUDA kernels run asynchronously: without a barrier, a host timer
            # only measures kernel-launch overhead, not execution time.
            if use_cuda:
                torch.cuda.synchronize(self.device)

        # Pinned host memory only speeds up host-to-GPU copies, so request it
        # only when evaluating on CUDA.
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                num_workers=4, pin_memory=use_cuda, drop_last=False)

        # Initialize metrics storage
        metrics = defaultdict(list)
        all_predictions = []
        all_targets = []
        all_inputs = []

        print(f"Processing {len(dataset)} images...")
        start_time = time.perf_counter()

        with torch.no_grad():
            for batch_idx, (input_batch, target_batch) in enumerate(dataloader):
                # Reconstruction timing covers the device copy, forward pass,
                # resize and clamp.
                _sync()
                reconstruction_start = time.perf_counter()

                input_batch = input_batch.to(self.device, non_blocking=True)
                target_batch = target_batch.to(self.device, non_blocking=True)

                # Inference timing covers only the forward pass. The barrier
                # keeps the non-blocking copies out of this measurement.
                _sync()
                inference_start = time.perf_counter()
                outputs = self.model(input_batch)
                _sync()
                inference_time = time.perf_counter() - inference_start

                # Ensure output shape matches target
                if outputs.shape != target_batch.shape:
                    outputs = F.interpolate(outputs, size=target_batch.shape[-2:],
                                            mode='bilinear', align_corners=True)

                # Clamp outputs to [0, 1]
                outputs = torch.clamp(outputs, 0, 1)

                _sync()
                reconstruction_time = time.perf_counter() - reconstruction_start

                # Calculate metrics for each sample in batch
                batch_size_actual = input_batch.shape[0]
                for i in range(batch_size_actual):
                    pred = outputs[i]
                    target = target_batch[i]
                    input_img = input_batch[i]

                    metrics['inference_time'].append(inference_time / batch_size_actual)
                    metrics['reconstruction_time'].append(reconstruction_time / batch_size_actual)
                    metrics['psnr'].append(calculate_psnr(pred, target))
                    metrics['ssim'].append(calculate_ssim(pred, target))
                    metrics['ms_ssim'].append(calculate_ms_ssim(pred, target))
                    metrics['mae'].append(calculate_mae(pred, target))
                    metrics['mse'].append(F.mse_loss(pred, target).item())
                    metrics['rmse'].append(calculate_rmse(pred, target))

                    # FPR hallucination score; the returned mask is not used here.
                    fpr_score, _ = calculate_fpr_score(pred, target, self.fpr_t_high, self.fpr_t_low)
                    metrics['fpr'].append(fpr_score)

                    # Store samples for visualization
                    if len(all_predictions) < visualize_samples:
                        all_predictions.append(pred.cpu())
                        all_targets.append(target.cpu())
                        all_inputs.append(input_img.cpu())

                # Print progress
                if (batch_idx + 1) % 10 == 0:
                    processed = min((batch_idx + 1) * batch_size, len(dataset))
                    print(f"Processed {processed}/{len(dataset)} samples")

        eval_time = time.perf_counter() - start_time
        print(f"Evaluation completed in {eval_time:.2f} seconds")

        # Calculate statistics (mean ± std)
        results = self._calculate_statistics(metrics)
        results.update({
            'model_type': self.model_type,
            'model_path': self.model_path,
            'total_samples': len(dataset),
            'evaluation_time': eval_time,
            'model_parameters': sum(p.numel() for p in self.model.parameters())
        })

        # Print results
        self._print_results(results)

        # Save results
        self._save_results(results, metrics, save_dir)

        # Create visualizations
        self._create_visualizations(all_inputs, all_predictions, all_targets,
                                    save_dir, results)

        return results
    
    
    def _calculate_statistics(self, metrics):
        """Calculate mean ± std for all metrics"""
        stats = {}
        for metric_name, values in metrics.items():
            values = np.array(values)
            stats[f'{metric_name}_mean'] = float(np.mean(values))
            stats[f'{metric_name}_std'] = float(np.std(values, ddof=1) if len(values) > 1 else 0)
            stats[f'{metric_name}_min'] = float(np.min(values))
            stats[f'{metric_name}_max'] = float(np.max(values))
            stats[f'{metric_name}_median'] = float(np.median(values))
        return stats
    
    def _print_results(self, results):
        """Print evaluation results in a formatted way"""
        print(f"\n{'='*60}")
        print(f"{self.model_type.upper()} EVALUATION RESULTS")
        print(f"{'='*60}")
        print(f"Samples: {results['total_samples']}")
        print(f"Parameters: {results['model_parameters']:,}")
        print(f"Total time: {results['evaluation_time']:.2f}s")
        print()
        
        # Print metrics with mean ± std format
        metrics_to_show = ['inference_time', 'reconstruction_time', 'psnr', 'ssim', 'ms_ssim', 'mae', 'mse', 'rmse', 'fpr']
        for metric in metrics_to_show:
            mean_key = f'{metric}_mean'
            std_key = f'{metric}_std'
            if mean_key in results and std_key in results:
                if 'time' in metric:
                    print(f"{metric.upper():16}: {results[mean_key]:.6f} ± {results[std_key]:.6f} seconds")
                else:
                    print(f"{metric.upper():16}: {results[mean_key]:.6f} ± {results[std_key]:.6f}")
    
    def _save_results(self, results, raw_metrics, save_dir):
        """Save results to files"""
        # Save summary results
        with open(os.path.join(save_dir, 'evaluation_results.json'), 'w') as f:
            json.dump(results, f, indent=2)
        
        # Save detailed metrics
        detailed_metrics = []
        n_samples = len(raw_metrics['psnr'])
        for i in range(n_samples):
            row = {'sample_id': i + 1}
            for metric, values in raw_metrics.items():
                row[metric] = values[i]
            detailed_metrics.append(row)
        
        # Save to CSV
        with open(os.path.join(save_dir, 'detailed_metrics.csv'), 'w', newline='') as f:
            if detailed_metrics:
                writer = csv.DictWriter(f, fieldnames=detailed_metrics[0].keys())
                writer.writeheader()
                writer.writerows(detailed_metrics)
        
        print(f"Results saved to {save_dir}")
    
    def _create_visualizations(self, inputs, predictions, targets, save_dir, results):
        """Create visualization plots"""
        n_samples = min(len(predictions), 20)
        
        # Create comparison grid
        fig, axes = plt.subplots(3, n_samples, figsize=(3*n_samples, 9))
        if n_samples == 1:
            axes = axes.reshape(-1, 1)
        
        for i in range(n_samples):
            # Convert tensors to numpy
            input_img = inputs[i].numpy().transpose(1, 2, 0)
            pred_img = predictions[i].numpy().transpose(1, 2, 0)
            target_img = targets[i].numpy().transpose(1, 2, 0)
            
            # Clip to [0, 1]
            input_img = np.clip(input_img, 0, 1)
            pred_img = np.clip(pred_img, 0, 1)  
            target_img = np.clip(target_img, 0, 1)
            
            # Plot images
            axes[0, i].imshow(input_img)
            axes[0, i].set_title(f"Input {i+1}")
            axes[0, i].axis('off')
            
            axes[1, i].imshow(pred_img)
            axes[1, i].set_title(f"Prediction {i+1}")
            axes[1, i].axis('off')
            
            axes[2, i].imshow(target_img)
            axes[2, i].set_title(f"Target {i+1}")
            axes[2, i].axis('off')
        
        # Add row labels
        axes[0, 0].set_ylabel("Input\n(Transformed)", rotation=90, size='large')
        axes[1, 0].set_ylabel(f"{self.model_type}\nPrediction", rotation=90, size='large')
        axes[2, 0].set_ylabel("Target\n(Original)", rotation=90, size='large')
        
        plt.suptitle(f'{self.model_type} Reconstruction Results', fontsize=16)
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, "visualizations", "comparison_grid.png"),
                    dpi=150, bbox_inches='tight')
        plt.close()
        
        # Create metrics distribution plot
        self._create_metrics_plot(save_dir, results)
    
    def _create_metrics_plot(self, save_dir, results):
        """Create metrics distribution visualization"""
        metrics = ['psnr', 'ssim', 'ms_ssim', 'mae', 'mse', 'rmse', 'fpr', 'inference_time', 'reconstruction_time']
        means = [results[f'{m}_mean'] for m in metrics]
        stds = [results[f'{m}_std'] for m in metrics]
        
        fig, ax = plt.subplots(figsize=(16, 8))
        x_pos = np.arange(len(metrics))
        
        bars = ax.bar(x_pos, means, yerr=stds, capsize=5, alpha=0.7)
        
        # Color code bars (higher is better for PSNR/SSIM, lower for others)
        colors = ['green', 'green', 'green', 'red', 'red', 'red', 'red', 'red', 'red']
        for bar, color in zip(bars, colors):
            bar.set_color(color)
        
        ax.set_xlabel('Metrics')
        ax.set_ylabel('Values')
        ax.set_title(f'{self.model_type} Performance Metrics (Mean ± Std)')
        ax.set_xticks(x_pos)
        ax.set_xticklabels([m.upper().replace('_', ' ') for m in metrics], rotation=45)
        ax.grid(True, alpha=0.3)
        
        # Add value annotations
        for i, (mean, std) in enumerate(zip(means, stds)):
            ax.text(i, mean + std + max(means) * 0.01, 
                   f'{mean:.3f}±{std:.3f}', 
                   ha='center', va='bottom', fontsize=8)
        
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, "visualizations", "metrics_summary.png"),
                    dpi=150, bbox_inches='tight')
        plt.close()

# ============================================================================
# MULTI-MODEL COMPARISON
# ============================================================================

class MultiModelComparator:
    """
    Compare multiple models and generate comprehensive comparison reports.

    Collect one results dict per model with :meth:`add_model_results`, then call
    :meth:`compare_models` to write a CSV table, comparison/timing/ranking plots
    and a text report into a directory.
    """

    # Metrics summarised for every model, in display order.
    METRICS = ['psnr', 'ssim', 'ms_ssim', 'mae', 'mse', 'rmse', 'fpr',
               'inference_time', 'reconstruction_time']
    # Metrics where a larger value is better; every other metric is "lower is better".
    HIGHER_IS_BETTER = ('psnr', 'ssim', 'ms_ssim')

    def __init__(self):
        self.results = []

    def add_model_results(self, results):
        """Add evaluation results for a model (a dict with ``model_type``,
        ``model_parameters``, ``total_samples`` and ``<metric>_mean/_std`` keys)."""
        self.results.append(results)

    def compare_models(self, save_dir):
        """Generate CSV, plots and a text report comparing all added models.

        Args:
            save_dir: Output directory (created if missing).

        Requires at least two sets of results; otherwise prints a notice and
        returns without writing anything besides the directory.
        """
        os.makedirs(save_dir, exist_ok=True)

        if len(self.results) < 2:
            print("Need at least 2 models for comparison")
            return

        print(f"\nGenerating comparison report for {len(self.results)} models...")

        # One flat row per model: name, mean/std of each available metric, size info.
        comparison_data = []
        for result in self.results:
            row = {'Model': result['model_type']}
            for metric in self.METRICS:
                mean_key, std_key = f'{metric}_mean', f'{metric}_std'
                if mean_key in result and std_key in result:
                    row[mean_key] = result[mean_key]
                    row[std_key] = result[std_key]
            row['parameters'] = result['model_parameters']
            row['samples'] = result['total_samples']
            comparison_data.append(row)

        self._save_comparison_data(comparison_data, save_dir)
        self._create_comparison_plots(comparison_data, save_dir)
        self._generate_comparison_report(comparison_data, save_dir)

        print(f"Comparison results saved to {save_dir}")

    def _save_comparison_data(self, data, save_dir):
        """Save one row per model to ``model_comparison.csv``.

        The header is the union of all rows' keys in first-seen order, so a model
        reporting a metric the first model lacks no longer makes ``DictWriter``
        raise ``ValueError``; cells for absent keys are left empty.
        """
        if not data:
            return

        fieldnames = list(dict.fromkeys(key for row in data for key in row))
        with open(os.path.join(save_dir, 'model_comparison.csv'), 'w',
                  newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(data)

    def _create_comparison_plots(self, data, save_dir):
        """Save a 3x3 grid of per-metric bar charts, then the timing and ranking plots.

        For each metric only the models that report it are plotted, keeping
        labels, means and stds aligned. Bars use numeric x positions with
        explicit tick labels so two models sharing a name get separate bars
        (a categorical string axis would stack them on one position).
        """
        if not data:
            return

        fig, axes = plt.subplots(3, 3, figsize=(20, 15))
        axes = axes.flatten()

        for ax, metric in zip(axes, self.METRICS):
            mean_key, std_key = f'{metric}_mean', f'{metric}_std'
            rows = [row for row in data if mean_key in row and std_key in row]
            if not rows:
                continue

            means = [row[mean_key] for row in rows]
            stds = [row[std_key] for row in rows]
            x = np.arange(len(rows))
            bars = ax.bar(x, means, yerr=stds, capsize=5, alpha=0.7)

            # Green for "higher is better", red for "lower is better".
            color = 'green' if metric in self.HIGHER_IS_BETTER else 'red'
            for bar in bars:
                bar.set_color(color)

            metric_title = metric.upper().replace('_', ' ')
            ax.set_title(f'{metric_title} Comparison')
            ax.set_ylabel(metric_title)
            ax.set_xticks(x)
            ax.set_xticklabels([row['Model'] for row in rows], rotation=45)
            ax.grid(True, alpha=0.3)

            offset = max(means) * 0.02
            for j, (mean, std) in enumerate(zip(means, stds)):
                ax.text(j, mean + std + offset, f'{mean:.3f}±{std:.3f}',
                        ha='center', va='bottom', fontsize=8)

        # Hide subplots beyond the number of metrics.
        for ax in axes[len(self.METRICS):]:
            ax.set_visible(False)

        fig.suptitle('Model Performance Comparison (Mean ± Std)', fontsize=16)
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, 'comparison_plots.png'),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

        self._create_timing_comparison_plot(data, save_dir)
        self._create_ranking_plot(data, save_dir)

    def _create_timing_comparison_plot(self, data, save_dir):
        """Save grouped bars of inference vs total reconstruction time per model.

        Only models reporting all four timing statistics are included, so the
        label and value lists always have the same length.
        """
        keys = ('inference_time_mean', 'inference_time_std',
                'reconstruction_time_mean', 'reconstruction_time_std')
        rows = [row for row in data if all(k in row for k in keys)]
        if not rows:
            return

        models = [row['Model'] for row in rows]
        inference_means = [row['inference_time_mean'] for row in rows]
        inference_stds = [row['inference_time_std'] for row in rows]
        recon_means = [row['reconstruction_time_mean'] for row in rows]
        recon_stds = [row['reconstruction_time_std'] for row in rows]

        x = np.arange(len(models))
        width = 0.35

        fig, ax = plt.subplots(figsize=(12, 8))
        ax.bar(x - width / 2, inference_means, width, yerr=inference_stds,
               label='Inference Time', capsize=5, alpha=0.8, color='lightcoral')
        ax.bar(x + width / 2, recon_means, width, yerr=recon_stds,
               label='Total Reconstruction Time', capsize=5, alpha=0.8, color='lightblue')

        ax.set_xlabel('Models')
        ax.set_ylabel('Time (seconds)')
        ax.set_title('Inference vs Total Reconstruction Time Comparison')
        ax.set_xticks(x)
        ax.set_xticklabels(models, rotation=45)
        ax.legend()
        ax.grid(True, alpha=0.3)

        inf_offset = max(inference_means) * 0.02
        rec_offset = max(recon_means) * 0.02
        for i, (inf_mean, inf_std, rec_mean, rec_std) in enumerate(
                zip(inference_means, inference_stds, recon_means, recon_stds)):
            ax.text(i - width / 2, inf_mean + inf_std + inf_offset,
                    f'{inf_mean:.3f}±{inf_std:.3f}', ha='center', va='bottom', fontsize=8)
            ax.text(i + width / 2, rec_mean + rec_std + rec_offset,
                    f'{rec_mean:.3f}±{rec_std:.3f}', ha='center', va='bottom', fontsize=8)

        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, 'timing_comparison.png'),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

    def _create_ranking_plot(self, data, save_dir):
        """Save a bar chart of a composite, min-max-normalised score per model.

        Each metric present for every model contributes a score in [0, 1]
        (1 = best). Timings are ranked by speed (1 / time). A metric is skipped
        when any model lacks it, when a timing is non-positive (speed undefined),
        or when any value is non-finite (it would turn every score into NaN).
        Scores are accumulated per row index, so models sharing a name are not
        merged. Bars are sorted best-first and coloured green (best) to red.
        """
        if not data:
            return

        models = [row['Model'] for row in data]
        scores = np.zeros(len(data))

        for metric in self.METRICS:
            key = f'{metric}_mean'
            if not all(key in row for row in data):
                continue
            values = np.array([row[key] for row in data], dtype=float)
            if metric.endswith('_time'):
                if np.any(values <= 0):
                    continue
                values = 1.0 / values  # rank on speed: faster is better
                higher_better = True
            else:
                higher_better = metric in self.HIGHER_IS_BETTER
            if not np.all(np.isfinite(values)):
                continue

            lo, hi = values.min(), values.max()
            if hi > lo:
                normalized = (values - lo) / (hi - lo)
                scores += normalized if higher_better else 1.0 - normalized

        # Stable descending sort: ties keep their original order.
        order = np.argsort(-scores, kind='stable')
        models_sorted = [models[i] for i in order]
        scores_sorted = scores[order]
        x = np.arange(len(order))

        fig, ax = plt.subplots(figsize=(12, 8))
        bars = ax.bar(x, scores_sorted, alpha=0.7)

        # Best-first bars: start at the green end of RdYlGn and move towards red.
        colors = plt.cm.RdYlGn(np.linspace(0.9, 0.3, len(bars)))
        for bar, color in zip(bars, colors):
            bar.set_color(color)

        ax.set_title('Overall Model Ranking (Normalized Composite Score)\n'
                     'Includes Quality Metrics + FPR + Speed')
        ax.set_ylabel('Composite Score')
        ax.set_xticks(x)
        ax.set_xticklabels(models_sorted, rotation=45)
        ax.grid(True, alpha=0.3)

        offset = scores_sorted.max() * 0.01
        for i, score in enumerate(scores_sorted):
            ax.text(i, score + offset, f'{score:.2f}', ha='center', va='bottom')

        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, 'model_ranking.png'),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

    @staticmethod
    def _format_mean_std(row, metric, digits):
        """Return ``"mean±std"`` for *metric* in *row*, or ``"n/a"`` if either is missing."""
        mean_key, std_key = f'{metric}_mean', f'{metric}_std'
        if mean_key not in row or std_key not in row:
            return 'n/a'
        return f"{row[mean_key]:.{digits}f}±{row[std_key]:.{digits}f}"

    def _generate_comparison_report(self, data, save_dir):
        """Write ``comparison_report.txt`` with a summary table, per-metric winners and notes.

        The file is written as UTF-8. Without an explicit encoding, ``open``
        uses the locale's preferred encoding (e.g. cp1252 on many Windows
        setups), which cannot encode the '≤' in the notes and raised
        ``UnicodeEncodeError``. Metrics a model does not report are shown as
        ``n/a`` instead of raising ``KeyError``.
        """
        # (column title, metric key, decimal places)
        columns = [('PSNR', 'psnr', 2), ('SSIM', 'ssim', 3), ('MS-SSIM', 'ms_ssim', 3),
                   ('MAE', 'mae', 3), ('MSE', 'mse', 3), ('RMSE', 'rmse', 3),
                   ('FPR', 'fpr', 3), ('Inf.Time', 'inference_time', 3),
                   ('Recon.Time', 'reconstruction_time', 3)]
        header = f"{'Model':<15} " + " ".join(f"{title:<12}" for title, _, _ in columns)

        with open(os.path.join(save_dir, 'comparison_report.txt'), 'w', encoding='utf-8') as f:
            f.write("=" * 80 + "\n")
            f.write("COMPREHENSIVE MODEL COMPARISON REPORT\n")
            f.write("=" * 80 + "\n\n")

            f.write(f"Models Compared: {len(data)}\n")
            f.write(f"Comparison Date: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")

            # Summary table
            f.write("PERFORMANCE SUMMARY (Mean ± Std):\n")
            f.write("-" * 50 + "\n")
            f.write(header + "\n")
            f.write("-" * len(header) + "\n")

            for row in data:
                model_name = row['Model'][:14]  # Truncate long names to fit the column
                cells = [self._format_mean_std(row, metric, digits)
                         for _, metric, digits in columns]
                f.write(f"{model_name:<15} " + " ".join(f"{cell:<12}" for cell in cells) + "\n")

            f.write("\n\nKEY FINDINGS:\n")
            f.write("-" * 15 + "\n")

            # Best performer per metric, among models that report it.
            for metric in self.METRICS:
                mean_key = f'{metric}_mean'
                reporting = [row for row in data if mean_key in row]
                if not reporting:
                    continue
                values = [row[mean_key] for row in reporting]
                pick = max if metric in self.HIGHER_IS_BETTER else min
                best_idx = values.index(pick(values))
                f.write(f"Best {metric.upper().replace('_', ' ')}: "
                        f"{reporting[best_idx]['Model']} ({values[best_idx]:.4f})\n")

            f.write("\nNOTES:\n")
            f.write("- Higher PSNR, SSIM, MS-SSIM indicate better quality\n")
            f.write("- Lower MAE, MSE, RMSE indicate better accuracy\n")
            f.write("- Lower FPR indicates less hallucination (false positive regions)\n")
            f.write("- Lower inference and reconstruction times indicate faster processing\n")
            f.write("- Inference time: Pure model forward pass time\n")
            f.write("- Reconstruction time: Total time including data movement and post-processing\n")
            f.write("- FPR measures hallucinated pixels (pred > t_high AND target ≤ t_low)\n")
            f.write("- All values reported as mean ± standard deviation\n")

# ============================================================================
# MAIN EVALUATION FUNCTIONS
# ============================================================================

def evaluate_single_model(model_type, model_path, test_dir, save_dir, 
                         num_samples=None, batch_size=8, device='cuda',
                         fpr_t_high=0.5, fpr_t_low=0.2):
    """
    Build a ``UnifiedModelEvaluator`` for one checkpoint and run it.

    Args:
        model_type: Model family ('transunet', 'unet', 'trust', 'restormer').
        model_path: Checkpoint file to load.
        test_dir: Directory holding the test images.
        save_dir: Directory that receives this model's outputs.
        num_samples: How many samples to evaluate (None evaluates all).
        batch_size: Evaluation batch size.
        device: Torch device string to run on.
        fpr_t_high: Upper threshold used by the FPR (hallucination) metric.
        fpr_t_low: Lower threshold used by the FPR (hallucination) metric.

    Returns:
        dict: Whatever ``UnifiedModelEvaluator.evaluate`` returns (the summary results).
    """
    return UnifiedModelEvaluator(
        model_type, model_path, device, fpr_t_high, fpr_t_low
    ).evaluate(test_dir, save_dir, num_samples, batch_size)

def evaluate_multiple_models(models_config, test_dir, comparison_save_dir,
                            num_samples=None, batch_size=8, device='cuda',
                            fpr_t_high=0.5, fpr_t_low=0.2):
    """
    Evaluate multiple models and generate comparison.

    Each model is evaluated independently; a model whose evaluation raises is
    reported (message plus traceback on stderr) and skipped, so one broken
    checkpoint does not abort the sweep. The comparison is only generated
    when at least two models succeed.

    Args:
        models_config: List of dicts with keys 'type', 'path', 'save_dir'
        test_dir: Directory containing test images
        comparison_save_dir: Directory to save comparison results
        num_samples: Number of samples to evaluate
        batch_size: Batch size
        device: Device to run on
        fpr_t_high: High threshold for FPR calculation
        fpr_t_low: Low threshold for FPR calculation

    Returns:
        list: Results for all successfully evaluated models
    """
    import traceback  # local import: only needed to report per-model failures

    comparator = MultiModelComparator()
    all_results = []
    failed = []

    print(f"Starting evaluation of {len(models_config)} models...")
    print(f"FPR thresholds: t_high={fpr_t_high}, t_low={fpr_t_low}")

    for i, config in enumerate(models_config, start=1):
        print(f"\n[{i}/{len(models_config)}] Evaluating {config['type']}...")

        try:
            results = evaluate_single_model(
                model_type=config['type'],
                model_path=config['path'],
                test_dir=test_dir,
                save_dir=config['save_dir'],
                num_samples=num_samples,
                batch_size=batch_size,
                device=device,
                fpr_t_high=fpr_t_high,
                fpr_t_low=fpr_t_low
            )
        except Exception as e:
            # Deliberate best-effort boundary: keep going, but keep the
            # traceback so the failure can actually be diagnosed.
            print(f"✗ {config['type']} evaluation failed: {e}")
            traceback.print_exc()
            failed.append(config['type'])
            continue

        comparator.add_model_results(results)
        all_results.append(results)
        print(f"✓ {config['type']} evaluation completed")

    if failed:
        print(f"\n{len(failed)} model(s) failed: {', '.join(failed)}")

    # Generate comparison
    if len(all_results) > 1:
        print(f"\nGenerating comparison for {len(all_results)} models...")
        comparator.compare_models(comparison_save_dir)
    else:
        print("\nSkipping comparison: fewer than 2 models were evaluated successfully")

    return all_results

# ============================================================================
# EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":
    # Example configuration: adjust these Windows paths to your own
    # checkpoints, test data and output locations before running.
    results_root = r'F:\ImageNet\evaluation_results'

    models_config = [
        {
            'type': 'transunet',
            'path': r'F:\ImageNet\transunet_16x16_checkpoints\best_model.pth',
            'save_dir': results_root + r'\transunet'
        },
        {
            'type': 'unet',
            'path': r'F:\ImageNet\unet_16x16_checkpoints\best_model.pth',
            'save_dir': results_root + r'\unet'
        },
        {
            'type': 'trust',
            'path': r'F:\ImageNet\trust_16x16_checkpoints\best_model.pth',
            'save_dir': results_root + r'\trust'
        },
        {
            'type': 'restormer',
            'path': r'F:\ImageNet\restormer_16x16_checkpoints\best_model.pth',
            'save_dir': results_root + r'\restormer'
        }
    ]

    # Multi-model comparison example
    print("\n=== Multi-Model Comparison ===")
    results = evaluate_multiple_models(
        models_config=models_config,
        test_dir=r'F:\imgnet\data\test',  # Your test data directory
        comparison_save_dir=results_root + r'\comparison',
        num_samples=1000,  # Evaluate on 1000 samples (None = whole test set)
        batch_size=8,
        device='cuda'
    )

    print(f"\nEvaluation completed for {len(results)} models!")
    print(f"Check the {results_root}\\ directory for detailed outputs and comparisons.")

    # Print summary of results
    if results:
        print("\nQUICK SUMMARY:")
        print("-" * 50)
        for result in results:
            print(f"{result['model_type'].upper():12}: "
                  f"PSNR={result['psnr_mean']:.2f}±{result['psnr_std']:.2f}, "
                  f"SSIM={result['ssim_mean']:.3f}±{result['ssim_std']:.3f}, "
                  f"Time={result['inference_time_mean']:.3f}±{result['inference_time_std']:.3f}s")