# #!/usr/bin/env python3
# import os, glob, time
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import DataLoader
# import matplotlib.pyplot as plt
# from transformers import ViTModel, ViTConfig
# from pathlib import Path
# from PIL import Image

# # Updated orthonormal matrix generation for 8x8 patch processing
# def make_gaussian_random_orthonormal_rows(h=64, w=64, seed=42):
#     """
#     Generate a matrix A of size [h, w] where rows are orthonormal.
#     Note: Requires h <= w for orthonormal rows to exist.
#     For 8x8 patches, we use 64x64 matrix.
#     """
#     if seed is not None:
#         torch.manual_seed(seed)
#     # Step 1: Random Gaussian matrix
#     A = torch.randn(h, w)
#     # Step 2: QR decomposition on transpose to orthonormalize rows
#     # A^T = Q*R → A = R^T * Q^T
#     # We want orthonormal rows, so we work with A^T first
#     Q, R = torch.linalg.qr(A.T)  # A.T is [w, h]
#     # Q is [w, h] with orthonormal columns
#     # Q.T is [h, w] with orthonormal rows
#     return Q.T  # [h, w] with orthonormal rows

# class PatchwiseOrthonormalDataset:
#     """
#     Dataset that applies patch-wise orthonormal transformation to images.
#     Each 8x8 patch gets transformed through a 64x64 orthonormal matrix.
#     Output: 112x112 masked input → 224x224 original target (super-resolution task)
#     """
#     def __init__(self, data_dir, seed=42, verbose=False):
#         self.data_dir = data_dir
        
#         # Generate fixed orthonormal matrix A for 8x8 patch-wise transformation
#         self.A = make_gaussian_random_orthonormal_rows(h=64, w=64, seed=seed)
        
#         # Get all image files
#         self.data_path = Path(data_dir)
#         if not self.data_path.exists():
#             raise FileNotFoundError(f"Data directory not found: {data_dir}")
        
#         image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.JPEG', '.JPG'}
#         self.image_files = [f for f in self.data_path.iterdir() 
#                            if f.is_file() and f.suffix in image_extensions]
        
#         if len(self.image_files) == 0:
#             raise ValueError(f"No images found in {data_dir}")
        
#         if verbose:
#             print(f"📁 Loaded {len(self.image_files)} images from {data_dir}")
#             print(f"🔍 Using 8x8 patch-wise orthonormal transformation with matrix shape: {self.A.shape}")
#             print(f"📐 Output: 112x112 masked → 224x224 original (super-resolution task)")

#     def __len__(self):
#         return len(self.image_files)

#     def resize_min_side(self, img, min_side=224):
#         w, h = img.size
#         s = min_side / min(w, h)
#         return img.resize((int(round(w*s)), int(round(h*s))), Image.Resampling.LANCZOS)

#     def center_crop(self, img, size=224):
#         w, h = img.size
#         left = (w - size) // 2
#         top = (h - size) // 2
#         return img.crop((left, top, left + size, top + size))

#     def preprocess_image(self, img):
#         img = img.convert("RGB")
#         img_resized = self.resize_min_side(img, 224)
#         img_crop = self.center_crop(img_resized, 224)
#         x = np.array(img_crop).astype(np.float32) / 255.0
#         return x

#     def process_image_with_orthonormal_masks(self, np_img, mask_matrix):
#         """
#         Apply orthonormal transformation to 8x8 patches of a 224x224 image.
        
#         Args:
#             np_img: numpy array of shape [224, 224, 3]
#             mask_matrix: torch tensor of shape [64, 64] with orthonormal rows
        
#         Returns:
#             transformed_patches: torch tensor of shape [28, 28, 16] (reduced from 64 to 16)
#         """
#         # Convert to torch and extract patches
#         img_tensor = torch.from_numpy(np_img).float()
        
#         # Convert to grayscale for 64 = 8*8
#         if img_tensor.shape[2] == 3:
#             img_gray = img_tensor.mean(dim=2)  # Convert to grayscale
#         else:
#             img_gray = img_tensor
        
#         # Extract 8x8 patches from 224x224 image (28x28 patches total)
#         patches = img_gray.unfold(0, 8, 8).unfold(1, 8, 8)  # [28, 28, 8, 8]
        
#         # Flatten each patch and apply transformation
#         transformed_patches = torch.zeros(28, 28, 16)  # Reduced dimension
        
#         for i in range(28):
#             for j in range(28):
#                 # Flatten 8x8 patch to 64x1
#                 patch_flat = patches[i, j].flatten()  # [64]
                
#                 # Apply orthonormal transformation: [64, 64] @ [64] → [64]
#                 transformed = mask_matrix @ patch_flat
                
#                 # Random subsample to 16 dimensions (as in your code)
#                 transformed = transformed[torch.randperm(transformed.shape[0])[:transformed.shape[0] // 4]]
                
#                 transformed_patches[i, j] = transformed
        
#         return transformed_patches

#     def reconstruct_masked_image(self, transformed_patches):
#         """
#         Reconstruct 112x112 image from 28x28x16 transformed patches.
        
#         Args:
#             transformed_patches: torch tensor of shape [28, 28, 16]
        
#         Returns:
#             masked_image: torch tensor of shape [112, 112]
#         """
#         masked_image = torch.zeros(112, 112)
        
#         for i in range(14):  # Only use first 14x14 patches to get 112x112
#             for j in range(14):
#                 # Get the 16-dimensional transformed patch
#                 transformed_patch = transformed_patches[i, j]  # [16]
                
#                 # Reshape back to 4x4 (since 16 = 4*4)
#                 patch_4x4 = transformed_patch.reshape(4, 4)
                
#                 # Upsample 4x4 to 8x8 for 112x112 output
#                 patch_8x8 = F.interpolate(patch_4x4.unsqueeze(0).unsqueeze(0), 
#                                         size=(8, 8), mode='bilinear', align_corners=True)[0, 0]
                
#                 # Place in the correct position in the 112x112 image
#                 start_h = i * 8
#                 end_h = start_h + 8
#                 start_w = j * 8
#                 end_w = start_w + 8
                
#                 masked_image[start_h:end_h, start_w:end_w] = patch_8x8
        
#         return masked_image

#     def apply_patchwise_orthonormal_transform(self, x):
#         """
#         Apply patch-wise orthonormal transformation to RGB image.
        
#         Args:
#             x: numpy array of shape [224, 224, 3] (original image)
        
#         Returns:
#             y: numpy array of shape [112, 112, 3] (transformed image - downscaled)
#         """
#         y_channels = []
        
#         # Apply transformation to each channel separately
#         for c in range(3):
#             # Create single-channel image for processing
#             single_channel = x[..., c]
            
#             # Apply patch-wise orthonormal transformation
#             transformed_patches = self.process_image_with_orthonormal_masks(
#                 np.expand_dims(single_channel, axis=2), self.A
#             )
            
#             # Reconstruct to 112x112
#             masked_channel = self.reconstruct_masked_image(transformed_patches)
            
#             # Store result
#             y_channels.append(masked_channel.numpy())
        
#         # Stack channels to create [112, 112, 3]
#         y = np.stack(y_channels, axis=2)
        
#         # Normalize to [0, 1] range
#         y_min = y.min()
#         y_max = y.max()
#         y_norm = (y - y_min) / (y_max - y_min + 1e-8)
        
#         return y_norm

#     def __getitem__(self, idx):
#         img_path = self.image_files[idx]
        
#         try:
#             img = Image.open(img_path)
#         except Exception as e:
#             print(f"Warning: Could not load image {img_path}: {e}")
#             img = Image.new('RGB', (224, 224), color=(0, 0, 0))
        
#         # Preprocess to get x (target - original 224x224 image)
#         x = self.preprocess_image(img)  # (224, 224, 3)
        
#         # Apply patch-wise orthonormal transform to get y (input - 112x112 masked image)
#         y = self.apply_patchwise_orthonormal_transform(x)  # (112, 112, 3)
        
#         # Convert to torch tensors and change to CHW format
#         x_tensor = torch.from_numpy(x).permute(2, 0, 1)  # (3, 224, 224) - target
#         y_tensor = torch.from_numpy(y).permute(2, 0, 1)  # (3, 112, 112) - input
        
#         return y_tensor, x_tensor  # (input 112x112, target 224x224)

# # ViT-UNet Model (same as before)
# class ViTUNetForInverseProblem(nn.Module):
#     def __init__(self, pretrained_model_name="google/vit-base-patch16-224", output_size=(224, 224)):
#         super().__init__()
    
#         # Load config and explicitly disable pooling layer
#         cfg = ViTConfig.from_pretrained(pretrained_model_name)
#         cfg.add_pooling_layer = False  # Explicitly disable pooler
    
#         # Load ViT without pooler to avoid unused parameters
#         self.vit = ViTModel.from_pretrained(pretrained_model_name, config=cfg, ignore_mismatched_sizes=True)
    
#         self.output_size = output_size
#         self.hidden_dim = 768

#         # Input upsampling layer to handle 112x112 → 224x224 for ViT
#         self.input_upsample = nn.Sequential(
#             nn.Conv2d(3, 16, 3, padding=1),
#             nn.BatchNorm2d(16), nn.ReLU(True),
#             nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
#             nn.Conv2d(16, 3, 3, padding=1),
#             nn.Sigmoid()
#         )

#         # Big skip connection: 112x112 → 224x224 
#         self.skip_upsample = nn.Sequential(
#             nn.Conv2d(3, 32, 3, padding=1),
#             nn.BatchNorm2d(32), nn.ReLU(True),
#             nn.Conv2d(32, 64, 3, padding=1), 
#             nn.BatchNorm2d(64), nn.ReLU(True),
#             nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
#             nn.Conv2d(64, 32, 3, padding=1),
#             nn.BatchNorm2d(32), nn.ReLU(True),
#             nn.Conv2d(32, 3, 3, padding=1),  # Same number of channels as output
#             nn.Tanh()  # Use tanh to allow both positive and negative contributions
#         )

#         # Feature pools - smaller since we're working with 224x224
#         self.adaptive_pool1 = nn.AdaptiveAvgPool2d((28, 28))   # ~1/8 of 224
#         self.adaptive_pool2 = nn.AdaptiveAvgPool2d((14, 14))   # ~1/16 of 224  
#         self.adaptive_pool3 = nn.AdaptiveAvgPool2d((7, 7))     # ~1/32 of 224
#         self.adaptive_pool_final = nn.AdaptiveAvgPool2d((7, 7))

#         self.skip_conn1 = nn.Conv2d(self.hidden_dim, 128, kernel_size=1)
#         self.skip_conn2 = nn.Conv2d(self.hidden_dim, 256, kernel_size=1)
#         self.skip_conn3 = nn.Conv2d(self.hidden_dim, 512, kernel_size=1)

#         # Decoder path
#         self.up1 = nn.Sequential(
#             nn.Upsample(size=(14, 14), mode='bilinear', align_corners=True),
#             nn.Conv2d(self.hidden_dim, 512, 3, padding=1), 
#             nn.BatchNorm2d(512), nn.ReLU(True)
#         )
#         self.up2 = nn.Sequential(
#             nn.Upsample(size=(28, 28), mode='bilinear', align_corners=True),
#             nn.Conv2d(512, 256, 3, padding=1), 
#             nn.BatchNorm2d(256), nn.ReLU(True)
#         )
#         self.up3 = nn.Sequential(
#             nn.Upsample(size=(56, 56), mode='bilinear', align_corners=True),
#             nn.Conv2d(256, 128, 3, padding=1), 
#             nn.BatchNorm2d(128), nn.ReLU(True)
#         )
#         self.up4 = nn.Sequential(
#             nn.Upsample(size=(112, 112), mode='bilinear', align_corners=True),
#             nn.Conv2d(128, 64, 3, padding=1), 
#             nn.BatchNorm2d(64), nn.ReLU(True)
#         )
    
#         # Final output layer - RGB output with super-resolution
#         self.final = nn.Sequential(
#             nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
#             nn.Conv2d(64, 32, 3, padding=1),
#             nn.BatchNorm2d(32), nn.ReLU(True),
#             nn.Conv2d(32, 3, 3, padding=1),  # 3 channels for RGB
#             nn.Sigmoid()  # Output in [0, 1] range
#         )

#         # Fusion layer to combine ViT output with skip connection
#         self.fusion = nn.Sequential(
#             nn.Conv2d(6, 32, 3, padding=1),  # 6 channels: 3 from ViT + 3 from skip
#             nn.BatchNorm2d(32), nn.ReLU(True),
#             nn.Conv2d(32, 3, 3, padding=1),
#             nn.Sigmoid()
#         )

#     def _extract(self, x3):
#         out = self.vit(x3, output_hidden_states=True)
#         early = out.hidden_states[3]   # Early features
#         mid   = out.hidden_states[6]   # Mid features  
#         late  = out.hidden_states[9]   # Late features
#         last  = out.last_hidden_state  # Final features
#         return early, mid, late, last

#     def _to_spatial(self, tokens):
#         B, N, C = tokens.shape
#         HW = int((N - 1) ** 0.5)  # Drop CLS token
#         t = tokens[:, 1:, :].permute(0, 2, 1)
#         return t.reshape(B, C, HW, HW)

#     def forward(self, x):
#         # x is [B, 3, 112, 112] input
        
#         # Create skip connection: 112x112 → 224x224
#         skip_features = self.skip_upsample(x)  # [B, 3, 224, 224]
        
#         # Process through ViT-UNet
#         if x.shape[-2:] != (224, 224):
#             x_upsampled = self.input_upsample(x)  # 112x112 → 224x224
#         else:
#             x_upsampled = x

#         # Extract multi-scale features from ViT
#         e, m, l, f = self._extract(x_upsampled)
#         e, m, l, f = map(self._to_spatial, (e, m, l, f))

#         # Create skip connections
#         skip1 = self.adaptive_pool1(e)   # 28x28
#         skip2 = self.adaptive_pool2(m)   # 14x14
#         skip3 = self.adaptive_pool3(l)   # 7x7
#         x_feat = self.adaptive_pool_final(f)  # 7x7

#         # Decoder with skip connections
#         x_feat = self.up1(x_feat)  # 7x7 -> 14x14
#         x_feat = x_feat + F.interpolate(self.skip_conn3(skip3), size=(14, 14), mode='bilinear', align_corners=True)
    
#         x_feat = self.up2(x_feat)  # 14x14 -> 28x28
#         x_feat = x_feat + F.interpolate(self.skip_conn2(skip2), size=(28, 28), mode='bilinear', align_corners=True)
    
#         x_feat = self.up3(x_feat)  # 28x28 -> 56x56
#         x_feat = x_feat + F.interpolate(self.skip_conn1(skip1), size=(56, 56), mode='bilinear', align_corners=True)
    
#         x_feat = self.up4(x_feat)  # 56x56 -> 112x112
    
#         # ViT-UNet output
#         vit_output = self.final(x_feat)  # 112x112 -> 224x224, 3 channels
        
#         # Combine ViT output with skip connection
#         combined = torch.cat([vit_output, skip_features], dim=1)  # [B, 6, 224, 224]
        
#         # Final fusion
#         out = self.fusion(combined)  # [B, 3, 224, 224]
    
#         # Resize to target if needed
#         if self.output_size != (224, 224):
#             out = F.interpolate(out, size=self.output_size, mode='bilinear', align_corners=True)
        
#         return out

# # Simple MSE Loss
# class MSELoss(nn.Module):
#     def forward(self, pred, target):
#         return F.mse_loss(pred, target)

# # Training function
# def run_epoch(loader, model, optim, device, criterion, train=True, visualize_every=0, vis_dir=None):
#     if train:
#         model.train()
#     else:
#         model.eval()

#     total_loss = 0.0
#     num_batches = 0

#     with torch.set_grad_enabled(train):
#         for i, (y_batch, x_batch) in enumerate(loader):
#             # y_batch: transformed images (input), x_batch: original images (target)
#             y_batch = y_batch.to(device, non_blocking=True)
#             x_batch = x_batch.to(device, non_blocking=True)

#             # Forward pass: predict x from y
#             pred_x = model(y_batch)
        
#             # Ensure shapes match
#             if pred_x.shape != x_batch.shape:
#                 pred_x = F.interpolate(pred_x, size=x_batch.shape[-2:], mode='bilinear', align_corners=True)

#             # Compute loss
#             loss = criterion(pred_x, x_batch)
        
#             if train:
#                 optim.zero_grad(set_to_none=True)
#                 loss.backward()
#                 optim.step()

#             total_loss += loss.item()
#             num_batches += 1

#             # Visualization during validation
#             if (not train) and visualize_every and (i % visualize_every == 0) and vis_dir:
#                 os.makedirs(vis_dir, exist_ok=True)
            
#                 # Convert to numpy for visualization
#                 y_vis = y_batch[0].detach().cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
#                 pred_vis = pred_x[0].detach().cpu().numpy().transpose(1, 2, 0)
#                 x_vis = x_batch[0].detach().cpu().numpy().transpose(1, 2, 0)
            
#                 fig, ax = plt.subplots(1, 3, figsize=(15, 5))
#                 ax[0].imshow(np.clip(y_vis, 0, 1))
#                 ax[0].set_title("Input (112x112 Masked)")
#                 ax[0].axis('off')
            
#                 ax[1].imshow(np.clip(pred_vis, 0, 1))
#                 ax[1].set_title("Predicted (224x224)")
#                 ax[1].axis('off')
            
#                 ax[2].imshow(np.clip(x_vis, 0, 1))
#                 ax[2].set_title("Target (224x224 Original)")
#                 ax[2].axis('off')
            
#                 plt.tight_layout()
#                 plt.savefig(os.path.join(vis_dir, f"val_{i:05d}.png"), dpi=120, bbox_inches='tight')
#                 plt.close()

#     return total_loss / max(1, num_batches)

# def main():
#     args = {
#         # Data paths (update these to your paths)
#         'train_dir': r"F:\imgnet\data\train",
#         'val_dir': r"F:\imgnet\data\val",
#         'save_dir': "./trust_8x8_checkpoints",
#         'load_path': "",  # Path to checkpoint to resume from (optional)
    
#         # Model settings
#         'target_size': (224, 224),  # Output image size
#         'seed': 42,  # Fixed seed for orthonormal matrix A
    
#         # Training settings
#         'batch_size': 16,  # Reduced for single GPU
#         'lr': 1e-4,
#         'epochs': 100,
#         'save_every': 10,
#         'viz_every': 50,  # Visualize every N validation batches
#     }

#     # Set device
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#     print(f"🔍 Using device: {device}")

#     # Create save directory
#     os.makedirs(args['save_dir'], exist_ok=True)

#     # Save args for reference
#     import json
#     with open(os.path.join(args['save_dir'], 'training_args.json'), 'w') as f:
#         json.dump(args, f, indent=2)

#     # Create datasets
#     train_ds = PatchwiseOrthonormalDataset(
#         data_dir=args['train_dir'],
#         seed=args['seed'],
#         verbose=True
#     )

#     val_ds = PatchwiseOrthonormalDataset(
#         data_dir=args['val_dir'],
#         seed=args['seed'],  # Same seed for consistent A matrix
#         verbose=True
#     )

#     # Create dataloaders
#     train_loader = DataLoader(
#         train_ds, batch_size=args['batch_size'], shuffle=True,
#         num_workers=4, pin_memory=True, drop_last=True
#     )
#     val_loader = DataLoader(
#         val_ds, batch_size=args['batch_size'], shuffle=False,
#         num_workers=4, pin_memory=True, drop_last=False
#     )

#     # Create model
#     model = ViTUNetForInverseProblem(output_size=args['target_size']).to(device)

#     # Load checkpoint if provided
#     if args['load_path'] and os.path.isfile(args['load_path']):
#         ckpt = torch.load(args['load_path'], map_location=device)
#         state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
#         model.load_state_dict(state_dict, strict=False)
#         print(f"[GPU] Loaded checkpoint: {args['load_path']}")

#     # Optimizer and scheduler
#     optimizer = torch.optim.AdamW(model.parameters(), lr=args['lr'], weight_decay=1e-4)
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#         optimizer, mode='min', factor=0.5, patience=5, verbose=True
#     )
#     criterion = MSELoss()

#     print(f"🚀 Starting 8x8 patch-wise orthonormal super-resolution training")
#     print(f"📊 Train: {len(train_ds)} images, Val: {len(val_ds)} images")
#     print(f"🔢 Batch size: {args['batch_size']}")
#     print(f"🎯 Task: 112x112 masked input → 224x224 original target (2x super-resolution)")

#     best_val_loss = float('inf')

#     for epoch in range(args['epochs']):
#         t0 = time.time()
    
#         # Training
#         train_loss = run_epoch(train_loader, model, optimizer, device, criterion, train=True)
    
#         # Validation
#         val_loss = run_epoch(
#             val_loader, model, optimizer, device, criterion, train=False,
#             visualize_every=args['viz_every'],
#             vis_dir=os.path.join(args['save_dir'], "val_vis")
#         )
    
#         scheduler.step(val_loss)

#         elapsed = time.time() - t0
#         lr = optimizer.param_groups[0]['lr']
#         print(f"Epoch {epoch+1:03d}/{args['epochs']:03d} | "
#               f"Train: {train_loss:.6f} | Val: {val_loss:.6f} | "
#               f"Time: {elapsed:.1f}s | LR: {lr:.2e}")
        
#         # Save checkpoints
#         os.makedirs(args['save_dir'], exist_ok=True)
        
#         # Save best model
#         if val_loss < best_val_loss:
#             best_val_loss = val_loss
#             torch.save({
#                 'epoch': epoch + 1,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'val_loss': val_loss,
#                 'args': args
#             }, os.path.join(args['save_dir'], "best_model.pth"))
        
#         # Save periodic checkpoints
#         if (epoch + 1) % args['save_every'] == 0:
#             torch.save({
#                 'epoch': epoch + 1,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'val_loss': val_loss,
#                 'args': args
#             }, os.path.join(args['save_dir'], f"epoch_{epoch+1}.pth"))

#     print("🎉 Training completed!")

# if __name__ == "__main__":
#     torch.cuda.empty_cache()
#     main()


#!/usr/bin/env python3
import os, glob, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from transformers import ViTModel, ViTConfig
from pathlib import Path
from PIL import Image

# Install this first: pip install pytorch-msssim
from pytorch_msssim import ssim, ms_ssim

# Updated orthonormal matrix generation for 8x8 patch processing
def make_gaussian_random_orthonormal_rows(h=64, w=64, seed=42):
    """
    Generate a random matrix of shape [h, w] whose rows are orthonormal.

    A Gaussian random matrix is drawn and its transpose is QR-decomposed;
    the transposed Q factor is returned, so ``A @ A.T`` is the h x h identity
    (up to floating-point error). For 8x8 patches the default 64x64 matrix
    is used.

    Args:
        h: number of rows (must satisfy ``h <= w``).
        w: number of columns.
        seed: if not None, passed to ``torch.manual_seed`` before sampling.
            NOTE: this reseeds torch's *global* RNG as a side effect, so any
            random draws made after this call are affected as well.

    Returns:
        torch tensor of shape [h, w] with orthonormal rows.

    Raises:
        ValueError: if ``h > w``. In that case h orthonormal rows of length w
            cannot exist, and the reduced QR below would silently produce a
            [w, w] matrix instead of the requested [h, w].
    """
    if h > w:
        raise ValueError(
            f"Orthonormal rows require h <= w, got h={h}, w={w}"
        )

    if seed is not None:
        torch.manual_seed(seed)

    # Step 1: random Gaussian matrix
    A = torch.randn(h, w)

    # Step 2: QR on the transpose. A.T is [w, h]; its reduced Q factor is
    # [w, h] with orthonormal columns, so Q.T is [h, w] with orthonormal rows.
    Q, _ = torch.linalg.qr(A.T)
    return Q.T

class PatchwiseOrthonormalDataset:
    """
    Map-style dataset that applies a patch-wise orthonormal transformation
    to images found directly inside ``data_dir``.

    Each 8x8 patch of the preprocessed 224x224 image is flattened and
    multiplied by a fixed 64x64 orthonormal matrix; 16 of the 64 resulting
    coefficients are kept and rendered as a 4x4 tile upsampled to 8x8.

    Each item is ``(input, target)``:
        input:  float tensor [3, 112, 112], min-max normalised to [0, 1]
        target: float tensor [3, 224, 224], the original image in [0, 1]

    NOTE(review): the kept coefficients are chosen by a fresh random
    permutation for every patch on every access, so the input for a given
    image is not deterministic. Also, only the top-left 14x14 patches
    (the top-left 112x112 pixels of the source) contribute to the input,
    while the target is the full image. Both are kept as in the original
    design -- confirm they are intended.
    """

    # Lower-case suffixes accepted as images; matching is case-insensitive.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tiff'})

    def __init__(self, data_dir, seed=42, verbose=False):
        """
        Args:
            data_dir: directory containing the image files (not recursed).
            seed: seed for the fixed orthonormal matrix ``self.A``.
            verbose: if True, print a short summary after scanning.

        Raises:
            FileNotFoundError: if ``data_dir`` does not exist.
            ValueError: if no image files are found in ``data_dir``.
        """
        self.data_dir = data_dir

        # Generate fixed orthonormal matrix A for 8x8 patch-wise transformation
        self.A = make_gaussian_random_orthonormal_rows(h=64, w=64, seed=seed)

        # Get all image files
        self.data_path = Path(data_dir)
        if not self.data_path.exists():
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        self.image_files = self._discover_images(self.data_path)

        if not self.image_files:
            raise ValueError(f"No images found in {data_dir}")

        if verbose:
            print(f"📁 Loaded {len(self.image_files)} images from {data_dir}")
            print(f"🔍 Using 8x8 patch-wise orthonormal transformation with matrix shape: {self.A.shape}")
            print(f"📐 Output: 112x112 masked → 224x224 original (super-resolution task)")

    @classmethod
    def _discover_images(cls, directory):
        """
        Return the image files directly inside ``directory``, sorted by path.

        The suffix comparison is case-insensitive (so ``.PNG`` or ``.Jpeg``
        are found), and the result is sorted because ``iterdir()`` yields
        entries in arbitrary order, which would make the index-to-file
        mapping differ between runs and machines.
        """
        return sorted(
            f for f in Path(directory).iterdir()
            if f.is_file() and f.suffix.lower() in cls.IMAGE_EXTENSIONS
        )

    def __len__(self):
        """Number of image files discovered at construction time."""
        return len(self.image_files)

    def resize_min_side(self, img, min_side=224):
        """Resize a PIL image (Lanczos) so its shorter side equals ``min_side``, keeping aspect ratio."""
        w, h = img.size
        s = min_side / min(w, h)
        return img.resize((int(round(w*s)), int(round(h*s))), Image.Resampling.LANCZOS)

    def center_crop(self, img, size=224):
        """Crop a ``size`` x ``size`` square from the centre of a PIL image."""
        w, h = img.size
        left = (w - size) // 2
        top = (h - size) // 2
        return img.crop((left, top, left + size, top + size))

    def preprocess_image(self, img):
        """
        Convert a PIL image to RGB, resize its shorter side to 224 and
        centre-crop to 224x224.

        Returns:
            numpy float32 array of shape [224, 224, 3] with values in [0, 1].
        """
        img = img.convert("RGB")
        img_resized = self.resize_min_side(img, 224)
        img_crop = self.center_crop(img_resized, 224)
        x = np.array(img_crop).astype(np.float32) / 255.0
        return x

    def process_image_with_orthonormal_masks(self, np_img, mask_matrix):
        """
        Apply the orthonormal transformation to the 8x8 patches of an image
        and keep a random quarter of the coefficients of every patch.

        Args:
            np_img: numpy array of shape [H, W, C]. A 3-channel image is
                averaged to grayscale; a single-channel image is used as is.
                The dataset calls this with [224, 224, 1].
            mask_matrix: torch tensor of shape [64, 64] with orthonormal rows.

        Returns:
            torch tensor of shape [H // 8, W // 8, 16] ([28, 28, 16] for a
            224x224 input): for each patch, 16 of its 64 coefficients.
        """
        # Convert to torch and extract patches
        img_tensor = torch.from_numpy(np_img).float()

        # Collapse RGB to one plane so that each 8x8 patch has 64 values
        if img_tensor.shape[2] == 3:
            img_gray = img_tensor.mean(dim=2)
        else:
            img_gray = img_tensor

        # Non-overlapping 8x8 patches: [rows, cols, (1,) 8, 8]
        patches = img_gray.unfold(0, 8, 8).unfold(1, 8, 8)
        n_rows, n_cols = patches.shape[0], patches.shape[1]
        n_patches = n_rows * n_cols

        # One batched matmul instead of one matvec per patch:
        # row p of `coeffs` equals mask_matrix @ flattened patch p.
        flat = patches.reshape(n_patches, -1)          # [P, 64]
        coeffs = flat @ mask_matrix.T                  # [P, 64]

        # Random subsample to a quarter of the coefficients, independently
        # per patch. randperm is called once per patch in row-major order,
        # so the global RNG stream is consumed exactly as a nested i/j loop
        # over the patches would consume it.
        n_coeffs = coeffs.shape[1]
        keep = n_coeffs // 4
        picks = torch.stack(
            [torch.randperm(n_coeffs)[:keep] for _ in range(n_patches)]
        )                                              # [P, keep]
        selected = torch.gather(coeffs, 1, picks)      # [P, keep]

        return selected.reshape(n_rows, n_cols, keep)

    def reconstruct_masked_image(self, transformed_patches):
        """
        Build the 112x112 input plane from the transformed patches.

        Only the first 14x14 patches are used. Each 16-vector is reshaped to
        4x4, bilinearly upsampled to 8x8 (align_corners=True) and placed at
        tile position (i, j) of the output.

        Args:
            transformed_patches: torch tensor of shape [28, 28, 16]

        Returns:
            masked_image: torch tensor of shape [112, 112]
        """
        # [14, 14, 16] -> batch of 196 single-channel 4x4 tiles
        tiles_4x4 = transformed_patches[:14, :14].reshape(-1, 1, 4, 4)

        # Upsample all tiles at once; each sample is interpolated independently
        tiles_8x8 = F.interpolate(tiles_4x4, size=(8, 8),
                                  mode='bilinear', align_corners=True)

        # [14, 14, 8, 8] -> [14, 8, 14, 8] -> [112, 112], so that tile (i, j)
        # lands at rows i*8:(i+1)*8 and columns j*8:(j+1)*8
        grid = tiles_8x8.reshape(14, 14, 8, 8)
        return grid.permute(0, 2, 1, 3).reshape(112, 112)

    def apply_patchwise_orthonormal_transform(self, x):
        """
        Apply the patch-wise orthonormal transformation to an RGB image,
        one channel at a time, using ``self.A``.

        Args:
            x: numpy array of shape [224, 224, 3] (original image)

        Returns:
            y: numpy array of shape [112, 112, 3], min-max normalised over
               the whole array to [0, 1]
        """
        y_channels = []

        # Apply transformation to each channel separately
        for c in range(3):
            # Keep a trailing channel axis so the helper sees [H, W, 1]
            single_channel = np.expand_dims(x[..., c], axis=2)

            transformed_patches = self.process_image_with_orthonormal_masks(
                single_channel, self.A
            )

            # Reconstruct to 112x112
            masked_channel = self.reconstruct_masked_image(transformed_patches)
            y_channels.append(masked_channel.numpy())

        # Stack channels to create [112, 112, 3]
        y = np.stack(y_channels, axis=2)

        # Normalize to [0, 1] range; the epsilon guards a constant image
        y_min = y.min()
        y_max = y.max()
        y_norm = (y - y_min) / (y_max - y_min + 1e-8)

        return y_norm

    def __getitem__(self, idx):
        """
        Load image ``idx`` and return ``(y_tensor, x_tensor)``:
        the [3, 112, 112] transformed input and the [3, 224, 224] target.

        An image that cannot be opened or decoded is replaced by a black
        224x224 image and a warning is printed.
        """
        img_path = self.image_files[idx]

        try:
            # Image.open is lazy: pixel data is only decoded on load/convert.
            # Decoding inside the try lets truncated/corrupt files hit the
            # fallback too, and the `with` releases the file handle promptly.
            with Image.open(img_path) as handle:
                img = handle.convert("RGB")
        except Exception as e:  # deliberate best-effort fallback, logged below
            print(f"Warning: Could not load image {img_path}: {e}")
            img = Image.new('RGB', (224, 224), color=(0, 0, 0))

        # Preprocess to get x (target - original 224x224 image)
        x = self.preprocess_image(img)  # (224, 224, 3)

        # Apply patch-wise orthonormal transform to get y (input - 112x112 masked image)
        y = self.apply_patchwise_orthonormal_transform(x)  # (112, 112, 3)

        # Convert to torch tensors and change to CHW format
        x_tensor = torch.from_numpy(x).permute(2, 0, 1)  # (3, 224, 224) - target
        y_tensor = torch.from_numpy(y).permute(2, 0, 1)  # (3, 112, 112) - input

        return y_tensor, x_tensor  # (input 112x112, target 224x224)

# ViT-UNet Model
class ViTUNetForInverseProblem(nn.Module):
    """
    U-Net style reconstruction network built around a pretrained ViT encoder.

    Data flow (see ``forward``):
      1. The input is upsampled to 224x224 by a small conv stack unless it
         already has that spatial size.
      2. The ViT encoder is run with ``output_hidden_states=True``; hidden
         states 3, 6 and 9 plus ``last_hidden_state`` are reshaped from token
         sequences into square feature maps.
      3. A four-stage upsample+conv decoder grows the final feature map from
         7x7 to 112x112, adding 1x1-projected encoder features after the
         first three stages.
      4. ``final`` upsamples to 224x224 and maps to 3 channels; the result is
         concatenated with a convolutional skip path computed directly from
         the raw input and fused into the 3-channel output (Sigmoid, so the
         output lies in [0, 1]).

    NOTE: attribute names and the positions of layers inside each
    ``nn.Sequential`` determine the ``state_dict`` keys. Renaming or
    reordering them will make existing checkpoints stop matching.

    Args:
        pretrained_model_name: Hugging Face model id passed to
            ``ViTConfig.from_pretrained`` / ``ViTModel.from_pretrained``.
        output_size: ``(H, W)`` of the returned tensor. When it differs from
            (224, 224) the fused output is bilinearly resized at the end of
            ``forward``.
    """
    def __init__(self, pretrained_model_name="google/vit-base-patch16-224", output_size=(224, 224)):
        super().__init__()

        # Load config and explicitly disable pooling layer
        # NOTE(review): in Hugging Face transformers `add_pooling_layer` appears to be
        # a ViTModel constructor argument rather than a ViTConfig field - confirm that
        # this assignment really removes the pooler. Changing how the pooler is
        # disabled would change the parameter set and may break optimizer-state
        # loading from existing checkpoints.
        cfg = ViTConfig.from_pretrained(pretrained_model_name)
        cfg.add_pooling_layer = False  # Explicitly disable pooler

        # Load ViT without pooler to avoid unused parameters
        self.vit = ViTModel.from_pretrained(pretrained_model_name, config=cfg, ignore_mismatched_sizes=True)

        self.output_size = output_size
        # Channel width of the ViT token embeddings consumed by the decoder.
        # NOTE(review): hard-coded; presumably must equal cfg.hidden_size of the
        # chosen checkpoint (768 for ViT-Base) - verify before using another model.
        self.hidden_dim = 768

        # Input upsampling layer to handle 112x112 → 224x224 for ViT
        # (ends in Sigmoid, so the ViT receives values in [0, 1])
        self.input_upsample = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.BatchNorm2d(16), nn.ReLU(True),
            nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
            nn.Conv2d(16, 3, 3, padding=1),
            nn.Sigmoid()
        )

        # Big skip connection: 112x112 → 224x224
        # Computed from the raw input and concatenated with the decoder output
        # just before `fusion`.
        self.skip_upsample = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 64, 3, padding=1), 
            nn.BatchNorm2d(64), nn.ReLU(True),
            nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 3, 3, padding=1),  # Same number of channels as output
            nn.Tanh()  # Use tanh to allow both positive and negative contributions
        )

        # Feature pools - smaller since we're working with 224x224
        # Each pool resamples a ViT feature map to a fixed spatial size.
        self.adaptive_pool1 = nn.AdaptiveAvgPool2d((28, 28))   # ~1/8 of 224
        self.adaptive_pool2 = nn.AdaptiveAvgPool2d((14, 14))   # ~1/16 of 224  
        self.adaptive_pool3 = nn.AdaptiveAvgPool2d((7, 7))     # ~1/32 of 224
        self.adaptive_pool_final = nn.AdaptiveAvgPool2d((7, 7))

        # 1x1 projections from the ViT width to the channel count of the decoder
        # stage each skip is added to (skip_conn3 -> up1, skip_conn2 -> up2,
        # skip_conn1 -> up3).
        self.skip_conn1 = nn.Conv2d(self.hidden_dim, 128, kernel_size=1)
        self.skip_conn2 = nn.Conv2d(self.hidden_dim, 256, kernel_size=1)
        self.skip_conn3 = nn.Conv2d(self.hidden_dim, 512, kernel_size=1)

        # Decoder path: each stage upsamples to a fixed size, then conv + BN + ReLU
        self.up1 = nn.Sequential(
            nn.Upsample(size=(14, 14), mode='bilinear', align_corners=True),
            nn.Conv2d(self.hidden_dim, 512, 3, padding=1), 
            nn.BatchNorm2d(512), nn.ReLU(True)
        )
        self.up2 = nn.Sequential(
            nn.Upsample(size=(28, 28), mode='bilinear', align_corners=True),
            nn.Conv2d(512, 256, 3, padding=1), 
            nn.BatchNorm2d(256), nn.ReLU(True)
        )
        self.up3 = nn.Sequential(
            nn.Upsample(size=(56, 56), mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, 3, padding=1), 
            nn.BatchNorm2d(128), nn.ReLU(True)
        )
        self.up4 = nn.Sequential(
            nn.Upsample(size=(112, 112), mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, 3, padding=1), 
            nn.BatchNorm2d(64), nn.ReLU(True)
        )

        # Final output layer - RGB output with super-resolution
        self.final = nn.Sequential(
            nn.Upsample(size=(224, 224), mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 3, 3, padding=1),  # 3 channels for RGB
            nn.Sigmoid()  # Output in [0, 1] range
        )

        # Fusion layer to combine ViT output with skip connection
        self.fusion = nn.Sequential(
            nn.Conv2d(6, 32, 3, padding=1),  # 6 channels: 3 from ViT + 3 from skip
            nn.BatchNorm2d(32), nn.ReLU(True),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Sigmoid()
        )

    def _extract(self, x3):
        """
        Run the ViT encoder and return four token tensors.

        Args:
            x3: Image batch fed to ``self.vit`` (3-channel, 224x224 in ``forward``).

        Returns:
            Tuple ``(early, mid, late, last)``: entries 3, 6 and 9 of the
            ``hidden_states`` tuple plus ``last_hidden_state``.
        """
        out = self.vit(x3, output_hidden_states=True)
        early = out.hidden_states[3]   # Early features
        mid   = out.hidden_states[6]   # Mid features  
        late  = out.hidden_states[9]   # Late features
        last  = out.last_hidden_state  # Final features
        return early, mid, late, last

    def _to_spatial(self, tokens):
        """
        Turn a ``[B, N, C]`` token sequence into a ``[B, C, HW, HW]`` feature map.

        The first token (index 0, the CLS token) is discarded and the remaining
        ``N - 1`` tokens are laid out on a square grid. This assumes ``N - 1``
        is a perfect square; otherwise the final ``reshape`` raises.
        """
        B, N, C = tokens.shape
        HW = int((N - 1) ** 0.5)  # Drop CLS token
        t = tokens[:, 1:, :].permute(0, 2, 1)
        return t.reshape(B, C, HW, HW)

    def forward(self, x):
        """
        Reconstruct an image batch from the transformed input.

        Args:
            x: Input batch; the surrounding comments describe it as
               ``[B, 3, 112, 112]``. A 224x224 input skips ``input_upsample``
               but still goes through ``skip_upsample``.

        Returns:
            Tensor ``[B, 3, *self.output_size]`` produced by the Sigmoid-terminated
            fusion head (bilinearly resized if ``output_size`` is not 224x224).
        """
        # x is [B, 3, 112, 112] input

        # Create skip connection: 112x112 → 224x224
        skip_features = self.skip_upsample(x)  # [B, 3, 224, 224]

        # Process through ViT-UNet
        if x.shape[-2:] != (224, 224):
            x_upsampled = self.input_upsample(x)  # 112x112 → 224x224
        else:
            x_upsampled = x

        # Extract multi-scale features from ViT
        e, m, l, f = self._extract(x_upsampled)
        e, m, l, f = map(self._to_spatial, (e, m, l, f))

        # Create skip connections
        skip1 = self.adaptive_pool1(e)   # 28x28
        skip2 = self.adaptive_pool2(m)   # 14x14
        skip3 = self.adaptive_pool3(l)   # 7x7
        x_feat = self.adaptive_pool_final(f)  # 7x7

        # Decoder with skip connections
        # Each skip is projected by a 1x1 conv, resized to the stage's output
        # resolution, and added element-wise.
        x_feat = self.up1(x_feat)  # 7x7 -> 14x14
        x_feat = x_feat + F.interpolate(self.skip_conn3(skip3), size=(14, 14), mode='bilinear', align_corners=True)

        x_feat = self.up2(x_feat)  # 14x14 -> 28x28
        x_feat = x_feat + F.interpolate(self.skip_conn2(skip2), size=(28, 28), mode='bilinear', align_corners=True)

        x_feat = self.up3(x_feat)  # 28x28 -> 56x56
        x_feat = x_feat + F.interpolate(self.skip_conn1(skip1), size=(56, 56), mode='bilinear', align_corners=True)

        x_feat = self.up4(x_feat)  # 56x56 -> 112x112

        # ViT-UNet output
        vit_output = self.final(x_feat)  # 112x112 -> 224x224, 3 channels

        # Combine ViT output with skip connection
        combined = torch.cat([vit_output, skip_features], dim=1)  # [B, 6, 224, 224]

        # Final fusion
        out = self.fusion(combined)  # [B, 3, 224, 224]

        # Resize to target if needed
        if self.output_size != (224, 224):
            out = F.interpolate(out, size=self.output_size, mode='bilinear', align_corners=True)

        return out

# Combined Loss Function
class CombinedLoss(nn.Module):
    """
    Weighted sum of three reconstruction terms: MSE, L1 and (1 - SSIM).

    Args:
        mse_weight (float): Multiplier applied to the MSE term.
        l1_weight (float): Multiplier applied to the L1 term.
        ssim_weight (float): Multiplier applied to the (1 - SSIM) term.
        use_ms_ssim (bool): Select ``ms_ssim`` instead of ``ssim`` for the
            structural-similarity term.
    """
    def __init__(self, mse_weight=1.0, l1_weight=1.0, ssim_weight=1.0, use_ms_ssim=False):
        super().__init__()
        self.mse_weight, self.l1_weight, self.ssim_weight = mse_weight, l1_weight, ssim_weight
        self.use_ms_ssim = use_ms_ssim

        # Pixel-wise criteria
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

        # Announce the configuration once at construction time
        print(f"🎯 Combined Loss - MSE: {mse_weight}, L1: {l1_weight}, SSIM: {ssim_weight}")
        print("📐 Using Multi-Scale SSIM" if use_ms_ssim else "📐 Using Standard SSIM")

    def forward(self, pred, target):
        """
        Evaluate every loss term and their weighted sum.

        Args:
            pred: Predicted images [B, C, H, W]
            target: Target images [B, C, H, W]

        Returns:
            dict with tensors under 'total_loss', 'mse_loss', 'l1_loss',
            'ssim_loss' (= 1 - SSIM) and 'ssim_value' (raw SSIM).
        """
        # Pixel terms use the raw tensors ...
        mse = self.mse_loss(pred, target)
        l1 = self.l1_loss(pred, target)

        # ... while the similarity metric sees copies clamped to [0, 1],
        # matching the data_range=1.0 passed below.
        similarity_fn = ms_ssim if self.use_ms_ssim else ssim
        ssim_val = similarity_fn(
            pred.clamp(0, 1), target.clamp(0, 1), data_range=1.0, size_average=True
        )

        # SSIM is a similarity, so minimise its complement.
        ssim_loss = 1 - ssim_val

        total_loss = self.mse_weight * mse + self.l1_weight * l1 + self.ssim_weight * ssim_loss

        # Individual components are returned alongside the total for monitoring.
        return dict(
            total_loss=total_loss,
            mse_loss=mse,
            l1_loss=l1,
            ssim_loss=ssim_loss,
            ssim_value=ssim_val,  # Higher is better for SSIM
        )

# Perceptual Loss (Optional)
class PerceptualLoss(nn.Module):
    """
    Optional perceptual loss: mean MSE between VGG16 feature maps of the
    prediction and of the target.

    Both inputs are normalised with the ImageNet mean/std buffers, passed
    through ``vgg.features`` layer by layer, and the activations at the
    indices in ``feature_layers`` are compared.

    Args:
        feature_layers: Indices into ``vgg.features`` whose outputs are
            compared. An immutable tuple is used as the default and a private
            list copy is stored, so neither the default nor the caller's
            sequence can be mutated through the instance.

    Raises:
        ValueError: If ``feature_layers`` is empty (the loss would otherwise
            divide by zero on the first forward pass).
    """
    def __init__(self, feature_layers=(3, 8, 15, 22)):
        super().__init__()
        feature_layers = list(feature_layers)
        if not feature_layers:
            # Fail before the (potentially slow) weight download.
            raise ValueError("feature_layers must contain at least one layer index")

        # Load pre-trained VGG16
        vgg = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
        self.features = vgg.features
        self.feature_layers = feature_layers

        # Freeze VGG parameters
        for param in self.features.parameters():
            param.requires_grad = False

        # Normalization for ImageNet pre-trained model
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, pred, target):
        """
        Return the average feature-space MSE between ``pred`` and ``target``.

        Args:
            pred: Predicted images, 3 channels (broadcast against the
                ``[1, 3, 1, 1]`` mean/std buffers).
            target: Target images with the same shape as ``pred``.

        Returns:
            Sum of per-layer ``F.mse_loss`` values divided by the number of
            collected feature maps.
        """
        # Normalize inputs
        pred_norm = (pred - self.mean) / self.std
        target_norm = (target - self.mean) / self.std

        # Extract features
        pred_features = self.extract_features(pred_norm)
        target_features = self.extract_features(target_norm)

        # Calculate perceptual loss
        loss = 0
        for pred_feat, target_feat in zip(pred_features, target_features):
            loss += F.mse_loss(pred_feat, target_feat)

        return loss / len(pred_features)

    def extract_features(self, x):
        """
        Collect the activations of the requested layers, in layer order.

        The walk through ``self.features`` stops right after the highest
        requested index: later layers cannot contribute to the result, so
        running them would only waste compute and memory.
        """
        if not self.feature_layers:
            return []

        wanted = set(self.feature_layers)
        last_needed = max(wanted)

        features = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in wanted:
                features.append(x)
            if i >= last_needed:
                break
        return features

# Advanced Combined Loss with optional perceptual term
class AdvancedCombinedLoss(nn.Module):
    """
    Weighted sum of MSE, L1, (1 - SSIM) and, when enabled, a perceptual term.

    The perceptual sub-module is only constructed when ``perceptual_weight``
    is strictly positive; otherwise the perceptual term is a constant zero.
    """
    def __init__(self, mse_weight=1.0, l1_weight=1.0, ssim_weight=1.0, 
                 perceptual_weight=0.0, use_ms_ssim=False):
        super().__init__()
        self.mse_weight, self.l1_weight = mse_weight, l1_weight
        self.ssim_weight, self.perceptual_weight = ssim_weight, perceptual_weight
        self.use_ms_ssim = use_ms_ssim

        # Pixel-wise criteria
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

        # The perceptual network is created only if it will actually be used
        if perceptual_weight > 0:
            self.perceptual_loss = PerceptualLoss()
            print(f"🎨 Including Perceptual Loss with weight: {perceptual_weight}")

        print(f"🎯 Advanced Combined Loss - MSE: {mse_weight}, L1: {l1_weight}, SSIM: {ssim_weight}")

    def forward(self, pred, target):
        """
        Evaluate all loss terms.

        Returns:
            dict with tensors under 'total_loss', 'mse_loss', 'l1_loss',
            'ssim_loss', 'ssim_value' and 'perceptual_loss'.
        """
        # SSIM and the perceptual term work on copies clamped to [0, 1]
        pred_01 = pred.clamp(0, 1)
        target_01 = target.clamp(0, 1)

        # Pixel terms use the unclamped tensors
        mse = self.mse_loss(pred, target)
        l1 = self.l1_loss(pred, target)

        # Structural similarity (single- or multi-scale)
        similarity_fn = ms_ssim if self.use_ms_ssim else ssim
        ssim_val = similarity_fn(pred_01, target_01, data_range=1.0, size_average=True)
        ssim_loss = 1 - ssim_val

        # Perceptual term, or a zero placeholder on the prediction's device
        if self.perceptual_weight > 0:
            perceptual = self.perceptual_loss(pred_01, target_01)
        else:
            perceptual = torch.tensor(0.0, device=pred.device)

        total_loss = (self.mse_weight * mse
                      + self.l1_weight * l1
                      + self.ssim_weight * ssim_loss
                      + self.perceptual_weight * perceptual)

        return dict(
            total_loss=total_loss,
            mse_loss=mse,
            l1_loss=l1,
            ssim_loss=ssim_loss,
            ssim_value=ssim_val,
            perceptual_loss=perceptual,
        )

# Loss Configuration Presets
def get_loss_configs():
    """
    Return the named loss-weight presets.

    Each call builds fresh dictionaries, so callers may modify the result
    without affecting later calls. Every preset can be splatted into a loss
    constructor; only 'perceptual' carries a 'perceptual_weight' entry.
    """
    # Basic combination - good starting point
    basic = dict(mse_weight=1.0, l1_weight=1.0, ssim_weight=1.0, use_ms_ssim=False)

    # SSIM-focused - better perceptual quality
    ssim_focused = dict(mse_weight=0.5, l1_weight=1.0, ssim_weight=2.0, use_ms_ssim=True)

    # L1-focused - sharper edges
    sharp_edges = dict(mse_weight=0.5, l1_weight=2.0, ssim_weight=1.0, use_ms_ssim=False)

    # Balanced with perceptual
    perceptual = dict(
        mse_weight=1.0,
        l1_weight=1.0,
        ssim_weight=1.0,
        perceptual_weight=0.1,
        use_ms_ssim=True,
    )

    return {
        'basic': basic,
        'ssim_focused': ssim_focused,
        'sharp_edges': sharp_edges,
        'perceptual': perceptual,
    }

# Training function
def run_epoch(loader, model, optim, device, criterion, train=True, visualize_every=0, vis_dir=None):
    """
    Run one pass over ``loader`` and return the per-batch average of each loss.

    Args:
        loader: Iterable yielding ``(y_batch, x_batch)`` = (input, target).
        model: Module mapping the input batch to a prediction.
        optim: Optimizer; only touched when ``train`` is true.
        device: Device both batches are moved to.
        criterion: Callable returning a dict of tensors with at least
            'total_loss', 'mse_loss', 'l1_loss', 'ssim_loss', 'ssim_value'
            and optionally 'perceptual_loss'.
        train: Training pass (gradients + optimizer step) when true,
            evaluation pass otherwise.
        visualize_every: In evaluation passes, save a comparison figure for
            every batch whose index is a multiple of this value (0 disables).
        vis_dir: Directory for the figures; nothing is saved when falsy.

    Returns:
        dict of floats keyed 'total_loss', 'mse_loss', 'l1_loss', 'ssim_loss',
        'ssim_value', 'perceptual_loss'. All zeros if the loader is empty.
    """
    switch_mode = model.train if train else model.eval
    switch_mode()

    # Running sums of every monitored quantity
    always_present = ('total_loss', 'mse_loss', 'l1_loss', 'ssim_loss', 'ssim_value')
    sums = dict.fromkeys(always_present + ('perceptual_loss',), 0.0)
    batches_seen = 0

    def first_sample_hwc(batch):
        # First item of the batch as an HWC numpy array for imshow
        return batch[0].detach().cpu().numpy().transpose(1, 2, 0)

    with torch.set_grad_enabled(train):
        for step, (y_batch, x_batch) in enumerate(loader):
            # y_batch is the transformed input, x_batch the original target
            y_batch = y_batch.to(device, non_blocking=True)
            x_batch = x_batch.to(device, non_blocking=True)

            # Predict the original image from the transformed one
            pred_x = model(y_batch)

            # Bring the prediction to the target's spatial size if they differ
            if pred_x.shape != x_batch.shape:
                pred_x = F.interpolate(pred_x, size=x_batch.shape[-2:], mode='bilinear', align_corners=True)

            loss_dict = criterion(pred_x, x_batch)
            loss = loss_dict['total_loss']

            if train:
                optim.zero_grad(set_to_none=True)
                loss.backward()
                optim.step()

            # Bookkeeping
            for key in always_present:
                sums[key] += loss_dict[key].item()
            if 'perceptual_loss' in loss_dict:
                sums['perceptual_loss'] += loss_dict['perceptual_loss'].item()
            batches_seen += 1

            # Figures are only produced during evaluation
            wants_figure = (not train) and visualize_every and vis_dir and (step % visualize_every == 0)
            if not wants_figure:
                continue

            os.makedirs(vis_dir, exist_ok=True)

            panels = [
                (first_sample_hwc(y_batch), "Input (112x112 Masked)"),
                (first_sample_hwc(pred_x), f"Predicted (224x224)\nSSIM: {loss_dict['ssim_value'].item():.4f}"),
                (first_sample_hwc(x_batch), "Target (224x224 Original)"),
            ]

            fig, ax = plt.subplots(1, 3, figsize=(15, 5))
            for axis, (image, title) in zip(ax, panels):
                axis.imshow(np.clip(image, 0, 1))
                axis.set_title(title)
                axis.axis('off')

            plt.tight_layout()
            plt.savefig(os.path.join(vis_dir, f"val_{step:05d}.png"), dpi=120, bbox_inches='tight')
            plt.close()

    # Average over batches; the max() guards against an empty loader
    denom = max(1, batches_seen)
    return {name: running / denom for name, running in sums.items()}

# def main():
#     args = {
#         # Data paths (update these to your paths)
#         'train_dir': r"F:\imgnet\data\train",
#         'val_dir': r"F:\imgnet\data\val",
#         'save_dir': "./trust_8x8_checkpoints",
#         'load_path': "",  # Path to checkpoint to resume from (optional)
    
#         # Model settings
#         'target_size': (224, 224),  # Output image size
#         'seed': 42,  # Fixed seed for orthonormal matrix A
    
#         # Training settings
#         'batch_size': 32,  # Reduced for single GPU
#         'lr': 1e-4,
#         'epochs': 100,
#         'save_every': 10,
#         'viz_every': 50,  # Visualize every N validation batches
        
#         # Loss settings - choose one of the configurations
#         'loss_config': 'basic',  # 'basic', 'ssim_focused', 'sharp_edges', 'perceptual'
#     }

#     # Set device
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#     print(f"🔍 Using device: {device}")

#     # Create save directory
#     os.makedirs(args['save_dir'], exist_ok=True)

#     # Save args for reference
#     import json
#     with open(os.path.join(args['save_dir'], 'training_args.json'), 'w') as f:
#         json.dump(args, f, indent=2)

#     # Create datasets
#     train_ds = PatchwiseOrthonormalDataset(
#         data_dir=args['train_dir'],
#         seed=args['seed'],
#         verbose=True
#     )

#     val_ds = PatchwiseOrthonormalDataset(
#         data_dir=args['val_dir'],
#         seed=args['seed'],  # Same seed for consistent A matrix
#         verbose=True
#     )

#     # Create dataloaders
#     train_loader = DataLoader(
#         train_ds, batch_size=args['batch_size'], shuffle=True,
#         num_workers=4, pin_memory=True, drop_last=True
#     )
#     val_loader = DataLoader(
#         val_ds, batch_size=args['batch_size'], shuffle=False,
#         num_workers=4, pin_memory=True, drop_last=False
#     )

#     # Create model
#     model = ViTUNetForInverseProblem(output_size=args['target_size']).to(device)

#     # Load checkpoint if provided
#     if args['load_path'] and os.path.isfile(args['load_path']):
#         ckpt = torch.load(args['load_path'], map_location=device)
#         state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
#         model.load_state_dict(state_dict, strict=False)
#         print(f"[GPU] Loaded checkpoint: {args['load_path']}")

#     # Create combined loss function
#     loss_configs = get_loss_configs()
#     config = loss_configs[args['loss_config']]
    
#     # Choose between basic combined loss or advanced with perceptual
#     if 'perceptual_weight' in config:
#         criterion = AdvancedCombinedLoss(**config)
#     else:
#         criterion = CombinedLoss(**config)

#     # Optimizer and scheduler
#     optimizer = torch.optim.AdamW(model.parameters(), lr=args['lr'], weight_decay=1e-4)
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#         optimizer, mode='min', factor=0.5, patience=5, verbose=True
#     )

#     print(f"🚀 Starting 8x8 patch-wise orthonormal super-resolution training")
#     print(f"📊 Train: {len(train_ds)} images, Val: {len(val_ds)} images")
#     print(f"🔢 Batch size: {args['batch_size']}")
#     print(f"🎯 Task: 112x112 masked input → 224x224 original target (2x super-resolution)")
#     print(f"🎲 Loss configuration: {args['loss_config']}")

#     best_val_loss = float('inf')

#     for epoch in range(args['epochs']):
#         t0 = time.time()
    
#         # Training
#         train_losses = run_epoch(train_loader, model, optimizer, device, criterion, train=True)
    
#         # Validation
#         val_losses = run_epoch(
#             val_loader, model, optimizer, device, criterion, train=False,
#             visualize_every=args['viz_every'],
#             vis_dir=os.path.join(args['save_dir'], "val_vis")
#         )
    
#         scheduler.step(val_losses['total_loss'])

#         elapsed = time.time() - t0
#         lr = optimizer.param_groups[0]['lr']
        
#         # Enhanced logging with all loss components
#         print(f"Epoch {epoch+1:03d}/{args['epochs']:03d} | Time: {elapsed:.1f}s | LR: {lr:.2e}")
#         print(f"  Train - Total: {train_losses['total_loss']:.6f} | "
#               f"MSE: {train_losses['mse_loss']:.6f} | "
#               f"L1: {train_losses['l1_loss']:.6f} | "
#               f"SSIM: {train_losses['ssim_value']:.4f}")
#         print(f"  Val   - Total: {val_losses['total_loss']:.6f} | "
#               f"MSE: {val_losses['mse_loss']:.6f} | "
#               f"L1: {val_losses['l1_loss']:.6f} | "
#               f"SSIM: {val_losses['ssim_value']:.4f}")
        
#         if val_losses['perceptual_loss'] > 0:
#             print(f"  Perceptual - Train: {train_losses['perceptual_loss']:.6f} | "
#                   f"Val: {val_losses['perceptual_loss']:.6f}")
        
#         # Save checkpoints
#         os.makedirs(args['save_dir'], exist_ok=True)
        
#         # Save best model
#         if val_losses['total_loss'] < best_val_loss:
#             best_val_loss = val_losses['total_loss']
#             torch.save({
#                 'epoch': epoch + 1,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'val_losses': val_losses,
#                 'train_losses': train_losses,
#                 'args': args
#             }, os.path.join(args['save_dir'], "best_model.pth"))
        
#         # Save periodic checkpoints
#         if (epoch + 1) % args['save_every'] == 0:
#             torch.save({
#                 'epoch': epoch + 1,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'val_losses': val_losses,
#                 'train_losses': train_losses,
#                 'args': args
#             }, os.path.join(args['save_dir'], f"epoch_{epoch+1}.pth"))

#     print("🎉 Training completed!")


def main():
    """
    Train (or resume training of) the ViT-UNet on the patch-wise orthonormal
    reconstruction task.

    Steps: write the run configuration to ``training_args.json``, build the
    train/val datasets and loaders, create the model, loss, optimizer and LR
    scheduler, optionally restore state from ``args['load_path']``, then loop
    over epochs - training, validating, stepping the scheduler, logging and
    saving both the best and the periodic checkpoints.
    """
    import json

    args = {
        # Data paths (update these to your paths)
        'train_dir': r"F:\imgnet\data\train",
        'val_dir': r"F:\imgnet\data\val",
        'save_dir': "./trust_8x8_checkpoints/0920",
        'load_path': r"D:\JHU\ImageNet\trust_8x8_checkpoints\epoch_80.pth",  # Path to your checkpoint

        # Model settings
        'target_size': (224, 224),  # Output image size
        'seed': 42,  # Fixed seed for orthonormal matrix A

        # Training settings
        'batch_size': 32,  # Reduced for single GPU
        'lr': 1e-4,
        'epochs': 100,
        'save_every': 10,
        'viz_every': 50,  # Visualize every N validation batches

        # Loss settings - choose one of the configurations
        'loss_config': 'basic',  # 'basic', 'ssim_focused', 'sharp_edges', 'perceptual'
    }
    save_dir = args['save_dir']
    resume_path = args['load_path']

    # Device selection
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"🔍 Using device: {device}")

    # Output directory plus a JSON record of the run configuration
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, 'training_args.json'), 'w') as args_file:
        json.dump(args, args_file, indent=2)

    # Datasets: both splits share the seed so they use the same matrix A
    train_ds, val_ds = (
        PatchwiseOrthonormalDataset(data_dir=args[split_key], seed=args['seed'], verbose=True)
        for split_key in ('train_dir', 'val_dir')
    )

    # Loaders: only shuffling and drop_last differ between the splits
    shared_loader_opts = dict(batch_size=args['batch_size'], num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **shared_loader_opts)
    val_loader = DataLoader(val_ds, shuffle=False, drop_last=False, **shared_loader_opts)

    # Model
    model = ViTUNetForInverseProblem(output_size=args['target_size']).to(device)

    # Loss: presets containing 'perceptual_weight' need the advanced variant
    config = get_loss_configs()[args['loss_config']]
    loss_cls = AdvancedCombinedLoss if 'perceptual_weight' in config else CombinedLoss
    criterion = loss_cls(**config)

    # Optimizer and plateau-based LR schedule
    optimizer = torch.optim.AdamW(model.parameters(), lr=args['lr'], weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )

    # Defaults for a fresh run; overwritten below when resuming
    start_epoch = 0
    best_val_loss = float('inf')

    if not (resume_path and os.path.isfile(resume_path)):
        print("🆕 Starting training from scratch")
    else:
        print(f"Loading checkpoint from: {args['load_path']}")
        ckpt = torch.load(resume_path, map_location=device)

        # Model weights: either wrapped in a checkpoint dict or a bare state_dict
        if 'model_state_dict' in ckpt:
            weights, loaded_msg = ckpt['model_state_dict'], "✅ Model state loaded"
        else:
            weights, loaded_msg = ckpt, "✅ Model state loaded (legacy format)"
        model.load_state_dict(weights, strict=False)
        print(loaded_msg)

        # Optimizer and scheduler states are restored only when present
        restorable = (
            (optimizer, 'optimizer_state_dict', "✅ Optimizer state loaded"),
            (scheduler, 'scheduler_state_dict', "✅ Scheduler state loaded"),
        )
        for component, state_key, done_msg in restorable:
            if state_key in ckpt:
                component.load_state_dict(ckpt[state_key])
                print(done_msg)

        # The stored epoch is 1-based "epochs completed", i.e. the next 0-based index
        if 'epoch' in ckpt:
            start_epoch = ckpt['epoch']
            print(f"✅ Resuming from epoch {start_epoch}")

        # Best validation loss: explicit field first, else that checkpoint's own val loss
        if 'best_val_loss' in ckpt:
            best_val_loss = ckpt['best_val_loss']
            print(f"✅ Best validation loss: {best_val_loss:.6f}")
        elif 'val_losses' in ckpt and isinstance(ckpt['val_losses'], dict):
            best_val_loss = ckpt['val_losses']['total_loss']
            print(f"✅ Best validation loss from val_losses: {best_val_loss:.6f}")

        print(f"🔄 Successfully loaded checkpoint: {args['load_path']}")

    print(f"🚀 Starting 8x8 patch-wise orthonormal super-resolution training")
    print(f"📊 Train: {len(train_ds)} images, Val: {len(val_ds)} images")
    print(f"🔢 Batch size: {args['batch_size']}")
    print(f"🎯 Task: 112x112 masked input → 224x224 original target (2x super-resolution)")
    print(f"🎲 Loss configuration: {args['loss_config']}")
    print(f"📈 Starting from epoch: {start_epoch + 1}")

    vis_dir = os.path.join(save_dir, "val_vis")

    def checkpoint_payload(epochs_done, train_stats, val_stats, best_so_far):
        # Everything needed to resume: weights, optimizer/scheduler state, metrics
        return {
            'epoch': epochs_done,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_losses': val_stats,
            'train_losses': train_stats,
            'best_val_loss': best_so_far,
            'args': args
        }

    for epoch in range(start_epoch, args['epochs']):
        t0 = time.time()

        # One training pass followed by one validation pass
        train_losses = run_epoch(train_loader, model, optimizer, device, criterion, train=True)
        val_losses = run_epoch(
            val_loader, model, optimizer, device, criterion, train=False,
            visualize_every=args['viz_every'],
            vis_dir=vis_dir
        )

        # The scheduler reacts to the validation total loss
        scheduler.step(val_losses['total_loss'])

        elapsed = time.time() - t0
        lr = optimizer.param_groups[0]['lr']

        # Per-epoch log with every loss component
        print(f"Epoch {epoch+1:03d}/{args['epochs']:03d} | Time: {elapsed:.1f}s | LR: {lr:.2e}")
        print(f"  Train - Total: {train_losses['total_loss']:.6f} | "
              f"MSE: {train_losses['mse_loss']:.6f} | "
              f"L1: {train_losses['l1_loss']:.6f} | "
              f"SSIM: {train_losses['ssim_value']:.4f}")
        print(f"  Val   - Total: {val_losses['total_loss']:.6f} | "
              f"MSE: {val_losses['mse_loss']:.6f} | "
              f"L1: {val_losses['l1_loss']:.6f} | "
              f"SSIM: {val_losses['ssim_value']:.4f}")

        if val_losses['perceptual_loss'] > 0:
            print(f"  Perceptual - Train: {train_losses['perceptual_loss']:.6f} | "
                  f"Val: {val_losses['perceptual_loss']:.6f}")

        # Make sure the target directory still exists before writing
        os.makedirs(save_dir, exist_ok=True)
        epochs_done = epoch + 1

        # Best-so-far checkpoint
        if val_losses['total_loss'] < best_val_loss:
            best_val_loss = val_losses['total_loss']
            print(f"💾 New best model! Validation loss: {best_val_loss:.6f}")
            torch.save(
                checkpoint_payload(epochs_done, train_losses, val_losses, best_val_loss),
                os.path.join(save_dir, "best_model.pth")
            )

        # Periodic checkpoint
        if epochs_done % args['save_every'] == 0:
            torch.save(
                checkpoint_payload(epochs_done, train_losses, val_losses, best_val_loss),
                os.path.join(save_dir, f"epoch_{epoch+1}.pth")
            )

    print("🎉 Training completed!")

if __name__ == "__main__":
    # Script entry point: runs only when the file is executed directly,
    # not when it is imported.
    # Prerequisite: install pytorch-msssim first (pip install pytorch-msssim);
    # presumably it provides the ssim / ms_ssim functions used by the loss classes.
    # Release cached CUDA allocator memory before training starts.
    torch.cuda.empty_cache()
    main()