#!/usr/bin/env python3
import os, glob, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import math

# Install this first: pip install pytorch-msssim
from pytorch_msssim import ssim, ms_ssim

# Orthonormal-row matrix generation for 16x16 (256-dim) patch processing (no subsampling)
def make_gaussian_random_orthonormal_rows(h=256, w=256, seed=42):
    """
    Generate a matrix A of shape [h, w] whose rows are orthonormal (A @ A.T == I_h).

    A standard Gaussian [h, w] matrix is drawn and its rows are orthonormalized
    with a reduced QR decomposition of its transpose. For 16x16 patches this is
    called with h = w = 256, which yields a square orthogonal matrix.

    Args:
        h: number of rows (output dimension); must satisfy h <= w.
        w: number of columns (input dimension, e.g. 16 * 16 = 256).
        seed: if not None, torch's *global* RNG is seeded with it before
            sampling, so equal seeds give equal matrices. This also resets the
            global RNG state for any later torch random calls.

    Returns:
        float32 torch.Tensor of shape [h, w] with orthonormal rows.

    Raises:
        ValueError: if h > w (at most w orthonormal vectors exist in R^w; the
            QR below would otherwise silently return a [w, w] matrix).
    """
    if h > w:
        raise ValueError(f"Orthonormal rows require h <= w, got h={h}, w={w}")
    if seed is not None:
        torch.manual_seed(seed)
    A = torch.randn(h, w)
    # Reduced QR of A^T ([w, h]): Q is [w, h] with orthonormal columns,
    # hence Q^T is [h, w] with orthonormal rows.
    Q, _ = torch.linalg.qr(A.T)
    return Q.T

class PatchwiseOrthonormalDataset:
    """
    Dataset that applies a patch-wise orthonormal transformation to images.

    Each image is resized so its shorter side is 224, center-cropped to
    224x224 and scaled to [0, 1]. Each color channel is processed separately:
    every 16x16 patch is flattened to a 256-vector and multiplied by a fixed
    256x256 matrix A with orthonormal rows. The 256 coefficients are then
    written back into that patch's location. NO SUBSAMPLING - all 256
    dimensions are kept.

    __getitem__ returns (input, target) float32 tensors, both [3, 224, 224].
    The input is the transformed image, min-max normalized to [0, 1]; the
    target is the original image.
    """

    # Image side length and patch side length; IMG_SIZE must be a multiple of PATCH_SIZE.
    IMG_SIZE = 224
    PATCH_SIZE = 16

    # Accepted image suffixes, compared case-insensitively.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'})

    def __init__(self, data_dir, seed=42, verbose=False):
        """
        Args:
            data_dir: directory whose top-level image files form the dataset
                (subdirectories are not searched).
            seed: seed for the fixed matrix A; datasets built with the same seed
                share the same transformation.
            verbose: print a short summary after loading.

        Raises:
            FileNotFoundError: if data_dir is not an existing directory.
            ValueError: if data_dir contains no supported image files.
        """
        self.data_dir = data_dir

        # Fixed matrix for the 16x16 patch-wise transformation (256 = 16 * 16).
        patch_dim = self.PATCH_SIZE * self.PATCH_SIZE
        self.A = make_gaussian_random_orthonormal_rows(h=patch_dim, w=patch_dim, seed=seed)

        self.data_path = Path(data_dir)
        if not self.data_path.is_dir():
            raise FileNotFoundError(f"Data directory not found: {data_dir}")

        # Sorted so the index -> file mapping is deterministic (iterdir order is
        # filesystem-dependent). Suffixes are lower-cased so '.PNG' etc. match.
        self.image_files = sorted(
            f for f in self.data_path.iterdir()
            if f.is_file() and f.suffix.lower() in self.IMAGE_EXTENSIONS
        )

        if not self.image_files:
            raise ValueError(f"No images found in {data_dir}")

        if verbose:
            print(f"Loaded {len(self.image_files)} images from {data_dir}")
            print(f"Using 16x16 patch-wise orthonormal transformation with matrix shape: {self.A.shape}")
            print(f"Output: 224x224 masked → 224x224 original (inverse problem, NO subsampling)")

    def __len__(self):
        return len(self.image_files)

    def resize_min_side(self, img, min_side=224):
        """Resize a PIL image (Lanczos) so its shorter side equals min_side, keeping aspect ratio."""
        w, h = img.size
        s = min_side / min(w, h)
        return img.resize((int(round(w * s)), int(round(h * s))), Image.Resampling.LANCZOS)

    def center_crop(self, img, size=224):
        """Crop the central size x size region of a PIL image."""
        w, h = img.size
        left = (w - size) // 2
        top = (h - size) // 2
        return img.crop((left, top, left + size, top + size))

    def preprocess_image(self, img):
        """Convert a PIL image to a float32 [224, 224, 3] array in [0, 1] (resize shorter side, center crop)."""
        img = img.convert("RGB")
        img = self.center_crop(self.resize_min_side(img, self.IMG_SIZE), self.IMG_SIZE)
        return np.array(img).astype(np.float32) / 255.0

    def process_image_with_orthonormal_masks(self, np_img, mask_matrix):
        """
        Apply an orthonormal transformation to every 16x16 patch of an image.

        Args:
            np_img: numpy array [H, W] or [H, W, C]. A channel axis is averaged
                away, which is an exact no-op for a single channel. H and W must
                be multiples of PATCH_SIZE (224x224 gives a 14x14 patch grid).
            mask_matrix: torch tensor [K, 256] with the image's dtype (float32
                here). K = 256 in this file, i.e. no subsampling.

        Returns:
            torch tensor [H/16, W/16, K] ([14, 14, 256] for a 224x224 image).
            Entry [i, j] is mask_matrix @ (row-major flattened patch (i, j)).

        Raises:
            ValueError: if the image is not 2-D/3-D or its size is not a
                multiple of PATCH_SIZE.
        """
        img = torch.from_numpy(np_img).float()
        if img.ndim == 3:
            img = img.mean(dim=2)  # grayscale by channel mean
        elif img.ndim != 2:
            raise ValueError(f"Expected a [H, W] or [H, W, C] image, got shape {tuple(img.shape)}")

        p = self.PATCH_SIZE
        height, width = img.shape
        if height % p or width % p:
            raise ValueError(f"Image size {height}x{width} is not a multiple of the patch size {p}")
        rows, cols = height // p, width // p

        # [H, W] -> [rows, p, cols, p] -> [rows, cols, p, p] -> [rows, cols, p*p].
        # Each patch is flattened row-major, exactly like unfolding patch by patch.
        patches = img.reshape(rows, p, cols, p).permute(0, 2, 1, 3).reshape(rows, cols, p * p)

        # Batched form of `mask_matrix @ patch` for all patches at once.
        return patches @ mask_matrix.T

    def reconstruct_masked_image(self, transformed_patches):
        """
        Tile per-patch coefficient vectors back into an image.

        Args:
            transformed_patches: torch tensor [rows, cols, 256]
                ([14, 14, 256] for a 224x224 image).

        Returns:
            torch tensor [rows*16, cols*16] ([224, 224]). The 256 values of
            patch (i, j) fill, row-major, the 16x16 block at (16*i, 16*j).
        """
        p = self.PATCH_SIZE
        rows, cols, _ = transformed_patches.shape
        # Inverse of the reshape/permute in process_image_with_orthonormal_masks.
        return (
            transformed_patches.reshape(rows, cols, p, p)
            .permute(0, 2, 1, 3)
            .reshape(rows * p, cols * p)
        )

    def apply_patchwise_orthonormal_transform(self, x):
        """
        Transform each channel of an image patch-wise with self.A, then normalize.

        Args:
            x: numpy array [224, 224, C] (original image; C = 3 here).

        Returns:
            numpy float32 array of the same shape. It holds the per-channel
            transformed images, min-max normalized to [0, 1] with a single
            min/max over all channels (1e-8 guards a constant image).
        """
        channels = []
        for c in range(x.shape[2]):
            # c:c+1 keeps a trailing channel axis of size 1.
            coeffs = self.process_image_with_orthonormal_masks(x[..., c:c + 1], self.A)
            channels.append(self.reconstruct_masked_image(coeffs).numpy())
        y = np.stack(channels, axis=2)

        y_min, y_max = y.min(), y.max()
        return (y - y_min) / (y_max - y_min + 1e-8)

    def __getitem__(self, idx):
        """Return (transformed input [3, 224, 224], original target [3, 224, 224]) for image idx."""
        img_path = self.image_files[idx]

        try:
            # The `with` releases the file handle. Decoding happens inside the
            # `try` because PIL often detects corrupt/truncated files only when
            # pixel data is loaded (here, in preprocess_image).
            with Image.open(img_path) as img:
                x = self.preprocess_image(img)  # (224, 224, 3) - target
        except Exception as e:
            # Deliberate best-effort: one bad file should not abort a long run.
            print(f"Warning: Could not load image {img_path}: {e}")
            x = self.preprocess_image(
                Image.new('RGB', (self.IMG_SIZE, self.IMG_SIZE), color=(0, 0, 0))
            )

        y = self.apply_patchwise_orthonormal_transform(x)  # (224, 224, 3) - input

        # HWC -> CHW
        x_tensor = torch.from_numpy(x).permute(2, 0, 1)  # (3, 224, 224) - target
        y_tensor = torch.from_numpy(y).permute(2, 0, 1)  # (3, 224, 224) - input

        return y_tensor, x_tensor  # (input 224x224, target 224x224)

# ========== TransUNet Components ==========

class PatchEmbedding(nn.Module):
    """
    Turn an image into a sequence of patch tokens for a Vision Transformer.

    A Conv2d whose kernel and stride both equal patch_size projects every
    non-overlapping patch to an embed_dim vector.

    forward(x) takes x of shape [B, in_channels, H, W] and returns
    (tokens, (grid_h, grid_w)). tokens has shape
    [B, grid_h * grid_w, embed_dim] with grid_h = H // patch_size and
    grid_w = W // patch_size.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        grid = img_size // patch_size
        self.n_patches = grid * grid  # token count for an img_size x img_size input

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        feature_map = self.proj(x)  # [B, embed_dim, grid_h, grid_w]
        grid_h, grid_w = feature_map.shape[-2:]
        # [B, embed_dim, grid_h*grid_w] -> [B, grid_h*grid_w, embed_dim]
        tokens = feature_map.flatten(start_dim=2).transpose(1, 2)
        return tokens, (grid_h, grid_w)

class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        embed_dim: token feature size C; must be divisible by num_heads.
        num_heads: number of attention heads (head_dim = embed_dim // num_heads).
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if embed_dim is not divisible by num_heads.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        # Explicit exception instead of `assert`, which is stripped under `python -O`.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got embed_dim={embed_dim}, num_heads={num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k) attention scaling

        # One fused projection produces Q, K and V (output size 3 * embed_dim).
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map tokens x of shape [B, N, C] to attended tokens of shape [B, N, C]."""
        B, N, C = x.shape

        # [B, N, 3C] -> [B, N, 3, heads, head_dim] -> [3, B, heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each [B, heads, N, head_dim]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
        attn = self.dropout(attn.softmax(dim=-1))

        # Merge heads back: [B, heads, N, head_dim] -> [B, N, C]
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)

class TransformerBlock(nn.Module):
    """
    Pre-norm Transformer encoder block.

    Computes h = x + Attn(LayerNorm(x)), then returns h + MLP(LayerNorm(h)).
    The MLP is Linear(C, mlp_ratio*C) -> GELU -> Dropout ->
    Linear(mlp_ratio*C, C) -> Dropout. Output shape equals input shape.
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        hidden_features = int(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))  # attention sub-layer + residual
        return attended + self.mlp(self.norm2(attended))  # MLP sub-layer + residual

class VisionTransformerEncoder(nn.Module):
    """
    Vision Transformer encoder: patch embedding + learned positional embedding
    + a stack of Transformer blocks + final LayerNorm.

    forward(x) returns (features, (grid_h, grid_w)). features holds the raw
    outputs of blocks 2, 5 and 8 (when depth is large enough), followed by
    the LayerNorm-ed output of the last block. Each entry has shape
    [B, n_patches, embed_dim].
    """

    # Block indices whose outputs are exposed as intermediate features.
    _FEATURE_LAYERS = (2, 5, 8)

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # One learned position vector per patch (fixed to img_size x img_size inputs).
        num_tokens = self.patch_embed.n_patches
        self.pos_embed = nn.Parameter(torch.randn(1, num_tokens, embed_dim) * 0.02)
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            [TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth)]
        )

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        tokens, grid_hw = self.patch_embed(x)  # [B, N, embed_dim]
        tokens = self.dropout(tokens + self.pos_embed)

        collected = []
        for layer_idx, block in enumerate(self.blocks):
            tokens = block(tokens)
            if layer_idx in self._FEATURE_LAYERS:
                collected.append(tokens)

        collected.append(self.norm(tokens))  # final, normalized features
        return collected, grid_hw

class ConvBlock(nn.Module):
    """
    Two stacked Conv2d -> BatchNorm2d -> ReLU stages.

    The first convolution uses `stride` (it may downsample). The second
    always uses stride 1. Convolutions carry no bias because BatchNorm
    follows each one.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        layers = []
        for stage_in, stage_stride in ((in_channels, stride), (out_channels, 1)):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size, stage_stride, padding, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class UpBlock(nn.Module):
    """
    Decoder stage with three steps:
      1. A 2x transposed-conv upsampling that halves the channel count.
      2. Optional channel-wise concatenation with an encoder skip tensor.
      3. A ConvBlock that maps the result to out_channels.

    skip_channels must equal the channel count of the skip tensor passed to
    forward (0 when no skip is used).
    """
    def __init__(self, in_channels, out_channels, skip_channels=0):
        super().__init__()
        up_channels = in_channels // 2
        self.up = nn.ConvTranspose2d(in_channels, up_channels, kernel_size=2, stride=2)
        self.conv = ConvBlock(up_channels + skip_channels, out_channels)

    def forward(self, x, skip=None):
        upsampled = self.up(x)
        merged = upsampled if skip is None else torch.cat([upsampled, skip], dim=1)
        return self.conv(merged)

class TransUNet(nn.Module):
    """
    TransUNet: Transformer + U-Net for the orthonormal inverse problem.
    NO subsampling version: input and output share the same resolution.

    Input:  [B, in_channels, img_size, img_size]  transformed ("masked") image
    Output: [B, out_channels, img_size, img_size] reconstruction, clamped to [0, 1]

    Data flow:
      * input_prep lifts the input to 64 channels. A 4-stage CNN encoder with
        three 2x max-pools yields skip features at 1, 1/2, 1/4 and 1/8 scale.
      * A ViT encoder runs on the same 64-channel map. Its final tokens are
        projected to 512 channels, reshaped to the patch grid, bilinearly
        resized to the 1/8-scale CNN map, and concatenated with it (1024 ch).
      * Four UpBlocks decode back to full resolution using the CNN skips.
      * final_conv (ends in Sigmoid) is added to a shallow input->output
        branch (ends in Tanh), and the sum is clamped to [0, 1].

    img_size must be divisible by patch_size (the ViT positional embedding
    assumes img_size x img_size inputs) and by 8 (three poolings).
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, out_channels=3,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, dropout=0.1):
        super().__init__()
        # NOTE: submodule creation order is kept stable so seeded initialization
        # and state_dict keys stay unchanged.

        # Input preprocessing at full resolution (no upsampling needed).
        self.input_prep = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Vision Transformer encoder on the 64-channel preprocessed map.
        self.vit_encoder = VisionTransformerEncoder(
            img_size=img_size, patch_size=patch_size, in_channels=64,
            embed_dim=embed_dim, depth=depth, num_heads=num_heads,
            mlp_ratio=mlp_ratio, dropout=dropout
        )

        # CNN encoder path (provides skip connections).
        self.cnn_enc1 = ConvBlock(64, 64)
        self.cnn_enc2 = ConvBlock(64, 128)
        self.cnn_enc3 = ConvBlock(128, 256)
        self.cnn_enc4 = ConvBlock(256, 512)

        self.pool = nn.MaxPool2d(2)

        # Per-token projection from ViT embedding space to 512 CNN channels.
        self.vit_to_cnn = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512)
        )

        # CNN decoder path.
        self.dec4 = UpBlock(1024, 256, skip_channels=256)  # bottleneck[1024] → 512, +enc3[256] = 768 → 256
        self.dec3 = UpBlock(256, 128, skip_channels=128)   # dec4[256] → 128, +enc2[128] = 256 → 128
        self.dec2 = UpBlock(128, 64, skip_channels=64)     # dec3[128] → 64, +enc1[64] = 128 → 64
        self.dec1 = UpBlock(64, 64, skip_channels=64)      # dec2[64] → 32, +enc1[64] = 96 → 64

        # Final output head.
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        # Shallow input -> output branch (residual-style contribution).
        self.skip_connection = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, kernel_size=3, padding=1),
            nn.Tanh()  # allows both positive and negative contributions
        )

        print("TransUNet Architecture (No Subsampling):")
        print(f"   Input: {in_channels} channels ({img_size}x{img_size})")
        print(f"   Output: {out_channels} channels ({img_size}x{img_size})")
        print(f"   Patch size: {patch_size}x{patch_size}")
        print(f"   ViT: embed_dim={embed_dim}, depth={depth}, heads={num_heads}")
        print(f"   Task: Same resolution inverse problem (224x224 → 224x224)")
        print(f"   Skip connection: Input → Output")

    def forward(self, x):
        """Reconstruct [B, out_channels, H, W] in [0, 1] from a [B, in_channels, H, W] input."""
        original_input = x

        x = self.input_prep(x)                                 # [B, 64, H, W]
        skip_features = self.skip_connection(original_input)   # [B, out, H, W]

        # CNN encoder for skip connections.
        enc1 = self.cnn_enc1(x)                  # [B, 64, H, W]
        enc2 = self.cnn_enc2(self.pool(enc1))    # [B, 128, H/2, W/2]
        enc3 = self.cnn_enc3(self.pool(enc2))    # [B, 256, H/4, W/4]
        enc4 = self.cnn_enc4(self.pool(enc3))    # [B, 512, H/8, W/8]

        # ViT encoder: only the final (normalized) tokens are used here.
        vit_features, (grid_h, grid_w) = self.vit_encoder(x)
        final_tokens = vit_features[-1]          # [B, grid_h*grid_w, embed_dim]
        batch = final_tokens.shape[0]

        vit_proj = self.vit_to_cnn(final_tokens)                                     # [B, grid_h*grid_w, 512]
        vit_spatial = vit_proj.transpose(1, 2).reshape(batch, 512, grid_h, grid_w)   # [B, 512, grid_h, grid_w]
        vit_spatial = F.interpolate(vit_spatial, size=enc4.shape[-2:],
                                    mode='bilinear', align_corners=True)             # [B, 512, H/8, W/8]

        # Bottleneck: ViT context + deepest CNN features.
        bottleneck = torch.cat([vit_spatial, enc4], dim=1)  # [B, 1024, H/8, W/8]

        # UpBlock.forward performs up -> concat(skip) -> conv.
        dec4_out = self.dec4(bottleneck, enc3)   # [B, 256, H/4, W/4]
        dec3_out = self.dec3(dec4_out, enc2)     # [B, 128, H/2, W/2]
        dec2_out = self.dec2(dec3_out, enc1)     # [B, 64, H, W]

        # dec1's transposed conv doubles the resolution to 2H x 2W; resample
        # back to enc1's size (rather than a hard-coded 224x224) so the skip
        # concatenation also works for other img_size values.
        dec1_up = self.dec1.up(dec2_out)         # [B, 32, 2H, 2W]
        dec1_up = F.interpolate(dec1_up, size=enc1.shape[-2:],
                                mode='bilinear', align_corners=True)  # [B, 32, H, W]
        dec1_out = self.dec1.conv(torch.cat([dec1_up, enc1], dim=1))  # [B, 64, H, W]

        # Sigmoid head + Tanh branch lies in (-1, 2); clamp to the image range.
        # NOTE(review): clamp gives zero gradient wherever the sum falls outside
        # [0, 1], which can stall learning for those pixels.
        output = self.final_conv(dec1_out) + skip_features
        return torch.clamp(output, 0, 1)

# Combined Loss Function
class CombinedLoss(nn.Module):
    """
    Weighted sum of MSE, L1 and (1 - SSIM) between a prediction and a target.

    SSIM (or MS-SSIM when use_ms_ssim=True) is computed on copies of both
    tensors clamped to [0, 1] with data_range=1.0. MSE and L1 use the raw
    tensors. forward returns a dict with the total loss, each component, and
    the raw SSIM value (higher is better).
    """
    def __init__(self, mse_weight=1.0, l1_weight=1.0, ssim_weight=1.0, use_ms_ssim=False):
        super().__init__()
        self.mse_weight = mse_weight
        self.l1_weight = l1_weight
        self.ssim_weight = ssim_weight
        self.use_ms_ssim = use_ms_ssim

        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

        print(f"Combined Loss - MSE: {mse_weight}, L1: {l1_weight}, SSIM: {ssim_weight}")
        print("Using Multi-Scale SSIM" if use_ms_ssim else "Using Standard SSIM")

    def forward(self, pred, target):
        mse = self.mse_loss(pred, target)
        l1 = self.l1_loss(pred, target)

        # SSIM assumes values inside data_range, so compare clamped copies.
        similarity_fn = ms_ssim if self.use_ms_ssim else ssim
        ssim_val = similarity_fn(
            torch.clamp(pred, 0, 1),
            torch.clamp(target, 0, 1),
            data_range=1.0,
            size_average=True,
        )
        ssim_loss = 1 - ssim_val  # minimize dissimilarity

        total_loss = (
            self.mse_weight * mse
            + self.l1_weight * l1
            + self.ssim_weight * ssim_loss
        )

        return {
            'total_loss': total_loss,
            'mse_loss': mse,
            'l1_loss': l1,
            'ssim_loss': ssim_loss,
            'ssim_value': ssim_val,
        }

# Loss Configuration Presets
def get_loss_configs():
    """
    Return named loss-weight presets, each usable as CombinedLoss(**preset).

    Presets:
        'basic'        - equal MSE/L1/SSIM weights; a good starting point.
        'ssim_focused' - emphasizes MS-SSIM for better perceptual quality.
        'sharp_edges'  - emphasizes L1 for sharper edges.

    A fresh dict is built on every call, so callers may mutate the result.
    """
    def preset(mse, l1, ssim_w, multi_scale):
        return {
            'mse_weight': mse,
            'l1_weight': l1,
            'ssim_weight': ssim_w,
            'use_ms_ssim': multi_scale,
        }

    return {
        'basic': preset(1.0, 1.0, 1.0, False),
        'ssim_focused': preset(0.5, 1.0, 2.0, True),
        'sharp_edges': preset(0.5, 2.0, 1.0, False),
    }

# One epoch of training or validation
def _save_val_visualization(vis_dir, batch_idx, y_batch, pred_x, x_batch, batch_ssim):
    """Save a side-by-side figure of the first sample's input, prediction and target."""
    os.makedirs(vis_dir, exist_ok=True)

    # CHW -> HWC numpy arrays for imshow.
    y_vis, pred_vis, x_vis = (
        t[0].detach().cpu().numpy().transpose(1, 2, 0) for t in (y_batch, pred_x, x_batch)
    )

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (y_vis, "Input (224x224 Masked)"),
        # The criterion's SSIM is averaged over the whole batch, not this sample.
        (pred_vis, f"Predicted (224x224)\nBatch SSIM: {batch_ssim:.4f}"),
        (x_vis, "Target (224x224 Original)"),
    )
    for axis, (image, title) in zip(ax, panels):
        axis.imshow(np.clip(image, 0, 1))
        axis.set_title(title)
        axis.axis('off')

    plt.tight_layout()
    fig.savefig(os.path.join(vis_dir, f"val_{batch_idx:05d}.png"), dpi=120, bbox_inches='tight')
    plt.close(fig)


def run_epoch(loader, model, optim, device, criterion, train=True, visualize_every=0, vis_dir=None):
    """
    Run one pass over `loader`, training (train=True) or evaluating (train=False).

    Args:
        loader: iterable of (y_batch, x_batch) pairs, i.e. transformed inputs and
            original targets, each [B, C, H, W].
        model: network mapping y -> predicted x. Predictions whose shape differs
            from the target are bilinearly resized to the target's H x W.
        optim: optimizer stepped once per batch when training (unused otherwise).
        device: device each batch is moved to.
        criterion: callable(pred, target) returning a dict of scalar tensors with
            keys 'total_loss', 'mse_loss', 'l1_loss', 'ssim_loss', 'ssim_value'.
            'total_loss' is back-propagated.
        train: enables autograd, backward, gradient-norm clipping (max 1.0) and
            optimizer steps; otherwise the model runs in eval mode without grad.
        visualize_every: when evaluating, save a figure every this many batches
            (0 disables).
        vis_dir: output directory for those figures; if falsy, nothing is saved.

    Returns:
        dict with the criterion's keys, each holding the per-sample mean over the
        epoch. Each batch value is weighted by its batch size, so a smaller final
        batch is not over-weighted. All values are 0.0 for an empty loader.
    """
    model.train(train)  # train(False) is equivalent to eval()

    keys = ('total_loss', 'mse_loss', 'l1_loss', 'ssim_loss', 'ssim_value')
    sums = dict.fromkeys(keys, 0.0)
    num_samples = 0

    with torch.set_grad_enabled(train):
        for i, (y_batch, x_batch) in enumerate(loader):
            # y_batch: transformed images (input), x_batch: original images (target)
            y_batch = y_batch.to(device, non_blocking=True)
            x_batch = x_batch.to(device, non_blocking=True)

            pred_x = model(y_batch)
            if pred_x.shape != x_batch.shape:
                pred_x = F.interpolate(pred_x, size=x_batch.shape[-2:], mode='bilinear', align_corners=True)

            loss_dict = criterion(pred_x, x_batch)
            loss = loss_dict['total_loss']

            if train:
                optim.zero_grad(set_to_none=True)
                loss.backward()
                # Gradient clipping for stable training.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optim.step()

            # The criterion returns batch means; weighting by batch size makes
            # the final average a true per-sample mean.
            batch_size = x_batch.size(0)
            for key in keys:
                sums[key] += loss_dict[key].item() * batch_size
            num_samples += batch_size

            if (not train) and visualize_every and (i % visualize_every == 0) and vis_dir:
                _save_val_visualization(vis_dir, i, y_batch, pred_x, x_batch,
                                        loss_dict['ssim_value'].item())

    denom = max(1, num_samples)
    return {key: sums[key] / denom for key in keys}  # ssim_value: higher is better

def _build_checkpoint(epoch, model, optimizer, scheduler, train_losses, val_losses, best_val_loss, args):
    """Assemble the dict written for both best-model and periodic checkpoints."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_losses': val_losses,
        'train_losses': train_losses,
        # Lets a resumed run keep comparing against the true best so far.
        'best_val_loss': best_val_loss,
        'args': args,
    }


def _resume_from_checkpoint(path, model, optimizer, scheduler, device):
    """
    Load a checkpoint into the model and, when present, the optimizer/scheduler.

    Accepts either a full training checkpoint (as written by main) or a bare
    model state_dict. Returns (start_epoch, best_val_loss): the number of
    completed epochs and the best validation loss so far. These are
    (0, inf) when the file carries no restorable training state.
    """
    # torch.load unpickles the file; only load checkpoints from trusted sources.
    ckpt = torch.load(path, map_location=device)
    is_full = isinstance(ckpt, dict) and 'model_state_dict' in ckpt
    state_dict = ckpt['model_state_dict'] if is_full else ckpt

    # strict=False keeps partial/transfer loading possible, but report
    # mismatches instead of hiding them.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"Warning: {len(result.missing_keys)} missing / "
              f"{len(result.unexpected_keys)} unexpected keys when loading {path}")
    print(f"Loaded checkpoint: {path}")

    if not is_full:
        return 0, float('inf')

    start_epoch = 0
    if 'optimizer_state_dict' in ckpt and 'scheduler_state_dict' in ckpt:
        try:
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
            scheduler.load_state_dict(ckpt['scheduler_state_dict'])
            start_epoch = int(ckpt.get('epoch', 0))
        except (ValueError, KeyError, RuntimeError) as e:
            print(f"Warning: could not restore optimizer/scheduler state ({e}); using a fresh schedule")
    return start_epoch, ckpt.get('best_val_loss', float('inf'))


def main():
    """
    Train TransUNet to invert the patch-wise orthonormal transform.

    All settings live in the `args` dict below. Writes the following into
    `save_dir`:
      * training_args.json
      * best_model.pth (lowest validation total loss)
      * epoch_N.pth every `save_every` epochs
      * validation figures under val_vis/

    If `load_path` names a checkpoint written by this script, model,
    optimizer and scheduler state are restored and training continues after
    the saved epoch. A bare state_dict only initializes the weights.
    """
    import json  # local: only needed to record the run configuration

    args = {
        # Data paths (update these to your paths)
        'train_dir': r"F:\imgnet\data\train",
        'val_dir': r"F:\imgnet\data\val",
        'save_dir': "./transunet_no_subsample_checkpoints",
        'load_path': "",  # Path to checkpoint to resume from (optional)

        # Model settings
        'target_size': (224, 224),  # Output image size
        'seed': 42,  # Fixed seed for orthonormal matrix A

        # TransUNet specific settings
        'patch_size': 16,       # ViT patch size
        'embed_dim': 768,       # ViT embedding dimension
        'depth': 12,            # Number of transformer layers
        'num_heads': 12,        # Number of attention heads
        'mlp_ratio': 4,         # MLP expansion ratio
        'dropout': 0.1,         # Dropout rate

        # Training settings
        'batch_size': 16,       # Reduced for TransUNet (higher memory usage)
        'lr': 1e-4,
        'epochs': 100,
        'save_every': 10,
        'viz_every': 50,        # Visualize every N validation batches

        # Loss settings - choose one of the configurations
        'loss_config': 'basic',  # 'basic', 'ssim_focused', 'sharp_edges'
    }

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    os.makedirs(args['save_dir'], exist_ok=True)

    # Save args for reference
    with open(os.path.join(args['save_dir'], 'training_args.json'), 'w') as f:
        json.dump(args, f, indent=2)

    # Same seed for both datasets -> identical matrix A.
    train_ds = PatchwiseOrthonormalDataset(
        data_dir=args['train_dir'],
        seed=args['seed'],
        verbose=True
    )
    val_ds = PatchwiseOrthonormalDataset(
        data_dir=args['val_dir'],
        seed=args['seed'],
        verbose=True
    )

    train_loader = DataLoader(
        train_ds, batch_size=args['batch_size'], shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True
    )
    val_loader = DataLoader(
        val_ds, batch_size=args['batch_size'], shuffle=False,
        num_workers=4, pin_memory=True, drop_last=False
    )

    model = TransUNet(
        img_size=args['target_size'][0],
        patch_size=args['patch_size'],
        in_channels=3,
        out_channels=3,
        embed_dim=args['embed_dim'],
        depth=args['depth'],
        num_heads=args['num_heads'],
        mlp_ratio=args['mlp_ratio'],
        dropout=args['dropout']
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"TransUNet Parameters: {total_params:,} total, {trainable_params:,} trainable")

    criterion = CombinedLoss(**get_loss_configs()[args['loss_config']])

    optimizer = torch.optim.AdamW(model.parameters(), lr=args['lr'], weight_decay=1e-4,
                                  betas=(0.9, 0.999), eps=1e-8)
    # Cosine annealing with warm restarts at epochs 10, 30, 70, ...
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-6
    )

    # Resume only after the optimizer and scheduler exist so their state can be restored too.
    start_epoch, best_val_loss = 0, float('inf')
    if args['load_path']:
        if os.path.isfile(args['load_path']):
            start_epoch, best_val_loss = _resume_from_checkpoint(
                args['load_path'], model, optimizer, scheduler, device
            )
        else:
            print(f"Warning: checkpoint not found, training from scratch: {args['load_path']}")

    print(f"Starting TransUNet training for orthonormal inverse problem (No Subsampling)")
    print(f"Train: {len(train_ds)} images, Val: {len(val_ds)} images")
    print(f"Batch size: {args['batch_size']}")
    print(f"Task: 224x224 masked input → 224x224 original target (same resolution)")
    print(f"Loss configuration: {args['loss_config']}")
    if start_epoch:
        print(f"Resuming after epoch {start_epoch}")

    vis_dir = os.path.join(args['save_dir'], "val_vis")

    for epoch in range(start_epoch, args['epochs']):
        t0 = time.time()

        train_losses = run_epoch(train_loader, model, optimizer, device, criterion, train=True)
        val_losses = run_epoch(
            val_loader, model, optimizer, device, criterion, train=False,
            visualize_every=args['viz_every'],
            vis_dir=vis_dir
        )

        scheduler.step()

        elapsed = time.time() - t0
        lr = optimizer.param_groups[0]['lr']

        print(f"Epoch {epoch+1:03d}/{args['epochs']:03d} | Time: {elapsed:.1f}s | LR: {lr:.2e}")
        print(f"  Train - Total: {train_losses['total_loss']:.6f} | "
              f"MSE: {train_losses['mse_loss']:.6f} | "
              f"L1: {train_losses['l1_loss']:.6f} | "
              f"SSIM: {train_losses['ssim_value']:.4f}")
        print(f"  Val   - Total: {val_losses['total_loss']:.6f} | "
              f"MSE: {val_losses['mse_loss']:.6f} | "
              f"L1: {val_losses['l1_loss']:.6f} | "
              f"SSIM: {val_losses['ssim_value']:.4f}")

        # Re-create in case the directory was removed during a long run.
        os.makedirs(args['save_dir'], exist_ok=True)

        if val_losses['total_loss'] < best_val_loss:
            best_val_loss = val_losses['total_loss']
            torch.save(
                _build_checkpoint(epoch + 1, model, optimizer, scheduler,
                                  train_losses, val_losses, best_val_loss, args),
                os.path.join(args['save_dir'], "best_model.pth"),
            )

        if (epoch + 1) % args['save_every'] == 0:
            torch.save(
                _build_checkpoint(epoch + 1, model, optimizer, scheduler,
                                  train_losses, val_losses, best_val_loss, args),
                os.path.join(args['save_dir'], f"epoch_{epoch+1}.pth"),
            )

    print("TransUNet training completed!")

if __name__ == "__main__":
    # Requires pytorch-msssim (pip install pytorch-msssim) for the SSIM losses.
    # The guard is also what keeps DataLoader worker processes (num_workers > 0)
    # from re-running training on platforms that spawn workers, e.g. Windows.
    # (A torch.cuda.empty_cache() call here was a no-op: nothing is allocated yet.)
    main()