# #!/usr/bin/env python3
# import os
# # Fix OpenMP library conflict
# os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import DataLoader
# import matplotlib.pyplot as plt
# from pathlib import Path
# from PIL import Image
# import json
# import time
# from tqdm import tqdm
# import seaborn as sns
# from pytorch_msssim import ssim, ms_ssim
# import pandas as pd
# from transformers import ViTModel, ViTConfig

# # ========== SHARED DATASET CLASS ==========
# def make_gaussian_random_orthonormal_rows(h=64, w=64, seed=42):
#     """Generate a matrix A of size [h, w] where rows are orthonormal.
#
#     A Gaussian [h, w] matrix is drawn and orthonormalised with a (reduced) QR
#     decomposition of its transpose. The reduced Q of the [w, h] transpose is
#     [w, min(w, h)], so the returned Q.T is [min(h, w), w]: the documented
#     [h, w] shape with orthonormal rows only holds when h <= w.
#
#     Side effect: when ``seed`` is not None this reseeds torch's *global* RNG
#     via ``torch.manual_seed``, which also fixes every later draw from the
#     global generator (e.g. the ``torch.randperm`` coefficient selection in
#     ``PatchwiseOrthonormalDataset.process_image_with_orthonormal_masks``).
#     """
#     if seed is not None:
#         torch.manual_seed(seed)
#     A = torch.randn(h, w)
#     # A.T is [w, h]; R is unused.
#     Q, R = torch.linalg.qr(A.T)
#     return Q.T

# class PatchwiseOrthonormalDataset:
#     """Dataset that applies patch-wise orthonormal transformation to images.
#
#     Each item comes from one image file directly inside ``data_dir``. The image is
#     resized/center-cropped to a 224x224 RGB target ``x``. A degraded 112x112x3
#     measurement ``y`` is synthesised per channel: 8x8 patches are projected
#     through the 64x64 matrix ``self.A``, a random quarter (16) of the
#     coefficients is kept, and the result is tiled back into an image.
#
#     ``__getitem__`` returns ``(y_tensor, x_tensor, path_str)`` in CHW layout:
#     y is [3, 112, 112] (min-max normalised) and x is [3, 224, 224] in [0, 1].
#
#     NOTE(review): coefficient selection draws from torch's global RNG
#     (``torch.randperm``), so the measurement for an index depends on how many
#     draws happened before it -- it is not a fixed function of ``idx``.
#     """
#     def __init__(self, data_dir, seed=42, verbose=False):
#         self.data_dir = data_dir
#         # Fixed 64x64 projection for flattened 8x8 patches. Building it reseeds
#         # torch's global RNG with ``seed``.
#         self.A = make_gaussian_random_orthonormal_rows(h=64, w=64, seed=seed)

#         self.data_path = Path(data_dir)
#         if not self.data_path.exists():
#             raise FileNotFoundError(f"Data directory not found: {data_dir}")

#         # Only direct children are scanned (no recursion); suffix matching is
#         # case-sensitive, so e.g. '.PNG' files are skipped.
#         image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.JPEG', '.JPG'}
#         self.image_files = [f for f in self.data_path.iterdir() 
#                            if f.is_file() and f.suffix in image_extensions]

#         if len(self.image_files) == 0:
#             raise ValueError(f"No images found in {data_dir}")

#         if verbose:
#             print(f"Loaded {len(self.image_files)} images from {data_dir}")

#     def __len__(self):
#         return len(self.image_files)

#     def resize_min_side(self, img, min_side=224):
#         """Resize a PIL image (LANCZOS) so its shorter side equals ``min_side``, keeping aspect ratio."""
#         w, h = img.size
#         s = min_side / min(w, h)
#         return img.resize((int(round(w*s)), int(round(h*s))), Image.Resampling.LANCZOS)

#     def center_crop(self, img, size=224):
#         """Return the centred ``size`` x ``size`` crop of a PIL image."""
#         w, h = img.size
#         left = (w - size) // 2
#         top = (h - size) // 2
#         return img.crop((left, top, left + size, top + size))

#     def preprocess_image(self, img):
#         """Convert to RGB, resize/crop to 224x224 and return a float32 HxWx3 array in [0, 1]."""
#         img = img.convert("RGB")
#         img_resized = self.resize_min_side(img, 224)
#         img_crop = self.center_crop(img_resized, 224)
#         x = np.array(img_crop).astype(np.float32) / 255.0
#         return x

#     def process_image_with_orthonormal_masks(self, np_img, mask_matrix):
#         """Project every 8x8 patch through ``mask_matrix`` and keep 16 random coefficients.
#
#         Args:
#             np_img: HxWxC numpy array; a 3-channel input is averaged to gray.
#                 The fixed 28x28 loop assumes H = W = 224.
#             mask_matrix: matrix applied to each flattened 64-pixel patch
#                 (``self.A``, 64x64 here; the hard-coded 16 assumes 64 rows).
#
#         Returns:
#             A [28, 28, 16] tensor holding, per patch, a random quarter of its
#             transformed coefficients (a fresh global-RNG ``randperm`` per patch).
#         """
#         img_tensor = torch.from_numpy(np_img).float()

#         if img_tensor.shape[2] == 3:
#             img_gray = img_tensor.mean(dim=2)
#         else:
#             # Single-channel HxWx1 input keeps its trailing singleton dim here.
#             img_gray = img_tensor

#         # Non-overlapping 8x8 windows: [28, 28, 8, 8] (or [28, 28, 1, 8, 8] for HxWx1
#         # input); either way patches[i, j] flattens to 64 values.
#         patches = img_gray.unfold(0, 8, 8).unfold(1, 8, 8)
#         transformed_patches = torch.zeros(28, 28, 16)

#         for i in range(28):
#             for j in range(28):
#                 patch_flat = patches[i, j].flatten()
#                 transformed = mask_matrix @ patch_flat
#                 # Keep a random subset of 1/4 of the coefficients, in random order.
#                 transformed = transformed[torch.randperm(transformed.shape[0])[:transformed.shape[0] // 4]]
#                 transformed_patches[i, j] = transformed

#         return transformed_patches

#     def reconstruct_masked_image(self, transformed_patches):
#         """Tile each patch's 16 kept coefficients (as 4x4, bilinearly upsampled to 8x8) into a 112x112 image.
#
#         NOTE(review): only the first 14x14 entries of the 28x28 patch grid are used,
#         so the output is built solely from patches in the top-left 112x112 quadrant
#         of the 224x224 crop while the target is the whole crop. The 16 coefficients
#         are also reshaped in the random order chosen above. Confirm both are intended.
#         """
#         masked_image = torch.zeros(112, 112)

#         for i in range(14):
#             for j in range(14):
#                 transformed_patch = transformed_patches[i, j]
#                 patch_4x4 = transformed_patch.reshape(4, 4)
#                 # interpolate needs [N, C, H, W]; add and strip the two leading dims.
#                 patch_8x8 = F.interpolate(patch_4x4.unsqueeze(0).unsqueeze(0), 
#                                         size=(8, 8), mode='bilinear', align_corners=True)[0, 0]

#                 start_h = i * 8
#                 end_h = start_h + 8
#                 start_w = j * 8
#                 end_w = start_w + 8

#                 masked_image[start_h:end_h, start_w:end_w] = patch_8x8

#         return masked_image

#     def apply_patchwise_orthonormal_transform(self, x):
#         """Build the 112x112x3 measurement from an HxWx3 image, channel by channel."""
#         y_channels = []

#         for c in range(3):
#             single_channel = x[..., c]
#             transformed_patches = self.process_image_with_orthonormal_masks(
#                 np.expand_dims(single_channel, axis=2), self.A
#             )
#             masked_channel = self.reconstruct_masked_image(transformed_patches)
#             y_channels.append(masked_channel.numpy())

#         y = np.stack(y_channels, axis=2)
#         # Min-max normalisation over all channels jointly (not per channel);
#         # the 1e-8 avoids division by zero for a constant measurement.
#         y_min = y.min()
#         y_max = y.max()
#         y_norm = (y - y_min) / (y_max - y_min + 1e-8)

#         return y_norm

#     def __getitem__(self, idx):
#         img_path = self.image_files[idx]

#         # Fall back to a black 224x224 image if the file cannot be opened.
#         # NOTE(review): PIL opens lazily, so decode errors raised later by
#         # convert() in preprocess_image are not caught by this try.
#         try:
#             img = Image.open(img_path)
#         except Exception as e:
#             print(f"Warning: Could not load image {img_path}: {e}")
#             img = Image.new('RGB', (224, 224), color=(0, 0, 0))

#         x = self.preprocess_image(img)
#         y = self.apply_patchwise_orthonormal_transform(x)

#         # HWC -> CHW
#         x_tensor = torch.from_numpy(x).permute(2, 0, 1)
#         y_tensor = torch.from_numpy(y).permute(2, 0, 1)

#         # (measurement, target, path)
#         return y_tensor, x_tensor, str(img_path)

# # ========== TRANSUNET MODEL ==========
# class PatchEmbedding(nn.Module):
#     """Split an image into non-overlapping patches and linearly embed each one.
#
#     A Conv2d with kernel_size == stride == patch_size maps [B, in_channels, H, W]
#     to [B, embed_dim, H // patch_size, W // patch_size]; ``forward`` returns the
#     flattened tokens [B, N, embed_dim] and the patch-grid size (H', W').
#     """
#     def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
#         super().__init__()
#         self.img_size = img_size
#         self.patch_size = patch_size
#         # Token count for an img_size x img_size input.
#         self.n_patches = (img_size // patch_size) ** 2
#         self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

#     def forward(self, x):
#         x = self.proj(x)
#         B, C, H, W = x.shape
#         # [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]
#         x = x.flatten(2).transpose(1, 2)
#         return x, (H, W)

# class MultiHeadSelfAttention(nn.Module):
#     """Multi-head scaled dot-product self-attention over a token sequence.
#
#     Input and output are [B, N, embed_dim]. Dropout is applied to the attention
#     weights only; the output projection is not followed by dropout.
#     """
#     def __init__(self, embed_dim, num_heads, dropout=0.1):
#         super().__init__()
#         self.embed_dim = embed_dim
#         self.num_heads = num_heads
#         self.head_dim = embed_dim // num_heads

#         assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

#         # One fused projection producing queries, keys and values.
#         self.qkv = nn.Linear(embed_dim, embed_dim * 3)
#         self.proj = nn.Linear(embed_dim, embed_dim)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, x):
#         B, N, C = x.shape

#         # [B, N, 3*C] -> [B, N, 3, heads, head_dim] -> [3, B, heads, N, head_dim]
#         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
#         q, k, v = qkv[0], qkv[1], qkv[2]

#         # Scores scaled by 1/sqrt(head_dim): [B, heads, N, N]
#         attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
#         attn = F.softmax(attn, dim=-1)
#         attn = self.dropout(attn)

#         # [B, heads, N, head_dim] -> [B, N, heads, head_dim] -> [B, N, C]
#         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
#         x = self.proj(x)
#         return x

# class TransformerBlock(nn.Module):
#     """Pre-norm transformer encoder block.
#
#     Computes x = x + Attn(LN(x)), then x = x + MLP(LN(x)); the MLP expands to
#     int(embed_dim * mlp_ratio) hidden units with GELU and dropout.
#     """
#     def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
#         super().__init__()
#         self.norm1 = nn.LayerNorm(embed_dim)
#         self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
#         self.norm2 = nn.LayerNorm(embed_dim)

#         mlp_hidden_dim = int(embed_dim * mlp_ratio)
#         self.mlp = nn.Sequential(
#             nn.Linear(embed_dim, mlp_hidden_dim),
#             nn.GELU(),
#             nn.Dropout(dropout),
#             nn.Linear(mlp_hidden_dim, embed_dim),
#             nn.Dropout(dropout)
#         )

#     def forward(self, x):
#         # Residual connections around normalised sub-layers (pre-norm).
#         x = x + self.attn(self.norm1(x))
#         x = x + self.mlp(self.norm2(x))
#         return x

# class VisionTransformerEncoder(nn.Module):
#     """ViT encoder (no class token) that also exposes intermediate token features.
#
#     Positional embeddings are learned and sized for exactly ``n_patches`` tokens,
#     so the input's spatial size must equal ``img_size``.
#
#     ``forward`` returns ``(features, (H, W))``. ``features`` lists the token
#     sequences [B, N, embed_dim] after blocks 2, 5 and 8 (0-based; fewer if
#     depth <= 8), followed by the final LayerNorm-ed output. (H, W) is the patch grid.
#     """
#     def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, 
#                  depth=12, num_heads=12, mlp_ratio=4, dropout=0.1):
#         super().__init__()
#         self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

#         self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.n_patches, embed_dim) * 0.02)
#         self.dropout = nn.Dropout(dropout)

#         self.blocks = nn.ModuleList([
#             TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
#             for _ in range(depth)
#         ])

#         self.norm = nn.LayerNorm(embed_dim)

#     def forward(self, x):
#         x, (H, W) = self.patch_embed(x)

#         # Broadcast-add the [1, N, C] positional embedding to every batch element.
#         x = x + self.pos_embed
#         x = self.dropout(x)

#         features = []
#         for i, block in enumerate(self.blocks):
#             x = block(x)
#             # Collect intermediate (un-normalised) tokens after blocks 2, 5, 8.
#             if i in [2, 5, 8]:
#                 features.append(x)

#         x = self.norm(x)
#         features.append(x)

#         return features, (H, W)

# class ConvBlock(nn.Module):
#     """Two Conv -> BatchNorm -> ReLU layers; only the first conv uses ``stride``.
#
#     Convolutions are bias-free because each is followed by BatchNorm.
#     """
#     def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
#         super().__init__()
#         self.conv = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding, bias=False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x):
#         return self.conv(x)

# class UpBlock(nn.Module):
#     """Decoder stage: 2x transposed-conv upsampling, optional skip concat, ConvBlock.
#
#     The transposed conv halves the channel count, so ``skip_channels`` must equal
#     the channel count of the skip tensor passed to ``forward``.
#     """
#     def __init__(self, in_channels, out_channels, skip_channels=0):
#         super().__init__()
#         self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
#         self.conv = ConvBlock(in_channels // 2 + skip_channels, out_channels)

#     def forward(self, x, skip=None):
#         x = self.up(x)
#         if skip is not None:
#             x = torch.cat([x, skip], dim=1)
#         return self.conv(x)

# class TransUNet(nn.Module):
#     """Hybrid CNN + ViT U-Net mapping a (low-resolution) measurement to an image.
#
#     The input is bilinearly resized to ``img_size`` and lifted to 64 channels.
#     A CNN encoder and a ViT encoder run on those features. The deepest CNN map is
#     concatenated with the projected ViT tokens as the bottleneck, and a 4-stage
#     decoder with CNN skips produces a sigmoid image. A shallow tanh branch on the
#     raw input is added to it and the sum is clamped to [0, 1].
#
#     NOTE(review): ``forward`` hard-codes 224 in two interpolations, so an
#     ``img_size`` other than 224 will likely cause shape mismatches.
#     """
#     def __init__(self, img_size=224, patch_size=16, in_channels=3, out_channels=3,
#                  embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, dropout=0.1):
#         super().__init__()

#         self.input_prep = nn.Sequential(
#             nn.Upsample(size=(img_size, img_size), mode='bilinear', align_corners=True),
#             nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
#             nn.BatchNorm2d(64),
#             nn.ReLU(inplace=True)
#         )

#         self.vit_encoder = VisionTransformerEncoder(
#             img_size=img_size, patch_size=patch_size, in_channels=64,
#             embed_dim=embed_dim, depth=depth, num_heads=num_heads,
#             mlp_ratio=mlp_ratio, dropout=dropout
#         )

#         self.cnn_enc1 = ConvBlock(64, 64)
#         self.cnn_enc2 = ConvBlock(64, 128)
#         self.cnn_enc3 = ConvBlock(128, 256)
#         self.cnn_enc4 = ConvBlock(256, 512)

#         self.pool = nn.MaxPool2d(2)

#         # Per-token projection of ViT features to 512 channels.
#         self.vit_to_cnn = nn.Sequential(
#             nn.Linear(embed_dim, 512),
#             nn.ReLU(inplace=True),
#             nn.Linear(512, 512)
#         )

#         # dec4 input: 512 (ViT) + 512 (enc4) = 1024 bottleneck channels.
#         self.dec4 = UpBlock(1024, 256, skip_channels=256)
#         self.dec3 = UpBlock(256, 128, skip_channels=128)
#         self.dec2 = UpBlock(128, 64, skip_channels=64)
#         self.dec1 = UpBlock(64, 64, skip_channels=64)

#         self.final_conv = nn.Sequential(
#             nn.Conv2d(64, 32, kernel_size=3, padding=1),
#             nn.BatchNorm2d(32),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(32, out_channels, kernel_size=1),
#             nn.Sigmoid()
#         )

#         # Residual branch on the raw input; tanh lets it add or subtract.
#         self.skip_connection = nn.Sequential(
#             nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
#             nn.BatchNorm2d(32),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(32, out_channels, kernel_size=3, padding=1),
#             nn.Tanh()
#         )

#     def forward(self, x):
#         original_input = x

#         # Resize to img_size and lift to 64 channels (shared by both encoders).
#         x = self.input_prep(x)

#         # Residual branch runs on the un-resized input and is resized afterwards.
#         skip_features = self.skip_connection(original_input)
#         skip_features = F.interpolate(skip_features, size=(224, 224), 
#                                     mode='bilinear', align_corners=True)

#         # CNN encoder; for img_size=224 the resolutions are 224, 112, 56, 28.
#         enc1 = self.cnn_enc1(x)
#         enc2 = self.cnn_enc2(self.pool(enc1))
#         enc3 = self.cnn_enc3(self.pool(enc2))
#         enc4 = self.cnn_enc4(self.pool(enc3))

#         vit_features, (H, W) = self.vit_encoder(x)

#         # Only the final normalised ViT tokens are used; intermediate ones are ignored.
#         final_vit_features = vit_features[-1]
#         B, N, C = final_vit_features.shape

#         vit_proj = self.vit_to_cnn(final_vit_features)

#         # [B, N, 512] -> [B, 512, H, W] patch grid (14x14 for img_size=224, patch_size=16).
#         vit_spatial = vit_proj.transpose(1, 2).reshape(B, 512, H, W)

#         # Resize the ViT grid to enc4's resolution so the two can be concatenated.
#         vit_spatial = F.interpolate(vit_spatial, size=enc4.shape[-2:], 
#                                   mode='bilinear', align_corners=True)

#         bottleneck = torch.cat([vit_spatial, enc4], dim=1)

#         # Decoder: UpBlock.up / UpBlock.conv are called directly with the matching skip.
#         dec4_up = self.dec4.up(bottleneck)
#         dec4_concat = torch.cat([dec4_up, enc3], dim=1)
#         dec4_out = self.dec4.conv(dec4_concat)

#         dec3_up = self.dec3.up(dec4_out)
#         dec3_concat = torch.cat([dec3_up, enc2], dim=1)
#         dec3_out = self.dec3.conv(dec3_concat)

#         dec2_up = self.dec2.up(dec3_out)
#         dec2_concat = torch.cat([dec2_up, enc1], dim=1)
#         dec2_out = self.dec2.conv(dec2_concat)

#         # dec1 upsamples to twice enc1's resolution; resize back to 224 so enc1
#         # can be reused as its skip connection.
#         dec1_up = self.dec1.up(dec2_out)
#         dec1_up = F.interpolate(dec1_up, size=(224, 224), mode='bilinear', align_corners=True)
#         dec1_concat = torch.cat([dec1_up, enc1], dim=1)
#         dec1_out = self.dec1.conv(dec1_concat)

#         output = self.final_conv(dec1_out)

#         # Sigmoid output + tanh residual can leave [0, 1]; clamp back.
#         output = output + skip_features

#         output = torch.clamp(output, 0, 1)

#         return output

# # ========== VIT-UNET MODEL ==========
# class ViTUNetForInverseProblem(nn.Module):
#     """U-Net-style decoder on top of a pretrained Hugging Face ViT encoder.
#
#     Token features from ``hidden_states[3]``, ``[6]``, ``[9]`` and the last
#     hidden state are reshaped to 2-D grids (first token dropped) and pooled to
#     28/14/7/7. They feed a progressive upsampling decoder
#     (7 -> 14 -> 28 -> 56 -> 112 -> 224). A parallel convolutional branch on the
#     raw input (tanh) is concatenated with the decoder's sigmoid output and fused
#     into the final 3-channel sigmoid image, which is then optionally resized to
#     ``output_size``.
#     """
#     def __init__(self, pretrained_model_name="google/vit-base-patch16-224", output_size=(224, 224)):
#         super().__init__()

#         cfg = ViTConfig.from_pretrained(pretrained_model_name)
#         # NOTE(review): in transformers, ViTModel's pooler is controlled by the
#         # ``add_pooling_layer`` constructor argument; setting it on the config may
#         # have no effect -- verify whether a pooler is still created.
#         cfg.add_pooling_layer = False
#         self.vit = ViTModel.from_pretrained(pretrained_model_name, config=cfg, ignore_mismatched_sizes=True)

#         self.output_size = output_size
#         # Must match the checkpoint's hidden size (768 for ViT-Base).
#         self.hidden_dim = 768

#         # Learned resize used only when the input is not already 224x224.
#         self.input_upsample = nn.Sequential(
#             nn.Conv2d(3, 16, 3, padding=1),
#             nn.BatchNorm2d(16), nn.ReLU(True),
#             nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
#             nn.Conv2d(16, 3, 3, padding=1),
#             nn.Sigmoid()
#         )

#         # Convolutional branch on the raw input, upsampled to 224x224.
#         self.skip_upsample = nn.Sequential(
#             nn.Conv2d(3, 32, 3, padding=1),
#             nn.BatchNorm2d(32), nn.ReLU(True),
#             nn.Conv2d(32, 64, 3, padding=1), 
#             nn.BatchNorm2d(64), nn.ReLU(True),
#             nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
#             nn.Conv2d(64, 32, 3, padding=1),
#             nn.BatchNorm2d(32), nn.ReLU(True),
#             nn.Conv2d(32, 3, 3, padding=1),
#             nn.Tanh()
#         )

#         # With a 14x14 token grid (224 input, patch 16 -- assumed), pooling to
#         # 28x28 repeats values rather than averaging.
#         self.adaptive_pool1 = nn.AdaptiveAvgPool2d((28, 28))
#         self.adaptive_pool2 = nn.AdaptiveAvgPool2d((14, 14))
#         self.adaptive_pool3 = nn.AdaptiveAvgPool2d((7, 7))
#         self.adaptive_pool_final = nn.AdaptiveAvgPool2d((7, 7))

#         # 1x1 convs matching skip channels to the decoder stage they join.
#         self.skip_conn1 = nn.Conv2d(self.hidden_dim, 128, kernel_size=1)
#         self.skip_conn2 = nn.Conv2d(self.hidden_dim, 256, kernel_size=1)
#         self.skip_conn3 = nn.Conv2d(self.hidden_dim, 512, kernel_size=1)

#         self.up1 = nn.Sequential(
#             nn.Upsample(size=(14, 14), mode='bilinear', align_corners=True),
#             nn.Conv2d(self.hidden_dim, 512, 3, padding=1), 
#             nn.BatchNorm2d(512), nn.ReLU(True)
#         )
#         self.up2 = nn.Sequential(
#             nn.Upsample(size=(28, 28), mode='bilinear', align_corners=True),
#             nn.Conv2d(512, 256, 3, padding=1), 
#             nn.BatchNorm2d(256), nn.ReLU(True)
#         )
#         self.up3 = nn.Sequential(
#             nn.Upsample(size=(56, 56), mode='bilinear', align_corners=True),
#             nn.Conv2d(256, 128, 3, padding=1), 
#             nn.BatchNorm2d(128), nn.ReLU(True)
#         )
#         self.up4 = nn.Sequential(
#             nn.Upsample(size=(112, 112), mode='bilinear', align_corners=True),
#             nn.Conv2d(128, 64, 3, padding=1), 
#             nn.BatchNorm2d(64), nn.ReLU(True)
#         )

#         self.final = nn.Sequential(
#             nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
#             nn.Conv2d(64, 32, 3, padding=1),
#             nn.BatchNorm2d(32), nn.ReLU(True),
#             nn.Conv2d(32, 3, 3, padding=1),
#             nn.Sigmoid()
#         )

#         # Fuses the decoder output (3 ch) with the skip branch (3 ch).
#         self.fusion = nn.Sequential(
#             nn.Conv2d(6, 32, 3, padding=1),
#             nn.BatchNorm2d(32), nn.ReLU(True),
#             nn.Conv2d(32, 3, 3, padding=1),
#             nn.Sigmoid()
#         )

#     def _extract(self, x3):
#         """Run the ViT and return (early, mid, late, last) token sequences [B, N, C].
#
#         Presumably the outputs of encoder layers 3, 6 and 9 plus the final
#         state (by HF convention hidden_states[0] is the embedding output) --
#         confirm against the transformers docs.
#         """
#         out = self.vit(x3, output_hidden_states=True)
#         early = out.hidden_states[3]
#         mid   = out.hidden_states[6]
#         late  = out.hidden_states[9]
#         last  = out.last_hidden_state
#         return early, mid, late, last

#     def _to_spatial(self, tokens):
#         """Drop the first (CLS) token and reshape the rest to [B, C, HW, HW].
#
#         Assumes the remaining N - 1 tokens form a square grid.
#         """
#         B, N, C = tokens.shape
#         HW = int((N - 1) ** 0.5)
#         t = tokens[:, 1:, :].permute(0, 2, 1)
#         return t.reshape(B, C, HW, HW)

#     def forward(self, x):
#         skip_features = self.skip_upsample(x)

#         if x.shape[-2:] != (224, 224):
#             x_upsampled = self.input_upsample(x)
#         else:
#             x_upsampled = x

#         e, m, l, f = self._extract(x_upsampled)
#         e, m, l, f = map(self._to_spatial, (e, m, l, f))

#         skip1 = self.adaptive_pool1(e)
#         skip2 = self.adaptive_pool2(m)
#         skip3 = self.adaptive_pool3(l)
#         x_feat = self.adaptive_pool_final(f)

#         # Each stage upsamples, then adds the resized skip from a shallower layer.
#         x_feat = self.up1(x_feat)
#         x_feat = x_feat + F.interpolate(self.skip_conn3(skip3), size=(14, 14), mode='bilinear', align_corners=True)

#         x_feat = self.up2(x_feat)
#         x_feat = x_feat + F.interpolate(self.skip_conn2(skip2), size=(28, 28), mode='bilinear', align_corners=True)

#         x_feat = self.up3(x_feat)
#         x_feat = x_feat + F.interpolate(self.skip_conn1(skip1), size=(56, 56), mode='bilinear', align_corners=True)

#         x_feat = self.up4(x_feat)
#         vit_output = self.final(x_feat)

#         combined = torch.cat([vit_output, skip_features], dim=1)
#         out = self.fusion(combined)

#         if self.output_size != (224, 224):
#             out = F.interpolate(out, size=self.output_size, mode='bilinear', align_corners=True)

#         return out

# # ========== TRADITIONAL U-NET MODEL ==========
# class DoubleConv(nn.Module):
#     """(Conv3x3 -> BatchNorm -> ReLU) x 2; ``mid_channels`` defaults to ``out_channels``."""
#     def __init__(self, in_channels, out_channels, mid_channels=None):
#         super().__init__()
#         if not mid_channels:
#             mid_channels = out_channels
#         self.double_conv = nn.Sequential(
#             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
#             nn.BatchNorm2d(mid_channels),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU(inplace=True)
#         )

#     def forward(self, x):
#         return self.double_conv(x)

# class Down(nn.Module):
#     """Encoder stage: 2x2 max-pool (halves H and W) followed by DoubleConv."""
#     def __init__(self, in_channels, out_channels):
#         super().__init__()
#         self.maxpool_conv = nn.Sequential(
#             nn.MaxPool2d(2),
#             DoubleConv(in_channels, out_channels)
#         )

#     def forward(self, x):
#         return self.maxpool_conv(x)

# class Up(nn.Module):
#     """Decoder stage: upsample x1, pad it to x2's size, concat [x2, x1], DoubleConv.
#
#     With ``bilinear=True`` the upsample keeps x1's channels, so ``in_channels``
#     must equal channels(x1) + channels(x2). Otherwise a transposed conv halves
#     x1's channels first.
#     """
#     def __init__(self, in_channels, out_channels, bilinear=True):
#         super().__init__()

#         if bilinear:
#             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
#             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
#         else:
#             self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
#             self.conv = DoubleConv(in_channels, out_channels)

#     def forward(self, x1, x2):
#         x1 = self.up(x1)
#         # Pad x1 (inputs are NCHW) so odd-sized skips still line up with x2.
#         diffY = x2.size()[2] - x1.size()[2]
#         diffX = x2.size()[3] - x1.size()[3]

#         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
#                        diffY // 2, diffY - diffY // 2])

#         x = torch.cat([x2, x1], dim=1)
#         return self.conv(x)

# class OutConv(nn.Module):
#     """1x1 convolution mapping decoder features to the output channels."""
#     def __init__(self, in_channels, out_channels):
#         super(OutConv, self).__init__()
#         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

#     def forward(self, x):
#         return self.conv(x)

# class UNetForInverseProblem(nn.Module):
#     """Classic 4-level U-Net preceded by a resize of the input to 224x224.
#
#     The output is sigmoid(logits + skip), where ``skip`` is a shallow tanh
#     branch computed on the resized input (unlike TransUNet, which applies its
#     skip branch to the raw input).
#     """
#     def __init__(self, n_channels=3, n_classes=3, bilinear=False):
#         super(UNetForInverseProblem, self).__init__()
#         self.n_channels = n_channels
#         self.n_classes = n_classes
#         self.bilinear = bilinear

#         self.input_prep = nn.Sequential(
#             nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
#             nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=1),
#             nn.BatchNorm2d(n_channels),
#             nn.ReLU(inplace=True)
#         )

#         self.inc = DoubleConv(n_channels, 64)
#         self.down1 = Down(64, 128)
#         self.down2 = Down(128, 256)
#         self.down3 = Down(256, 512)
#         # Bilinear upsampling keeps channels, so deeper maps are halved to compensate.
#         factor = 2 if bilinear else 1
#         self.down4 = Down(512, 1024 // factor)

#         self.up1 = Up(1024, 512 // factor, bilinear)
#         self.up2 = Up(512, 256 // factor, bilinear)
#         self.up3 = Up(256, 128 // factor, bilinear)
#         self.up4 = Up(128, 64, bilinear)

#         self.outc = OutConv(64, n_classes)

#         self.final_activation = nn.Sigmoid()

#         self.skip_connection = nn.Sequential(
#             nn.Conv2d(n_channels, 32, kernel_size=3, padding=1),
#             nn.BatchNorm2d(32),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(32, n_classes, kernel_size=3, padding=1),
#             nn.Tanh()
#         )

#     def forward(self, x):
#         # NOTE(review): original_input is never used below.
#         original_input = x

#         x = self.input_prep(x)

#         skip_features = self.skip_connection(x)

#         # Encoder: 224 -> 112 -> 56 -> 28 -> 14
#         x1 = self.inc(x)
#         x2 = self.down1(x1)
#         x3 = self.down2(x2)
#         x4 = self.down3(x3)
#         x5 = self.down4(x4)

#         # Decoder with skip connections back up to 224
#         x = self.up1(x5, x4)
#         x = self.up2(x, x3)
#         x = self.up3(x, x2)
#         x = self.up4(x, x1)

#         logits = self.outc(x)

#         output = logits + skip_features

#         output = self.final_activation(output)

#         return output

# # ========== EVALUATION FUNCTIONS ==========
# def calculate_psnr(pred, target):
#     """Calculate PSNR (in dB) between predicted and target images.
#
#     Assumes intensities in [0, 1], i.e. a peak signal value of 1.0.
#
#     Returns:
#         A 0-dim tensor on every path, so callers can uniformly call ``.item()``.
#         Identical inputs give ``inf`` as a tensor. Returning a bare Python float
#         here made ``calculate_psnr(...).item()`` raise AttributeError.
#     """
#     mse = torch.mean((pred - target) ** 2)
#     if mse == 0:
#         # Same dtype/device as the regular result.
#         return torch.full_like(mse, float('inf'))
#     return 20 * torch.log10(1.0 / torch.sqrt(mse))

# def calculate_metrics(pred, target):
#     """Calculate comprehensive metrics for one prediction/target pair.
#
#     Args:
#         pred, target: image tensors of identical shape. pytorch_msssim's ssim /
#             ms_ssim expect 4-D [N, C, H, W] input; callers pass N = 1 slices.
#
#     Returns:
#         dict of Python floats under 'mse', 'l1', 'psnr', 'ssim' and 'ms_ssim'.
#
#     MSE and L1 are computed on the raw tensors. PSNR, SSIM and MS-SSIM use
#     copies clamped to [0, 1], matching the data_range those metrics assume.
#     """
#     pred_clamped = torch.clamp(pred, 0, 1)
#     target_clamped = torch.clamp(target, 0, 1)
#
#     # Basic metrics (unclamped)
#     mse = F.mse_loss(pred, target).item()
#     l1 = F.l1_loss(pred, target).item()
#
#     # PSNR: float() accepts both a 0-dim tensor and a plain Python float
#     # (e.g. float('inf') for identical images), where .item() would fail on the latter.
#     psnr = float(calculate_psnr(pred_clamped, target_clamped))
#
#     # SSIM metrics. NOTE(review): pytorch_msssim's default 5-scale MS-SSIM needs
#     # the short image side to be larger than ~160 px (224 here) -- revisit if
#     # image sizes change.
#     ssim_val = ssim(pred_clamped, target_clamped, data_range=1.0, size_average=True).item()
#     ms_ssim_val = ms_ssim(pred_clamped, target_clamped, data_range=1.0, size_average=True).item()
#
#     return {
#         'mse': mse,
#         'l1': l1,
#         'psnr': psnr,
#         'ssim': ssim_val,
#         'ms_ssim': ms_ssim_val
#     }

# def evaluate_model(model, test_loader, device, model_name):
#     """Evaluate a single model and return per-image results.
#
#     Args:
#         model: network called as ``model(inputs)``; it is put in eval mode.
#         test_loader: iterable of ``(inputs, targets, paths)`` batches.
#         device: ``torch.device``. On CUDA, the device is synchronised around
#             every timed forward pass.
#         model_name: label used in progress output and stored in each result row.
#
#     Returns:
#         ``(all_results, sample_images, timing_data)``:
#           all_results   -- one calculate_metrics dict per image, plus
#                            'image_path', 'reconstruction_time_ms' and 'model';
#           sample_images -- the first (up to) 32 inputs/predictions/targets/paths;
#           timing_data   -- per-image forward time in seconds, one entry per image.
#     """
#     print(f"Evaluating {model_name}...")
#
#     max_samples = 32
#     all_results = []
#     sample_images = {'inputs': [], 'predictions': [], 'targets': [], 'paths': []}
#     timing_data = []
#
#     # Enter eval mode BEFORE the warmup. In train mode BatchNorm updates its
#     # running statistics even under torch.no_grad(), so warming up on random
#     # noise would overwrite the checkpoint's BN statistics (and dropout
#     # would be active).
#     model.eval()
#
#     # Warmup passes, excluded from timing.
#     print(f"Performing warmup for {model_name}...")
#     with torch.no_grad():
#         dummy_input = torch.randn(1, 3, 112, 112).to(device)
#         for _ in range(3):
#             model(dummy_input)
#         if device.type == 'cuda':
#             torch.cuda.synchronize()
#
#     with torch.no_grad():
#         for inputs, targets, paths in tqdm(test_loader, desc=f"Testing {model_name}"):
#             inputs = inputs.to(device)
#             targets = targets.to(device)
#             batch_size = inputs.size(0)
#
#             # CUDA kernels run asynchronously: synchronise before starting and
#             # before stopping the clock so the whole forward pass is measured.
#             if device.type == 'cuda':
#                 torch.cuda.synchronize()
#             start_time = time.perf_counter()  # monotonic, high-resolution clock
#             predictions = model(inputs)
#             if device.type == 'cuda':
#                 torch.cuda.synchronize()
#             per_image_time = (time.perf_counter() - start_time) / batch_size
#
#             timing_data.extend([per_image_time] * batch_size)
#
#             # Per-image metrics; slicing keeps the 4-D [1, C, H, W] shape.
#             for i in range(batch_size):
#                 metrics = calculate_metrics(predictions[i:i+1], targets[i:i+1])
#                 metrics['image_path'] = paths[i]
#                 metrics['reconstruction_time_ms'] = per_image_time * 1000
#                 metrics['model'] = model_name
#                 all_results.append(metrics)
#
#                 # Keep the first few images for later visualisation.
#                 if len(sample_images['inputs']) < max_samples:
#                     sample_images['inputs'].append(inputs[i])
#                     sample_images['predictions'].append(predictions[i])
#                     sample_images['targets'].append(targets[i])
#                     sample_images['paths'].append(paths[i])
#
#     return all_results, sample_images, timing_data

# def create_comparative_visualization(all_sample_images, save_path, max_images=8):
#     """Create a comparative visualization showing all three models.
#
#     Builds a 4-row grid (input, TransUNet, ViT-UNet, U-Net) with one column
#     per sample, up to ``max_images`` columns, and saves it to ``save_path``.
#     ``all_sample_images`` maps each model name to the ``sample_images`` dict
#     produced by ``evaluate_model``. Inputs are taken from the TransUNet entry.
#     """
#     n_images = min(len(all_sample_images['TransUNet']['inputs']), max_images)
#
#     fig, axes = plt.subplots(4, n_images, figsize=(3*n_images, 12))
#     if n_images == 1:
#         # plt.subplots drops the column axis for a single column; restore 2-D indexing.
#         axes = axes.reshape(4, 1)
#
#     def to_display(tensor):
#         # CHW tensor -> HWC numpy array clipped to [0, 1] for imshow.
#         return np.clip(tensor.cpu().numpy().transpose(1, 2, 0), 0, 1)
#
#     # (row, source model, field) for each panel in a column; row 0 is the shared input.
#     panels = [
#         (0, 'TransUNet', 'inputs'),
#         (1, 'TransUNet', 'predictions'),
#         (2, 'ViT-UNet', 'predictions'),
#         (3, 'U-Net', 'predictions'),
#     ]
#
#     for col in range(n_images):
#         for row, model_key, field in panels:
#             ax = axes[row, col]
#             ax.imshow(to_display(all_sample_images[model_key][field][col]))
#             if field == 'inputs':
#                 ax.set_title(f"Input {col+1}\n(112×112)")
#             else:
#                 ax.set_title(f"{model_key} {col+1}")
#             ax.axis('off')
#
#     plt.suptitle("Model Comparison: Input → TransUNet → ViT-UNet → U-Net", fontsize=16)
#     plt.tight_layout()
#     plt.savefig(save_path, dpi=150, bbox_inches='tight')
#     plt.close()

# def create_metrics_comparison(combined_df, save_dir):
#     """Create comprehensive comparison plots and a per-model summary table.
#
#     Args:
#         combined_df: pandas DataFrame with one row per (model, image) and
#             columns 'model', 'mse', 'l1', 'psnr', 'ssim', 'ms_ssim' and
#             'reconstruction_time_ms'.
#         save_dir: directory that receives 'metrics_comparison_boxplots.png',
#             'quality_vs_speed.png' and 'model_comparison_summary.csv'.
#
#     Returns:
#         The per-model mean/std summary DataFrame, rounded to 4 decimals.
#     """
#     metrics = ['mse', 'l1', 'psnr', 'ssim', 'ms_ssim', 'reconstruction_time_ms']
#     metric_labels = ['MSE', 'L1 Loss', 'PSNR (dB)', 'SSIM', 'MS-SSIM', 'Time (ms)']
#
#     # 1. Performance comparison boxplots: 2x3 grid, one panel per metric (row-major).
#     _, axes = plt.subplots(2, 3, figsize=(18, 12))
#     for ax, metric, label in zip(axes.flat, metrics, metric_labels):
#         combined_df.boxplot(column=metric, by='model', ax=ax)
#         ax.set_title(f'{label} Comparison')
#         ax.set_xlabel('Model')
#         ax.set_ylabel(label)
#
#     plt.suptitle('Performance Metrics Comparison Across Models')
#     plt.tight_layout()
#     plt.savefig(os.path.join(save_dir, 'metrics_comparison_boxplots.png'), dpi=150, bbox_inches='tight')
#     plt.close()
#
#     # 2. Performance vs Speed scatter plot
#     plt.figure(figsize=(12, 8))
#     colors = ['blue', 'red', 'green']
#
#     for i, model in enumerate(combined_df['model'].unique()):
#         # Cycle the palette: indexing colors[i] raised IndexError for > 3 models.
#         color = colors[i % len(colors)]
#         model_data = combined_df[combined_df['model'] == model]
#         plt.scatter(model_data['reconstruction_time_ms'], model_data['ssim'], 
#                    alpha=0.6, label=model, color=color, s=50)
#
#         # Mean point (large star, no legend entry)
#         mean_time = model_data['reconstruction_time_ms'].mean()
#         mean_ssim = model_data['ssim'].mean()
#         plt.scatter(mean_time, mean_ssim, color=color, s=200, marker='*', 
#                    edgecolors='black', linewidth=2)
#
#     plt.xlabel('Reconstruction Time (ms)')
#     plt.ylabel('SSIM')
#     plt.title('Quality vs Speed Trade-off\n(Stars indicate model averages)')
#     plt.legend()
#     plt.grid(True, alpha=0.3)
#     plt.savefig(os.path.join(save_dir, 'quality_vs_speed.png'), dpi=150, bbox_inches='tight')
#     plt.close()
#
#     # 3. Summary statistics table (mean and std of every metric per model)
#     summary_stats = combined_df.groupby('model').agg(
#         {metric: ['mean', 'std'] for metric in metrics}
#     ).round(4)
#
#     summary_stats.to_csv(os.path.join(save_dir, 'model_comparison_summary.csv'))
#
#     return summary_stats

# def main():
#     # Configuration
#     config = {
#         'models': {
#             'TransUNet': {
#                 'model_path': r"D:\JHU\ImageNet\transunet_8x8_checkpoints\epoch_20.pth",
#                 'params': {
#                     'img_size': 224,
#                     'patch_size': 16,
#                     'embed_dim': 768,
#                     'depth': 12,
#                     'num_heads': 12,
#                     'mlp_ratio': 4,
#                     'dropout': 0.1
#                 }
#             },
#             'ViT-UNet': {
#                 'model_path': r"D:\JHU\ImageNet\trust_8x8_checkpoints\epoch_80.pth",
#                 'params': {
#                     'pretrained_model_name': "google/vit-base-patch16-224",
#                     'output_size': (224, 224)
#                 }
#             },
#             'U-Net': {
#                 'model_path': r"D:\JHU\ImageNet\unet_8x8_checkpoints\epoch_10.pth",
#                 'params': {
#                     'n_channels': 3,
#                     'n_classes': 3,
#                     'bilinear': True
#                 }
#             }
#         },
#         'test_dir': r"F:\imgnet\data\test",
#         'results_dir': "./new_test",
#         'seed': 42,
#         'batch_size': 8,  # Conservative batch size to accommodate all models
#         'max_test_images': 500,
#         'device': 'cuda:0' if torch.cuda.is_available() else 'cpu'
#     }
    
#     print("="*80)
#     print("COMPREHENSIVE MODEL COMPARISON")
#     print("="*80)
#     print(f"Device: {config['device']}")
#     print(f"Test directory: {config['test_dir']}")
#     print(f"Max test images: {config['max_test_images']}")
#     print(f"Batch size: {config['batch_size']}")
#     print("Models to evaluate:", list(config['models'].keys()))
    
#     # Create results directory
#     os.makedirs(config['results_dir'], exist_ok=True)
    
#     # Load dataset
#     test_dataset = PatchwiseOrthonormalDataset(
#         data_dir=config['test_dir'], 
#         seed=config['seed'], 
#         verbose=True
#     )
    
#     # Limit dataset size if specified
#     if config['max_test_images'] and len(test_dataset) > config['max_test_images']:
#         indices = np.random.choice(len(test_dataset), config['max_test_images'], replace=False)
#         test_dataset.image_files = [test_dataset.image_files[i] for i in sorted(indices)]
#         print(f"Limited test set to {config['max_test_images']} images")
    
#     test_loader = DataLoader(
#         test_dataset, 
#         batch_size=config['batch_size'], 
#         shuffle=False,
#         num_workers=4, 
#         pin_memory=True
#     )
    
#     device = torch.device(config['device'])
    
#     # Load and evaluate each model
#     all_results = []
#     all_sample_images = {}
#     all_timing_data = {}
#     model_info = {}
    
#     for model_name, model_config in config['models'].items():
#         print(f"\n{'-'*50}")
#         print(f"Loading {model_name}...")
        
#         # Load model
#         if model_name == 'TransUNet':
#             model = TransUNet(**model_config['params']).to(device)
#         elif model_name == 'ViT-UNet':
#             model = ViTUNetForInverseProblem(**model_config['params']).to(device)
#         elif model_name == 'U-Net':
#             model = UNetForInverseProblem(**model_config['params']).to(device)
        
#         # Load checkpoint
#         if not os.path.exists(model_config['model_path']):
#             print(f"Warning: Model checkpoint not found: {model_config['model_path']}")
#             continue
            
#         checkpoint = torch.load(model_config['model_path'], map_location=device)
#         if 'model_state_dict' in checkpoint:
#             model.load_state_dict(checkpoint['model_state_dict'])
#             print(f"Loaded {model_name} from epoch {checkpoint.get('epoch', 'unknown')}")
#         else:
#             model.load_state_dict(checkpoint)
        
#         # Model info
#         total_params = sum(p.numel() for p in model.parameters())
#         trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
#         model_info[model_name] = {
#             'total_parameters': total_params,
#             'trainable_parameters': trainable_params,
#             'config': model_config['params']
#         }
#         print(f"{model_name} - Total parameters: {total_params:,}")
        
#         # Evaluate model
#         results, sample_images, timing_data = evaluate_model(model, test_loader, device, model_name)
        
#         all_results.extend(results)
#         all_sample_images[model_name] = sample_images
#         all_timing_data[model_name] = timing_data
        
#         # Clear GPU memory
#         del model
#         if device.type == 'cuda':
#             torch.cuda.empty_cache()
    
#     # Combine all results
#     combined_df = pd.DataFrame(all_results)
    
#     # Calculate summary statistics for each model
#     print("\n" + "="*80)
#     print("COMPARISON RESULTS SUMMARY")
#     print("="*80)
    
#     summary_stats = {}
#     for model_name in config['models'].keys():
#         if model_name not in combined_df['model'].values:
#             continue
            
#         model_data = combined_df[combined_df['model'] == model_name]
#         timing_stats = {
#             'mean_time_ms': np.mean(all_timing_data[model_name]) * 1000,
#             'std_time_ms': np.std(all_timing_data[model_name]) * 1000,
#             'mean_fps': 1.0 / np.mean(all_timing_data[model_name])
#         }
        
#         summary_stats[model_name] = {
#             'total_images': len(model_data),
#             'mean_mse': model_data['mse'].mean(),
#             'mean_l1': model_data['l1'].mean(),
#             'mean_psnr': model_data['psnr'].mean(),
#             'std_psnr': model_data['psnr'].std(),
#             'mean_ssim': model_data['ssim'].mean(),
#             'std_ssim': model_data['ssim'].std(),
#             'mean_ms_ssim': model_data['ms_ssim'].mean(),
#             'timing_stats': timing_stats,
#             'model_info': model_info[model_name]
#         }
        
#         print(f"\n{model_name}:")
#         print(f"  Parameters: {model_info[model_name]['total_parameters']:,}")
#         print(f"  PSNR: {summary_stats[model_name]['mean_psnr']:.2f} ± {summary_stats[model_name]['std_psnr']:.2f} dB")
#         print(f"  SSIM: {summary_stats[model_name]['mean_ssim']:.4f} ± {summary_stats[model_name]['std_ssim']:.4f}")
#         print(f"  MS-SSIM: {summary_stats[model_name]['mean_ms_ssim']:.4f}")
#         print(f"  Speed: {timing_stats['mean_time_ms']:.2f} ± {timing_stats['std_time_ms']:.2f} ms/image")
#         print(f"  Throughput: {timing_stats['mean_fps']:.2f} images/second")
    
#     # Save results
#     combined_df.to_csv(os.path.join(config['results_dir'], 'combined_results.csv'), index=False)
    
#     with open(os.path.join(config['results_dir'], 'comparison_summary.json'), 'w') as f:
#         json.dump(summary_stats, f, indent=2)
    
#     # Create visualizations
#     print("\nCreating comparative visualizations...")
    
#     # 1. Comparative model outputs
#     create_comparative_visualization(
#         all_sample_images,
#         os.path.join(config['results_dir'], 'model_comparison_grid.png'),
#         max_images=8
#     )
    
#     # 2. Metrics comparison
#     comparison_summary = create_metrics_comparison(combined_df, config['results_dir'])
    
#     print(f"\nAll comparison results saved to: {config['results_dir']}")
#     print("Generated files:")
#     print("- combined_results.csv: All individual results")
#     print("- comparison_summary.json: Summary statistics")
#     print("- model_comparison_summary.csv: Statistical summary table")
#     print("- model_comparison_grid.png: Side-by-side visual comparison")
#     print("- metrics_comparison_boxplots.png: Performance metrics comparison")
#     print("- quality_vs_speed.png: Quality vs speed trade-off analysis")
    
#     # Print ranking
#     print("\n" + "="*50)
#     print("MODEL RANKING")
#     print("="*50)
#     print("By SSIM (Quality):")
#     ssim_ranking = combined_df.groupby('model')['ssim'].mean().sort_values(ascending=False)
#     for i, (model, ssim_val) in enumerate(ssim_ranking.items(), 1):
#         print(f"  {i}. {model}: {ssim_val:.4f}")
    
#     print("\nBy Speed (ms/image):")
#     speed_ranking = combined_df.groupby('model')['reconstruction_time_ms'].mean().sort_values()
#     for i, (model, time_val) in enumerate(speed_ranking.items(), 1):
#         print(f"  {i}. {model}: {time_val:.2f} ms")
    
#     print("\nBy Parameter Efficiency (SSIM/Million Parameters):")
#     efficiency_scores = {}
#     for model in ssim_ranking.index:
#         ssim_val = ssim_ranking[model]
#         params = model_info[model]['total_parameters'] / 1e6  # Convert to millions
#         efficiency_scores[model] = ssim_val / params
    
#     efficiency_ranking = pd.Series(efficiency_scores).sort_values(ascending=False)
#     for i, (model, eff_val) in enumerate(efficiency_ranking.items(), 1):
#         print(f"  {i}. {model}: {eff_val:.6f}")
    
#     return combined_df, summary_stats

# if __name__ == "__main__":
#     # Make sure required packages are installed:
#     # pip install pytorch-msssim transformers
#     combined_df, summary_stats = main()




#!/usr/bin/env python3
import os
# Fix OpenMP library conflict
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import json
import time
from tqdm import tqdm
import seaborn as sns
from pytorch_msssim import ssim, ms_ssim
import pandas as pd
from transformers import ViTModel, ViTConfig
from collections import defaultdict
import csv

# ========== SHARED DATASET CLASS ==========
def make_gaussian_random_orthonormal_rows(h=64, w=64, seed=42):
    """Return a matrix of shape ``[h, w]`` whose rows are orthonormal.

    A standard-normal ``[h, w]`` matrix is drawn, its transpose is factorised
    with a (reduced) QR decomposition, and the transposed Q factor is returned.

    Args:
        h: number of rows. Should not exceed ``w``; for ``h > w`` the reduced QR
           only yields ``w`` orthonormal vectors and the result is ``[w, w]``.
        w: number of columns.
        seed: when not None, passed to ``torch.manual_seed``. This reseeds the
           *global* torch RNG, so it also affects later random draws.
    """
    if seed is not None:
        torch.manual_seed(seed)
    gaussian = torch.randn(h, w)
    q_factor, _ = torch.linalg.qr(gaussian.T)
    return q_factor.T

class PatchwiseOrthonormalDataset:
    """Dataset of (measurement, ground truth, path) triples built from a folder of images.

    Each image is converted to RGB, resized so its shorter side is 224,
    centre-cropped to 224x224 and scaled to [0, 1]; this is the ground truth
    ``x``. The measurement ``y`` is built per colour channel:

    * the channel is split into 8x8 patches;
    * each flattened patch is multiplied by the orthonormal matrix ``self.A``;
    * a random quarter of the resulting coefficients is kept;
    * those coefficients are reshaped to 4x4, bilinearly upsampled to 8x8 and
      tiled into a 112x112 image.

    The three channels are then min-max normalised jointly.

    ``__getitem__`` returns ``(y, x, path)``, where ``y`` has shape
    (3, 112, 112), ``x`` has shape (3, 224, 224) and ``path`` is the image path
    as a string.

    NOTE(review): the kept coefficient subset is drawn from torch's global RNG
    on every ``__getitem__`` call, so two calls for the same index generally
    produce different measurements.
    """

    # Compared against the lower-cased file suffix, so ".PNG", ".Jpg" etc. match too.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'})

    def __init__(self, data_dir, seed=42, verbose=False):
        """Index all images directly inside *data_dir*.

        Raises:
            FileNotFoundError: if *data_dir* does not exist.
            ValueError: if it contains no files with a recognised image suffix.
        """
        self.data_dir = data_dir
        self.A = make_gaussian_random_orthonormal_rows(h=64, w=64, seed=seed)

        self.data_path = Path(data_dir)
        if not self.data_path.exists():
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        self.image_files = self._list_image_files(self.data_path)
        if not self.image_files:
            raise ValueError(f"No images found in {data_dir}")

        if verbose:
            print(f"Loaded {len(self.image_files)} images from {data_dir}")

    @classmethod
    def _list_image_files(cls, data_path):
        """Return the image files directly inside *data_path*, sorted by path.

        Sorting makes the dataset order independent of the filesystem's
        directory-listing order, so index-based subsetting is reproducible.
        """
        return sorted(
            f for f in data_path.iterdir()
            if f.is_file() and f.suffix.lower() in cls.IMAGE_EXTENSIONS
        )

    def __len__(self):
        return len(self.image_files)

    def resize_min_side(self, img, min_side=224):
        """Resize a PIL image (Lanczos) so its shorter side equals *min_side*."""
        w, h = img.size
        s = min_side / min(w, h)
        return img.resize((int(round(w*s)), int(round(h*s))), Image.Resampling.LANCZOS)

    def center_crop(self, img, size=224):
        """Crop a centred *size* x *size* square out of a PIL image."""
        w, h = img.size
        left = (w - size) // 2
        top = (h - size) // 2
        return img.crop((left, top, left + size, top + size))

    def preprocess_image(self, img):
        """Convert to RGB, resize/crop to 224x224 and return float32 HWC in [0, 1]."""
        img = img.convert("RGB")
        img_resized = self.resize_min_side(img, 224)
        img_crop = self.center_crop(img_resized, 224)
        x = np.array(img_crop).astype(np.float32) / 255.0
        return x

    def process_image_with_orthonormal_masks(self, np_img, mask_matrix):
        """Project every 8x8 patch with *mask_matrix* and keep a random quarter.

        Args:
            np_img: HxWxC array; a 3-channel input is averaged to grey first.
            mask_matrix: matrix applied to each flattened 64-value patch.

        Returns:
            Tensor of shape (H // 8, W // 8, mask_matrix.shape[0] // 4) holding
            the kept coefficients of each patch, in random order.
        """
        img_tensor = torch.from_numpy(np_img).float()
        img_gray = img_tensor.mean(dim=2) if img_tensor.shape[2] == 3 else img_tensor

        patches = img_gray.unfold(0, 8, 8).unfold(1, 8, 8)
        n_rows, n_cols = patches.shape[0], patches.shape[1]
        n_coeffs = mask_matrix.shape[0]
        n_keep = n_coeffs // 4
        transformed_patches = torch.zeros(n_rows, n_cols, n_keep)

        # One randperm per patch, in row-major order, from the global RNG.
        for i in range(n_rows):
            for j in range(n_cols):
                transformed = mask_matrix @ patches[i, j].flatten()
                keep = torch.randperm(n_coeffs)[:n_keep]
                transformed_patches[i, j] = transformed[keep]

        return transformed_patches

    def reconstruct_masked_image(self, transformed_patches):
        """Tile per-patch coefficient vectors into a 112x112 single-channel image.

        Each patch's 16 coefficients are reshaped to 4x4 and bilinearly
        upsampled to 8x8. Patch (i, j) is written to rows ``8i:8i+8`` and
        columns ``8j:8j+8``.

        NOTE(review): only the first 14x14 entries of the (typically 28x28)
        patch grid are used, so the measurement covers just the top-left
        112x112 quadrant of the 224x224 crop. This is kept as-is because
        changing it would change the inputs the evaluated models see. Confirm
        whether it is intended.
        """
        grid = 14
        blocks = transformed_patches[:grid, :grid].reshape(grid * grid, 1, 4, 4)
        # One batched interpolate call instead of one call per patch.
        upsampled = F.interpolate(blocks, size=(8, 8), mode='bilinear', align_corners=True)
        # (i, j, r, c) -> (i, r, j, c) so that row = 8i + r and col = 8j + c.
        return (upsampled.reshape(grid, grid, 8, 8)
                .permute(0, 2, 1, 3)
                .reshape(grid * 8, grid * 8))

    def apply_patchwise_orthonormal_transform(self, x):
        """Build the 112x112x3 measurement image for an HxWx3 float image *x*.

        Channels are processed independently and then min-max normalised
        jointly to [0, 1].
        """
        y_channels = []
        for c in range(3):
            single_channel = np.expand_dims(x[..., c], axis=2)
            transformed_patches = self.process_image_with_orthonormal_masks(single_channel, self.A)
            y_channels.append(self.reconstruct_masked_image(transformed_patches).numpy())

        y = np.stack(y_channels, axis=2)
        y_min = y.min()
        y_max = y.max()
        # The epsilon avoids division by zero for a constant measurement.
        return (y - y_min) / (y_max - y_min + 1e-8)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]

        try:
            # Image.open is lazy: corrupt or truncated data only raises once the
            # pixels are decoded, so preprocessing must be inside the try too.
            # The context manager releases the file handle promptly.
            with Image.open(img_path) as img:
                x = self.preprocess_image(img)
        except Exception as e:
            # Best-effort fallback: a black image keeps the batch shape intact.
            print(f"Warning: Could not load image {img_path}: {e}")
            x = self.preprocess_image(Image.new('RGB', (224, 224), color=(0, 0, 0)))

        y = self.apply_patchwise_orthonormal_transform(x)

        x_tensor = torch.from_numpy(x).permute(2, 0, 1)
        y_tensor = torch.from_numpy(y).permute(2, 0, 1)

        return y_tensor, x_tensor, str(img_path)

# ========== TRANSUNET MODEL ==========
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    A conv with ``kernel_size == stride == patch_size`` yields one
    ``embed_dim`` vector per patch. ``forward`` returns the tokens as
    ``(batch, num_patches, embed_dim)`` together with the patch-grid size
    ``(H, W)``.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        grid = img_size // patch_size
        self.n_patches = grid * grid
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        feature_map = self.proj(x)
        grid_h, grid_w = feature_map.shape[-2:]
        tokens = feature_map.flatten(start_dim=2).transpose(1, 2)
        return tokens, (grid_h, grid_w)

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over token sequences.

    Input and output have shape ``(batch, tokens, embed_dim)``. Dropout is
    applied to the attention weights only.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # An explicit raise rather than ``assert``, so the check survives ``python -O``.
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape

        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = self.dropout(F.softmax(attn, dim=-1))

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: LN -> self-attention -> residual, then LN -> MLP -> residual."""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        hidden = int(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Sequential indices appear in state_dict keys, so the layer order is fixed.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim), nn.Dropout(dropout),
        )

    def forward(self, x):
        after_attn = x + self.attn(self.norm1(x))
        return after_attn + self.mlp(self.norm2(after_attn))

class VisionTransformerEncoder(nn.Module):
    """ViT encoder: conv patch embedding, learned positional embedding, transformer stack.

    ``forward`` returns ``(features, (H, W))``:

    * ``features`` holds the token sequences after blocks 2, 5 and 8 (0-based,
      only those that exist for the given ``depth``), followed by the final
      LayerNorm-ed output;
    * ``(H, W)`` is the patch-grid size.
    """
    # Zero-based block indices whose outputs are also returned as intermediate features.
    _TAP_LAYERS = (2, 5, 8)

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_tokens = self.patch_embed.n_patches

        # Fixed token count: the input must yield exactly n_patches tokens to broadcast.
        self.pos_embed = nn.Parameter(torch.randn(1, n_tokens, embed_dim) * 0.02)
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            [TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        tokens, (grid_h, grid_w) = self.patch_embed(x)
        tokens = self.dropout(tokens + self.pos_embed)

        features = []
        for layer_idx, block in enumerate(self.blocks):
            tokens = block(tokens)
            if layer_idx in self._TAP_LAYERS:
                features.append(tokens)

        features.append(self.norm(tokens))
        return features, (grid_h, grid_w)

class ConvBlock(nn.Module):
    """Two Conv-BatchNorm-ReLU stages (bias-free convs); only the first conv uses ``stride``."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        first_stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # A single Sequential keeps the original state_dict keys (conv.0 ... conv.5).
        self.conv = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        out = self.conv(x)
        return out

class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, optional skip concat, then a ConvBlock.

    The transposed conv halves the channel count (``in_channels // 2``). When a
    ``skip`` tensor is given, it is concatenated on the channel axis, so it must
    carry ``skip_channels`` channels at the upsampled resolution.
    """
    def __init__(self, in_channels, out_channels, skip_channels=0):
        super().__init__()
        halved = in_channels // 2
        self.up = nn.ConvTranspose2d(in_channels, halved, kernel_size=2, stride=2)
        self.conv = ConvBlock(halved + skip_channels, out_channels)

    def forward(self, x, skip=None):
        upsampled = self.up(x)
        merged = upsampled if skip is None else torch.cat([upsampled, skip], dim=1)
        return self.conv(merged)

class TransUNet(nn.Module):
    """Hybrid CNN + ViT U-Net for image reconstruction.

    The input (any spatial size) is resized to ``img_size`` x ``img_size`` and
    lifted to 64 channels. A four-level CNN encoder and a ViT encoder both run
    on that map. The ViT's final tokens are projected to 512 channels, resized
    to the deepest CNN level and concatenated with it as the bottleneck. A
    four-stage decoder with CNN skip connections produces a sigmoid image.
    A shallow tanh branch computed from the raw input is added to it, and the
    sum is clamped to [0, 1].

    Output shape: ``(B, out_channels, img_size, img_size)``.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, out_channels=3,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, dropout=0.1):
        super().__init__()

        # NOTE: submodule creation order and names are kept stable so parameter
        # initialisation and checkpoint state_dict keys are unchanged.
        self.input_prep = nn.Sequential(
            nn.Upsample(size=(img_size, img_size), mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.vit_encoder = VisionTransformerEncoder(
            img_size=img_size, patch_size=patch_size, in_channels=64,
            embed_dim=embed_dim, depth=depth, num_heads=num_heads,
            mlp_ratio=mlp_ratio, dropout=dropout
        )

        self.cnn_enc1 = ConvBlock(64, 64)
        self.cnn_enc2 = ConvBlock(64, 128)
        self.cnn_enc3 = ConvBlock(128, 256)
        self.cnn_enc4 = ConvBlock(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.vit_to_cnn = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512)
        )

        # dec4 consumes the 1024-channel bottleneck (512 ViT + 512 CNN).
        self.dec4 = UpBlock(1024, 256, skip_channels=256)
        self.dec3 = UpBlock(256, 128, skip_channels=128)
        self.dec2 = UpBlock(128, 64, skip_channels=64)
        self.dec1 = UpBlock(64, 64, skip_channels=64)

        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.skip_connection = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        original_input = x

        x = self.input_prep(x)
        # Working resolution (img_size x img_size). It is used instead of a
        # hard-coded 224 so that non-default img_size values produce matching
        # tensor sizes at the final concatenation and addition.
        full_res = x.shape[-2:]

        skip_features = self.skip_connection(original_input)
        skip_features = F.interpolate(skip_features, size=full_res,
                                      mode='bilinear', align_corners=True)

        # CNN encoder pyramid: full, 1/2, 1/4 and 1/8 resolution.
        enc1 = self.cnn_enc1(x)
        enc2 = self.cnn_enc2(self.pool(enc1))
        enc3 = self.cnn_enc3(self.pool(enc2))
        enc4 = self.cnn_enc4(self.pool(enc3))

        # ViT branch: only the final (LayerNorm-ed) token sequence is used.
        vit_features, (H, W) = self.vit_encoder(x)
        final_vit_features = vit_features[-1]
        B = final_vit_features.shape[0]

        vit_proj = self.vit_to_cnn(final_vit_features)            # (B, N, 512)
        vit_spatial = vit_proj.transpose(1, 2).reshape(B, 512, H, W)
        vit_spatial = F.interpolate(vit_spatial, size=enc4.shape[-2:],
                                    mode='bilinear', align_corners=True)

        bottleneck = torch.cat([vit_spatial, enc4], dim=1)        # 1024 channels

        dec4_out = self.dec4(bottleneck, enc3)
        dec3_out = self.dec3(dec4_out, enc2)
        dec2_out = self.dec2(dec3_out, enc1)

        # dec2 is already at full resolution, so dec1's transposed conv
        # overshoots to 2x. Resize back before concatenating the enc1 skip.
        dec1_up = F.interpolate(self.dec1.up(dec2_out), size=enc1.shape[-2:],
                                mode='bilinear', align_corners=True)
        dec1_out = self.dec1.conv(torch.cat([dec1_up, enc1], dim=1))

        output = self.final_conv(dec1_out) + skip_features
        return torch.clamp(output, 0, 1)

# ========== VIT-UNET MODEL ==========
class ViTUNetForInverseProblem(nn.Module):
    """U-Net-style decoder on a pretrained Hugging Face ViT, plus a convolutional skip path.

    ``forward`` takes a ``(B, 3, H, W)`` image and proceeds as follows:

    * inputs that are not 224x224 are first mapped to 224x224 by a small
      learned upsampler;
    * hidden states 3, 6 and 9 and the last hidden state of the ViT are turned
      into square feature maps (CLS token dropped);
    * these are decoded 7 -> 14 -> 28 -> 56 -> 112 -> 224 with additive skips;
    * the decoded image is concatenated with a conv branch computed from the
      raw input and fused into a 3-channel sigmoid output.

    Args:
        pretrained_model_name: Hugging Face id passed to ``ViTConfig`` and
            ``ViTModel.from_pretrained``.
        output_size: final (H, W). Anything other than (224, 224) is produced
            by bilinear resizing of the 224x224 result.
    """
    def __init__(self, pretrained_model_name="google/vit-base-patch16-224", output_size=(224, 224)):
        super().__init__()

        cfg = ViTConfig.from_pretrained(pretrained_model_name)
        # NOTE(review): ``add_pooling_layer`` is a ViTModel constructor argument
        # rather than a config field, so this assignment probably has no effect
        # and the pooler is still built. It is left unchanged because dropping
        # the pooler would presumably change the state_dict keys that saved
        # checkpoints contain.
        cfg.add_pooling_layer = False
        self.vit = ViTModel.from_pretrained(pretrained_model_name, config=cfg, ignore_mismatched_sizes=True)

        self.output_size = output_size
        # Encoder token width (768 for ViT-Base). It is read from the config so
        # the decoder matches the loaded backbone instead of assuming ViT-Base.
        self.hidden_dim = cfg.hidden_size

        self.input_upsample = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.BatchNorm2d(16), nn.ReLU(True),
            nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
            nn.Conv2d(16, 3, 3, padding=1),
            nn.Sigmoid()
        )

        self.skip_upsample = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(True),
            nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Tanh()
        )

        self.adaptive_pool1 = nn.AdaptiveAvgPool2d((28, 28))
        self.adaptive_pool2 = nn.AdaptiveAvgPool2d((14, 14))
        self.adaptive_pool3 = nn.AdaptiveAvgPool2d((7, 7))
        self.adaptive_pool_final = nn.AdaptiveAvgPool2d((7, 7))

        self.skip_conn1 = nn.Conv2d(self.hidden_dim, 128, kernel_size=1)
        self.skip_conn2 = nn.Conv2d(self.hidden_dim, 256, kernel_size=1)
        self.skip_conn3 = nn.Conv2d(self.hidden_dim, 512, kernel_size=1)

        self.up1 = nn.Sequential(
            nn.Upsample(size=(14, 14), mode='bilinear', align_corners=True),
            nn.Conv2d(self.hidden_dim, 512, 3, padding=1),
            nn.BatchNorm2d(512), nn.ReLU(True)
        )
        self.up2 = nn.Sequential(
            nn.Upsample(size=(28, 28), mode='bilinear', align_corners=True),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(True)
        )
        self.up3 = nn.Sequential(
            nn.Upsample(size=(56, 56), mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(True)
        )
        self.up4 = nn.Sequential(
            nn.Upsample(size=(112, 112), mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(True)
        )

        self.final = nn.Sequential(
            nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Sigmoid()
        )

        self.fusion = nn.Sequential(
            nn.Conv2d(6, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Sigmoid()
        )

    def _extract(self, x3):
        """Run the ViT; return hidden states 3, 6 and 9 and the last hidden state."""
        out = self.vit(x3, output_hidden_states=True)
        early = out.hidden_states[3]
        mid = out.hidden_states[6]
        late = out.hidden_states[9]
        last = out.last_hidden_state
        return early, mid, late, last

    def _to_spatial(self, tokens):
        """Drop the leading (CLS) token and fold the rest into a square (B, C, s, s) map.

        Assumes the number of remaining tokens is a perfect square.
        """
        B, N, C = tokens.shape
        HW = int((N - 1) ** 0.5)
        t = tokens[:, 1:, :].permute(0, 2, 1)
        return t.reshape(B, C, HW, HW)

    def forward(self, x):
        skip_features = self.skip_upsample(x)

        if x.shape[-2:] != (224, 224):
            x_upsampled = self.input_upsample(x)
        else:
            x_upsampled = x

        early, mid, late, last = map(self._to_spatial, self._extract(x_upsampled))

        skip1 = self.adaptive_pool1(early)
        skip2 = self.adaptive_pool2(mid)
        skip3 = self.adaptive_pool3(late)
        x_feat = self.adaptive_pool_final(last)

        x_feat = self.up1(x_feat)
        x_feat = x_feat + F.interpolate(self.skip_conn3(skip3), size=(14, 14), mode='bilinear', align_corners=True)

        x_feat = self.up2(x_feat)
        x_feat = x_feat + F.interpolate(self.skip_conn2(skip2), size=(28, 28), mode='bilinear', align_corners=True)

        x_feat = self.up3(x_feat)
        x_feat = x_feat + F.interpolate(self.skip_conn1(skip1), size=(56, 56), mode='bilinear', align_corners=True)

        x_feat = self.up4(x_feat)
        vit_output = self.final(x_feat)

        # Fuse the ViT decoder output with the direct convolutional path.
        combined = torch.cat([vit_output, skip_features], dim=1)
        out = self.fusion(combined)

        if self.output_size != (224, 224):
            out = F.interpolate(out, size=self.output_size, mode='bilinear', align_corners=True)

        return out

# ========== RESTORMER MODEL ==========
class ChannelAttention(nn.Module):
    """Multi-head attention computed across channels rather than pixels.

    Q, K and V come from a 1x1 conv followed by a depthwise 3x3 conv. For each
    head, Q and K (channels x pixels) are L2-normalised along the pixel axis.
    The (channels/head x channels/head) affinity is scaled by a learnable
    per-head ``temperature`` and softmaxed, then applied to V. The result is
    projected back with a 1x1 conv. ``dim`` must be divisible by ``num_heads``.
    """
    def __init__(self, dim, num_heads, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        triple = dim * 3
        self.qkv = nn.Conv2d(dim, triple, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(triple, triple, kernel_size=3, stride=1, padding=1, groups=triple, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        batch, _, height, width = x.shape
        n_pix = height * width

        projected = self.qkv_dwconv(self.qkv(x))
        query, key, value = (
            part.view(batch, self.num_heads, -1, n_pix) for part in projected.chunk(3, dim=1)
        )
        query = F.normalize(query, dim=-1)
        key = F.normalize(key, dim=-1)

        weights = torch.softmax((query @ key.transpose(-2, -1)) * self.temperature, dim=-1)
        mixed = (weights @ value).view(batch, -1, height, width)
        return self.project_out(mixed)

class FeedForward(nn.Module):
    """Gated depthwise-conv feed-forward network with a GELU gate.

    ``project_in`` expands to ``2 * int(dim * ffn_expansion_factor)`` channels
    and a depthwise 3x3 conv mixes spatially. The result is split into two
    halves, and ``gelu(first) * second`` is projected back to ``dim`` channels.
    """
    def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        doubled = hidden * 2

        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1, padding=1, groups=doubled, bias=bias)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)

class RestormerTransformerBlock(nn.Module):
    """Restormer-style block: pre-norm channel attention and gated FFN, each with a residual.

    ``GroupNorm(num_groups=1)`` normalises each sample over all channels and
    pixels jointly, with a learnable per-channel affine transform.
    """
    def __init__(self, dim, num_heads, ffn_expansion_factor=2.66, bias=False):
        super().__init__()

        def make_norm():
            return nn.GroupNorm(num_groups=1, num_channels=dim, eps=1e-6, affine=True)

        self.norm1 = make_norm()
        self.attn = ChannelAttention(dim, num_heads, bias)
        self.norm2 = make_norm()
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        after_attn = x + self.attn(self.norm1(x))
        return after_attn + self.ffn(self.norm2(after_attn))

class OverlapPatchEmbed(nn.Module):
    """Embed an image into ``embed_dim`` channels with one 3x3, stride-1, padding-1 conv.

    The spatial size is preserved; neighbouring 3x3 windows overlap.
    """
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        return self.proj(x)

class Downsample(nn.Module):
    """Halve height and width via a channel-reducing conv plus ``PixelUnshuffle(2)``.

    The bias-free 3x3 conv maps ``n_feat`` to ``n_feat // 2`` channels. The
    unshuffle then folds each 2x2 neighbourhood into channels (x4), giving
    ``4 * (n_feat // 2)`` output channels (``2 * n_feat`` for even ``n_feat``).
    """
    def __init__(self, n_feat):
        super().__init__()
        squeeze_conv = nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(squeeze_conv, nn.PixelUnshuffle(2))

    def forward(self, x):
        out = self.body(x)
        return out

class Upsample(nn.Module):
    """Double height and width via a channel-expanding conv plus ``PixelShuffle(2)``.

    The bias-free 3x3 conv maps ``n_feat`` to ``2 * n_feat`` channels. The
    shuffle then trades 4 channels for each 2x2 output neighbourhood, giving
    ``n_feat // 2`` output channels.
    """
    def __init__(self, n_feat):
        super().__init__()
        expand_conv = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(expand_conv, nn.PixelShuffle(2))

    def forward(self, x):
        out = self.body(x)
        return out

class RestormerForInverseProblem(nn.Module):
    """
    Restormer model for orthonormal inverse problem:
    Input: 112x112x3 (degraded/masked image)
    Output: 224x224x3 (reconstructed original image)

    The input is resized to ``output_size`` by a small conv -> bilinear
    upsample -> conv stem. The resized image then goes through a 4-level
    Restormer encoder/decoder with skip concatenations. A shallow conv branch
    on the resized input is added to the decoder output before a final
    Sigmoid, so outputs lie in (0, 1).

    Args:
        inp_channels: channels of the input image.
        out_channels: channels of the reconstructed image.
        dim: base feature width; level ``k`` uses ``dim * 2**k`` channels.
        num_blocks: transformer blocks per level. Encoder and decoder level
            ``k`` both use ``num_blocks[k]``, and the refinement stage uses
            ``num_blocks[0]``. Any length-4 sequence is accepted.
        num_heads: attention heads per level, indexed like ``num_blocks``.
        ffn_expansion_factor: hidden-width multiplier of each feed-forward.
        bias: bias flag for the convolutions built here. The patch embedding
            keeps its own default, as before.
        output_size: (H, W) the input is resized to. The default reproduces
            the previous hard-coded 224x224. Both sides must be divisible by
            8, because three PixelUnshuffle(2) stages follow.
    """
    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=48,
                 num_blocks=(4, 6, 6, 8),
                 num_heads=(1, 2, 4, 8),
                 ffn_expansion_factor=2.66,
                 bias=False,
                 output_size=(224, 224)):
        # Tuple defaults replace the previous mutable list defaults; callers
        # may still pass lists.
        super().__init__()

        def stage(level):
            """Stack ``num_blocks[level]`` Restormer blocks at width ``dim * 2**level``."""
            return nn.Sequential(*[
                RestormerTransformerBlock(
                    dim=int(dim * 2 ** level),
                    num_heads=num_heads[level],
                    ffn_expansion_factor=ffn_expansion_factor,
                    bias=bias,
                )
                for _ in range(num_blocks[level])
            ])

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        # Input upsampling to handle 112x112 -> output_size (default 224x224)
        self.input_upsample = nn.Sequential(
            nn.Conv2d(inp_channels, dim // 2, kernel_size=3, padding=1, bias=bias),
            nn.GELU(),
            nn.Upsample(size=tuple(output_size), mode='bilinear', align_corners=True),
            nn.Conv2d(dim // 2, inp_channels, kernel_size=3, padding=1, bias=bias),
        )

        # Skip connection processing (added to the decoder output before Sigmoid)
        self.skip_conv = nn.Sequential(
            nn.Conv2d(inp_channels, dim, kernel_size=3, padding=1, bias=bias),
            nn.GELU(),
            nn.Conv2d(dim, out_channels, kernel_size=3, padding=1, bias=bias),
        )

        # Encoder (attribute names must stay stable for checkpoint loading)
        self.encoder_level1 = stage(0)

        self.down1_2 = Downsample(dim)
        self.encoder_level2 = stage(1)

        self.down2_3 = Downsample(int(dim * 2 ** 1))
        self.encoder_level3 = stage(2)

        self.down3_4 = Downsample(int(dim * 2 ** 2))
        self.latent = stage(3)

        # Decoder: upsample, concatenate the encoder skip, 1x1 conv back to level width
        self.up4_3 = Upsample(int(dim * 2 ** 3))
        self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias)
        self.decoder_level3 = stage(2)

        self.up3_2 = Upsample(int(dim * 2 ** 2))
        self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
        self.decoder_level2 = stage(1)

        self.up2_1 = Upsample(int(dim * 2 ** 1))
        self.reduce_chan_level1 = nn.Conv2d(int(dim * 2 ** 1), dim, kernel_size=1, bias=bias)
        self.decoder_level1 = stage(0)

        # Final refinement
        self.refinement = stage(0)

        # Output projection
        self.output = nn.Conv2d(dim, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

        # Final activation
        self.final_activation = nn.Sigmoid()

    def forward(self, inp_img):
        # Shape comments assume the default config (dim=48, 112x112 input, 224x224 output).
        inp_upsampled = self.input_upsample(inp_img)  # [B, 3, 224, 224]
        skip_features = self.skip_conv(inp_upsampled)  # [B, 3, 224, 224]

        # Encoder
        out_enc_level1 = self.encoder_level1(self.patch_embed(inp_upsampled))  # [B, 48, 224, 224]
        out_enc_level2 = self.encoder_level2(self.down1_2(out_enc_level1))  # [B, 96, 112, 112]
        out_enc_level3 = self.encoder_level3(self.down2_3(out_enc_level2))  # [B, 192, 56, 56]
        latent = self.latent(self.down3_4(out_enc_level3))  # [B, 384, 28, 28]

        # Decoder with skip concatenations
        inp_dec_level3 = torch.cat([self.up4_3(latent), out_enc_level3], 1)  # [B, 384, 56, 56]
        out_dec_level3 = self.decoder_level3(self.reduce_chan_level3(inp_dec_level3))  # [B, 192, 56, 56]

        inp_dec_level2 = torch.cat([self.up3_2(out_dec_level3), out_enc_level2], 1)  # [B, 192, 112, 112]
        out_dec_level2 = self.decoder_level2(self.reduce_chan_level2(inp_dec_level2))  # [B, 96, 112, 112]

        inp_dec_level1 = torch.cat([self.up2_1(out_dec_level2), out_enc_level1], 1)  # [B, 96, 224, 224]
        out_dec_level1 = self.decoder_level1(self.reduce_chan_level1(inp_dec_level1))  # [B, 48, 224, 224]

        # Refinement, projection to image channels, skip add, and squash to (0, 1)
        out_dec_level1 = self.refinement(out_dec_level1)  # [B, 48, 224, 224]
        output = self.output(out_dec_level1) + skip_features  # [B, 3, 224, 224]
        return self.final_activation(output)

# ========== TRADITIONAL U-NET MODEL ==========
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages; spatial size is preserved.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions. Any falsy value
            (``None`` or ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, mid), (mid, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """U-Net encoder step: 2x2 max-pool (halving H and W) followed by ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """U-Net decoder step: upsample ``x1``, pad it to ``x2``'s size, concatenate, ``DoubleConv``.

    Args:
        in_channels: channel count of the concatenated tensor fed to ``DoubleConv``.
        out_channels: channels produced by ``DoubleConv``.
        bilinear: if True, upsample with a parameter-free bilinear layer and use
            ``in_channels // 2`` as the DoubleConv mid width; otherwise use a
            2x2 stride-2 transposed conv that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        dy = x2.size(2) - x1.size(2)
        dx = x2.size(3) - x1.size(3)
        left, top = dx // 2, dy // 2
        # F.pad pads the last dim first: (left, right, top, bottom).
        x1 = F.pad(x1, [left, dx - left, top, dy - top])
        return self.conv(torch.cat([x2, x1], dim=1))

class OutConv(nn.Module):
    """Final 1x1 convolution (with bias) mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNetForInverseProblem(nn.Module):
    """Classic 4-level U-Net adapted to reconstruct an image at ``output_size``.

    The input is resized to ``output_size`` (bilinear) and passed through a
    conv/BN/ReLU stem. The result feeds two paths: a U-Net, and a shallow
    conv branch whose Tanh-bounded output is added to the U-Net logits. A
    final Sigmoid maps the sum into (0, 1).

    Args:
        n_channels: input image channels.
        n_classes: output image channels.
        bilinear: decoder upsampling mode. True means bilinear and False means
            transposed convs. With bilinear, the bottleneck and decoder widths
            are halved.
        output_size: (H, W) the input is resized to. The default reproduces
            the previous hard-coded 224x224.
    """

    def __init__(self, n_channels=3, n_classes=3, bilinear=False, output_size=(224, 224)):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.input_prep = nn.Sequential(
            nn.Upsample(size=tuple(output_size), mode='bilinear', align_corners=True),
            nn.Conv2d(n_channels, n_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(n_channels),
            nn.ReLU(inplace=True)
        )

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # Bilinear Up layers keep the channel count when upsampling, so the
        # deeper widths are halved to make each skip concatenation add up to
        # the Up block's in_channels.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        self.outc = OutConv(64, n_classes)

        self.final_activation = nn.Sigmoid()

        self.skip_connection = nn.Sequential(
            nn.Conv2d(n_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, n_classes, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.input_prep(x)

        # Shallow residual branch on the resized input.
        skip_features = self.skip_connection(x)

        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        logits = self.outc(x)
        return self.final_activation(logits + skip_features)

# ========== EVALUATION FUNCTIONS ==========
def calculate_psnr(pred, target, data_range=1.0):
    """Calculate PSNR (dB) between predicted and target images.

    Args:
        pred: predicted tensor.
        target: reference tensor of the same shape.
        data_range: maximum possible pixel value. The default of 1.0 matches
            images scaled to [0, 1].

    Returns:
        A 0-dim tensor. For identical inputs (MSE == 0) this is a tensor
        holding ``inf`` rather than a Python float, so callers can always use
        ``.item()``.
    """
    mse = torch.mean((pred - target) ** 2)
    if mse == 0:
        return torch.full_like(mse, float('inf'))
    return 20 * torch.log10(data_range / torch.sqrt(mse))

def calculate_mae(pred, target):
    """Return the mean absolute error between ``pred`` and ``target`` as a Python float."""
    abs_err = (pred - target).abs()
    return abs_err.mean().item()

def calculate_rmse(pred, target):
    """Return the root-mean-square error between ``pred`` and ``target`` as a Python float."""
    residual = pred - target
    return residual.pow(2).mean().sqrt().item()

def calculate_fpr_score(pred, target, t_high=0.5, t_low=0.2):
    """
    Compute False Positive Regions (FPR) hallucination score

    Each input is min-max normalized to [0, 1] independently. A pixel counts
    as hallucinated when it is bright in the prediction (> ``t_high``) but dark
    in the ground truth (<= ``t_low``).

    Parameters
    ----------
    pred : torch.Tensor
        Generated or reconstructed image. Tensors that require grad are
        accepted; they are detached first.
    target : torch.Tensor
        Ground truth image
    t_high : float
        Upper threshold for generated image
    t_low : float
        Lower threshold for ground truth

    Returns
    -------
    score : float
        Hallucination score (fraction of hallucinated pixels)

    Notes
    -----
    Only 3-D inputs (presumably [C, H, W]) are reduced to their first
    channel. Inputs of any other rank, e.g. a 4-D [1, C, H, W] batch slice,
    are scored over all of their elements jointly.
    """
    # Detach so tensors that require grad can still be converted to NumPy.
    pred_np = pred.detach().cpu().numpy()
    target_np = target.detach().cpu().numpy()

    # Take first channel if the prediction is 3-D.
    if pred_np.ndim == 3:
        pred_np = pred_np[0]
        target_np = target_np[0]

    def _minmax(arr):
        # Epsilon avoids division by zero for constant images (they map to 0).
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

    pred_norm = _minmax(pred_np)
    target_norm = _minmax(target_np)

    # Define hallucination mask H = (pred > t_high) ∧ (target ≤ t_low)
    hallucinated = np.logical_and(pred_norm > t_high, target_norm <= t_low)

    # Fraction of hallucinated pixels, as a plain Python float.
    return float(np.sum(hallucinated) / hallucinated.size)

def calculate_metrics(pred, target):
    """Calculate per-image reconstruction metrics.

    MSE, MAE, RMSE and FPR use the raw tensors. PSNR, SSIM and MS-SSIM use
    copies clamped to [0, 1] with ``data_range=1.0``. The tensors are
    presumably [N, C, H, W] batches, since that is what ``ssim``/``ms_ssim``
    expect; the evaluation loop passes single-image slices. TODO confirm for
    other callers.

    Returns:
        dict with Python-float values for the keys 'mse', 'mae', 'rmse',
        'psnr', 'ssim', 'ms_ssim' and 'fpr'.
    """
    pred_clamped = torch.clamp(pred, 0, 1)
    target_clamped = torch.clamp(target, 0, 1)

    # Basic metrics
    mse = F.mse_loss(pred, target).item()
    mae = calculate_mae(pred, target)
    rmse = calculate_rmse(pred, target)

    # PSNR: float() accepts both a 0-dim tensor and a plain float('inf'),
    # so a perfect reconstruction no longer crashes on `.item()`.
    psnr = float(calculate_psnr(pred_clamped, target_clamped))

    # SSIM metrics
    ssim_val = ssim(pred_clamped, target_clamped, data_range=1.0, size_average=True).item()
    ms_ssim_val = ms_ssim(pred_clamped, target_clamped, data_range=1.0, size_average=True).item()

    # FPR score for hallucination detection
    fpr_score = float(calculate_fpr_score(pred, target))

    return {
        'mse': mse,
        'mae': mae,
        'rmse': rmse,
        'psnr': psnr,
        'ssim': ssim_val,
        'ms_ssim': ms_ssim_val,
        'fpr': fpr_score
    }

def calculate_statistics(raw_metrics):
    """Summarise per-sample metric lists.

    For every ``name -> values`` entry, emit ``name_mean``, ``name_std``
    (sample std with ddof=1, or 0 for a single value), ``name_min``,
    ``name_max`` and ``name_median``, all as Python floats and in that order.
    """
    summary = {}
    for name, seq in raw_metrics.items():
        arr = np.array(seq)
        spread = np.std(arr, ddof=1) if len(arr) > 1 else 0
        summary.update({
            f'{name}_mean': float(np.mean(arr)),
            f'{name}_std': float(spread),
            f'{name}_min': float(np.min(arr)),
            f'{name}_max': float(np.max(arr)),
            f'{name}_median': float(np.median(arr)),
        })
    return summary

def evaluate_model(model, test_loader, device, model_name,
                   warmup_shape=(1, 3, 112, 112), max_sample_images=32):
    """Evaluate a single model and return results with mean ± std format.

    Args:
        model: network mapping an input batch to predictions.
        test_loader: iterable yielding ``(inputs, targets, paths)`` batches.
        device: ``torch.device`` the model lives on.
        model_name: label used in progress messages and stored in the stats.
        warmup_shape: shape of the random tensor used for the 3 warmup passes.
        max_sample_images: number of (input, prediction, target, path)
            samples to keep for visualization.

    Returns:
        ``(stats, sample_images, raw_metrics)``.
        ``stats`` is ``calculate_statistics(raw_metrics)`` plus
        'model_type', 'total_samples', 'evaluation_time' and
        'model_parameters'.
        ``sample_images`` holds up to ``max_sample_images`` entries per list.
        ``raw_metrics`` maps each metric name to its per-image values.
    """
    from collections import defaultdict

    print(f"Evaluating {model_name}...")

    raw_metrics = defaultdict(list)
    sample_images = {'inputs': [], 'predictions': [], 'targets': [], 'paths': []}
    use_cuda = device.type == 'cuda'

    # Switch to eval mode BEFORE warming up. In train mode, BatchNorm layers
    # would update their running statistics from the random warmup input
    # (overwriting the loaded checkpoint's statistics), and dropout would be active.
    model.eval()

    print(f"Performing warmup for {model_name}...")
    with torch.no_grad():
        dummy_input = torch.randn(*warmup_shape, device=device)
        for _ in range(3):
            model(dummy_input)
        if use_cuda:
            torch.cuda.synchronize()

    start_time = time.time()

    with torch.no_grad():
        for inputs, targets, paths in tqdm(test_loader, desc=f"Testing {model_name}"):
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_size = inputs.size(0)

            # Synchronize around the forward pass so the timing covers the GPU
            # work itself, not just the asynchronous kernel launches.
            if use_cuda:
                torch.cuda.synchronize()
            inference_start = time.perf_counter()
            predictions = model(inputs)
            if use_cuda:
                torch.cuda.synchronize()
            per_image_inference_time = (time.perf_counter() - inference_start) / batch_size

            for i in range(batch_size):
                metrics = calculate_metrics(predictions[i:i + 1], targets[i:i + 1])
                # Timing is per batch; every image gets the batch average.
                metrics['inference_time'] = per_image_inference_time
                metrics['reconstruction_time'] = per_image_inference_time  # Same as inference for simplicity

                for key, value in metrics.items():
                    raw_metrics[key].append(value)

                if len(sample_images['inputs']) < max_sample_images:
                    sample_images['inputs'].append(inputs[i])
                    sample_images['predictions'].append(predictions[i])
                    sample_images['targets'].append(targets[i])
                    sample_images['paths'].append(paths[i])

    eval_time = time.time() - start_time

    # Calculate statistics (mean ± std)
    stats = calculate_statistics(raw_metrics)
    stats.update({
        'model_type': model_name,
        'total_samples': len(raw_metrics.get('psnr', ())),
        'evaluation_time': eval_time,
        'model_parameters': sum(p.numel() for p in model.parameters())
    })

    return stats, sample_images, raw_metrics

def print_model_results(results, model_name):
    """Print one model's evaluation summary, one ``mean ± std`` line per metric.

    Metrics missing either their ``_mean`` or ``_std`` key are skipped.
    Time metrics get a " seconds" suffix, PSNR uses two decimals and " dB",
    and everything else uses six decimals.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"{model_name.upper()} EVALUATION RESULTS")
    print(rule)
    print(f"Samples: {results['total_samples']}")
    print(f"Parameters: {results['model_parameters']:,}")
    print(f"Total time: {results['evaluation_time']:.2f}s")
    print()

    ordered = ('psnr', 'ssim', 'ms_ssim', 'mae', 'mse', 'rmse', 'fpr',
               'inference_time', 'reconstruction_time')
    for metric in ordered:
        mean_key, std_key = f'{metric}_mean', f'{metric}_std'
        if mean_key not in results or std_key not in results:
            continue
        label = f"{metric.upper():16}"
        mean, std = results[mean_key], results[std_key]
        if 'time' in metric:
            print(f"{label}: {mean:.6f} ± {std:.6f} seconds")
        elif metric == 'psnr':
            print(f"{label}: {mean:.2f} ± {std:.2f} dB")
        else:
            print(f"{label}: {mean:.6f} ± {std:.6f}")

def create_comparative_visualization(all_sample_images, save_path, max_images=8):
    """Save a grid comparing each model's predictions on the same sample inputs.

    Row 0 shows the inputs, taken from the first model's samples. Each further
    row shows one model's predictions. Column count is
    ``min(number of stored samples, max_images)``. Prints a message and
    returns early when there are no models or no stored samples.
    """
    model_names = list(all_sample_images.keys())
    if not model_names:
        print("No model results available for visualization")
        return

    reference = all_sample_images[model_names[0]]
    if not reference['inputs']:
        print("No sample images available for visualization")
        return

    n_cols = min(len(reference['inputs']), max_images)
    n_rows = len(model_names) + 1  # extra first row for the shared input

    def to_hwc(img):
        # CHW tensor -> HWC NumPy array for imshow.
        return img.cpu().numpy().transpose(1, 2, 0)

    # squeeze=False always yields a 2-D axes array, even for a single column.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)

    for col in range(n_cols):
        cell = axes[0, col]
        cell.imshow(np.clip(to_hwc(reference['inputs'][col]), 0, 1))
        cell.set_title(f"Input {col+1}\n(112×112)")
        cell.axis('off')

        for row, name in enumerate(model_names, start=1):
            cell = axes[row, col]
            cell.imshow(np.clip(to_hwc(all_sample_images[name]['predictions'][col]), 0, 1))
            cell.set_title(f"{name} {col+1}")
            cell.axis('off')

    title_models = " → ".join(model_names)
    plt.suptitle(f"Model Comparison: Input → {title_models}", fontsize=16)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

def create_metrics_comparison(all_results, save_dir):
    """Create comprehensive comparison plots with mean ± std format.

    Writes ``metrics_comparison_with_std.png`` (one bar chart per metric) to
    ``save_dir``. When more than one model is present it also writes
    ``quality_vs_speed_with_std.png`` (SSIM vs reconstruction time).

    Args:
        all_results: mapping of model name -> stats dict containing
            ``<metric>_mean`` / ``<metric>_std`` keys.
        save_dir: existing directory to write the figures into.

    Returns:
        ``all_results`` unchanged, or ``None`` (after a message) if it is empty.
    """
    if not all_results:
        print("No results available for metrics comparison")
        return

    models = list(all_results.keys())
    metrics = ['psnr', 'ssim', 'ms_ssim', 'mae', 'mse', 'rmse', 'fpr', 'inference_time', 'reconstruction_time']
    higher_is_better = {'psnr', 'ssim', 'ms_ssim'}

    # 1. Performance comparison bar plots
    _, axes = plt.subplots(3, 3, figsize=(20, 15))
    axes = axes.flatten()

    for ax, metric in zip(axes, metrics):
        mean_key, std_key = f'{metric}_mean', f'{metric}_std'
        # Keep names, means and stds aligned: only models reporting BOTH
        # statistics are plotted. Filtering the lists separately could pair
        # a bar with another model's error bar, or mismatch the bar count.
        rows = [(model, all_results[model][mean_key], all_results[model][std_key])
                for model in models
                if mean_key in all_results[model] and std_key in all_results[model]]
        if not rows:
            ax.set_visible(False)
            continue
        names = [r[0] for r in rows]
        means = [r[1] for r in rows]
        stds = [r[2] for r in rows]

        bars = ax.bar(names, means, yerr=stds, capsize=5, alpha=0.7)

        # Color coding: green for "higher is better", red for "lower is better"
        color = 'green' if metric in higher_is_better else 'red'
        for bar in bars:
            bar.set_color(color)

        metric_title = metric.upper().replace('_', ' ')
        ax.set_title(f'{metric_title} Comparison (Mean ± Std)')
        ax.set_ylabel(f'{metric_title}')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(True, alpha=0.3)

        # Add value labels just above each error bar
        label_pad = max(means) * 0.02
        for j, (mean, std) in enumerate(zip(means, stds)):
            ax.text(j, mean + std + label_pad,
                    f'{mean:.3f}±{std:.3f}',
                    ha='center', va='bottom', fontsize=8)

    # Hide unused subplots
    for ax in axes[len(metrics):]:
        ax.set_visible(False)

    plt.suptitle('Model Performance Comparison (Mean ± Standard Deviation)', fontsize=16)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'metrics_comparison_with_std.png'), dpi=150, bbox_inches='tight')
    plt.close()

    # 2. Performance vs Speed scatter plot (only if we have multiple models)
    if len(models) > 1:
        plt.figure(figsize=(12, 8))
        colors = ['blue', 'red', 'green', 'orange', 'purple']

        for i, model in enumerate(models):
            results = all_results[model]
            color = colors[i % len(colors)]
            plt.scatter(results['reconstruction_time_mean'], results['ssim_mean'],
                        alpha=0.6, label=model, color=color, s=200, marker='*',
                        edgecolors='black', linewidth=2)

            # Add error bars
            plt.errorbar(results['reconstruction_time_mean'], results['ssim_mean'],
                         xerr=results['reconstruction_time_std'], yerr=results['ssim_std'],
                         fmt='none', color=color, alpha=0.3, capsize=5)

        plt.xlabel('Reconstruction Time (seconds)')
        plt.ylabel('SSIM')
        plt.title('Quality vs Speed Trade-off (Mean ± Std)')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(os.path.join(save_dir, 'quality_vs_speed_with_std.png'), dpi=150, bbox_inches='tight')
        plt.close()

    return all_results

def main():
    # Configuration using the cleaner list format
    models_config = [
        {
            'type': 'transunet',
            'path': r'F:\ImageNet\transunet_8x8_checkpoints\epoch_20.pth',
            'save_dir': r'F:\ImageNet\evaluation_results\transunet'
        },
        {
            'type': 'unet', 
            'path': r'F:\ImageNet\unet_8x8_checkpoints\epoch_10.pth',
            'save_dir': r'F:\ImageNet\evaluation_results\unet'
        },
        {
            'type': 'trust',
            'path': r'F:\ImageNet\trust_8x8_checkpoints\epoch_80.pth', 
            'save_dir': r'F:\ImageNet\evaluation_results\trust'
        },
        {
            'type': 'restormer',
            'path': r'F:\ImageNet\restormer_8x8_checkpoints\epoch_50.pth',
            'save_dir': r'F:\ImageNet\evaluation_results\restormer'
        }
    ]
    
    config = {
        'test_dir': r"F:\imgnet\data\test",
        'results_dir': "./comprehensive_evaluation_results_four_models",
        'seed': 42,
        'batch_size': 8,
        'max_test_images': 1000,
        'device': 'cuda:0' if torch.cuda.is_available() else 'cpu'
    }
    
    print("="*80)
    print("COMPREHENSIVE MODEL COMPARISON WITH MEAN ± STD REPORTING (4 MODELS)")
    print("="*80)
    print(f"Device: {config['device']}")
    print(f"Test directory: {config['test_dir']}")
    print(f"Max test images: {config['max_test_images']}")
    print(f"Batch size: {config['batch_size']}")
    print("Models to evaluate:", [model['type'].upper() for model in models_config])
    
    # Create results directory
    os.makedirs(config['results_dir'], exist_ok=True)
    
    # Load dataset
    test_dataset = PatchwiseOrthonormalDataset(
        data_dir=config['test_dir'], 
        seed=config['seed'], 
        verbose=True
    )
    
    # Limit dataset size if specified
    if config['max_test_images'] and len(test_dataset) > config['max_test_images']:
        indices = np.random.choice(len(test_dataset), config['max_test_images'], replace=False)
        test_dataset.image_files = [test_dataset.image_files[i] for i in sorted(indices)]
        print(f"Limited test set to {config['max_test_images']} images")
    
    test_loader = DataLoader(
        test_dataset, 
        batch_size=config['batch_size'], 
        shuffle=False,
        num_workers=4, 
        pin_memory=True
    )
    
    device = torch.device(config['device'])
    
    # Load and evaluate each model
    all_results = {}
    all_sample_images = {}
    all_raw_metrics = {}
    model_info = {}
    
    for model_config in models_config:
        model_type = model_config['type']
        model_path = model_config['path']
        save_dir = model_config['save_dir']
        
        print(f"\n{'-'*50}")
        print(f"Loading {model_type.upper()}...")
        
        # Create individual model save directory
        os.makedirs(save_dir, exist_ok=True)
        
        # Load model based on type
        try:
            if model_type.lower() == 'transunet':
                model = TransUNet(
                    img_size=224,
                    patch_size=16,
                    embed_dim=768,
                    depth=12,
                    num_heads=12,
                    mlp_ratio=4,
                    dropout=0.1
                ).to(device)
            elif model_type.lower() == 'unet':
                model = UNetForInverseProblem(
                    n_channels=3,
                    n_classes=3,
                    bilinear=True
                ).to(device)
            elif model_type.lower() in ['trust', 'vit_unet']:
                model = ViTUNetForInverseProblem(
                    pretrained_model_name="google/vit-base-patch16-224",
                    output_size=(224, 224)
                ).to(device)
            elif model_type.lower() == 'restormer':
                model = RestormerForInverseProblem(
                    inp_channels=3,
                    out_channels=3,
                    dim=48,
                    num_blocks=[4, 6, 6, 8],
                    num_heads=[1, 2, 4, 8],
                    ffn_expansion_factor=2.66,
                    bias=False
                ).to(device)
            else:
                print(f"Unknown model type: {model_type}")
                continue
        except Exception as e:
            print(f"Error creating {model_type} model: {e}")
            continue
        
        # Load checkpoint
        if not os.path.exists(model_path):
            print(f"Warning: Model checkpoint not found: {model_path}")
            continue
            
        try:
            checkpoint = torch.load(model_path, map_location=device)
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"Loaded {model_type.upper()} from epoch {checkpoint.get('epoch', 'unknown')}")
            else:
                model.load_state_dict(checkpoint)
        except Exception as e:
            print(f"Error loading checkpoint for {model_type}: {e}")
            continue
        
        # Model info
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        model_info[model_type.upper()] = {
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'model_path': model_path
        }
        print(f"{model_type.upper()} - Total parameters: {total_params:,}")
        
        # Evaluate model
        try:
            results, sample_images, raw_metrics = evaluate_model(model, test_loader, device, model_type.upper())
            
            # Print individual results
            print_model_results(results, model_type.upper())
            
            all_results[model_type.upper()] = results
            all_sample_images[model_type.upper()] = sample_images
            all_raw_metrics[model_type.upper()] = raw_metrics
            
            # Save individual model results to the specified save_dir
            with open(os.path.join(save_dir, 'evaluation_results.json'), 'w') as f:
                json.dump(results, f, indent=2)
            
            # Save raw metrics for individual model
            with open(os.path.join(save_dir, 'raw_metrics.json'), 'w') as f:
                # Convert raw_metrics to serializable format
                serializable_metrics = {}
                for key, values in raw_metrics.items():
                    serializable_metrics[key] = [float(v) for v in values]
                json.dump(serializable_metrics, f, indent=2)
            
            print(f"Individual results saved to: {save_dir}")
            
        except Exception as e:
            print(f"Error evaluating {model_type}: {e}")
            continue
        
        # Clear GPU memory
        del model
        if device.type == 'cuda':
            torch.cuda.empty_cache()
    
    # Check if we have any results
    if not all_results:
        print("No models were successfully evaluated. Please check your model paths and configurations.")
        return None
    
    # Generate comparison results
    print("\n" + "="*80)
    print(f"COMPARISON RESULTS SUMMARY (MEAN ± STD) - {len(all_results)} MODELS")
    print("="*80)
    
    # Print comparison table
    print(f"{'Model':<15} {'PSNR (dB)':<15} {'SSIM':<15} {'MAE':<15} {'MSE':<15} {'FPR':<15} {'Time (ms)':<15}")
    print("-" * 105)
    
    for model_name in all_results.keys():
        results = all_results[model_name]
        psnr_str = f"{results['psnr_mean']:.2f}±{results['psnr_std']:.2f}"
        ssim_str = f"{results['ssim_mean']:.3f}±{results['ssim_std']:.3f}"
        mae_str = f"{results['mae_mean']:.3f}±{results['mae_std']:.3f}"
        mse_str = f"{results['mse_mean']:.3f}±{results['mse_std']:.3f}"
        fpr_str = f"{results['fpr_mean']:.3f}±{results['fpr_std']:.3f}"
        time_str = f"{results['inference_time_mean']*1000:.1f}±{results['inference_time_std']*1000:.1f}"
        
        print(f"{model_name:<15} {psnr_str:<15} {ssim_str:<15} {mae_str:<15} {mse_str:<15} {fpr_str:<15} {time_str:<15}")
    
    # Create visualizations
    print("\nCreating comparative visualizations...")
    
    # 1. Comparative model outputs
    create_comparative_visualization(
        all_sample_images,
        os.path.join(config['results_dir'], f'{len(all_results)}_model_comparison_grid.png'),
        max_images=8
    )
    
    # 2. Metrics comparison with mean ± std
    create_metrics_comparison(all_results, config['results_dir'])
    
    # 3. Save comprehensive summary
    with open(os.path.join(config['results_dir'], f'comprehensive_summary_{len(all_results)}_models.json'), 'w') as f:
        json.dump(all_results, f, indent=2)
    
    # 4. Save model configuration and paths
    config_summary = {
        'models_evaluated': [
            {
                'type': model_config['type'],
                'path': model_config['path'],
                'save_dir': model_config['save_dir'],
                'parameters': model_info.get(model_config['type'].upper(), {}).get('total_parameters', 0)
            }
            for model_config in models_config
            if model_config['type'].upper() in all_results
        ],
        'evaluation_config': config,
        'model_info': model_info
    }
    
    with open(os.path.join(config['results_dir'], 'evaluation_config.json'), 'w') as f:
        json.dump(config_summary, f, indent=2)
    
    print(f"\nAll comparison results saved to: {config['results_dir']}")
    print("Generated files:")
    print(f"- comprehensive_summary_{len(all_results)}_models.json: All results with mean ± std")
    print("- evaluation_config.json: Model paths and configuration used")
    print(f"- {len(all_results)}_model_comparison_grid.png: Side-by-side visual comparison")
    print("- metrics_comparison_with_std.png: Performance metrics with error bars")
    if len(all_results) > 1:
        print("- quality_vs_speed_with_std.png: Quality vs speed with error bars")
    print("Individual model results saved to their respective save_dir paths:")
    for model_config in models_config:
        if model_config['type'].upper() in all_results:
            print(f"- {model_config['type'].upper()}: {model_config['save_dir']}")
    
    # Print final ranking (only if we have results)
    if len(all_results) > 1:
        print("\n" + "="*50)
        print(f"MODEL RANKING BY KEY METRICS ({len(all_results)} MODELS)")
        print("="*50)
        
        # Ranking by SSIM (Quality)
        print("By SSIM (Higher is Better):")
        ssim_ranking = sorted(all_results.items(), key=lambda x: x[1]['ssim_mean'], reverse=True)
        for i, (model, results) in enumerate(ssim_ranking, 1):
            print(f"  {i}. {model}: {results['ssim_mean']:.4f} ± {results['ssim_std']:.4f}")
        
        # Ranking by Speed
        print("\nBy Inference Speed (Lower is Better):")
        speed_ranking = sorted(all_results.items(), key=lambda x: x[1]['inference_time_mean'])
        for i, (model, results) in enumerate(speed_ranking, 1):
            time_ms = results['inference_time_mean'] * 1000
            std_ms = results['inference_time_std'] * 1000
            print(f"  {i}. {model}: {time_ms:.1f} ± {std_ms:.1f} ms")
        
        # Ranking by FPR (Hallucination - Lower is Better)
        print("\nBy FPR Score (Lower is Better - Less Hallucination):")
        fpr_ranking = sorted(all_results.items(), key=lambda x: x[1]['fpr_mean'])
        for i, (model, results) in enumerate(fpr_ranking, 1):
            print(f"  {i}. {model}: {results['fpr_mean']:.4f} ± {results['fpr_std']:.4f}")
        
        # Parameter efficiency ranking
        print("\nBy Parameter Efficiency (SSIM per Million Parameters):")
        efficiency_scores = {}
        for model_name, results in all_results.items():
            ssim_val = results['ssim_mean']
            params = model_info[model_name]['total_parameters'] / 1e6  # Convert to millions
            efficiency_scores[model_name] = ssim_val / params
        
        efficiency_ranking = sorted(efficiency_scores.items(), key=lambda x: x[1], reverse=True)
        for i, (model, eff_val) in enumerate(efficiency_ranking, 1):
            params_m = model_info[model]['total_parameters'] / 1e6
            ssim_val = all_results[model]['ssim_mean']
            print(f"  {i}. {model}: {eff_val:.6f} (SSIM: {ssim_val:.4f}, Params: {params_m:.1f}M)")
    
    return all_results

if __name__ == "__main__":
    # Script entry point: run the multi-model evaluation pipeline.
    # main() ends by returning the per-model results dict, or None when no
    # model was evaluated successfully. The value stays bound to the
    # module-level name `results` so it can be inspected after the run
    # (e.g. when launched with `python -i`).
    results = main()