import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
import time
import numpy as np
from einops import rearrange
from torch.cuda.amp import autocast, GradScaler
import os

# Import necessary components from your existing implementation
from vit import ResponseGTImageDataset, transform, target_transform, SSIMLoss, RidgeLoss
from utilies import normalize_matrix

# --------------------------------------------------------------------------------
# Restormer Components for Optics Reconstruction
# --------------------------------------------------------------------------------

class LayerNorm(nn.Module):
    """Layer normalization applied over the channel axis of a feature map.

    Wraps ``nn.LayerNorm(dim)`` so it can be used on channels-first tensors:
    the spatial axes are flattened into a token axis, channels are moved last,
    the wrapped norm is applied, and the original layout is restored.
    """

    def __init__(self, dim):
        super().__init__()
        self.body = nn.LayerNorm(dim)

    def forward(self, x):
        """Normalize ``x`` over dimension 1 and return a tensor shaped (-1, x.shape[1], h, w)."""
        height, width = x.shape[-2:]
        channels = x.shape[1]
        # Flatten everything after the channel axis and put channels last,
        # which is the axis nn.LayerNorm normalizes over.
        tokens = x.flatten(2).transpose(1, 2)
        normed = self.body(tokens)
        # Back to channels-first, then unfold the token axis into (h, w).
        return normed.transpose(1, 2).reshape(-1, channels, height, width)


class FeedForward(nn.Module):
    """Gated depth-wise convolutional feed-forward network.

    A 1x1 conv expands ``dim`` channels to ``2 * hidden`` channels, a 3x3
    depth-wise conv mixes each channel spatially, and the result is split into
    two halves: the first half is multiplied by the sigmoid of the second half
    (a sigmoid gate). A final 1x1 conv projects ``hidden`` channels back to
    ``dim``.

    Args:
        dim: Number of input and output channels.
        ffn_expansion_factor: ``hidden = int(dim * ffn_expansion_factor)``.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()

        hidden = int(dim * ffn_expansion_factor)
        doubled = hidden * 2  # value half + gate half

        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(
            doubled, doubled, kernel_size=3, stride=1, padding=1, groups=doubled, bias=bias
        )
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply expand -> depth-wise conv -> sigmoid gate -> project."""
        expanded = self.dwconv(self.project_in(x))
        value, gate = expanded.chunk(2, dim=1)
        gated = value * torch.sigmoid(gate)
        return self.project_out(gated)


class MDTA(nn.Module):
    """Multi-Dconv-Head Transposed Attention.

    Queries, keys and values come from a 1x1 conv followed by a 3x3 depth-wise
    conv. Attention is computed between the channels of each head (the
    flattened spatial axis is the one contracted in ``q @ k^T``), scaled by a
    learnable per-head temperature.

    Args:
        dim: Number of input and output channels.
        num_heads: Number of heads the channels are split into.
        bias: Whether the convolutions carry a bias term.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        triple = dim * 3  # q, k and v stacked along the channel axis
        self.qkv = nn.Conv2d(dim, triple, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            triple, triple, kernel_size=3, stride=1, padding=1, groups=triple, bias=bias
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        """Return the channel-attention output, same shape as ``x``."""
        _, _, h, w = x.shape
        heads = self.num_heads

        # Split the stacked projection into q/k/v and lay each one out as
        # (batch, head, channels_per_head, h*w).
        q, k, v = (
            rearrange(part, 'b (head c) h w -> b head c (h w)', head=heads)
            for part in self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)
        )

        # L2-normalize along the flattened spatial axis so that q @ k^T is a
        # cosine similarity between channels.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        weights = scores.softmax(dim=-1)
        mixed = torch.matmul(weights, v)

        # Fold heads back into the channel axis and restore the spatial grid.
        mixed = rearrange(mixed, 'b head c (h w) -> b (head c) h w', head=heads, h=h, w=w)
        return self.project_out(mixed)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: channel attention then gated feed-forward.

    Args:
        dim: Number of channels.
        num_heads: Number of attention heads passed to ``MDTA``.
        ffn_expansion_factor: Expansion factor passed to ``FeedForward``.
        bias: Whether the convolutions carry a bias term.
        LayerNorm_type: Accepted for interface compatibility; this
            implementation does not read it.
    """

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type='WithBias'):
        super().__init__()

        self.norm1 = LayerNorm(dim)
        self.attn = MDTA(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        """Apply both residual sub-layers and return a tensor shaped like ``x``."""
        # Residual attention sub-layer.
        attended = x + self.attn(self.norm1(x))
        # Residual feed-forward sub-layer.
        return attended + self.ffn(self.norm2(attended))


class DownSample(nn.Module):
    """Downsample a feature map with a single 3x3 convolution of stride 2.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the downsampled feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        strided_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        # Kept inside a Sequential so parameter names stay ``body.0.*``.
        self.body = nn.Sequential(strided_conv)

    def forward(self, x):
        """Return the strided-convolution output for ``x``."""
        return self.body(x)


class UpSample(nn.Module):
    """Upsample a feature map by 2x using a 3x3 convolution and pixel shuffle.

    The convolution produces ``out_channels * 4`` channels, which
    ``nn.PixelShuffle(2)`` rearranges into ``out_channels`` channels at twice
    the height and width.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the upsampled feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        expand_conv = nn.Conv2d(in_channels, out_channels * 4, kernel_size=3, stride=1, padding=1)
        self.body = nn.Sequential(expand_conv, nn.PixelShuffle(2))

    def forward(self, x):
        """Return the 2x-upsampled feature map for ``x``."""
        return self.body(x)


class RestormerOptics(nn.Module):
    """Four-level encoder/decoder built from ``TransformerBlock`` stacks.

    Structure (channel widths relative to ``dim``):

    * ``patch_embed``: 3x3 conv, ``inp_channels -> dim``.
    * Encoder levels 1-3 at widths ``dim``, ``2*dim``, ``4*dim``, each followed
      by a stride-2 ``DownSample``; the ``latent`` stack runs at ``8*dim``.
    * Decoder levels 3-1: ``UpSample``, concatenate the matching encoder
      output, reduce channels with a 1x1 conv, then a ``TransformerBlock``
      stack.
    * ``output``: 3x3 conv to ``out_channels`` followed by a sigmoid, so the
      result lies in [0, 1].

    Args:
        inp_channels: Channels of the input image.
        out_channels: Channels of the reconstructed image.
        dim: Base channel width (level 1).
        num_blocks: Number of transformer blocks at levels 1, 2, 3 and latent.
        num_heads: Number of attention heads at levels 1, 2, 3 and latent.
        ffn_expansion_factor: Expansion factor of every feed-forward network.
        bias: Whether convolutions built with this flag carry a bias term.
    """

    def __init__(
        self,
        inp_channels=3,  # Modified for RGB images
        out_channels=3,  # Modified for RGB images
        dim=48,
        # Tuples instead of lists: defaults are shared between calls, so they
        # must be immutable. Any indexable sequence is still accepted.
        num_blocks=(4, 6, 6, 8),
        num_heads=(1, 2, 4, 8),
        ffn_expansion_factor=2.66,
        bias=False
    ):
        super().__init__()

        def make_stage(level):
            """Build the TransformerBlock stack for ``level`` (0 = full resolution)."""
            return nn.Sequential(*[
                TransformerBlock(
                    dim=dim * 2 ** level,
                    num_heads=num_heads[level],
                    ffn_expansion_factor=ffn_expansion_factor,
                    bias=bias,
                )
                for _ in range(num_blocks[level])
            ])

        # NOTE: attribute names and registration order are kept stable so that
        # existing checkpoints (state_dict keys) continue to load.
        self.patch_embed = nn.Conv2d(inp_channels, dim, kernel_size=3, stride=1, padding=1, bias=bias)

        # Encoder
        self.encoder_level1 = make_stage(0)

        self.down1_2 = DownSample(dim, 2*dim)
        self.encoder_level2 = make_stage(1)

        self.down2_3 = DownSample(2*dim, 4*dim)
        self.encoder_level3 = make_stage(2)

        self.down3_4 = DownSample(4*dim, 8*dim)
        self.latent = make_stage(3)

        # Decoder
        self.up4_3 = UpSample(8*dim, 4*dim)
        self.reduce_chan_level3 = nn.Conv2d(8*dim, 4*dim, kernel_size=1, bias=bias)
        self.decoder_level3 = make_stage(2)

        self.up3_2 = UpSample(4*dim, 2*dim)
        self.reduce_chan_level2 = nn.Conv2d(4*dim, 2*dim, kernel_size=1, bias=bias)
        self.decoder_level2 = make_stage(1)

        self.up2_1 = UpSample(2*dim, dim)
        self.reduce_chan_level1 = nn.Conv2d(2*dim, dim, kernel_size=1, bias=bias)
        self.decoder_level1 = make_stage(0)

        self.output = nn.Conv2d(dim, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
        self.output_activation = nn.Sigmoid()  # Normalize to [0,1]

    def forward(self, x):
        """Reconstruct an image from ``x``.

        The input is padded on the bottom/right so height and width are
        multiples of 8 (three stride-2 stages); otherwise the upsampled decoder
        features would not match the encoder features at the skip
        concatenations. The padding is cropped from the output, so the result
        has the same spatial size as ``x``. Inputs whose sides are already
        multiples of 8 are processed unchanged.
        """
        height, width = x.shape[-2:]
        pad_h = (-height) % 8
        pad_w = (-width) % 8
        needs_padding = bool(pad_h or pad_w)
        if needs_padding:
            # Reflection gives cleaner borders but requires pad < side length;
            # fall back to edge replication for very small inputs.
            mode = 'reflect' if (pad_h < height and pad_w < width) else 'replicate'
            x = F.pad(x, (0, pad_w, 0, pad_h), mode=mode)

        # Initial feature extraction
        inp_features = self.patch_embed(x)

        # Encoder
        out_enc_level1 = self.encoder_level1(inp_features)
        out_enc_level2 = self.encoder_level2(self.down1_2(out_enc_level1))
        out_enc_level3 = self.encoder_level3(self.down2_3(out_enc_level2))
        latent = self.latent(self.down3_4(out_enc_level3))

        # Decoder: upsample, fuse with the encoder skip, reduce channels, refine.
        out_dec_level3 = torch.cat([self.up4_3(latent), out_enc_level3], 1)
        out_dec_level3 = self.decoder_level3(self.reduce_chan_level3(out_dec_level3))

        out_dec_level2 = torch.cat([self.up3_2(out_dec_level3), out_enc_level2], 1)
        out_dec_level2 = self.decoder_level2(self.reduce_chan_level2(out_dec_level2))

        out_dec_level1 = torch.cat([self.up2_1(out_dec_level2), out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(self.reduce_chan_level1(out_dec_level1))

        # Output projection, squashed to [0, 1] with a sigmoid
        output = self.output_activation(self.output(out_dec_level1))

        if needs_padding:
            output = output[..., :height, :width]

        return output


def setup(rank, world_size):
    """
    Set up distributed training for the calling process.

    Binds the process to GPU ``rank`` and joins the NCCL process group.

    Args:
        rank: Rank of this process; also used as its CUDA device index.
        world_size: Total number of processes in the group.
    """
    # Respect a rendezvous address/port supplied by the environment (e.g. to
    # run two jobs on one machine); fall back to the local defaults otherwise.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # Bind this process to its own GPU before creating NCCL communicators so
    # collectives and default CUDA allocations do not land on device 0.
    torch.cuda.set_device(rank)

    # Initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """
    Clean up distributed training.

    Destroys the default process group if one is initialized. Calling this
    when no group exists (e.g. after a failed setup) is a no-op instead of an
    error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def visualize_new(low_res, outputs, high_res, index=0, save_dir="results"):
    """Save a side-by-side figure of input, prediction and ground truth.

    Args:
        low_res: Input (response) tensor.
        outputs: Model prediction tensor.
        high_res: Ground-truth tensor.
        index: Number used in the file name ``prediction_{index:03d}.png``.
        save_dir: Directory the figure is written to (created if missing).

    Returns:
        The path of the saved PNG.
    """
    def to_numpy(tensor):
        # Bring the tensor to host memory, drop autograd, convert to numpy.
        if tensor.is_cuda:
            tensor = tensor.cpu()
        return tensor.detach().numpy()

    low_res, outputs, high_res = (to_numpy(t) for t in (low_res, outputs, high_res))

    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"prediction_{index:03d}.png")

    plt.figure(figsize=(12, 6))

    # The shape of the input decides how all three arrays are sliced:
    # 4-D input -> first sample, first channel of every array; otherwise each
    # array is sliced on its own (first entry if it has more than 2 axes).
    if low_res.ndim == 4:
        def pick(array):
            return array[0, 0]
    else:
        def pick(array):
            return array[0] if array.ndim > 2 else array

    panels = (
        (low_res, 'Input (Response)'),
        (outputs, 'Predicted (Restormer)'),
        (high_res, 'Ground Truth'),
    )
    for position, (array, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(normalize_matrix(pick(array)), cmap='gray')
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

    return save_path  # Return the save path for logging


def train_restormer_optics_distributed(rank, world_size, response_dir, gt_dir, num_epochs, learning_rate, 
                                     batch_size, loss_func, model_load_path, model_save_path, 
                                     check_num, save_interval=5, visualize_interval=1000):
    """
    Distributed training function for Optics reconstruction using Restormer model.

    Runs in one process per GPU (it is the target of ``mp.spawn``). Every rank
    trains on its own shard of the dataset; rank 0 additionally prints
    progress and writes visualizations, checkpoints, metrics and loss curves.

    Args:
        rank: Current GPU rank
        world_size: Total number of GPUs
        response_dir: Directory with response files
        gt_dir: Directory with ground truth files
        num_epochs: Number of epochs to train
        learning_rate: Learning rate for optimizer
        batch_size: Total batch size (will be divided by world_size)
        loss_func: Loss function ('ridge', 'ssim', or 'mse')
        model_load_path: Path to checkpoint for resuming training ('none' if starting fresh)
        model_save_path: Directory to save model checkpoints
        check_num: Epoch offset added to checkpoint/metric file names and
            visualization indices (the epoch number training continues from)
        save_interval: Save model checkpoint every N epochs
        visualize_interval: Visualize results every N batches

    Raises:
        ValueError: If ``loss_func`` is not one of the supported names.
    """
    # Fail fast on a bad loss name, before any process group or GPU state is
    # created (previously this surfaced later as a NameError on `criterion`).
    if loss_func not in ('ridge', 'ssim', 'mse'):
        raise ValueError(
            f"Unknown loss_func {loss_func!r}; expected 'ridge', 'ssim', or 'mse'"
        )

    # Bind this process to its GPU so NCCL collectives (all_reduce, barrier)
    # and default CUDA allocations use cuda:{rank} rather than device 0.
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # Setup the distributed environment
    setup(rank, world_size)

    # For tracking metrics (only filled on rank 0)
    loss_history = []
    epoch_times = []
    total_start_time = time.time()

    # Create model and move it to GPU
    model = RestormerOptics(
        inp_channels=3,  # For RGB images, change if needed
        out_channels=3,  # For RGB images, change if needed
        dim=48,          # Base dimension
        num_blocks=[4, 6, 6, 8],  # Number of transformer blocks at each level
        num_heads=[1, 2, 4, 8],   # Number of attention heads at each level
        ffn_expansion_factor=2.66,
        bias=False
    ).to(device)

    # Initialize optimizer and scaler for mixed precision
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scaler = GradScaler()

    # Load model if specified (for resuming training)
    if model_load_path != 'none':
        try:
            # Map every stored tensor onto this rank's GPU, whatever device
            # the checkpoint was saved from.
            model.load_state_dict(torch.load(model_load_path, map_location=device))
            if rank == 0:
                print(f'Loading Path: {model_load_path}')
                print("Model weights loaded.")
        except Exception as e:
            # Best effort: continue from the current weights. DDP broadcasts
            # rank 0's parameters when the model is wrapped, so all ranks
            # still start from the same state.
            if rank == 0:
                print(f"Error loading checkpoint: {e}")
                print("Starting training from scratch.")

    # Wrap model with DDP after loading weights
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

    # Set up the dataset and distributed sampler
    dataset = ResponseGTImageDataset(response_dir, gt_dir, transform=transform, target_transform=target_transform)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

    # Calculate per-GPU batch size
    batch_size_per_gpu = batch_size // world_size
    if batch_size_per_gpu < 1:
        batch_size_per_gpu = 1
        if rank == 0:
            print(f"Warning: Adjusted batch size to 1 per GPU (total: {world_size})")

    dataloader = DataLoader(
        dataset, 
        batch_size=batch_size_per_gpu, 
        sampler=sampler, 
        num_workers=4,
        pin_memory=True
    )

    # Set up loss function (loss_func was validated above)
    if loss_func == 'ridge':
        criterion = RidgeLoss(alpha=0.5).to(device)
        if rank == 0:
            print('Using Ridge Loss with alpha=0.5')
    elif loss_func == 'ssim':
        criterion = SSIMLoss(alpha=0.05).to(device)
        if rank == 0:
            print('Using SSIM Loss with alpha=0.05')
    else:  # 'mse'
        criterion = nn.MSELoss().to(device)
        if rank == 0:
            print('Using MSE Loss')

    # Learning rate scheduler (one instance per rank, stepped identically)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=(rank == 0)
    )

    if rank == 0:
        print(f'Start training Restormer with {loss_func} loss, lr = {learning_rate}, batch size = {batch_size} (per GPU: {batch_size_per_gpu})')
        print(f'Total epochs: {num_epochs}, Save interval: {save_interval}')

    for epoch in range(num_epochs):
        # Set the epoch for the sampler so each epoch uses a different shuffle
        sampler.set_epoch(epoch)

        model.train()
        total_loss = 0
        epoch_start_time = time.time()

        for i, (responses, labels) in enumerate(dataloader):
            # Move data to the correct device
            responses = responses.float().to(device, non_blocking=True)
            labels = labels.float().to(device, non_blocking=True)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass with mixed precision
            with autocast():
                outputs = model(responses)

                # Only print shapes on the first GPU during first iteration of first epoch
                if i == 0 and epoch == 0 and rank == 0:
                    print(f"Response shape: {responses.shape}")
                    print(f"Output shape: {outputs.shape}")
                    print(f"Label shape: {labels.shape}")

                # Ensure outputs match the size of labels
                if outputs.shape != labels.shape:
                    outputs = F.interpolate(
                        outputs, 
                        size=(labels.shape[2], labels.shape[3]), 
                        mode='bilinear', 
                        align_corners=True
                    )

                # Calculate loss
                loss = criterion(outputs, labels)

            # Backward pass and optimizer step with scaling for mixed precision
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

            # Memory cleanup to avoid OOM errors
            if i % 10 == 0:
                torch.cuda.empty_cache()

            # Visualization during training (only on rank 0)
            if rank == 0 and i % visualize_interval == 0 and i > 0:
                with torch.no_grad():
                    vis_dir = os.path.join(model_save_path, 'visualizations')
                    os.makedirs(vis_dir, exist_ok=True)
                    vis_path = visualize_new(
                        responses[:1], outputs[:1], labels[:1],
                        index=epoch*1000+i+check_num,
                        save_dir=vis_dir
                    )
                    print(f"Epoch {epoch+1}, Batch {i} visualization saved to {vis_path}")

        # Calculate epoch time
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time

        # Synchronize losses across GPUs; afterwards avg_loss is identical on
        # every rank.
        loss_tensor = torch.tensor(total_loss).to(device)
        dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
        avg_loss = loss_tensor.item() / (len(dataloader) * world_size)

        # Step the scheduler on EVERY rank. Each rank owns its own optimizer,
        # so stepping only on rank 0 would lower the learning rate there while
        # the other ranks kept the old one. Since avg_loss is the same on all
        # ranks, all schedulers make the same decision.
        # (This is the training loss; no separate validation set is used.)
        scheduler.step(avg_loss)

        if rank == 0:
            # Store metrics
            loss_history.append(avg_loss)
            epoch_times.append(epoch_time)

            print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}, Time: {epoch_time:.2f}s")

        # Keep ranks aligned at the epoch boundary
        dist.barrier()

        # Save model periodically (only on the first GPU)
        if rank == 0 and (epoch + 1) % save_interval == 0:
            # Create the visualization directory if it doesn't exist
            vis_dir = os.path.join(model_save_path, 'visualizations')
            os.makedirs(vis_dir, exist_ok=True)

            # Visualize one batch. Use the underlying module rather than the
            # DDP wrapper: only rank 0 runs this code, so the forward pass
            # must not involve any cross-rank communication.
            eval_model = model.module
            with torch.no_grad():
                eval_model.eval()  # Set to evaluation mode
                for responses, labels in dataloader:
                    responses = responses.float().to(device)
                    labels = labels.float().to(device)
                    outputs = eval_model(responses)

                    # Make sure outputs match the size of labels
                    if outputs.shape != labels.shape:
                        outputs = F.interpolate(
                            outputs, 
                            size=(labels.shape[2], labels.shape[3]), 
                            mode='bilinear', 
                            align_corners=True
                        )

                    vis_path = visualize_new(
                        responses[:1], outputs[:1], labels[:1],
                        index=epoch+check_num,
                        save_dir=vis_dir
                    )
                    print(f"Epoch {epoch+1} visualization saved to {vis_path}")
                    break  # Only use one batch for visualization

                eval_model.train()  # Set back to training mode

            # Save model checkpoint (unwrapped weights, loadable without DDP)
            save_path = os.path.join(model_save_path, f'Restormer_epoch{epoch+1+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pth')
            torch.save(model.module.state_dict(), save_path)
            print(f"Model saved to {save_path}")

            # Save loss history and timing
            metrics_path = os.path.join(model_save_path, f'metrics_epoch{epoch+1+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pt')
            total_time = time.time() - total_start_time
            torch.save({
                'loss_history': loss_history,
                'epoch_times': epoch_times,
                'total_time': total_time,
                'average_epoch_time': sum(epoch_times) / len(epoch_times) if epoch_times else 0,
                'current_epoch': epoch + 1 + check_num,
                'learning_rate': optimizer.param_groups[0]['lr']
            }, metrics_path)
            print(f"Training metrics saved to {metrics_path}")

            # Plot loss curve
            plt.figure(figsize=(10, 5))
            plt.plot(range(1, len(loss_history) + 1), loss_history)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training Loss')
            plt.grid(True)
            plt.savefig(os.path.join(model_save_path, f'loss_curve_epoch{epoch+1+check_num}.png'))
            plt.close()

            # Memory cleanup
            torch.cuda.empty_cache()

    # At the end of training, save the final metrics and model
    if rank == 0:
        total_time = time.time() - total_start_time
        # Guarded so that num_epochs == 0 does not divide by zero.
        average_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0

        # Save final model
        final_model_path = os.path.join(model_save_path, f'Restormer_final_lr{learning_rate}_batch{batch_size}_{loss_func}.pth')
        torch.save(model.module.state_dict(), final_model_path)
        print(f"Final model saved to {final_model_path}")

        # Save final metrics
        final_metrics_path = os.path.join(model_save_path, f'metrics_final_lr{learning_rate}_batch{batch_size}_{loss_func}.pt')
        torch.save({
            'loss_history': loss_history,
            'epoch_times': epoch_times,
            'total_time': total_time,
            'average_epoch_time': average_epoch_time,
            'total_epochs': num_epochs,
            'learning_rate': optimizer.param_groups[0]['lr']
        }, final_metrics_path)
        print(f"Final training metrics saved to {final_metrics_path}")
        print(f"Total training time: {total_time:.2f}s, Average epoch time: {average_epoch_time:.2f}s")

    # Clean up distributed environment
    cleanup()


def main_train_restormer_optics(response_dir, gt_dir, num_epochs, learning_rate, batch_size, 
                              loss_func, model_load_path, model_save_path, check_num, 
                              save_interval=5, visualize_interval=1000):
    """
    Main function to start Restormer training for optics reconstruction

    With more than one visible GPU, one training process per GPU is spawned
    via ``mp.spawn``. With exactly one GPU, training runs in this process
    without a process group.

    Args:
        response_dir: Directory with response files
        gt_dir: Directory with ground truth files
        num_epochs: Number of epochs to train
        learning_rate: Learning rate for optimizer
        batch_size: Total batch size
        loss_func: Loss function ('ridge', 'ssim', or 'mse')
        model_load_path: Path to checkpoint for resuming training ('none' if starting fresh)
        model_save_path: Directory to save model checkpoints
        check_num: Epoch offset added to checkpoint/metric file names and
            visualization indices (the epoch number training continues from)
        save_interval: Save model checkpoint every N epochs
        visualize_interval: Visualize results every N batches

    Raises:
        ValueError: If ``loss_func`` is not one of the supported names.
        RuntimeError: If no CUDA device is available.
    """
    # Fail fast on a bad loss name, before touching the filesystem or GPUs
    # (previously this surfaced later as a NameError on `criterion`).
    if loss_func not in ('ridge', 'ssim', 'mse'):
        raise ValueError(
            f"Unknown loss_func {loss_func!r}; expected 'ridge', 'ssim', or 'mse'"
        )

    # Create model save directory if it doesn't exist
    os.makedirs(model_save_path, exist_ok=True)

    # Number of GPUs available
    world_size = torch.cuda.device_count()
    print(f"Using {world_size} GPUs for training")

    if world_size < 1:
        # Without this check the zero-GPU case fell into the single-GPU branch
        # and failed with an unrelated-looking CUDA error.
        raise RuntimeError("No CUDA devices available; Restormer training requires at least one GPU.")

    if world_size > 1:
        # Set environment variable for detailed debug information
        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

        # Use multiprocessing to launch one process per GPU; mp.spawn passes
        # the process index (rank) as the first positional argument.
        mp.spawn(
            train_restormer_optics_distributed,
            args=(world_size, response_dir, gt_dir, num_epochs, learning_rate, 
                 batch_size, loss_func, model_load_path, model_save_path, 
                 check_num, save_interval, visualize_interval),
            nprocs=world_size,
            join=True
        )
        return

    # Fall back to single GPU training if only one GPU is available
    print("Only one GPU detected. Using single GPU training.")

    device = torch.device("cuda:0")
    model = RestormerOptics(
        inp_channels=3,  # For RGB images
        out_channels=3,  # For RGB images
        dim=48,
        num_blocks=[4, 6, 6, 8],
        num_heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False
    ).to(device)

    if model_load_path != 'none':
        # map_location makes checkpoints saved from any device loadable here.
        model.load_state_dict(torch.load(model_load_path, map_location=device))
        print(f'Loading Path: {model_load_path}')
        print("Model weights loaded.")

    # Set up the dataset and loader
    dataset = ResponseGTImageDataset(response_dir, gt_dir, transform=transform, target_transform=target_transform)
    dataloader = DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=4,
        pin_memory=True
    )

    # Set up optimizer, loss function, and scaler for mixed precision
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scaler = GradScaler()

    # loss_func was validated above
    if loss_func == 'ridge':
        criterion = RidgeLoss(alpha=0.5).to(device)
        print('Using Ridge Loss with alpha=0.5')
    elif loss_func == 'ssim':
        criterion = SSIMLoss(alpha=0.05).to(device)
        print('Using SSIM Loss with alpha=0.05')
    else:  # 'mse'
        criterion = nn.MSELoss().to(device)
        print('Using MSE Loss')

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )

    loss_history = []
    epoch_times = []
    total_start_time = time.time()

    print(f'Start training Restormer with {loss_func} loss, lr = {learning_rate}, batch size = {batch_size}')
    print(f'Total epochs: {num_epochs}, Save interval: {save_interval}')

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        epoch_start_time = time.time()

        for i, (responses, labels) in enumerate(dataloader):
            responses = responses.float().to(device, non_blocking=True)
            labels = labels.float().to(device, non_blocking=True)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass with mixed precision
            with autocast():
                outputs = model(responses)

                if i == 0 and epoch == 0:
                    print(f"Response shape: {responses.shape}")
                    print(f"Output shape: {outputs.shape}")
                    print(f"Label shape: {labels.shape}")

                # Ensure outputs match the size of labels
                if outputs.shape != labels.shape:
                    outputs = F.interpolate(
                        outputs, 
                        size=(labels.shape[2], labels.shape[3]), 
                        mode='bilinear', 
                        align_corners=True
                    )

                # Calculate loss
                loss = criterion(outputs, labels)

            # Backward pass and optimizer step with scaling
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

            # Memory cleanup
            if i % 10 == 0:
                torch.cuda.empty_cache()

            # Visualization during training
            if i % visualize_interval == 0 and i > 0:
                with torch.no_grad():
                    vis_dir = os.path.join(model_save_path, 'visualizations')
                    os.makedirs(vis_dir, exist_ok=True)
                    vis_path = visualize_new(
                        responses[:1], outputs[:1], labels[:1],
                        index=epoch*1000+i+check_num,
                        save_dir=vis_dir
                    )
                    print(f"Epoch {epoch+1}, Batch {i} visualization saved to {vis_path}")

        # Calculate epoch time and average loss
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time
        avg_loss = total_loss / len(dataloader)

        # Store metrics
        loss_history.append(avg_loss)
        epoch_times.append(epoch_time)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}, Time: {epoch_time:.2f}s")

        # Update learning rate based on the (training) loss
        scheduler.step(avg_loss)

        # Save model and visualize periodically
        if (epoch + 1) % save_interval == 0:
            vis_dir = os.path.join(model_save_path, 'visualizations')
            os.makedirs(vis_dir, exist_ok=True)

            with torch.no_grad():
                model.eval()  # Set to evaluation mode
                for responses, labels in dataloader:
                    responses = responses.float().to(device)
                    labels = labels.float().to(device)
                    outputs = model(responses)

                    if outputs.shape != labels.shape:
                        outputs = F.interpolate(
                            outputs, 
                            size=(labels.shape[2], labels.shape[3]), 
                            mode='bilinear', 
                            align_corners=True
                        )

                    vis_path = visualize_new(
                        responses[:1], outputs[:1], labels[:1],
                        index=epoch+check_num,
                        save_dir=vis_dir
                    )
                    print(f"Epoch {epoch+1} visualization saved to {vis_path}")
                    break  # Only use one batch for visualization

                model.train()  # Set back to training mode

            # Save model checkpoint
            save_path = os.path.join(model_save_path, f'Restormer_epoch{epoch+1+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pth')
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")

            # Save metrics
            metrics_path = os.path.join(model_save_path, f'metrics_epoch{epoch+1+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pt')
            total_time = time.time() - total_start_time
            torch.save({
                'loss_history': loss_history,
                'epoch_times': epoch_times,
                'total_time': total_time,
                'average_epoch_time': sum(epoch_times) / len(epoch_times) if epoch_times else 0,
                'current_epoch': epoch + 1 + check_num,
                'learning_rate': optimizer.param_groups[0]['lr']
            }, metrics_path)
            print(f"Training metrics saved to {metrics_path}")

            # Plot loss curve
            plt.figure(figsize=(10, 5))
            plt.plot(range(1, len(loss_history) + 1), loss_history)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training Loss')
            plt.grid(True)
            plt.savefig(os.path.join(model_save_path, f'loss_curve_epoch{epoch+1+check_num}.png'))
            plt.close()

            # Memory cleanup
            torch.cuda.empty_cache()

    # Save final model and metrics
    total_time = time.time() - total_start_time
    # Guarded so that num_epochs == 0 does not divide by zero.
    average_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0

    # Save final model
    final_model_path = os.path.join(model_save_path, f'Restormer_final_lr{learning_rate}_batch{batch_size}_{loss_func}.pth')
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved to {final_model_path}")

    # Save final metrics
    final_metrics_path = os.path.join(model_save_path, f'metrics_final_lr{learning_rate}_batch{batch_size}_{loss_func}.pt')
    torch.save({
        'loss_history': loss_history,
        'epoch_times': epoch_times,
        'total_time': total_time,
        'average_epoch_time': average_epoch_time,
        'total_epochs': num_epochs,
        'learning_rate': optimizer.param_groups[0]['lr']
    }, final_metrics_path)
    print(f"Final training metrics saved to {final_metrics_path}")
    print(f"Total training time: {total_time:.2f}s, Average epoch time: {average_epoch_time:.2f}s")




if __name__ == '__main__':
    # Example run: resume from the epoch-15 checkpoint and train 50 more
    # epochs with the SSIM loss. Responses and ground truth are read from the
    # same dataset directory.
    run_config = {
        'response_dir': '/home/share/0125_dataset/01092025 test slide 1/',
        'gt_dir': '/home/share/0125_dataset/01092025 test slide 1/',
        'num_epochs': 50,
        'learning_rate': 1e-4,
        'batch_size': 16,  # Adjust based on your GPU memory
        'loss_func': 'ssim',
        # Checkpoint to resume from; use 'none' to start from scratch.
        'model_load_path': '/home/dan5/optics_recon/Optics_Recon_Project/restormer_optics_param/Restormer_epoch15_lr0.0001_batch16_ssim.pth',
        'model_save_path': '/home/dan5/optics_recon/Optics_Recon_Project/restormer_optics_param/',
        'check_num': 15,  # matches the epoch of the checkpoint being resumed
        'save_interval': 5,
        'visualize_interval': 1000,
    }

    # Release cached CUDA memory before training starts
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    print('GPU Clean')
    print('Restormer Model Training for Optics with DDP')

    main_train_restormer_optics(**run_config)