import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
import time
import numpy as np

# Import the TransUNet model from the transunet.py file
from transunet import TransUNet, SSIMLoss
from utilies import normalize_matrix

# Dataset class, input/target transforms, RidgeLoss and a plotting helper defined in vit.py
from vit import ResponseGTImageDataset, transform, target_transform, RidgeLoss, visualize_predictions


def setup(rank, world_size):
    """
    Initialise the NCCL process group for this process and bind it to a GPU.

    Rendezvous uses the MASTER_ADDR / MASTER_PORT environment variables.
    They default to 'localhost' and '12355', but a value that is already set
    in the environment is respected. For example, export a different
    MASTER_PORT to run two trainings on the same machine without a port clash.

    Args:
        rank: Index of this process in the group; also used as its CUDA
            device index.
        world_size: Total number of processes in the group.
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # Initialize the process group (blocks until all `world_size` ranks join)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # Make `cuda:rank` the default CUDA device for this process
    torch.cuda.set_device(rank)

    print(f"Rank {rank}: Setup complete")


def cleanup():
    """
    Tear down the default distributed process group, if one exists.

    Safe to call when distributed support is unavailable or when no process
    group was initialised; in those cases it does nothing.
    """
    # destroy_process_group() raises when there is no default group, so check
    # first; is_available() guards builds compiled without distributed support.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def _first_image(tensor):
    """Return the 2-D slice of ``tensor`` that visualize_new draws, as a NumPy array.

    4-D input -> ``[0, 0]``; any other rank above 2 -> ``[0]``; rank 2 or
    lower -> returned unchanged.
    """
    array = tensor.detach().cpu().numpy()
    if array.ndim == 4:
        return array[0, 0]
    if array.ndim > 2:
        return array[0]
    return array


def visualize_new(low_res, outputs, high_res, index=0, save_dir="results"):
    """
    Save a side-by-side PNG of the input, the prediction and the ground truth.

    One 2-D image per tensor is drawn:
    - a 4-D tensor shows its first channel of the first sample (``[0, 0]``);
    - any other tensor with more than two dimensions shows its leading
      slice (``[0]``);
    - a 2-D tensor is drawn as-is.

    Each tensor is sliced according to its own rank. Every image is passed
    through ``normalize_matrix`` and drawn with a gray colormap.

    Args:
        low_res: Network input tensor (CPU or CUDA; gradients are ignored).
        outputs: Predicted tensor.
        high_res: Ground-truth tensor.
        index: Number embedded in the file name ``prediction_{index:03d}.png``.
        save_dir: Output directory; created if it does not exist.

    Returns:
        Path of the written PNG file.
    """
    # Slice each tensor by its own rank so that, e.g., a 4-D input shown next
    # to a 3-D target is still reduced to a 2-D image for every panel.
    panels = (
        (_first_image(low_res), 'Input'),
        (_first_image(outputs), 'Predicted'),
        (_first_image(high_res), 'Ground Truth'),
    )

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"prediction_{index:03d}.png")

    fig = plt.figure(figsize=(12, 6))
    try:
        for position, (image, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(normalize_matrix(image), cmap='gray')
            plt.title(title)
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(save_path)
    finally:
        # Release the figure even if plotting/saving fails, so repeated calls
        # during training do not accumulate open matplotlib figures.
        plt.close(fig)

    return save_path  # Return the save path for logging


def create_transunet_model(img_size=224, in_channels=3, out_channels=3, device=None):
    """
    Build a TransUNet with this project's fixed architecture hyper-parameters.

    The transformer part uses 16x16 patches, a 768-dim embedding, 12 layers
    and 12 heads. The CNN encoder/decoder use feature widths
    [64, 128, 256, 512] / [256, 128, 64, 32].

    Args:
        img_size: Input image size (default: 224).
        in_channels: Number of input channels (default: 3).
        out_channels: Number of output channels (default: 3).
        device: Optional device to move the model to; when None the model is
            returned wherever TransUNet's constructor placed it.

    Returns:
        The TransUNet module.
    """
    architecture = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        encoder_features=[64, 128, 256, 512],
        decoder_features=[256, 128, 64, 32],
    )
    network = TransUNet(
        img_size=img_size,
        in_channels=in_channels,
        out_channels=out_channels,
        **architecture,
    )
    return network if device is None else network.to(device)


def load_checkpoint(model, checkpoint_path, device=None):
    """
    Load a saved state dict into ``model`` (in place, strict key matching).

    Args:
        model: Module whose parameters/buffers are overwritten.
        checkpoint_path: File produced by ``torch.save(model.state_dict(), ...)``.
        device: Optional ``torch.device`` (or device string) onto which the
            stored tensors are mapped while reading. When None they are read
            onto the CPU.

    Returns:
        The same ``model`` object.
    """
    # Map every stored tensor straight to the target device. This works for
    # CPU devices and any CUDA index, unlike a {'cuda:0': 'cuda:N'} dict,
    # which needs an integer device.index and only remaps cuda:0 tensors.
    # Reading onto the CPU by default also lets CUDA-saved checkpoints load on
    # CPU-only hosts. load_state_dict copies values into the model's existing
    # parameters wherever those live, so the loaded weights are the same.
    map_location = device if device is not None else 'cpu'
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model


def train_transunet_distributed(rank, world_size, response_dir, gt_dir, num_epochs, learning_rate,
                              batch_size, loss_func, model_load_path, model_save_path,
                              check_num, save_interval=5, visualize_interval=1000,
                              img_size=224, in_channels=1, out_channels=1):
    """
    Per-process worker for DistributedDataParallel (DDP) training of TransUNet.

    Meant to be launched once per GPU via ``torch.multiprocessing.spawn``,
    which passes ``rank`` as the first argument. Every process trains a DDP
    replica on its own shard of the dataset. Only rank 0 prints progress and
    writes visualizations, checkpoints and metrics.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of processes (one per GPU).
        response_dir: Directory containing input images.
        gt_dir: Directory containing ground truth images.
        num_epochs: Number of training epochs.
        learning_rate: Learning rate for the Adam optimizer.
        batch_size: TOTAL batch size across all GPUs. Each process loads
            ``batch_size // world_size`` samples per step (at least 1).
        loss_func: Loss function type ('ridge', 'ssim', or 'mse'); any other
            value falls back to MSE.
        model_load_path: Path to load pre-trained model weights, or 'none'.
        model_save_path: Directory for checkpoints, visualizations and metrics.
        check_num: Offset added to the 0-based epoch index in saved file names.
        save_interval: Save model/visualization/metrics every this many epochs.
        visualize_interval: Currently unused; kept for signature compatibility.
        img_size: Currently ignored; the model is always built with img_size=224.
        in_channels: Currently ignored; the model always has 3 input channels.
        out_channels: Currently ignored; the model always has 3 output channels.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    print(f"Starting process rank: {rank}, world_size: {world_size}")

    # Join the process group and bind this process to GPU `rank`
    setup(rank, world_size)

    # Set device for this process
    device = torch.device(f"cuda:{rank}")

    # For tracking loss and timing (only rank 0 appends to these)
    loss_history = []
    epoch_times = []
    total_start_time = time.time()

    # NOTE(review): img_size / in_channels / out_channels are NOT forwarded;
    # the architecture is fixed at 224 px and 3 -> 3 channels, and outputs are
    # resized to the label size further below. Confirm the intended sizes
    # before wiring these parameters through (the script passes img_size=449).
    model = create_transunet_model(img_size=224, in_channels=3, out_channels=3, device=device)

    # Every rank loads the same checkpoint onto its own GPU
    if model_load_path != 'none':
        model = load_checkpoint(model, model_load_path, device)
        if rank == 0:
            print(f'Loading Path: {model_load_path}')
            print("Model weights loaded successfully.")

    # find_unused_parameters=True tolerates parameters that get no gradient in
    # a step, at some extra per-iteration cost.
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

    # Set up the dataset and distributed sampler
    dataset = ResponseGTImageDataset(response_dir, gt_dir, transform=transform, target_transform=target_transform)
    if len(dataset) == 0:
        # All ranks read the same directories, so they all stop here together.
        # Otherwise the run would fail later with a ZeroDivisionError or an
        # undefined batch.
        raise ValueError(
            f"No training samples found (response_dir={response_dir!r}, gt_dir={gt_dir!r})"
        )
    # Each rank gets its own shard of indices. With drop_last=False the sampler
    # pads so every rank sees the same number of samples.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)

    # batch_size is the global batch; split it across processes
    per_gpu_batch_size = batch_size // world_size
    if per_gpu_batch_size < 1:
        per_gpu_batch_size = 1
        if rank == 0:
            print(f"Warning: Batch size {batch_size} is smaller than world_size {world_size}. "
                  f"Setting per-GPU batch size to 1.")

    # Create dataloader (shuffling is delegated to the sampler)
    dataloader = DataLoader(
        dataset,
        batch_size=per_gpu_batch_size,
        sampler=sampler,
        num_workers=2,  # Reduced to avoid memory issues
        pin_memory=True
    )

    if rank == 0:
        print(f"Dataset size: {len(dataset)}, Dataloader batches: {len(dataloader)}")

    # Set up optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Set up loss function
    if loss_func == 'ridge':
        criterion = RidgeLoss(alpha=0.5).to(device)
        if rank == 0:
            print('Using Ridge Loss with alpha=0.5')
    elif loss_func == 'ssim':
        criterion = SSIMLoss(alpha=0.05).to(device)
        if rank == 0:
            print('Using SSIM Loss with alpha=0.05')
    elif loss_func == 'mse':
        criterion = nn.MSELoss().to(device)
        if rank == 0:
            print('Using MSE Loss')
    else:
        # Default to MSE loss
        criterion = nn.MSELoss().to(device)
        if rank == 0:
            print(f'Unknown loss function: {loss_func}, using MSE Loss')

    if rank == 0:
        print(f'Start training TransUNet with {loss_func} loss, lr = {learning_rate}, '
              f'total batch size = {batch_size}, per-GPU batch size = {per_gpu_batch_size}')

    # Training loop
    for epoch in range(num_epochs):
        # Re-seed the sampler so each epoch uses a different shuffle
        sampler.set_epoch(epoch)

        model.train()
        total_loss = 0.0
        epoch_start_time = time.time()

        # Synchronize before starting epoch
        dist.barrier()

        for i, (responses, labels) in enumerate(dataloader):
            # Move data to the correct device
            responses = responses.float().to(device)
            labels = labels.float().to(device)

            optimizer.zero_grad()

            outputs = model(responses)

            # Print shapes during first iteration of first epoch
            if i == 0 and epoch == 0 and rank == 0:
                print(f"Response shape: {responses.shape}")
                print(f"Output shape: {outputs.shape}")
                print(f"Label shape: {labels.shape}")

            # Resize outputs to the label's spatial size. This only fixes H/W;
            # a channel-count mismatch is left for the criterion to handle.
            if outputs.shape != labels.shape:
                outputs = F.interpolate(outputs, size=(labels.shape[2], labels.shape[3]),
                                       mode='bilinear', align_corners=True)

            loss = criterion(outputs, labels)

            # DDP all-reduces gradients during backward
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Log roughly five times per epoch
            if rank == 0 and (i + 1) % (len(dataloader) // 5 or 1) == 0:
                print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], "
                      f"Loss: {loss.item():.4f}")

            # NOTE(review): releasing the CUDA cache every step usually costs
            # throughput rather than preventing OOM; consider removing.
            torch.cuda.empty_cache()

        # Calculate epoch time
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time

        # Mean per-batch loss over all ranks. Every rank has the same number
        # of batches because the sampler pads the shards.
        loss_tensor = torch.tensor(total_loss).to(device)
        dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
        avg_loss = loss_tensor.item() / (len(dataloader) * world_size)

        if rank == 0:
            # Store loss and time
            loss_history.append(avg_loss)
            epoch_times.append(epoch_time)

            print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}, "
                  f"Time: {epoch_time:.2f}s")

            # Save model and visualizations periodically
            if (epoch + 1) % save_interval == 0:
                vis_dir = os.path.join(model_save_path, 'visualizations')
                os.makedirs(vis_dir, exist_ok=True)

                model.eval()

                # Visualize the last training batch of this epoch (no new forward pass)
                with torch.no_grad():
                    vis_path = visualize_new(
                        responses[:1], outputs[:1], labels[:1],
                        index=epoch+check_num,
                        save_dir=vis_dir
                    )
                print(f"Visualization saved to {vis_path}")

                model.train()

                # Save the model without DDP wrapper (keys have no 'module.' prefix)
                save_path = os.path.join(
                    model_save_path,
                    f'TransUNet_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pth'
                )
                torch.save(model.module.state_dict(), save_path)
                print(f"Model saved to {save_path}")

                # Save loss history and timing
                metrics_path = os.path.join(
                    model_save_path,
                    f'metrics_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pt'
                )
                total_time = time.time() - total_start_time
                torch.save({
                    'loss_history': loss_history,
                    'epoch_times': epoch_times,
                    'total_time': total_time,
                    'average_epoch_time': sum(epoch_times) / len(epoch_times) if epoch_times else 0,
                    'current_epoch': epoch + check_num
                }, metrics_path)
                print(f"Training metrics saved to {metrics_path}")

        # Other ranks wait here while rank 0 writes files
        dist.barrier()

    # At the end of training, save the final metrics
    if rank == 0:
        total_time = time.time() - total_start_time
        # epoch_times is empty when num_epochs == 0; avoid dividing by zero
        average_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0
        metrics_path = os.path.join(
            model_save_path,
            f'metrics_final_lr{learning_rate}_batch{batch_size}_{loss_func}.pt'
        )
        torch.save({
            'loss_history': loss_history,
            'epoch_times': epoch_times,
            'total_time': total_time,
            'average_epoch_time': average_epoch_time,
            'total_epochs': num_epochs
        }, metrics_path)
        print(f"Final training metrics saved to {metrics_path}")
        print(f"Total training time: {total_time:.2f}s, "
              f"Average epoch time: {average_epoch_time:.2f}s")

    # Synchronize before cleanup
    dist.barrier()

    cleanup()


def train_transunet_single_gpu(response_dir, gt_dir, num_epochs, learning_rate, batch_size,
                             loss_func, model_load_path, model_save_path, check_num,
                             save_interval=5, visualize_interval=1000,
                             img_size=224, in_channels=1, out_channels=1):
    """
    Train TransUNet in a single process on ``cuda:0``, or on the CPU when CUDA is unavailable.

    Args:
        response_dir: Directory containing input images.
        gt_dir: Directory containing ground truth images.
        num_epochs: Number of training epochs.
        learning_rate: Learning rate for the Adam optimizer.
        batch_size: Batch size.
        loss_func: Loss function type ('ridge', 'ssim', or 'mse'); any other
            value falls back to MSE.
        model_load_path: Path to load pre-trained model weights, or 'none'.
        model_save_path: Directory for checkpoints, visualizations and metrics.
        check_num: Offset added to the 0-based epoch index in saved file names.
        save_interval: Save model/visualization/metrics every this many epochs.
        visualize_interval: Currently unused; kept for signature compatibility.
        img_size: Image size passed to the model constructor.
        in_channels: Currently ignored; the model always has 3 input channels.
        out_channels: Currently ignored; the model always has 3 output channels.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # For tracking loss and timing
    loss_history = []
    epoch_times = []
    total_start_time = time.time()

    # NOTE(review): in_channels/out_channels are not forwarded; the model is
    # always 3 -> 3 channels. Confirm before wiring them through.
    model = create_transunet_model(img_size, in_channels=3, out_channels=3, device=device)

    # Load model if specified
    if model_load_path != 'none':
        model = load_checkpoint(model, model_load_path)
        print(f'Loading Path: {model_load_path}')
        print("Model weights loaded successfully.")

    # Set up the dataset and dataloader
    dataset = ResponseGTImageDataset(response_dir, gt_dir, transform=transform, target_transform=target_transform)
    if len(dataset) == 0:
        # Fail early and clearly. Otherwise the run would fail later with a
        # ZeroDivisionError in the average loss, or an undefined batch at
        # visualization time.
        raise ValueError(
            f"No training samples found (response_dir={response_dir!r}, gt_dir={gt_dir!r})"
        )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )

    print(f"Dataset size: {len(dataset)}, Dataloader batches: {len(dataloader)}")

    # Set up optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Set up loss function
    if loss_func == 'ridge':
        criterion = RidgeLoss(alpha=0.5).to(device)
        print('Using Ridge Loss with alpha=0.5')
    elif loss_func == 'ssim':
        criterion = SSIMLoss(alpha=0.05).to(device)
        print('Using SSIM Loss with alpha=0.05')
    elif loss_func == 'mse':
        criterion = nn.MSELoss().to(device)
        print('Using MSE Loss')
    else:
        # Default to MSE loss
        criterion = nn.MSELoss().to(device)
        print(f'Unknown loss function: {loss_func}, using MSE Loss')

    print(f'Start training TransUNet with {loss_func} loss, lr = {learning_rate}, batch size = {batch_size}')

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        epoch_start_time = time.time()

        for i, (responses, labels) in enumerate(dataloader):
            # Move data to the correct device
            responses = responses.float().to(device)
            labels = labels.float().to(device)

            optimizer.zero_grad()

            outputs = model(responses)

            # Print shapes during first iteration of first epoch
            if i == 0 and epoch == 0:
                print(f"Response shape: {responses.shape}")
                print(f"Output shape: {outputs.shape}")
                print(f"Label shape: {labels.shape}")

            # Resize outputs to the label's spatial size. This only fixes H/W;
            # a channel-count mismatch is left for the criterion to handle.
            if outputs.shape != labels.shape:
                outputs = F.interpolate(outputs, size=(labels.shape[2], labels.shape[3]),
                                      mode='bilinear', align_corners=True)

            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Log roughly five times per epoch
            if (i + 1) % (len(dataloader) // 5 or 1) == 0:
                print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], "
                      f"Loss: {loss.item():.4f}")

            # NOTE(review): releasing the CUDA cache every step usually costs
            # throughput rather than preventing OOM; consider removing.
            torch.cuda.empty_cache()

        # Calculate epoch stats (len(dataloader) >= 1: empty datasets were rejected above)
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time
        avg_loss = total_loss / len(dataloader)

        loss_history.append(avg_loss)
        epoch_times.append(epoch_time)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}, Time: {epoch_time:.2f}s")

        # Save model periodically
        if (epoch + 1) % save_interval == 0:
            vis_dir = os.path.join(model_save_path, 'visualizations')
            os.makedirs(vis_dir, exist_ok=True)

            model.eval()

            # Visualize the last training batch of this epoch (no new forward pass)
            with torch.no_grad():
                vis_path = visualize_new(
                    responses[:1], outputs[:1], labels[:1],
                    index=epoch+check_num,
                    save_dir=vis_dir
                )
            print(f"Visualization saved to {vis_path}")

            # Save the model
            save_path = os.path.join(
                model_save_path,
                f'TransUNet_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pth'
            )
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")

            # Save loss history and timing
            metrics_path = os.path.join(
                model_save_path,
                f'metrics_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pt'
            )
            total_time = time.time() - total_start_time
            torch.save({
                'loss_history': loss_history,
                'epoch_times': epoch_times,
                'total_time': total_time,
                'average_epoch_time': sum(epoch_times) / len(epoch_times) if epoch_times else 0,
                'current_epoch': epoch + check_num
            }, metrics_path)
            print(f"Training metrics saved to {metrics_path}")

            model.train()

    # At the end of training, save the final metrics
    total_time = time.time() - total_start_time
    # epoch_times is empty when num_epochs == 0; avoid dividing by zero
    average_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0
    metrics_path = os.path.join(
        model_save_path,
        f'metrics_final_lr{learning_rate}_batch{batch_size}_{loss_func}.pt'
    )
    torch.save({
        'loss_history': loss_history,
        'epoch_times': epoch_times,
        'total_time': total_time,
        'average_epoch_time': average_epoch_time,
        'total_epochs': num_epochs
    }, metrics_path)
    print(f"Final training metrics saved to {metrics_path}")
    print(f"Total training time: {total_time:.2f}s, "
          f"Average epoch time: {average_epoch_time:.2f}s")


def main_train_transunet(response_dir, gt_dir, num_epochs, learning_rate, batch_size,
                        loss_func, model_load_path, model_save_path, check_num,
                        save_interval=5, visualize_interval=1000,
                        img_size=224, in_channels=1, out_channels=1):
    """
    Main function to start TransUNet training.

    With more than one visible CUDA device it spawns one DDP worker per GPU.
    Otherwise it trains in a single process (on one GPU, or on the CPU when no
    GPU is visible). If the distributed run raises, training is restarted in a
    single process.

    Args:
        response_dir: Directory containing input images.
        gt_dir: Directory containing ground truth images.
        num_epochs: Number of training epochs.
        learning_rate: Learning rate for optimizer.
        batch_size: Total batch size (divided by the number of GPUs in DDP mode).
        loss_func: Loss function type ('ridge', 'ssim', or 'mse').
        model_load_path: Path to load pre-trained model weights or 'none'.
        model_save_path: Directory to save model checkpoints (created if missing).
        check_num: Offset added to the epoch index in saved file names.
        save_interval: Epoch interval for saving model.
        visualize_interval: Passed through; currently unused by the trainers.
        img_size: Input image size (default: 224).
        in_channels: Number of input channels (default: 1).
        out_channels: Number of output channels (default: 1).
    """
    os.makedirs(model_save_path, exist_ok=True)

    world_size = torch.cuda.device_count()
    print(f"Available GPUs: {world_size}")

    # Positional arguments shared by both trainers. The distributed worker
    # takes the same list prefixed by (rank, world_size); mp.spawn supplies rank.
    trainer_args = (
        response_dir, gt_dir, num_epochs, learning_rate, batch_size,
        loss_func, model_load_path, model_save_path, check_num,
        save_interval, visualize_interval,
        img_size, in_channels, out_channels,
    )

    if world_size > 1:
        print(f"Using {world_size} GPUs for distributed training")

        # Verbose DDP diagnostics, unless the user already chose a level
        os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")

        try:
            mp.spawn(
                train_transunet_distributed,
                args=(world_size,) + trainer_args,
                nprocs=world_size,
                join=True
            )
        except Exception as e:
            # NOTE(review): the fallback restarts training from model_load_path
            # and writes files with the same names as the distributed run,
            # overwriting any checkpoints it had already saved.
            print(f"Error during distributed training: {e}")
            print("Falling back to single GPU training")
            train_transunet_single_gpu(*trainer_args)
    else:
        if world_size == 1:
            print("Only one GPU detected. Using single GPU training.")
        else:
            print("No GPU detected. Training on CPU.")
        train_transunet_single_gpu(*trainer_args)


def evaluate_transunet(model_path, test_response_dir, test_gt_dir, output_dir=None,
                     device=None, batch_size=8, img_size=224, in_channels=1, out_channels=1):
    """
    Evaluate a trained TransUNet checkpoint on a test set.

    Computes sample-weighted mean MSE and mean SSIM score over the test set,
    plus a PSNR derived from the mean MSE. It also writes visualizations of
    the first 10 batches and an ``evaluation_metrics.pt`` dict under
    ``output_dir``.

    Args:
        model_path: Path to the trained model weights (a state dict).
        test_response_dir: Directory containing test input images.
        test_gt_dir: Directory containing test ground truth images.
        output_dir: Directory to save evaluation results. Defaults to an
            ``evaluation`` folder next to ``model_path``.
        device: Device to run evaluation on. Defaults to ``cuda:0`` when CUDA
            is available, otherwise the CPU.
        batch_size: Batch size for evaluation.
        img_size: Input image size used to build the model (default: 224).
        in_channels: Number of input channels (default: 1). The training
            functions in this file always build 3-channel models, so pass 3
            to load their checkpoints.
        out_channels: Number of output channels (default: 1); same caveat.

    Returns:
        dict with keys 'MSE', 'SSIM', 'PSNR' and 'total_samples'.

    Raises:
        ValueError: If the test dataset contains no samples.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if output_dir is None:
        output_dir = os.path.join(os.path.dirname(model_path), 'evaluation')

    os.makedirs(output_dir, exist_ok=True)

    # Load model
    model = create_transunet_model(img_size, in_channels, out_channels, device)
    model = load_checkpoint(model, model_path, device)
    model.eval()

    print(f"Model loaded from {model_path}")

    # Set up the dataset and dataloader
    test_dataset = ResponseGTImageDataset(
        test_response_dir,
        test_gt_dir,
        transform=transform,
        target_transform=target_transform
    )
    if len(test_dataset) == 0:
        # Without samples the averages below would divide by zero
        raise ValueError(
            f"No test samples found (test_response_dir={test_response_dir!r}, "
            f"test_gt_dir={test_gt_dir!r})"
        )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    print(f"Test dataset size: {len(test_dataset)}, Test dataloader batches: {len(test_loader)}")

    # Metrics
    mse_loss = nn.MSELoss().to(device)
    # NOTE(review): assumes SSIMLoss(alpha=1.0) returns exactly 1 - SSIM, so
    # that 1 - loss is the SSIM score; verify against transunet.SSIMLoss.
    ssim_loss = SSIMLoss(alpha=1.0).to(device)  # Alpha=1.0 for pure SSIM

    total_mse = 0.0
    total_ssim = 0.0
    total_samples = 0

    # Create evaluation results directory
    results_dir = os.path.join(output_dir, 'results')
    os.makedirs(results_dir, exist_ok=True)

    print("Starting evaluation...")

    with torch.no_grad():
        for i, (responses, labels) in enumerate(test_loader):
            # Move data to device
            responses = responses.float().to(device)
            labels = labels.float().to(device)

            outputs = model(responses)

            # Resize outputs to the label's spatial size
            if outputs.shape != labels.shape:
                outputs = F.interpolate(outputs, size=(labels.shape[2], labels.shape[3]),
                                     mode='bilinear', align_corners=True)

            # Per-batch means
            mse = mse_loss(outputs, labels).item()
            ssim = 1 - ssim_loss(outputs, labels).item()  # 1 - SSIM loss gives SSIM score

            # Weight by batch size so the final averages are per-sample means
            # (the last batch may be smaller).
            n_samples = responses.size(0)
            total_mse += mse * n_samples
            total_ssim += ssim * n_samples
            total_samples += n_samples

            # Progress update, roughly ten times per run
            if (i + 1) % (len(test_loader) // 10 or 1) == 0:
                print(f"Evaluated {i + 1}/{len(test_loader)} batches, "
                      f"Current batch MSE: {mse:.4f}, SSIM: {ssim:.4f}")

            # Save first 10 results for visualization (first image of each batch)
            if i < 10:
                vis_path = visualize_new(
                    responses, outputs, labels,
                    index=i,
                    save_dir=results_dir
                )
                print(f"Saved result visualization to {vis_path}")

            # Memory cleanup
            torch.cuda.empty_cache()

    # Calculate average metrics
    avg_mse = total_mse / total_samples
    avg_ssim = total_ssim / total_samples
    # NOTE(review): PSNR uses a peak value of 1.0, i.e. assumes targets are
    # scaled to [0, 1]; confirm against target_transform.
    psnr = 10 * np.log10(1.0 / avg_mse) if avg_mse > 0 else float('inf')

    # Save metrics to file
    metrics = {
        'MSE': avg_mse,
        'SSIM': avg_ssim,
        'PSNR': psnr,
        'total_samples': total_samples
    }

    metrics_path = os.path.join(output_dir, 'evaluation_metrics.pt')
    torch.save(metrics, metrics_path)

    # Print metrics
    print(f"\nEvaluation Results on {total_samples} samples:")
    print(f"MSE: {avg_mse:.4f}")
    print(f"SSIM: {avg_ssim:.4f}")
    print(f"PSNR: {psnr:.2f} dB")

    return metrics


if __name__ == '__main__':
    # Script entry point. The guard matters: mp.spawn (used by
    # main_train_transunet for multi-GPU runs) starts workers with the 'spawn'
    # start method, which re-imports this module in every child process.
    response_file_path = '/home/share/0125_dataset/01092025 test slide 1/'
    # NOTE(review): inputs and ground truth are read from the same directory;
    # presumably ResponseGTImageDataset tells them apart by file name — confirm.
    gt_file_path = response_file_path

    training_config = dict(
        response_dir=response_file_path,
        gt_dir=gt_file_path,
        num_epochs=50,
        learning_rate=1e-4,
        batch_size=150,            # total across all GPUs in DDP mode
        loss_func='ssim',
        model_load_path='none',    # 'none' = do not load a checkpoint
        model_save_path='/home/dan5/optics_recon/Optics_Recon_Project/transunet_param/',
        check_num=0,
        save_interval=5,
        visualize_interval=50,
        # NOTE(review): the DDP trainer builds a fixed 224 px, 3 -> 3 channel
        # model and ignores the next three values; the single-GPU trainer
        # uses img_size only.
        img_size=449,
        in_channels=3,
        out_channels=3,
    )

    print('TransUNet Model Training with DDP - 50 epochs')
    main_train_transunet(**training_config)