#!/usr/bin/env python3
"""
Comprehensive Model Evaluation Script
====================================

This script provides comprehensive evaluation capabilities for various image reconstruction models
including L1 minimization, deep learning models (ViT, UNet, Restormer, etc.), and includes
advanced metrics like FPR (False Positive Regions) for hallucination detection.

Features:
- Multiple model support (TRUST, UNet, Restormer, TransUNet, etc.)
- Comprehensive metrics (PSNR, SSIM, MAE, MSE, FPR)
- Variance and statistical analysis
- Image saving and visualization
- Comparison across multiple models
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import scipy.io
import os
import time
import pandas as pd
from PIL import Image
import math
import sys

# Core imports
from scipy.optimize import linprog
from scipy.ndimage import gaussian_filter
from skimage.metrics import structural_similarity as ssim
from sklearn.linear_model import OrthogonalMatchingPursuit
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.transforms.functional import normalize, to_tensor
from transformers import ViTModel, AutoModelForImageClassification
import pytorch_ssim

# Add custom paths (modify as needed for your environment)
# NOTE: both entries are absolute, machine-specific directories. They are only
# appended to the module search path; nothing here checks that they exist.
sys.path.append("/home/dan5/optics_recon/Optics_Recon_Project/Abalation/remove_skip")
sys.path.append("/home/dan5/optics_recon/Optics_Recon_Project/Abalation/non_pretrained")

# Import custom modules (these should be available in your environment)
# The import is best-effort: on ImportError only a warning is printed, and any
# name that had not been imported yet stays undefined. Code in this file that
# uses those helpers guards the calls with `except NameError` fallbacks.
# `transform_single_to_three_channel` is also defined further down in this
# file, and that later module-level definition is the one in effect at call time.
try:
    from utilies import normalize_matrix, down_sample_matrix, mat2vec, vec2mat, mask_response_circle
    from load_data import load_and_process_data
    from vit import ViTForImageReconstruction
    from opticsNN_data_loader import transform_single_to_three_channel
except ImportError as e:
    print(f"Warning: Some custom modules not found: {e}")
    print("Please ensure all custom modules are in your Python path")


# ============================================================================
# L1 MINIMIZATION FUNCTIONS
# ============================================================================

def l1min(A, y):
    """
    Sparse recovery based on Linear Programming

    Solves ``min ||x||_1  subject to  A @ x = y`` by writing ``x = u - v``
    with ``u, v >= 0`` and minimising ``sum(u) + sum(v)`` over the stacked
    variable ``[u, v]``.

    Parameters
    ----------
    A : numpy.ndarray
        Dictionary or Sensing Matrix, shape (M, N)
    y : numpy.ndarray
        Observed data with M entries (flattened to 1-D before solving)

    Returns
    -------
    xr : numpy.ndarray
        Sparse codes, shape (N,)

    Raises
    ------
    RuntimeError
        If the LP solver does not report success (e.g. the system
        ``A @ x = y`` is infeasible).

    Notes
    -----
    Python translation of MATLAB code by Trac D. Tran
    """
    A = np.asarray(A, dtype=float)
    y = np.asarray(y, dtype=float).ravel()
    _, N = A.shape

    # Augmented matrix [A, -A] acting on the stacked variable [u, v]
    A_tilde = np.hstack((A, -A))

    # Objective: sum(u) + sum(v), which equals ||x||_1 at the optimum
    c = np.ones(2 * N)

    # Non-negativity is expressed through `bounds` rather than a dense
    # 2N x 2N inequality block. The HiGHS backend is used because the legacy
    # 'simplex' method is no longer available in current SciPy releases.
    result = linprog(c, A_eq=A_tilde, b_eq=y, bounds=(0, None),
                     method='highs', options={'disp': False})

    # Without this check a failed solve would surface as an opaque
    # "NoneType is not subscriptable" error on result.x.
    if not result.success:
        raise RuntimeError(f"l1min: linear program failed: {result.message}")

    # Recover the signed solution x = u - v from the stacked variable
    x_tilde = result.x
    return x_tilde[:N] - x_tilde[N:2 * N]


# ============================================================================
# DATASET CLASS
# ============================================================================

class ResponseGTImageDataset(Dataset):
    """Paired (response, ground-truth) dataset backed by per-sample MAT files.

    Sample ``k`` is read from ``<response_dir>/s<k>.mat`` (variable ``img1``)
    and ``<gt_dir>/gt<k>.mat`` (variable ``img0``). Dataset index ``i`` maps to
    file index ``start_idx + i``.
    """

    def __init__(self, response_dir, gt_dir, transform=None, target_transform=None, start_idx=0, end_idx=None):
        """
        Initialize dataset with customizable starting and ending indices.

        Args:
        response_dir (str): Directory containing the response images.
        gt_dir (str): Directory containing the ground truth images.
        transform (callable, optional): A function/transform that takes in an image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the GT image and returns a transformed version.
        start_idx (int, optional): The starting index to load data from (0-indexed). Default is 0.
        end_idx (int, optional): The ending index to load data until (exclusive). Default is None, in which case
            half the number of ``.mat`` files found in ``response_dir`` is used.
        """
        self.response_dir = response_dir
        self.gt_dir = gt_dir
        self.response_files = [f for f in sorted(os.listdir(response_dir)) if f.endswith('.mat')]
        self.transform = transform
        self.target_transform = target_transform

        # Set the starting and ending indices
        self.start_idx = start_idx
        # NOTE(review): the default covers only half of the .mat files that
        # were found; the reason is not visible here - confirm it is intended.
        self.end_idx = end_idx if end_idx is not None else int(len(self.response_files)/2)

    def __len__(self):
        # Clamp at zero: a negative value would make len() itself raise when
        # end_idx < start_idx.
        return max(0, self.end_idx - self.start_idx)

    def __getitem__(self, idx):
        """Load, crop, normalise and transform the pair at dataset index ``idx``.

        Returns:
            tuple: (response, ground truth) after all processing steps.

        Raises:
            IndexError: if ``idx`` is negative or not smaller than ``len(self)``.
        """
        # Negative indices are rejected as well: adding start_idx to them would
        # silently address a file outside the configured [start_idx, end_idx) window.
        if not 0 <= idx < len(self):
            raise IndexError(f"Index {idx} out of range for dataset of length {len(self)}")
        # Adjust the index by adding the start_idx offset
        actual_idx = idx + self.start_idx
        response_path = os.path.join(self.response_dir, f's{actual_idx}.mat')
        gt_path = os.path.join(self.gt_dir, f'gt{actual_idx}.mat')

        # Load MAT files
        mat_response = scipy.io.loadmat(response_path)
        mat_gt = scipy.io.loadmat(gt_path)

        # Process response data: fixed 1200x1200 crop of the 'img1' array
        mat_response_content = mat_response['img1'].astype(np.float32)
        mat_response_content_part = mat_response_content[300:1500, 480:1680]

        # Apply mask if function is available (NameError means the optional
        # project helper was not imported)
        try:
            mask = mask_response_circle(mat_response_content_part)
            mat_response_content_part_masked = mat_response_content_part * mask
        except NameError:
            mat_response_content_part_masked = mat_response_content_part

        # Clamp negative values to zero
        mat_response_content_part_masked[mat_response_content_part_masked < 0] = 0

        # Process ground truth: clamp negatives, then fixed 500x500 crop of 'img0'
        mat_gt_content = mat_gt['img0'].astype(np.float32)
        mat_gt_content[mat_gt_content < 0] = 0
        mat_gt_content_part = mat_gt_content[1070:1570, 1080:1580]

        # Normalize if functions are available
        try:
            mat_response_normalized = normalize_matrix(mat_response_content_part_masked)
            mat_gt_normalized = normalize_matrix(mat_gt_content_part)

            # presumably the second argument is a down-sampling factor - TODO confirm
            mat_response_down = down_sample_matrix(mat_response_normalized, 25)
            mat_gt_down = down_sample_matrix(mat_gt_normalized, 5)

            mat_response_down = normalize_matrix(mat_response_down)
            mat_gt_down = normalize_matrix(mat_gt_down)
        except NameError:
            # Fallback normalization: min-max scaling at full crop resolution
            # (no down-sampling); the epsilon avoids division by zero.
            mat_response_down = (mat_response_content_part_masked - mat_response_content_part_masked.min()) / \
                              (mat_response_content_part_masked.max() - mat_response_content_part_masked.min() + 1e-8)
            mat_gt_down = (mat_gt_content_part - mat_gt_content_part.min()) / \
                         (mat_gt_content_part.max() - mat_gt_content_part.min() + 1e-8)

        # Apply additional transforms if any
        if self.transform:
            mat_response_down = self.transform(mat_response_down)
        if self.target_transform:
            mat_gt_down = self.target_transform(mat_gt_down)

        # Final normalization (applied after the transforms when the helper exists)
        try:
            return normalize_matrix(mat_response_down), normalize_matrix(mat_gt_down)
        except NameError:
            return mat_response_down, mat_gt_down


# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def transform_single_to_three_channel(x):
    """Replicate a lone leading channel three times; return other inputs as-is."""
    leading_dim = x.shape[0]
    if leading_dim != 1:
        # Not single-channel: hand back the very same object
        return x
    # Tile along the first axis only, leaving the two trailing axes untouched
    return x.repeat(3, 1, 1)


def psnr(target, ref, data_range=255.0):
    """Calculate PSNR (in dB) between target and reference images.

    Parameters
    ----------
    target, ref : array-like
        Images of identical shape; both are converted to float64.
    data_range : float, optional
        Peak signal value used in the ratio. The default of 255.0 keeps the
        previous 8-bit behaviour; pass 1.0 for images scaled to [0, 1].

    Returns
    -------
    float
        PSNR in decibels, or ``inf`` when the images are identical.
    """
    target_data = np.asarray(target, dtype=np.float64)
    ref_data = np.asarray(ref, dtype=np.float64)
    mean_sq_err = np.mean((target_data - ref_data) ** 2)
    if mean_sq_err == 0:
        # Identical images: the ratio is unbounded
        return float('inf')
    return 20 * math.log10(data_range / math.sqrt(mean_sq_err))


def calculate_ssim(img1, img2, win_size=7):
    """Compute SSIM, shrinking the window when the images are smaller than it.

    The dynamic range handed to ``ssim`` is the value span of ``img2``. Inputs
    with a third axis longer than one are treated as channel-last.
    """
    # Smallest spatial extent over both images
    smallest_side = min(img1.shape[0], img1.shape[1], img2.shape[0], img2.shape[1])

    # The window must fit in the image: take the largest odd size not exceeding
    # the smallest side, but never go below 3
    if smallest_side < win_size:
        largest_odd = smallest_side - 1 if smallest_side % 2 == 0 else smallest_side
        win_size = max(3, largest_odd)

    ssim_kwargs = {'win_size': win_size, 'data_range': img2.max() - img2.min()}

    # Only declare a channel axis for genuinely multichannel input
    is_multichannel = len(img1.shape) == 3 and img1.shape[2] > 1
    if is_multichannel:
        ssim_kwargs['channel_axis'] = 2

    return ssim(img1, img2, **ssim_kwargs)


def mae(imageA, imageB):
    """Return the mean absolute difference between two images."""
    # Cast first so unsigned integer inputs cannot wrap around on subtraction
    difference = imageA.astype("float") - imageB.astype("float")
    return np.mean(np.abs(difference))


def mse(imageA, imageB):
    """Calculate Mean Squared Error between two images.

    The squared error is averaged over every element, so the result is a true
    mean for inputs of any dimensionality (consistent with ``mae``). For 2-D
    images this is identical to dividing the summed error by height * width.

    Parameters
    ----------
    imageA, imageB : numpy.ndarray
        Arrays of identical shape.

    Returns
    -------
    numpy.float64
        Mean of the element-wise squared differences.
    """
    # Cast first so unsigned integer inputs cannot wrap around on subtraction
    diff = imageA.astype("float") - imageB.astype("float")
    return np.mean(diff ** 2)


def compute_fpr_score(x_hat, x_true, t_high=0.5, t_low=0.2):
    """
    False Positive Regions (FPR) hallucination score.

    Both images are min-max rescaled independently. A pixel counts as
    hallucinated when the rescaled reconstruction exceeds ``t_high`` while the
    rescaled ground truth is at most ``t_low``.

    Parameters
    ----------
    x_hat : numpy.ndarray
        Generated or reconstructed image
    x_true : numpy.ndarray
        Ground truth image
    t_high : float
        Threshold the reconstruction must exceed
    t_low : float
        Threshold the ground truth must not exceed

    Returns
    -------
    score : float
        Fraction of pixels flagged as hallucinated
    mask : numpy.ndarray
        Boolean mask of the flagged pixels
    """
    def _rescale(img):
        # Min-max scaling; the epsilon keeps constant images from dividing by zero
        lowest = img.min()
        return (img - lowest) / (img.max() - lowest + 1e-8)

    recon_scaled = _rescale(x_hat)
    truth_scaled = _rescale(x_true)

    # Bright in the reconstruction AND dark in the ground truth
    bright_in_recon = recon_scaled > t_high
    dark_in_truth = truth_scaled <= t_low
    hallucinated = np.logical_and(bright_in_recon, dark_in_truth)

    # Share of flagged pixels among all pixels
    score = np.sum(hallucinated) / recon_scaled.size
    return score, hallucinated


# ============================================================================
# MAIN EVALUATION FUNCTION
# ============================================================================

def _eval_build_model(model_type, device):
    """Instantiate the network selected by ``model_type`` on ``device``.

    Returns the model, or None (after printing a warning) when the supporting
    module cannot be imported. Raises ValueError for an unknown ``model_type``.
    """
    vit_name = "google/vit-base-patch16-224"

    if model_type == 'TRUST_no_skip':
        try:
            return ViTForImageReconstruction(vit_name).to(device)
        except Exception as exc:
            # Covers both a missing class (NameError when the optional import
            # at the top of the file failed) and errors while building it.
            print("Warning: ViTForImageReconstruction not available")
            print(f"  Reason: {exc}")
            return None

    if model_type == 'unet':
        try:
            from unet_ddp import UNet
            return UNet(n_channels=3, n_classes=3, bilinear=True, output_size=(449, 449)).to(device)
        except ImportError:
            print("Warning: UNet model not available")
            return None

    if model_type == 'Transunet':
        try:
            from transunet import TransUNet, SSIMLoss
            from transunet_ddp import create_transunet_model
            return create_transunet_model(img_size=224, in_channels=3, out_channels=3, device=device)
        except ImportError:
            print("Warning: TransUNet model not available")
            return None

    if model_type == 'bottom_skip':
        try:
            from vunt_ddp import ViTUNetHybridFixed
            return ViTUNetHybridFixed(pretrained_model_name=vit_name, output_size=(449, 449)).to(device)
        except ImportError:
            print("Warning: ViTUNetHybridFixed model not available")
            return None

    if model_type == 'TRUST':
        try:
            from vunt_early_ddp import ViTUNetForImageReconstruction
            return ViTUNetForImageReconstruction(vit_name).to(device)
        except ImportError:
            print("Warning: TRUST model not available")
            return None

    if model_type == 'TRUST_mv_skip1':
        try:
            from vunt_earlt_removeskip1 import ViTUNetForImageReconstruction
            return ViTUNetForImageReconstruction(vit_name).to(device)
        except ImportError:
            print("Warning: TRUST_mv_skip1 model not available")
            return None

    if model_type == 'TRUST_mv_skip1_skip2':
        try:
            from vunt_early_remove12 import ViTUNetForImageReconstruction
            return ViTUNetForImageReconstruction(vit_name).to(device)
        except ImportError:
            print("Warning: TRUST_mv_skip1_skip2 model not available")
            return None

    if model_type == 'TRUST_nopre':
        try:
            from vunt_early_nonpretrained_ddp import ViTUNetForImageReconstruction
            from transformers import ViTConfig, ViTModel
            vit_config = ViTConfig(
                hidden_size=768,
                num_hidden_layers=12,
                num_attention_heads=12,
                intermediate_size=3072,
                hidden_act="gelu",
                image_size=224,
                patch_size=16,
                num_channels=3
            )
            return ViTUNetForImageReconstruction(pretrained=False, model_name_or_config=vit_config).to(device)
        except ImportError:
            print("Warning: TRUST_nopre model not available")
            return None

    if model_type == 'Restormer':
        try:
            from restormer_ddp_modified import RestormerOptics
            return RestormerOptics(
                inp_channels=3,  # For RGB images
                out_channels=3,  # For RGB images
                dim=48,
                num_blocks=[4, 6, 6, 8],
                num_heads=[1, 2, 4, 8],
                ffn_expansion_factor=2.66,
                bias=False
            ).to(device)
        except ImportError:
            print("Warning: Restormer model not available")
            return None

    raise ValueError(f"Unsupported model type: {model_type}")


def _eval_save_tensor_as_image(tensor, path):
    """Write a CxHxW tensor to ``path`` as an 8-bit image, clipping to [0, 1]."""
    img_np = tensor.cpu().numpy()
    img_np = np.transpose(img_np, (1, 2, 0))  # Convert from CxHxW to HxWxC
    img_np = np.clip(img_np, 0, 1)  # Clip values to [0, 1] range
    img_np = (img_np * 255).astype(np.uint8)  # Convert to uint8 [0, 255]
    Image.fromarray(img_np).save(path)


def _eval_summarize(values):
    """Return (mean, sample variance, sample std) for a non-empty metric list."""
    mean = sum(values) / len(values)
    # The sample variance (ddof=1) is undefined for one value; report 0 then
    variance = np.var(values, ddof=1) if len(values) > 1 else 0
    return mean, variance, np.sqrt(variance)


def comprehensive_eval_new2(test_response_file_path, test_gt_file_path, start_idx, end_idx, 
                      model_type, model_path, save_dir, t_high=0.5, t_low=0.2):
    """
    Comprehensive evaluation including FPR for hallucination detection
    and saving reconstructed images, with mean/variance/std for all metrics.

    Parameters
    ----------
    test_response_file_path : str
        Directory with the response ``.mat`` files.
    test_gt_file_path : str
        Directory with the ground-truth ``.mat`` files.
    start_idx, end_idx : int
        Sample index window forwarded to ``ResponseGTImageDataset``.
    model_type : str
        Architecture key (e.g. 'TRUST', 'unet', 'Restormer').
    model_path : str
        Path of the state-dict checkpoint to load.
    save_dir : str
        Output directory. Per-image PNGs go to ``<save_dir>/reconstructed_images``;
        CSV files, plots and the text summary go to ``save_dir``.
    t_high, t_low : float
        Thresholds forwarded to ``compute_fpr_score``.

    Returns
    -------
    dict or None
        Mean, std and variance of every metric plus bookkeeping entries, or
        None when the model cannot be built or loaded, or no image was processed.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported keys.
    """
    # Creating the nested directory also creates save_dir itself
    recon_img_dir = os.path.join(save_dir, 'reconstructed_images')
    os.makedirs(recon_img_dir, exist_ok=True)

    # Shared by the Normalize transform and the de-normalisation before saving,
    # so the two can never drift apart
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    # Define the transformations - ENSURE FLOAT32 TENSORS
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(transform_single_to_three_channel),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=norm_mean, std=norm_std),
        transforms.Lambda(lambda x: x.float())  # Ensure float32
    ])

    target_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(transform_single_to_three_channel),
        transforms.Resize((449, 449)),
        transforms.Lambda(lambda x: x.float())  # Ensure float32
    ])

    # Create dataset
    test_dataset = ResponseGTImageDataset(
        response_dir=test_response_file_path,
        gt_dir=test_gt_file_path,
        transform=transform,
        target_transform=target_transform,
        start_idx=start_idx,
        end_idx=end_idx
    )

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    use_cuda = device.type == "cuda"

    # Initialize model based on model_type
    model = _eval_build_model(model_type, device)
    if model is None:
        return None

    # Load model parameters
    try:
        model.load_state_dict(torch.load(model_path, weights_only=True, map_location=device))
        model.eval()
        model = model.float()  # Ensure model is in float32
        print(f"Successfully loaded model: {model_type}")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    # Create dataloaders
    batch_size = 1
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Lists to store metrics. Only scalars are kept per image; holding on to
    # the batch tensors would make memory grow with the dataset size.
    inference_times = []
    psnr_values = []
    ssim_values = []
    mae_values = []
    mse_values = []
    fpr_values = []

    # Process each image and calculate metrics
    print(f"Processing {len(test_loader)} images...")
    for i, (input_images, target_images) in enumerate(test_loader):
        # Move to device and ensure FLOAT32 type
        input_images = input_images.to(device).float()
        target_images = target_images.to(device).float()

        # CUDA kernels run asynchronously: synchronise before and after the
        # forward pass so the wall-clock interval covers the actual compute.
        if use_cuda:
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()

        # Forward pass
        with torch.no_grad():
            outputs = model(input_images)

        if use_cuda:
            torch.cuda.synchronize(device)
        inference_time = time.perf_counter() - start_time
        inference_times.append(inference_time)

        # Convert tensors to numpy for metric calculation
        target_np = target_images.cpu().numpy()[0]
        output_np = outputs.cpu().numpy()[0]

        # Convert from channel-first to channel-last for metric functions
        target_np_display = np.transpose(target_np, (1, 2, 0))
        output_np_display = np.transpose(output_np, (1, 2, 0))

        # Scalar metrics are computed on the first channel only
        output_ch0 = output_np_display[:, :, 0]
        target_ch0 = target_np_display[:, :, 0]

        # NOTE(review): psnr() uses a peak value of 255 while the tensors here
        # appear to be scaled to [0, 1]; confirm the intended peak before
        # comparing these PSNR numbers with other work.
        current_psnr = psnr(output_ch0, target_ch0)

        try:
            current_ssim = calculate_ssim(output_np_display, target_np_display)
        except Exception as e:
            print(f"Warning: SSIM calculation failed for image {i+1}: {e}")
            current_ssim = 0.0

        current_mae = mae(output_ch0, target_ch0)
        current_mse = mse(output_ch0, target_ch0)

        # Calculate FPR hallucination score
        current_fpr, fpr_mask = compute_fpr_score(
            output_ch0,
            target_ch0,
            t_high=t_high,
            t_low=t_low
        )

        # Store metrics
        psnr_values.append(current_psnr)
        ssim_values.append(current_ssim)
        mae_values.append(current_mae)
        mse_values.append(current_mse)
        fpr_values.append(current_fpr)

        # Save reconstructed image
        img_filename = f'image_{i+1:04d}.png'
        recon_img_path = os.path.join(recon_img_dir, img_filename)
        _eval_save_tensor_as_image(outputs[0], recon_img_path)

        # Save input and target images for comparison
        input_img_path = os.path.join(recon_img_dir, f'input_{i+1:04d}.png')
        target_img_path = os.path.join(recon_img_dir, f'target_{i+1:04d}.png')

        # Denormalize input image before saving: invert Normalize per channel
        # with exactly the mean/std used in the transform above
        input_denorm = input_images[0].clone()
        for channel, (ch_mean, ch_std) in enumerate(zip(norm_mean, norm_std)):
            input_denorm[channel] = input_denorm[channel] * ch_std + ch_mean

        _eval_save_tensor_as_image(input_denorm, input_img_path)
        _eval_save_tensor_as_image(target_images[0], target_img_path)

        # Save comparison visualization
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Convert input to grayscale for visualization
        r, g, b = input_denorm[0], input_denorm[1], input_denorm[2]
        gray_img = (0.2989 * r + 0.5870 * g + 0.1140 * b).cpu().numpy()
        axes[0].imshow(gray_img, cmap='gray')
        axes[0].set_title('Input Image')
        axes[0].axis('off')

        axes[1].imshow(target_np_display)
        axes[1].set_title('Ground Truth')
        axes[1].axis('off')

        axes[2].imshow(output_np_display)
        axes[2].set_title(f'Reconstructed\nPSNR: {current_psnr:.2f}, FPR: {current_fpr:.4f}')
        axes[2].axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(recon_img_dir, f'comparison_{i+1:04d}.png'))
        plt.close(fig)

        # Save the FPR mask for visualization
        mask_fig = plt.figure(figsize=(8, 8))
        plt.imshow(fpr_mask, cmap='hot')
        plt.colorbar(label='FPR Score')
        plt.title(f'FPR Mask - Score: {current_fpr:.4f}')
        plt.savefig(os.path.join(recon_img_dir, f'fpr_mask_{i+1:04d}.png'))
        plt.close(mask_fig)

        print(f"Image {i+1}: Time = {inference_time:.4f}s, PSNR = {current_psnr:.2f}, SSIM = {current_ssim:.4f}, "
              f"MAE = {current_mae:.4f}, MSE = {current_mse:.4f}, FPR = {current_fpr:.4f}")

    # Nothing to average: bail out instead of dividing by zero below
    if not psnr_values:
        print("Warning: no images were processed; skipping metric aggregation")
        return None

    # Calculate average, variance and standard deviation for each metric
    avg_inference_time, var_inference_time, std_inference_time = _eval_summarize(inference_times)
    avg_psnr, var_psnr, std_psnr = _eval_summarize(psnr_values)
    avg_ssim, var_ssim, std_ssim = _eval_summarize(ssim_values)
    avg_mae, var_mae, std_mae = _eval_summarize(mae_values)
    avg_mse, var_mse, std_mse = _eval_summarize(mse_values)
    avg_fpr, var_fpr, std_fpr = _eval_summarize(fpr_values)

    num_images = len(psnr_values)

    # Print average metrics with variance
    print(f"\n{'='*60}")
    print(f"EVALUATION RESULTS FOR {model_type.upper()}")
    print(f"{'='*60}")
    print(f"Average metrics across {num_images} images:")
    print(f"Inference Time: {avg_inference_time:.4f} ± {std_inference_time:.4f} seconds")
    print(f"PSNR: {avg_psnr:.2f} ± {std_psnr:.2f} dB")
    print(f"SSIM: {avg_ssim:.4f} ± {std_ssim:.4f}")
    print(f"MAE: {avg_mae:.4f} ± {std_mae:.4f}")
    print(f"MSE: {avg_mse:.4f} ± {std_mse:.4f}")
    print(f"FPR: {avg_fpr:.4f} ± {std_fpr:.4f}")

    # Create a DataFrame to store per-image metrics
    metrics_df = pd.DataFrame({
        'Image': [f'Image_{n}' for n in range(1, num_images + 1)],
        'Inference_Time': inference_times,
        'PSNR': psnr_values,
        'SSIM': ssim_values,
        'MAE': mae_values,
        'MSE': mse_values,
        'FPR': fpr_values,
        'ReconstructedImagePath': [os.path.join(recon_img_dir, f'image_{n:04d}.png')
                                   for n in range(1, num_images + 1)]
    })

    # Save metrics to CSV
    metrics_df.to_csv(f'{save_dir}/metrics_results.csv', index=False)

    # Include variance information in the CSV
    variance_df = pd.DataFrame({
        'Metric': ['Inference_Time', 'PSNR', 'SSIM', 'MAE', 'MSE', 'FPR'],
        'Average': [avg_inference_time, avg_psnr, avg_ssim, avg_mae, avg_mse, avg_fpr],
        'Variance': [var_inference_time, var_psnr, var_ssim, var_mae, var_mse, var_fpr],
        'Std_Dev': [std_inference_time, std_psnr, std_ssim, std_mae, std_mse, std_fpr]
    })
    variance_df.to_csv(f'{save_dir}/metrics_variance.csv', index=False)

    # Create visualizations
    create_evaluation_plots(save_dir, model_type, inference_times, psnr_values, ssim_values, 
                           mae_values, mse_values, fpr_values, 
                           avg_inference_time, avg_psnr, avg_ssim, avg_mae, avg_mse, avg_fpr,
                           std_inference_time, std_psnr, std_ssim, std_mae, std_mse, std_fpr,
                           var_inference_time, var_psnr, var_ssim, var_mae, var_mse, var_fpr)

    # Save metrics summary to text file
    save_metrics_summary(save_dir, model_type, model_path, num_images, t_high, t_low,
                        avg_inference_time, avg_psnr, avg_ssim, avg_mae, avg_mse, avg_fpr,
                        std_inference_time, std_psnr, std_ssim, std_mae, std_mse, std_fpr,
                        var_inference_time, var_psnr, var_ssim, var_mae, var_mse, var_fpr,
                        recon_img_dir)

    # Return dictionary with all metrics including FPR and variance values
    return {
        'inference_time': avg_inference_time,
        'inference_time_std': std_inference_time,
        'inference_time_var': var_inference_time,
        'psnr': avg_psnr,
        'psnr_std': std_psnr,
        'psnr_var': var_psnr,
        'ssim': avg_ssim,
        'ssim_std': std_ssim,
        'ssim_var': var_ssim,
        'mae': avg_mae,
        'mae_std': std_mae,
        'mae_var': var_mae,
        'mse': avg_mse,
        'mse_std': std_mse,
        'mse_var': var_mse,
        'fpr': avg_fpr,
        'fpr_std': std_fpr,
        'fpr_var': var_fpr,
        'model_type': model_type,
        'images_processed': num_images,
        'reconstructed_images_dir': recon_img_dir
    }


# ============================================================================
# VISUALIZATION AND REPORTING FUNCTIONS
# ============================================================================

def create_evaluation_plots(save_dir, model_type, inference_times, psnr_values, ssim_values, 
                           mae_values, mse_values, fpr_values, 
                           avg_inference_time, avg_psnr, avg_ssim, avg_mae, avg_mse, avg_fpr,
                           std_inference_time, std_psnr, std_ssim, std_mae, std_mse, std_fpr,
                           var_inference_time, var_psnr, var_ssim, var_mae, var_mse, var_fpr):
    """Create comprehensive visualization plots for evaluation metrics.

    Writes two figures into ``save_dir``:

    - ``comprehensive_metrics_with_fpr.png``: a 3x2 grid of per-image trend
      lines, each annotated with its mean ± standard deviation.
    - ``metrics_distribution.png``: box plots of the metrics, each divided by
      its own maximum (SSIM is used as-is), with the variances in a footer.

    The ``*_values`` / ``inference_times`` arguments are per-image sequences;
    the ``avg_*``, ``std_*`` and ``var_*`` arguments are their precomputed
    summary statistics. When no values are given, a warning is printed and no
    file is written.
    """
    if not psnr_values:
        # max() and the trend axes are meaningless without data
        print("Warning: no metric values to plot; skipping evaluation plots")
        return

    # Get image indices for x-axis
    image_indices = list(range(1, len(psnr_values) + 1))

    # One entry per panel: (row, col, values, colour, title, y label, annotation)
    panels = [
        (0, 0, inference_times, 'orange', 'Inference Time Trend (Lower is better)', 'Time (seconds)',
         f'Avg: {avg_inference_time:.4f} ± {std_inference_time:.4f}s'),
        (0, 1, psnr_values, 'blue', 'PSNR Trend (Higher is better)', 'PSNR (dB)',
         f'Avg: {avg_psnr:.2f} ± {std_psnr:.2f}'),
        (1, 0, ssim_values, 'green', 'SSIM Trend (Higher is better)', 'SSIM',
         f'Avg: {avg_ssim:.4f} ± {std_ssim:.4f}'),
        (1, 1, fpr_values, 'magenta', 'FPR Trend (Lower is better)', 'FPR Score',
         f'Avg: {avg_fpr:.4f} ± {std_fpr:.4f}'),
        (2, 0, mse_values, 'purple', 'MSE Trend (Lower is better)', 'MSE',
         f'Avg: {avg_mse:.4f} ± {std_mse:.4f}'),
        (2, 1, mae_values, 'red', 'MAE Trend (Lower is better)', 'MAE',
         f'Avg: {avg_mae:.4f} ± {std_mae:.4f}'),
    ]

    # Create a 3x2 subplot for all metrics including FPR
    trend_fig, axs = plt.subplots(3, 2, figsize=(16, 15))

    for row, col, values, line_color, title, y_label, annotation in panels:
        ax = axs[row, col]
        ax.plot(image_indices, values, marker='o', color=line_color, linestyle='-')
        ax.set_title(title)
        ax.set_xlabel('Image Number')
        ax.set_ylabel(y_label)
        ax.grid(True, linestyle='--', alpha=0.7)
        # Anchor the annotation in axes-fraction coordinates so it always sits
        # inside the panel, whatever the data range or number of images.
        ax.text(0.05, 0.95, annotation, transform=ax.transAxes,
                verticalalignment='top',
                bbox=dict(facecolor='white', alpha=0.8))

    # Save trend plot
    plt.suptitle(f'Comprehensive Evaluation with FPR - {model_type.upper()} Model', fontsize=16)
    plt.tight_layout()
    plt.savefig(f'{save_dir}/comprehensive_metrics_with_fpr.png', dpi=300, bbox_inches='tight')
    plt.close(trend_fig)

    def _peak_normalized(values):
        # Divide by the maximum so differently scaled metrics share one axis;
        # a non-positive maximum yields all zeros.
        peak = max(values)
        if peak > 0:
            return [value / peak for value in values]
        return [0] * len(values)

    # Create a box plot for all metrics to visualize variance
    box_fig = plt.figure(figsize=(15, 10))

    # Create box plot data (order must match the tick labels below)
    box_data = [
        _peak_normalized(psnr_values),
        ssim_values,  # SSIM is already normalized between 0-1
        _peak_normalized(inference_times),
        _peak_normalized(mae_values),
        _peak_normalized(mse_values),
        _peak_normalized(fpr_values),
    ]

    # Create boxplot
    bp = plt.boxplot(box_data, patch_artist=True)

    # Set colors for each box
    box_colors = ['lightblue', 'lightgreen', 'coral', 'orchid', 'wheat', 'pink']
    for patch, face_color in zip(bp['boxes'], box_colors):
        patch.set_facecolor(face_color)

    # Set labels and title
    plt.title('Distribution of Normalized Metrics (Box Plot)', fontsize=14)
    plt.xlabel('Metrics')
    plt.ylabel('Normalized Values')
    plt.xticks([1, 2, 3, 4, 5, 6], ['PSNR', 'SSIM', 'Inference', 'MAE', 'MSE', 'FPR'])
    plt.grid(True, linestyle='--', alpha=0.3)

    # Add annotation about variance
    plt.figtext(0.5, 0.01, 
                f'Variance: PSNR: {var_psnr:.4f}, SSIM: {var_ssim:.4f}, Time: {var_inference_time:.4f}, '
                f'MAE: {var_mae:.4f}, MSE: {var_mse:.4f}, FPR: {var_fpr:.4f}',
                ha='center', fontsize=10)

    # Save the plot
    plt.tight_layout()
    plt.savefig(f'{save_dir}/metrics_distribution.png', dpi=300, bbox_inches='tight')
    plt.close(box_fig)


def save_metrics_summary(save_dir, model_type, model_path, num_images, t_high, t_low,
                        avg_inference_time, avg_psnr, avg_ssim, avg_mae, avg_mse, avg_fpr,
                        std_inference_time, std_psnr, std_ssim, std_mae, std_mse, std_fpr,
                        var_inference_time, var_psnr, var_ssim, var_mae, var_mse, var_fpr,
                        recon_img_dir):
    """
    Save comprehensive metrics summary to text file

    Writes ``metrics_summary_with_variance.txt`` inside ``save_dir``
    (the directory must already exist; it is not created here).

    Parameters
    ----------
    save_dir : str
        Directory the report is written to; also echoed in the report as the
        location of the generated plots.
    model_type, model_path :
        Echoed verbatim in the report header.
    num_images :
        Echoed as "Total Images Processed".
    t_high, t_low :
        FPR threshold parameters, echoed verbatim.
    avg_*, std_*, var_* : float
        Mean, standard deviation and variance of each metric. They are
        formatted with fixed precision, so they must be numeric.
    recon_img_dir : str
        Echoed as the location of the reconstructed images.
    """
    rule = "=" * 80 + "\n"

    report = "".join([
        rule,
        "COMPREHENSIVE MODEL EVALUATION REPORT\n",
        rule + "\n",

        f"Model Type: {model_type}\n",
        f"Model Path: {model_path}\n",
        f"Total Images Processed: {num_images}\n",
        f"FPR Threshold Parameters: t_high={t_high}, t_low={t_low}\n\n",

        "AVERAGE METRICS WITH STATISTICAL ANALYSIS:\n",
        "-" * 50 + "\n",
        f"Inference Time: {avg_inference_time:.4f} ± {std_inference_time:.4f} seconds (var: {var_inference_time:.6f})\n",
        f"PSNR: {avg_psnr:.4f} ± {std_psnr:.4f} dB (var: {var_psnr:.6f})\n",
        f"SSIM: {avg_ssim:.4f} ± {std_ssim:.4f} (var: {var_ssim:.6f})\n",
        f"MAE: {avg_mae:.4f} ± {std_mae:.4f} (var: {var_mae:.6f})\n",
        f"MSE: {avg_mse:.4f} ± {std_mse:.4f} (var: {var_mse:.6f})\n",
        f"FPR: {avg_fpr:.4f} ± {std_fpr:.4f} (var: {var_fpr:.6f})\n\n",

        "INTERPRETATION:\n",
        "-" * 20 + "\n",
        "• PSNR (Peak Signal-to-Noise Ratio): Higher values indicate better quality\n",
        "• SSIM (Structural Similarity Index): Range [0,1], higher values indicate better structural similarity\n",
        "• MAE (Mean Absolute Error): Lower values indicate better accuracy\n",
        "• MSE (Mean Squared Error): Lower values indicate better accuracy\n",
        "• FPR (False Positive Regions): Lower values indicate less hallucination\n",
        "• Inference Time: Lower values indicate faster processing\n\n",

        f"Reconstructed images saved to: {recon_img_dir}\n",
        f"Generated plots saved to: {save_dir}\n\n",

        "FILES GENERATED:\n",
        "-" * 15 + "\n",
        "• metrics_results.csv - Per-image detailed metrics\n",
        "• metrics_variance.csv - Statistical summary\n",
        "• comprehensive_metrics_with_fpr.png - Trend plots\n",
        "• metrics_distribution.png - Distribution analysis\n",
        "• reconstructed_images/ - All output images and comparisons\n",
    ])

    # The report contains non-ASCII characters ("±", "•"); pin the encoding so
    # the write neither fails nor changes bytes depending on the host locale.
    with open(f'{save_dir}/metrics_summary_with_variance.txt', 'w', encoding='utf-8') as f:
        f.write(report)


# ============================================================================
# TRANSUNET SPECIFIC EVALUATION FUNCTION
# ============================================================================

def comprehensive_eval_transunet(test_response_file_path, test_gt_file_path, start_idx, end_idx, 
                               model_path, save_dir, t_high=0.5, t_low=0.2):
    """
    Comprehensive evaluation specifically for TransUNet model
    - Input: 224x224
    - Output: 224x224
    - Handles size differences from other models

    For every (response, ground-truth) pair the model output is compared with
    the target. PSNR, MAE, MSE and the FPR score are computed on channel 0
    only; SSIM is computed on the full channel-last image. Per-image PNGs
    (reconstruction, de-normalized input, target, side-by-side comparison,
    FPR mask), two CSV files, the trend/box plots and a text summary are
    written under ``save_dir``.

    Parameters
    ----------
    test_response_file_path : str
        Passed to ``ResponseGTImageDataset`` as ``response_dir``.
    test_gt_file_path : str
        Passed to ``ResponseGTImageDataset`` as ``gt_dir``.
    start_idx, end_idx : int
        Forwarded unchanged to ``ResponseGTImageDataset``.
        NOTE(review): inclusive/exclusive semantics are defined by the
        dataset class - confirm there.
    model_path : str
        Path of the state dict loaded into the TransUNet model.
    save_dir : str
        Output directory; created if missing.
    t_high, t_low : float
        Thresholds forwarded to ``compute_fpr_score``.

    Returns
    -------
    dict or None
        Mean / standard deviation / variance of every metric plus
        ``model_type``, ``images_processed`` and ``reconstructed_images_dir``.
        ``None`` when the model cannot be imported or loaded, or when the
        dataset yields no images.
    """
    import torch
    import time
    import numpy as np
    import matplotlib.pyplot as plt
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import os
    import pandas as pd
    from PIL import Image

    # Normalization statistics shared by the input transform and by the
    # de-normalization used when saving the input image, so the two can
    # never drift apart.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    # Create output directories (race-free: no separate existence check)
    os.makedirs(save_dir, exist_ok=True)
    recon_img_dir = os.path.join(save_dir, 'reconstructed_images')
    os.makedirs(recon_img_dir, exist_ok=True)

    # Define the transformations for TransUNet - BOTH INPUT AND TARGET ARE 224x224
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(transform_single_to_three_channel),
        transforms.Resize((224, 224)),  # TransUNet input size
        transforms.Normalize(mean=norm_mean, std=norm_std),
        transforms.Lambda(lambda x: x.float())  # Ensure float32
    ])

    # For TransUNet, target should also be 224x224 to match output
    target_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(transform_single_to_three_channel),
        transforms.Resize((224, 224)),  # Same as input for TransUNet
        transforms.Lambda(lambda x: x.float())  # Ensure float32
    ])

    # Create dataset
    test_dataset = ResponseGTImageDataset(
        response_dir=test_response_file_path,
        gt_dir=test_gt_file_path,
        transform=transform,
        target_transform=target_transform,
        start_idx=start_idx,
        end_idx=end_idx
    )

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    use_cuda = device.type == "cuda"

    # Initialize TransUNet model
    try:
        # TransUNet / SSIMLoss are not used below; the import is kept so a
        # missing transunet package is reported through the ImportError branch.
        from transunet import TransUNet, SSIMLoss
        from transunet_ddp import create_transunet_model

        model = create_transunet_model(img_size=224, in_channels=3, out_channels=3, device=device)

        # Load model parameters
        model.load_state_dict(torch.load(model_path, weights_only=True, map_location=device))
        model.eval()
        model = model.float()  # Ensure model is in float32
        print("Successfully loaded TransUNet model")

    except ImportError:
        print("Error: TransUNet model not available. Please ensure transunet modules are installed.")
        return None
    except Exception as e:
        print(f"Error loading TransUNet model: {e}")
        return None

    # Create dataloaders
    batch_size = 1
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Lists to store metrics
    inference_times = []
    psnr_values = []
    ssim_values = []
    mae_values = []
    mse_values = []
    fpr_values = []

    # Per-channel statistics shaped (3, 1, 1) so they broadcast over CxHxW
    mean_t = torch.tensor(norm_mean, dtype=torch.float32, device=device).view(3, 1, 1)
    std_t = torch.tensor(norm_std, dtype=torch.float32, device=device).view(3, 1, 1)

    # Helper function to save tensor as image
    def save_tensor_as_image(tensor, path):
        img_np = tensor.detach().cpu().numpy()
        img_np = np.transpose(img_np, (1, 2, 0))  # Convert from CxHxW to HxWxC
        img_np = np.clip(img_np, 0, 1)  # Clip values to [0, 1] range
        img_np = (img_np * 255).astype(np.uint8)  # Convert to uint8 [0, 255]
        Image.fromarray(img_np).save(path)

    print(f"Starting TransUNet evaluation on {len(test_dataset)} images...")
    print("Input size: 224x224, Output size: 224x224")

    # Process each image and calculate metrics
    for image_number, (input_images, target_images) in enumerate(test_loader, start=1):
        # Move to device and ensure FLOAT32 type
        input_images = input_images.to(device).float()
        target_images = target_images.to(device).float()

        # Verify sizes for TransUNet
        assert input_images.shape[2:] == (224, 224), f"Input size mismatch: {input_images.shape[2:]}"
        assert target_images.shape[2:] == (224, 224), f"Target size mismatch: {target_images.shape[2:]}"

        # Measure inference time. CUDA kernels are launched asynchronously,
        # so synchronize before starting and before stopping the clock;
        # otherwise only the launch overhead would be measured.
        if use_cuda:
            torch.cuda.synchronize(device)
        start_time = time.perf_counter()

        # Forward pass
        with torch.no_grad():
            outputs = model(input_images)

        if use_cuda:
            torch.cuda.synchronize(device)
        inference_time = time.perf_counter() - start_time
        inference_times.append(inference_time)

        # Verify output size
        assert outputs.shape[2:] == (224, 224), f"Output size mismatch: {outputs.shape[2:]}"

        # Convert tensors to numpy (channel-last) for metric calculation.
        # Tensors are deliberately not retained across iterations so device
        # memory stays constant regardless of dataset size.
        target_np_display = np.transpose(target_images.cpu().numpy()[0], (1, 2, 0))
        output_np_display = np.transpose(outputs.cpu().numpy()[0], (1, 2, 0))
        output_ch0 = output_np_display[:, :, 0]
        target_ch0 = target_np_display[:, :, 0]

        # Calculate standard quality metrics
        current_psnr = psnr(output_ch0, target_ch0)

        try:
            current_ssim = calculate_ssim(output_np_display, target_np_display)
        except Exception as e:
            # Best effort: a failed SSIM is recorded as 0.0 so the run continues
            print(f"Warning: SSIM calculation failed for image {image_number}: {e}")
            current_ssim = 0.0

        current_mae = mae(output_ch0, target_ch0)
        current_mse = mse(output_ch0, target_ch0)

        # Calculate FPR hallucination score
        current_fpr, fpr_mask = compute_fpr_score(
            output_ch0,
            target_ch0,
            t_high=t_high,
            t_low=t_low
        )

        # Store metrics
        psnr_values.append(current_psnr)
        ssim_values.append(current_ssim)
        mae_values.append(current_mae)
        mse_values.append(current_mse)
        fpr_values.append(current_fpr)

        # Save reconstructed image
        recon_img_path = os.path.join(recon_img_dir, f'image_{image_number:04d}.png')
        save_tensor_as_image(outputs[0], recon_img_path)

        # Save input and target images for comparison
        input_img_path = os.path.join(recon_img_dir, f'input_{image_number:04d}.png')
        target_img_path = os.path.join(recon_img_dir, f'target_{image_number:04d}.png')

        # Denormalize input image before saving (exact inverse of Normalize)
        input_denorm = input_images[0] * std_t + mean_t

        save_tensor_as_image(input_denorm, input_img_path)
        save_tensor_as_image(target_images[0], target_img_path)

        # Save comparison visualization
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(np.transpose(input_denorm.cpu().numpy(), (1, 2, 0)))
        axes[0].set_title('Input Image (224x224)')
        axes[0].axis('off')

        axes[1].imshow(target_np_display)
        axes[1].set_title('Ground Truth (224x224)')
        axes[1].axis('off')

        axes[2].imshow(output_np_display)
        axes[2].set_title(f'TransUNet Output\nPSNR: {current_psnr:.2f}, FPR: {current_fpr:.4f}')
        axes[2].axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(recon_img_dir, f'comparison_{image_number:04d}.png'))
        plt.close()

        # Save the FPR mask for visualization
        plt.figure(figsize=(8, 8))
        plt.imshow(fpr_mask, cmap='hot')
        plt.colorbar(label='FPR Score')
        plt.title(f'FPR Mask - Score: {current_fpr:.4f}')
        plt.savefig(os.path.join(recon_img_dir, f'fpr_mask_{image_number:04d}.png'))
        plt.close()

        print(f"Image {image_number}: Time = {inference_time:.4f}s, PSNR = {current_psnr:.2f}, SSIM = {current_ssim:.4f}, "
              f"MAE = {current_mae:.4f}, MSE = {current_mse:.4f}, FPR = {current_fpr:.4f}")

    # Nothing was evaluated: averaging would divide by zero, so report and bail
    # out the same way as for a model-loading failure.
    if not psnr_values:
        print("Error: No images were evaluated for TransUNet; check the dataset paths and index range.")
        return None

    def summarize(values):
        """Return (mean, sample variance, standard deviation) of a metric list."""
        mean = sum(values) / len(values)
        # Sample variance is undefined for a single observation; report 0
        variance = np.var(values, ddof=1) if len(values) > 1 else 0
        return mean, variance, np.sqrt(variance)

    # Calculate average metrics, variance and standard deviation
    avg_inference_time, var_inference_time, std_inference_time = summarize(inference_times)
    avg_psnr, var_psnr, std_psnr = summarize(psnr_values)
    avg_ssim, var_ssim, std_ssim = summarize(ssim_values)
    avg_mae, var_mae, std_mae = summarize(mae_values)
    avg_mse, var_mse, std_mse = summarize(mse_values)
    avg_fpr, var_fpr, std_fpr = summarize(fpr_values)

    num_images = len(psnr_values)

    # Print average metrics with variance
    print(f"\n{'='*60}")
    print("TRANSUNET EVALUATION RESULTS")
    print(f"{'='*60}")
    print(f"TransUNet Evaluation Results across {num_images} images:")
    print(f"Inference Time: {avg_inference_time:.4f} ± {std_inference_time:.4f} seconds")
    print(f"PSNR: {avg_psnr:.2f} ± {std_psnr:.2f} dB")
    print(f"SSIM: {avg_ssim:.4f} ± {std_ssim:.4f}")
    print(f"MAE: {avg_mae:.4f} ± {std_mae:.4f}")
    print(f"MSE: {avg_mse:.4f} ± {std_mse:.4f}")
    print(f"FPR: {avg_fpr:.4f} ± {std_fpr:.4f}")

    # Create a DataFrame to store per-image metrics
    metrics_df = pd.DataFrame({
        'Image': [f'Image_{n}' for n in range(1, num_images + 1)],
        'Inference_Time': inference_times,
        'PSNR': psnr_values,
        'SSIM': ssim_values,
        'MAE': mae_values,
        'MSE': mse_values,
        'FPR': fpr_values,
        'ReconstructedImagePath': [os.path.join(recon_img_dir, f'image_{n:04d}.png')
                                   for n in range(1, num_images + 1)]
    })

    # Save metrics to CSV
    metrics_df.to_csv(f'{save_dir}/metrics_results.csv', index=False)

    # Include variance information in the CSV
    variance_df = pd.DataFrame({
        'Metric': ['Inference_Time', 'PSNR', 'SSIM', 'MAE', 'MSE', 'FPR'],
        'Average': [avg_inference_time, avg_psnr, avg_ssim, avg_mae, avg_mse, avg_fpr],
        'Variance': [var_inference_time, var_psnr, var_ssim, var_mae, var_mse, var_fpr],
        'Std_Dev': [std_inference_time, std_psnr, std_ssim, std_mae, std_mse, std_fpr]
    })
    variance_df.to_csv(f'{save_dir}/metrics_variance.csv', index=False)

    # Create visualizations for TransUNet
    create_transunet_evaluation_plots(save_dir, inference_times, psnr_values, ssim_values, 
                                    mae_values, mse_values, fpr_values, 
                                    avg_inference_time, avg_psnr, avg_ssim, avg_mae, avg_mse, avg_fpr,
                                    std_inference_time, std_psnr, std_ssim, std_mae, std_mse, std_fpr,
                                    var_inference_time, var_psnr, var_ssim, var_mae, var_mse, var_fpr)

    # Save metrics summary to text file
    save_transunet_metrics_summary(save_dir, model_path, num_images, t_high, t_low,
                                  avg_inference_time, avg_psnr, avg_ssim, avg_mae, avg_mse, avg_fpr,
                                  std_inference_time, std_psnr, std_ssim, std_mae, std_mse, std_fpr,
                                  var_inference_time, var_psnr, var_ssim, var_mae, var_mse, var_fpr,
                                  recon_img_dir)

    # Return dictionary with all metrics including FPR and variance values
    return {
        'inference_time': avg_inference_time,
        'inference_time_std': std_inference_time,
        'inference_time_var': var_inference_time,
        'psnr': avg_psnr,
        'psnr_std': std_psnr,
        'psnr_var': var_psnr,
        'ssim': avg_ssim,
        'ssim_std': std_ssim,
        'ssim_var': var_ssim,
        'mae': avg_mae,
        'mae_std': std_mae,
        'mae_var': var_mae,
        'mse': avg_mse,
        'mse_std': std_mse,
        'mse_var': var_mse,
        'fpr': avg_fpr,
        'fpr_std': std_fpr,
        'fpr_var': var_fpr,
        'model_type': 'TransUNet',
        'images_processed': num_images,
        'reconstructed_images_dir': recon_img_dir
    }


def _normalize_by_max(values):
    """Scale ``values`` by their maximum; all zeros when the maximum is not positive."""
    peak = max(values)  # computed once instead of once per element
    if peak > 0:
        return [v / peak for v in values]
    return [0] * len(values)


def _plot_metric_trend(ax, image_indices, values, color, title, ylabel, annotation):
    """Draw one per-image metric trend on ``ax`` with an average annotation box."""
    ax.plot(image_indices, values, marker='o', color=color, linestyle='-')
    ax.set_title(title)
    ax.set_xlabel('Image Number')
    ax.set_ylabel(ylabel)
    ax.grid(True, linestyle='--', alpha=0.7)
    # Annotation is anchored near the left edge, just below the largest value
    ax.text(image_indices[-1] * 0.05, max(values) * 0.95,
            annotation,
            bbox=dict(facecolor='white', alpha=0.8))


def create_transunet_evaluation_plots(save_dir, inference_times, psnr_values, ssim_values, 
                                    mae_values, mse_values, fpr_values, 
                                    avg_inference_time, avg_psnr, avg_ssim, avg_mae, avg_mse, avg_fpr,
                                    std_inference_time, std_psnr, std_ssim, std_mae, std_mse, std_fpr,
                                    var_inference_time, var_psnr, var_ssim, var_mae, var_mse, var_fpr):
    """
    Create comprehensive visualization plots for TransUNet evaluation metrics

    Two PNG files are written into ``save_dir``:

    - ``transunet_comprehensive_metrics_with_fpr.png``: 3x2 grid of per-image
      trend lines, each annotated with "mean ± std".
    - ``transunet_metrics_distribution.png``: box plot of every metric scaled
      by its maximum (SSIM is plotted unscaled), with the variances in a
      caption.

    Parameters
    ----------
    save_dir : str
        Existing directory the figures are written to.
    inference_times, psnr_values, ssim_values, mae_values, mse_values, fpr_values : list
        Per-image metric values; the number of images is taken from
        ``psnr_values``.
    avg_*, std_*, var_* : float
        Pre-computed summary statistics, used only for annotations.

    Returns
    -------
    None
        When there are no images a warning is printed and nothing is plotted.
    """
    # No data: max() / [-1] below would raise after a figure was already
    # opened, so skip plotting entirely.
    if len(psnr_values) == 0:
        print("Warning: No TransUNet metrics available; skipping plot generation.")
        return None

    # Get image indices for x-axis
    image_indices = list(range(1, len(psnr_values) + 1))

    # (row, col, values, line color, title, y-label, annotation text)
    trend_specs = [
        (0, 0, inference_times, 'orange', 'Inference Time Trend (Lower is better)', 'Time (seconds)',
         f'Avg: {avg_inference_time:.4f} ± {std_inference_time:.4f}s'),
        (0, 1, psnr_values, 'blue', 'PSNR Trend (Higher is better)', 'PSNR (dB)',
         f'Avg: {avg_psnr:.2f} ± {std_psnr:.2f}'),
        (1, 0, ssim_values, 'green', 'SSIM Trend (Higher is better)', 'SSIM',
         f'Avg: {avg_ssim:.4f} ± {std_ssim:.4f}'),
        (1, 1, fpr_values, 'magenta', 'FPR Trend (Lower is better)', 'FPR Score',
         f'Avg: {avg_fpr:.4f} ± {std_fpr:.4f}'),
        (2, 0, mse_values, 'purple', 'MSE Trend (Lower is better)', 'MSE',
         f'Avg: {avg_mse:.4f} ± {std_mse:.4f}'),
        (2, 1, mae_values, 'red', 'MAE Trend (Lower is better)', 'MAE',
         f'Avg: {avg_mae:.4f} ± {std_mae:.4f}'),
    ]

    # Create a 3x2 subplot for all metrics including FPR
    fig, axs = plt.subplots(3, 2, figsize=(16, 15))
    for row, col, values, color, title, ylabel, annotation in trend_specs:
        _plot_metric_trend(axs[row, col], image_indices, values, color, title, ylabel, annotation)

    # Save trend plot
    plt.suptitle('TransUNet Evaluation - Comprehensive Metrics with FPR', fontsize=16)
    plt.tight_layout()
    plt.savefig(f'{save_dir}/transunet_comprehensive_metrics_with_fpr.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Create a box plot for all metrics to visualize variance
    plt.figure(figsize=(15, 10))

    # Normalize metrics so they share one axis; SSIM is used as-is
    box_data = [
        _normalize_by_max(psnr_values),
        ssim_values,
        _normalize_by_max(inference_times),
        _normalize_by_max(mae_values),
        _normalize_by_max(mse_values),
        _normalize_by_max(fpr_values),
    ]

    # Create boxplot
    bp = plt.boxplot(box_data, patch_artist=True)

    # Set colors for each box
    box_colors = ['lightblue', 'lightgreen', 'coral', 'orchid', 'wheat', 'pink']
    for patch, color in zip(bp['boxes'], box_colors):
        patch.set_facecolor(color)

    # Set labels and title
    plt.title('TransUNet - Distribution of Normalized Metrics (Box Plot)', fontsize=14)
    plt.xlabel('Metrics')
    plt.ylabel('Normalized Values')
    plt.xticks([1, 2, 3, 4, 5, 6], ['PSNR', 'SSIM', 'Inference', 'MAE', 'MSE', 'FPR'])
    plt.grid(True, linestyle='--', alpha=0.3)

    # Add annotation about variance
    plt.figtext(0.5, 0.01, 
                f'Variance: PSNR: {var_psnr:.4f}, SSIM: {var_ssim:.4f}, Time: {var_inference_time:.4f}, '
                f'MAE: {var_mae:.4f}, MSE: {var_mse:.4f}, FPR: {var_fpr:.4f}',
                ha='center', fontsize=10)

    # Save the plot
    plt.tight_layout()
    plt.savefig(f'{save_dir}/transunet_metrics_distribution.png', dpi=300, bbox_inches='tight')
    plt.close()


def save_transunet_metrics_summary(save_dir, model_path, num_images, t_high, t_low,
                                  avg_inference_time, avg_psnr, avg_ssim, avg_mae, avg_mse, avg_fpr,
                                  std_inference_time, std_psnr, std_ssim, std_mae, std_mse, std_fpr,
                                  var_inference_time, var_psnr, var_ssim, var_mae, var_mse, var_fpr,
                                  recon_img_dir):
    """
    Save comprehensive TransUNet metrics summary to text file

    Writes ``transunet_metrics_summary_with_variance.txt`` inside ``save_dir``
    (the directory must already exist; it is not created here).

    Parameters
    ----------
    save_dir : str
        Directory the report is written to; also echoed in the report as the
        location of the generated plots.
    model_path :
        Echoed verbatim in the report header.
    num_images :
        Echoed as "Total Images Processed".
    t_high, t_low :
        FPR threshold parameters, echoed verbatim.
    avg_*, std_*, var_* : float
        Mean, standard deviation and variance of each metric. They are
        formatted with fixed precision, so they must be numeric.
    recon_img_dir : str
        Echoed as the location of the reconstructed images.
    """
    rule = "=" * 80 + "\n"

    report = "".join([
        rule,
        "TRANSUNET MODEL EVALUATION REPORT\n",
        rule + "\n",

        "Model Type: TransUNet\n",
        f"Model Path: {model_path}\n",
        "Input/Output Size: 224x224\n",
        f"Total Images Processed: {num_images}\n",
        f"FPR Threshold Parameters: t_high={t_high}, t_low={t_low}\n\n",

        "AVERAGE METRICS WITH STATISTICAL ANALYSIS:\n",
        "-" * 50 + "\n",
        f"Inference Time: {avg_inference_time:.4f} ± {std_inference_time:.4f} seconds (var: {var_inference_time:.6f})\n",
        f"PSNR: {avg_psnr:.4f} ± {std_psnr:.4f} dB (var: {var_psnr:.6f})\n",
        f"SSIM: {avg_ssim:.4f} ± {std_ssim:.4f} (var: {var_ssim:.6f})\n",
        f"MAE: {avg_mae:.4f} ± {std_mae:.4f} (var: {var_mae:.6f})\n",
        f"MSE: {avg_mse:.4f} ± {std_mse:.4f} (var: {var_mse:.6f})\n",
        f"FPR: {avg_fpr:.4f} ± {std_fpr:.4f} (var: {var_fpr:.6f})\n\n",

        "TRANSUNET SPECIFIC NOTES:\n",
        "-" * 25 + "\n",
        "• TransUNet uses 224x224 input/output resolution\n",
        "• Both input and target images are resized to match model requirements\n",
        "• This differs from other models that may use 449x449 output\n",
        "• Metrics are calculated on the 224x224 resolution\n\n",

        "INTERPRETATION:\n",
        "-" * 20 + "\n",
        "• PSNR (Peak Signal-to-Noise Ratio): Higher values indicate better quality\n",
        "• SSIM (Structural Similarity Index): Range [0,1], higher values indicate better structural similarity\n",
        "• MAE (Mean Absolute Error): Lower values indicate better accuracy\n",
        "• MSE (Mean Squared Error): Lower values indicate better accuracy\n",
        "• FPR (False Positive Regions): Lower values indicate less hallucination\n",
        "• Inference Time: Lower values indicate faster processing\n\n",

        f"Reconstructed images saved to: {recon_img_dir}\n",
        f"Generated plots saved to: {save_dir}\n\n",

        "FILES GENERATED:\n",
        "-" * 15 + "\n",
        "• metrics_results.csv - Per-image detailed metrics\n",
        "• metrics_variance.csv - Statistical summary\n",
        "• transunet_comprehensive_metrics_with_fpr.png - Trend plots\n",
        "• transunet_metrics_distribution.png - Distribution analysis\n",
        "• reconstructed_images/ - All output images and comparisons\n",
    ])

    # The report contains non-ASCII characters ("±", "•"); pin the encoding so
    # the write neither fails nor changes bytes depending on the host locale.
    with open(f'{save_dir}/transunet_metrics_summary_with_variance.txt', 'w', encoding='utf-8') as f:
        f.write(report)

def run_multi_model_evaluation(models_config, test_response_path, test_gt_path, 
                              start_index, end_index, t_high=0.5, t_low=0.2,
                              comparison_dir='./evaluation_results/comparison_report'):
    """
    Run evaluation on multiple models and create comparison report
    
    Each model is evaluated with ``comprehensive_eval_new2``. A model whose
    evaluation raises or returns ``None`` is reported on stdout and skipped,
    so one failing model does not abort the whole run.
    
    Parameters
    ----------
    models_config : list
        List of dictionaries containing model configuration:
        [{'type': 'model_name', 'path': 'model_path', 'save_dir': 'save_directory'}, ...]
        Results for a model are written to ``<type>_results`` placed next to
        (i.e. in the parent directory of) its ``save_dir``.
    test_response_path : str
        Path to test response images
    test_gt_path : str
        Path to ground truth images
    start_index : int
        Starting index for evaluation
    end_index : int
        Ending index for evaluation
    t_high : float
        FPR high threshold
    t_low : float
        FPR low threshold
    comparison_dir : str
        Directory the cross-model comparison report is written to. Defaults
        to the location that was previously hard-coded.
    
    Returns
    -------
    list
        The result dictionaries of the models that were evaluated
        successfully, in input order (empty when none succeeded).
    """
    
    print("="*80)
    print("MULTI-MODEL EVALUATION STARTED")
    print("="*80)
    
    results_summary = []
    total_models = len(models_config)
    
    for position, model_config in enumerate(models_config, start=1):
        model_type = model_config['type']
        print(f"\n[{position}/{total_models}] Evaluating {model_type}...")
        print("-" * 60)
        
        # Create a unique directory for each model type
        model_specific_dir = os.path.join(
            os.path.dirname(model_config['save_dir']),
            f"{model_type}_results"
        )
        
        # Run evaluation; failures are isolated per model (best effort)
        try:
            result = comprehensive_eval_new2(
                test_response_file_path=test_response_path,
                test_gt_file_path=test_gt_path,
                start_idx=start_index,
                end_idx=end_index,
                model_type=model_type,
                model_path=model_config['path'],
                save_dir=model_specific_dir,
                t_high=t_high,
                t_low=t_low
            )
        except Exception as e:
            print(f"✗ Error evaluating {model_type}: {e}")
        else:
            if result is not None:
                results_summary.append(result)
                print(f"✓ Successfully evaluated {model_type}")
            else:
                print(f"✗ Failed to evaluate {model_type}")
    
    # Create comparison report
    if results_summary:
        create_comparison_report(results_summary, comparison_dir)
        print(f"\n✓ Comparison report saved to: {comparison_dir}")
    
    print("\n" + "="*80)
    print("MULTI-MODEL EVALUATION COMPLETED")
    print("="*80)
    
    return results_summary


def create_comparison_report(results_summary, save_dir):
    """Write a cross-model comparison report (CSV, bar charts, text summary).

    Three files are written into ``save_dir``:

    * ``model_comparison.csv``: one row per model with the average and
      standard deviation of every metric.
    * ``model_comparison_plots.png``: a 2x3 grid of bar charts with the
      standard deviation drawn as error bars.
    * ``comparison_summary.txt``: per-model summary followed by the best
      model for PSNR, SSIM, FPR and inference time.

    Parameters
    ----------
    results_summary : list of dict
        One dict per evaluated model. Each dict must provide
        ``model_type`` and ``images_processed``, plus ``psnr``, ``ssim``,
        ``mae``, ``mse``, ``fpr`` and ``inference_time`` together with the
        matching ``<name>_std`` key for each of those six metrics.
    save_dir : str
        Output directory; created (including parents) if it is missing.

    Returns
    -------
    None
        If ``results_summary`` is empty a warning is printed and nothing
        is written.

    Raises
    ------
    KeyError
        If a result dict lacks one of the required keys.
    """
    # Nothing to compare: bail out before touching the filesystem instead of
    # crashing later on the missing 'Model' column / empty max().
    if not results_summary:
        print("Warning: no evaluation results to compare; comparison report not created")
        return

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)

    # (CSV column prefix, key in the result dict), in CSV column order.
    metric_keys = (
        ('PSNR', 'psnr'),
        ('SSIM', 'ssim'),
        ('MAE', 'mae'),
        ('MSE', 'mse'),
        ('FPR', 'fpr'),
        ('Inference_Time', 'inference_time'),
    )

    # Create comparison DataFrame
    comparison_data = []
    for result in results_summary:
        row = {'Model': result['model_type']}
        for column, key in metric_keys:
            row[f'{column}_avg'] = result[key]
            row[f'{column}_std'] = result[f'{key}_std']
        row['Images_Processed'] = result['images_processed']
        comparison_data.append(row)

    comparison_df = pd.DataFrame(comparison_data)
    comparison_df.to_csv(os.path.join(save_dir, 'model_comparison.csv'), index=False)

    # (CSV column prefix, title, y-label, bar colour), in row-major subplot
    # order. A colour of None lets matplotlib pick its default.
    panels = (
        ('PSNR', 'PSNR Comparison (Higher is better)', 'PSNR (dB)', None),
        ('SSIM', 'SSIM Comparison (Higher is better)', 'SSIM', 'green'),
        ('FPR', 'FPR Comparison (Lower is better)', 'FPR Score', 'red'),
        ('MAE', 'MAE Comparison (Lower is better)', 'MAE', 'orange'),
        ('MSE', 'MSE Comparison (Lower is better)', 'MSE', 'purple'),
        ('Inference_Time', 'Inference Time Comparison (Lower is better)',
         'Time (seconds)', 'brown'),
    )

    # Create comparison plots
    models = comparison_df['Model'].tolist()
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    try:
        for ax, (column, title, ylabel, color) in zip(axes.flat, panels):
            ax.bar(
                models,
                comparison_df[f'{column}_avg'].tolist(),
                yerr=comparison_df[f'{column}_std'].tolist(),
                capsize=5,
                alpha=0.7,
                color=color,
            )
            ax.set_title(title)
            ax.set_ylabel(ylabel)
            ax.tick_params(axis='x', rotation=45)

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'model_comparison_plots.png'),
                    dpi=300, bbox_inches='tight')
    finally:
        # Close this specific figure even if saving fails, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)

    # Best model per headline metric (higher is better for PSNR/SSIM,
    # lower is better for FPR and inference time).
    best_psnr = max(results_summary, key=lambda x: x['psnr'])
    best_ssim = max(results_summary, key=lambda x: x['ssim'])
    best_fpr = min(results_summary, key=lambda x: x['fpr'])
    best_time = min(results_summary, key=lambda x: x['inference_time'])

    # Save comparison summary. The text contains the non-ASCII '±' sign, so
    # the encoding is pinned rather than left to the platform default.
    summary_path = os.path.join(save_dir, 'comparison_summary.txt')
    with open(summary_path, 'w', encoding='utf-8') as f:
        f.write("MODEL COMPARISON SUMMARY\n")
        f.write("="*50 + "\n\n")

        for result in results_summary:
            f.write(f"Model: {result['model_type']}\n")
            f.write(f"  PSNR: {result['psnr']:.4f} ± {result['psnr_std']:.4f}\n")
            f.write(f"  SSIM: {result['ssim']:.4f} ± {result['ssim_std']:.4f}\n")
            f.write(f"  FPR:  {result['fpr']:.4f} ± {result['fpr_std']:.4f}\n")
            f.write(f"  Time: {result['inference_time']:.4f} ± {result['inference_time_std']:.4f}s\n")
            f.write("-" * 30 + "\n")

        f.write("\nBEST PERFORMING MODELS:\n")
        f.write("="*30 + "\n")
        f.write(f"Best PSNR: {best_psnr['model_type']} ({best_psnr['psnr']:.4f})\n")
        f.write(f"Best SSIM: {best_ssim['model_type']} ({best_ssim['ssim']:.4f})\n")
        f.write(f"Best FPR:  {best_fpr['model_type']} ({best_fpr['fpr']:.4f})\n")
        f.write(f"Fastest:   {best_time['model_type']} ({best_time['inference_time']:.4f}s)\n")


# ============================================================================
# MAIN EXECUTION EXAMPLE
# ============================================================================

if __name__ == "__main__":
    # Example configuration for multiple models. Each entry provides:
    #   'type'     - model identifier, forwarded to the evaluator as model_type
    #   'path'     - checkpoint path, forwarded as model_path
    #   'save_dir' - only its parent directory is used by
    #                run_multi_model_evaluation; results are written to
    #                "<parent>/<type>_results"
    models_to_compare = [
        {
            'type': 'Restormer',
            'path': '/pretrained_param/Restormer_final.pth',
            'save_dir': './evaluation_results/comparison/Restormer'
        },
        {
            'type': 'TRUST',
            'path': '/pretrained_param/TRUST_final.pth',
            'save_dir': './evaluation_results/comparison/vunt_early_skip'
        },
        {
            'type': 'unet',
            'path': '/pretrained_param/UNet_final.pth',
            'save_dir': './evaluation_results/comparison/unet'
        }
    ]

    # Dataset paths and parameters (placeholders; adjust for your environment)
    test_response_path = '/data/'
    test_gt_path = '/data/'
    start_index = 1
    end_index = 5

    # Evaluate TRUST, Restormer and UNet in one run and build the
    # cross-model comparison report.
    results = run_multi_model_evaluation(
        models_config=models_to_compare,
        test_response_path=test_response_path,
        test_gt_path=test_gt_path,
        start_index=start_index,
        end_index=end_index,
        t_high=0.5,
        t_low=0.2
    )

    # Alternative: evaluate TransUNet on its own (disabled by default).
    # NOTE(review): if enabled in place of the call above, confirm what
    # comprehensive_eval_transunet returns; the summary below assumes
    # `results` is a list with one entry per evaluated model.
    #
    # results = comprehensive_eval_transunet(
    #      test_response_file_path=test_response_path,
    #      test_gt_file_path=test_gt_path,
    #      start_idx=start_index,
    #      end_idx=end_index,
    #      model_path='/pretrained_param/TransUNEt_final.pth',
    #      save_dir='./evaluation_results/TransUNet_results',
    #      t_high=0.5,
    #      t_low=0.2
    # )

    # run_multi_model_evaluation skips models that fail and returns only the
    # successful results, so only claim success when every model made it.
    evaluated_count = len(results)
    requested_count = len(models_to_compare)
    if evaluated_count == requested_count:
        print("\nEvaluation completed successfully!")
    else:
        print(
            f"\nEvaluation completed with failures: "
            f"{requested_count - evaluated_count} of {requested_count} "
            f"models could not be evaluated."
        )
    print(f"Results for {evaluated_count} models have been saved.")
    
