import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import scipy.io
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.transforms.functional import normalize
from transformers import ViTModel
import torch.optim as optim
import pytorch_ssim
from sklearn.linear_model import OrthogonalMatchingPursuit
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import scipy.io
from torchvision.transforms.functional import to_tensor
from torch.utils.data import Dataset
import math

import torch
import torch.optim as optim
from transformers import AutoModelForImageClassification
from utilies import normalize_matrix, down_sample_matrix, mat2vec, vec2mat, mask_response_circle






#import torch
from torch import nn
from transformers import ViTModel

'''
class ViTForImageReconstruction(nn.Module):
    def __init__(self, pretrained_model_name):
        super().__init__()

        #self.vit = ViTModel.from_pretrained(pretrained_model_name)

        self.vit = AutoModelForImageClassification.from_pretrained(pretrained_model_name)

        self.reshape = nn.Linear(1000, 768 * 31 * 31)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((31, 31)) 

        #self.adaptive_pool = nn.AdaptiveAvgPool2d((15,15))  # This is a strategic choice to simplify upsampling

        # Decoder to upscale to 449x449
        self.decoder = nn.Sequential(
            # Step up gradually to the nearest size that can be fine-tuned to 449
            nn.Upsample(size=(30, 30), mode='bilinear', align_corners=True),  # First upscale
            nn.Conv2d(768, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(60, 60), mode='bilinear', align_corners=True),  # Further upscale
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(120, 120), mode='bilinear', align_corners=True),  # And upscale more
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(240, 240), mode='bilinear', align_corners=True),  # Nearly there
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(449, 449), mode='bilinear', align_corners=True),  # Final size
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, pixel_values):
        features = self.vit(pixel_values=pixel_values).logits


        # Reshape features to have spatial dimensions
        features = self.reshape(features)
        features = features.view(-1, 768, 31, 31)  # Reshaping to (batch_size, channels, height, width)

        pooled_features = self.adaptive_pool(features)
        reconstructed_image = self.decoder(pooled_features)
        return reconstructed_image


'''

class ViTForImageReconstruction(nn.Module):
    """ViT encoder followed by a convolutional upsampling decoder.

    The patch tokens of ``ViTModel.last_hidden_state`` (CLS token removed) are
    laid out on a square grid, adaptively pooled to 15x15 and decoded by five
    (bilinear upsample -> 3x3 conv -> activation) stages into a 3-channel
    449x449 image. Every stage uses ReLU except the last, which uses a sigmoid.

    Args:
        pretrained_model_name: name or path given to ``ViTModel.from_pretrained``.
    """

    def __init__(self, pretrained_model_name):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model_name)

        # Pool the token grid to a fixed 15x15 so the decoder input size does
        # not depend on the backbone's patch grid.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((15, 15))

        # (upsample target size, conv in-channels, conv out-channels) per stage.
        # The first stage expects 768 channels, i.e. the ViT hidden size.
        stages = [
            ((30, 30), 768, 256),
            ((60, 60), 256, 128),
            ((120, 120), 128, 64),
            ((240, 240), 64, 32),
            ((449, 449), 32, 3),
        ]
        layers = []
        last_stage = len(stages) - 1
        for stage_idx, (target_size, c_in, c_out) in enumerate(stages):
            layers.append(nn.Upsample(size=target_size, mode='bilinear', align_corners=True))
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            # Sigmoid on the final stage keeps the reconstruction in (0, 1).
            layers.append(nn.Sigmoid() if stage_idx == last_stage else nn.ReLU())
        # Layer order (and therefore state_dict indices) is upsample, conv,
        # activation for each stage in turn.
        self.decoder = nn.Sequential(*layers)

    def forward(self, pixel_values):
        """Reconstruct an image batch from ``pixel_values``.

        Args:
            pixel_values: batch accepted by the ViT backbone (presumably
                (batch, 3, 224, 224) given this file's input transform — confirm).

        Returns:
            Tensor of shape (batch, 3, 449, 449) with values in (0, 1).
        """
        tokens = self.vit(pixel_values=pixel_values).last_hidden_state
        n_batch, n_tokens, n_dim = tokens.shape

        # Token 0 is the CLS token; the remaining tokens are assumed to form a
        # square patch grid.
        side = int((n_tokens - 1) ** 0.5)
        patch_tokens = tokens[:, 1:, :]
        grid = patch_tokens.permute(0, 2, 1).view(n_batch, n_dim, side, side)

        return self.decoder(self.adaptive_pool(grid))





import torch
import torch.nn.functional as F

def gaussian_window(size, sigma):
    """Return a 1-D Gaussian kernel of length ``size`` that sums to 1.

    Sample positions are ``0..size-1`` shifted by ``size // 2``, so for odd
    ``size`` the peak is at the centre element. The result is float32.
    """
    offsets = torch.arange(size, dtype=torch.float32)
    offsets = offsets - size // 2
    two_variance = 2 * sigma ** 2
    weights = torch.exp(-offsets.pow(2) / two_variance)
    return weights / weights.sum()

def create_window(window_size, channel, sigma=1.5):
    """Build a per-channel 2-D Gaussian kernel for grouped convolution.

    The 2-D kernel is the outer product of ``gaussian_window(window_size, sigma)``
    with itself, replicated once per channel so it can be used with
    ``F.conv2d(..., groups=channel)``.

    Args:
        window_size: side length of the square window.
        channel: number of channels to replicate the kernel for.
        sigma: standard deviation of the Gaussian (default 1.5).

    Returns:
        Float32 tensor of shape (channel, 1, window_size, window_size).
    """
    _1D_window = gaussian_window(window_size, sigma).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # expand() creates a broadcast view; contiguous() materializes it.
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window

def ssim(img1, img2, window_size=11, size_average=True, data_range=1.0):
    """Compute the structural similarity index (SSIM) between two image batches.

    Local means, variances and covariance are computed with a Gaussian window
    (``create_window``) applied independently per channel, with zero padding of
    ``window_size // 2`` so the SSIM map keeps the input's spatial size.

    Args:
        img1, img2: 4-D tensors (batch, channel, height, width) of equal shape;
            the channel count is read from dim 1.
        window_size: side length of the Gaussian window.
        size_average: if True, return the scalar mean of the SSIM map;
            otherwise return one value per batch element.
        data_range: dynamic range L of the pixel values. The stabilizing
            constants are C1 = (0.01 * L)**2 and C2 = (0.03 * L)**2; the
            default 1.0 suits images in [0, 1].

    Returns:
        Scalar tensor, or a tensor of shape (batch,) when ``size_average`` is False.
    """
    channel = img1.size(1)
    # Match the input's device and (floating) dtype so conv2d does not fail
    # on float64 or half-precision inputs.
    window_dtype = img1.dtype if img1.is_floating_point() else torch.float32
    window = create_window(window_size, channel).to(device=img1.device, dtype=window_dtype)
    pad = window_size // 2

    mu1 = F.conv2d(img1, window, padding=pad, groups=channel)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # E[x^2] - E[x]^2 etc. under the Gaussian weighting.
    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channel) - mu1_mu2

    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    # Average over channel, height and width, leaving one value per sample.
    return ssim_map.mean(1).mean(1).mean(1)

class SSIMLoss(torch.nn.Module):
    """Weighted blend of structural dissimilarity and mean squared error.

    ``loss = alpha * (1 - ssim(img1, img2)) + (1 - alpha) * MSE(img1, img2)``

    Args:
        window_size: Gaussian window size forwarded to ``ssim``.
        size_average: forwarded to ``ssim``. When False, ``ssim`` returns one
            value per batch element, so the loss is per-sample as well.
        alpha: weight of the SSIM term; ``1 - alpha`` weights the MSE term.
    """

    def __init__(self, window_size=9, size_average=True, alpha=1e-3):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.alpha = alpha
        self.mse_loss = nn.MSELoss()

    def forward(self, img1, img2):
        structural_term = 1 - ssim(img1, img2, self.window_size, self.size_average)
        pixel_term = self.mse_loss(img1, img2)
        return self.alpha * structural_term + (1 - self.alpha) * pixel_term



'''
Ridge loss: MSE data term + L2 (sum-of-squares) penalty on the predicted outputs
'''
import torch
import torch.nn as nn

class RidgeLoss(nn.Module):
    """Mean squared error plus an L2 penalty on the predictions.

    ``loss = mean((outputs - targets)**2) + alpha * sum(outputs**2)``

    Note: the penalty is on the model outputs (not its weights) and uses a
    sum rather than a mean, so its size grows with batch size and image area.

    Args:
        alpha: weight of the output penalty.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha

    def forward(self, outputs, targets):
        residual = outputs - targets
        fidelity = torch.mean(residual ** 2)
        penalty = torch.sum(outputs ** 2)
        # Cast so the returned loss is float32 regardless of input dtype.
        return (fidelity + self.alpha * penalty).float()


'''
Ridge Loss End
'''


'''
DataLoader
'''


class ResponseGTImageDataset(Dataset):
    """Pairs of response and ground-truth images stored as ``.mat`` files.

    Sample ``idx`` (0-based) is read from ``s{idx + 1}.mat`` in
    ``response_dir`` (MAT variable ``img1``) and ``gt{idx + 1}.mat`` in
    ``gt_dir`` (MAT variable ``img0``). The two directories may be the same.
    """

    def __init__(self, response_dir, gt_dir, transform=None, target_transform=None):
        """
        Initialize dataset.

        Args:
        response_dir (str): Directory containing the response images (``s<n>.mat``).
        gt_dir (str): Directory containing the ground truth images (``gt<n>.mat``).
        transform (callable, optional): A function/transform that takes in an image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the GT image and returns a transformed version.
        """
        self.response_dir = response_dir
        self.gt_dir = gt_dir
        self.response_files = [f for f in sorted(os.listdir(response_dir)) if f.endswith('.mat')]
        self.transform = transform
        self.target_transform = target_transform
        # Count only files that follow the naming scheme __getitem__ reads, so
        # the length is correct whether or not responses and GT share a
        # directory (halving the total .mat count only works when they do).
        self._num_pairs = min(
            self._count_indexed(response_dir, 's'),
            self._count_indexed(gt_dir, 'gt'),
        )

    @staticmethod
    def _count_indexed(directory, prefix):
        """Count files in ``directory`` named ``<prefix><n>.mat`` (n without leading zeros)."""
        suffix = '.mat'
        count = 0
        for name in os.listdir(directory):
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            digits = name[len(prefix):-len(suffix)]
            if digits.isascii() and digits.isdigit() and not digits.startswith('0'):
                count += 1
        return count

    def __len__(self):
        """Number of (response, ground-truth) pairs."""
        return self._num_pairs

    def __getitem__(self, idx):
        """Load, crop, mask, normalize and down-sample the ``idx``-th pair.

        Returns:
            ``(response, gt)``, each passed through its optional transform and
            then through ``normalize_matrix`` once more.
        """
        response_path = os.path.join(self.response_dir, f's{idx + 1}.mat')
        gt_path = os.path.join(self.gt_dir, f'gt{idx + 1}.mat')

        # Load MAT files
        mat_response = scipy.io.loadmat(response_path)
        mat_gt = scipy.io.loadmat(gt_path)

        # Response: crop rows 300:1500 / cols 480:1680, apply the circular
        # mask, then clamp negative values to zero.
        mat_response_content = mat_response['img1'].astype(np.float32)
        mat_response_content_part = mat_response_content[300:1500, 480:1680]
        mask = mask_response_circle(mat_response_content_part)
        mat_response_content_part_masked = mat_response_content_part * mask
        mat_response_content_part_masked[mat_response_content_part_masked < 0] = 0

        # Ground truth: clamp negative values to zero, then crop rows
        # 1070:1570 / cols 1080:1580.
        mat_gt_content = mat_gt['img0'].astype(np.float32)
        mat_gt_content[mat_gt_content < 0] = 0
        mat_gt_content_part = mat_gt_content[1070:1570, 1080:1580]

        mat_response_normalized = normalize_matrix(mat_response_content_part_masked)
        mat_gt_normalized = normalize_matrix(mat_gt_content_part)

        # NOTE(review): 25 and 5 are presumably down-sampling factors of
        # down_sample_matrix — confirm against its definition.
        mat_response_down = down_sample_matrix(mat_response_normalized, 25)
        mat_gt_down = down_sample_matrix(mat_gt_normalized, 5)

        mat_response_down = normalize_matrix(mat_response_down)
        mat_gt_down = normalize_matrix(mat_gt_down)

        # Apply additional transforms if any
        if self.transform:
            mat_response_down = self.transform(mat_response_down)
        if self.target_transform:
            mat_gt_down = self.target_transform(mat_gt_down)

        # NOTE(review): this final normalize_matrix runs after the transforms
        # (e.g. an ImageNet Normalize), so it may rescale their output —
        # confirm this is intended.
        return normalize_matrix(mat_response_down), normalize_matrix(mat_gt_down)


from torchvision import transforms

def transform_single_to_three_channel(image):
    """Tile a channel-first image tensor three times along dim 0.

    A (1, H, W) tensor becomes (3, H, W), which lets single-channel data be
    fed to models that expect three input channels.
    """
    channel_reps, height_reps, width_reps = 3, 1, 1
    return image.repeat(channel_reps, height_reps, width_reps)

# Define the transformations
# Input pipeline for the ViT: convert to tensor, repeat to three channels,
# resize to the ViT input resolution, then apply ImageNet mean/std statistics.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(transform_single_to_three_channel),
        transforms.Resize((224, 224)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Target pipeline: same tensor conversion and channel repetition, resized to
# the 449x449 reconstruction size, with no normalization.
target_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(transform_single_to_three_channel),
        transforms.Resize((449, 449)),
    ]
)



'''
DataLoader End
'''




from torch.utils.data import DataLoader

def get_dataloader(response_dir, gt_dir, batch_size):
    """Return a shuffled DataLoader over ``ResponseGTImageDataset``.

    The module-level ``transform`` / ``target_transform`` pipelines are applied
    to the responses and ground-truth images respectively.
    """
    pairs = ResponseGTImageDataset(
        response_dir,
        gt_dir,
        transform=transform,
        target_transform=target_transform,
    )
    return DataLoader(pairs, batch_size=batch_size, shuffle=True)


def visualize_predictions(low_res, outputs, high_res):
    """Show input, prediction and ground truth side by side, then call ``plt.show()``.

    Each tensor may live on CPU or GPU; it is detached and converted to NumPy.
    It is reduced to 2-D by repeatedly taking index 0 along the leading axis,
    so a (batch, channel, height, width) tensor shows channel 0 of sample 0.
    Each panel is passed through ``normalize_matrix`` before display.
    """
    images = []
    for tensor in (low_res, outputs, high_res):
        image = tensor.detach().cpu().numpy()
        # Index instead of squeeze(): squeezing a batch of size 1 (or a
        # single-channel batch) drops an axis, after which [0, 0] selects a
        # 1-D row that imshow cannot display.
        while image.ndim > 2:
            image = image[0]
        images.append(image)

    plt.figure(figsize=(12, 6))
    for position, (title, image) in enumerate(zip(('Input', 'Predicted', 'Ground Truth'), images), start=1):
        plt.subplot(1, 3, position)
        plt.imshow(normalize_matrix(image), cmap='gray')
        plt.title(title)
        plt.axis('off')

    plt.show()





def train_model(response_dir, gt_dir, num_epochs, learning_rate, batch_size, loss_func, model_load_path, model_save_path, check_num, save_interval=5, visualize_interval=1000):
    """Train ``ViTForImageReconstruction`` (google/vit-base-patch16-224 backbone).

    Args:
        response_dir: directory of ``s<n>.mat`` response files.
        gt_dir: directory of ``gt<n>.mat`` ground-truth files.
        num_epochs: number of passes over the dataset.
        learning_rate: Adam learning rate.
        batch_size: DataLoader batch size.
        loss_func: one of ``'ridge'``, ``'ssim'`` or ``'mse'``.
        model_load_path: checkpoint to load before training, or the string ``'none'``.
        model_save_path: directory for checkpoints (created if missing).
        check_num: offset added to the 0-based epoch index in checkpoint names.
        save_interval: save a checkpoint every this many epochs.
        visualize_interval: plot predictions every this many batches.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``loss_func`` is unknown or no training pairs are found.
    """
    # Validate up front; otherwise an unknown name leaves `criterion` unbound
    # and only fails inside the training loop, after loading the backbone.
    if loss_func not in ('ridge', 'ssim', 'mse'):
        raise ValueError(f"Unknown loss_func {loss_func!r}; expected 'ridge', 'ssim' or 'mse'")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTForImageReconstruction("google/vit-base-patch16-224").to(device)

    if model_load_path != 'none':
        print(f'Loading Path: {model_load_path}')
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(model_load_path, map_location=device))
        print("Model weights loaded.")

    dataloader = get_dataloader(response_dir, gt_dir, batch_size)
    if len(dataloader) == 0:
        raise ValueError(f"No training pairs found (response_dir={response_dir!r}, gt_dir={gt_dir!r})")

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    if loss_func == 'ridge':
        criterion = RidgeLoss(alpha=0.5).to(device)
        print('Using Ridge Loss with alpha=0.5')
    elif loss_func == 'ssim':
        criterion = SSIMLoss(alpha=0.05).to(device)
        print('Using SSIM Loss with alpha=0.05')
    else:  # 'mse' (validated above)
        criterion = nn.MSELoss().to(device)
        print('Using MSE Loss')

    print(f'Start training with {loss_func} loss, lr = {learning_rate}, batch size = {batch_size}')

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for i, (responses, labels) in enumerate(dataloader):
            responses, labels = responses.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(responses)

            loss = criterion(outputs, labels.to(torch.float32))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            if i % visualize_interval == 0:
                visualize_predictions(responses, outputs, labels)

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

        if (epoch + 1) % save_interval == 0:
            os.makedirs(model_save_path, exist_ok=True)
            save_path = os.path.join(model_save_path, f'ViT_Base_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pth')
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")

    return model


def visualize_new(low_res, outputs, high_res, index=0, save_dir="results"):
    """Save an Input / Predicted / Ground Truth figure to ``save_dir``.

    Tensors may be on CPU or CUDA; they are detached and converted to NumPy.
    If ``low_res`` is 4-D, element ``[0, 0]`` of every array is shown.
    Otherwise each array is shown as-is when 2-D, or its first entry along
    dim 0 when it has more dimensions. Each panel goes through
    ``normalize_matrix`` first.

    Args:
        low_res, outputs, high_res: tensors to display.
        index: number used in the file name ``prediction_{index:03d}.png``.
        save_dir: output directory, created if it does not exist.

    Returns:
        Path of the saved PNG file.
    """

    def to_numpy(tensor):
        if tensor.is_cuda:
            tensor = tensor.cpu()
        return tensor.detach().numpy()

    arrays = [to_numpy(t) for t in (low_res, outputs, high_res)]
    # The layout decision is taken from the input tensor alone.
    batched = len(arrays[0].shape) == 4

    def pick_image(arr):
        if batched:
            return arr[0, 0]
        return arr[0] if len(arr.shape) > 2 else arr

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"prediction_{index:03d}.png")

    plt.figure(figsize=(12, 6))
    titles = ('Input', 'Predicted', 'Ground Truth')
    for slot, (title, arr) in enumerate(zip(titles, arrays), start=1):
        plt.subplot(1, 3, slot)
        plt.imshow(normalize_matrix(pick_image(arr)), cmap='gray')
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

    return save_path



# Example usage
#response_dir = 'F:/0125_dataset/0125_res_1200/'
#gt_dir = 'F:/0125_dataset/0125_gt/'
'''

response_file_path='F:/0125_dataset/01092025 test slide 1/'
gt_file_path='F:/0125_dataset/01092025 test slide 1/'

train_model(
    response_dir=response_file_path, 
    gt_dir=gt_file_path, 
    num_epochs=200, 
    learning_rate=5e-5, 
    batch_size=5,
    loss_func='ssim', 
    model_load_path='F:/new_dictionary_param//vit_ssim_iter_100_patch200_nodown_10152024.pth', 
    model_save_path='F:/0125_dataset/ViT_Param', 
    check_num=4, 
    save_interval=100,
    visualize_interval=5)
'''
#'F:/new_dictionary_param/vit_ssim_iter_120_patch200_nodown_10152024.pth'
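# NOTE: a Hugging Face access token was committed here and has been removed; revoke it and supply tokens via the HF_TOKEN environment variable instead.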


'''


-----------------UNET----------------------


'''


import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages, as used throughout UNet.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        mid_channels: channels between the two convolutions; a falsy value
            (None or 0) means ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            # bias=False: the following BatchNorm supplies the shift.
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """UNet encoder step: 2x2 max-pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """UNet decoder step: upsample, concatenate the skip tensor, then DoubleConv.

    Args:
        in_channels: channels of the concatenated (skip + upsampled) tensor.
        out_channels: channels produced by the DoubleConv.
        bilinear: if True, use parameter-free bilinear upsampling and let the
            DoubleConv's first conv reduce to ``in_channels // 2``; otherwise
            use a learned 2x2 transposed convolution that halves the channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s spatial size, concatenate and convolve."""
        upsampled = self.up(x1)

        # Odd input sizes leave the upsampled map smaller than the skip
        # tensor; pad (split as evenly as possible) to line them up.
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        left, top = dx // 2, dy // 2
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


class UNet(nn.Module):
    """Four-level UNet whose sigmoid output is resized to ``output_size``.

    Args:
        n_channels: input channels.
        n_classes: output channels.
        bilinear: use bilinear upsampling (halving the deepest channel counts)
            instead of transposed convolutions.
        output_size: (height, width) of the returned images.
    """

    def __init__(self, n_channels=3, n_classes=3, bilinear=True, output_size=(449, 449)):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.output_size = output_size

        # Encoder path
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # Bilinear upsampling keeps channel counts, so the bottleneck is halved
        # to make the concatenations in the decoder line up.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        # Decoder path with skip connections
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Final convolution
        self.outc = OutConv(64, n_classes)

        # Sigmoid keeps the output in (0, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (batch, n_classes, *output_size) tensor with values in (0, 1)."""
        # Ensure the input is float32
        x = x.float()

        # Contracting path (encoder)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Expanding path (decoder) with skip connections
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        # Final convolution and activation
        x = self.outc(x)
        x = self.sigmoid(x)

        # Compare only the spatial dims: x.size() is 4-D, so comparing it to
        # the 2-tuple output_size was always unequal and always resampled.
        if tuple(x.shape[-2:]) != tuple(self.output_size):
            x = F.interpolate(x, size=self.output_size, mode='bilinear', align_corners=True)

        return x


def train_unet_model(response_dir, gt_dir, num_epochs, learning_rate, batch_size,
                    loss_func, model_load_path, model_save_path, check_num,
                    save_interval=5, visualize_interval=1000):
    """Train a ``UNet`` (3 -> 3 channels, 449x449 output) on response/GT pairs.

    Args:
        response_dir: directory of ``s<n>.mat`` response files.
        gt_dir: directory of ``gt<n>.mat`` ground-truth files.
        num_epochs: number of passes over the dataset.
        learning_rate: Adam learning rate.
        batch_size: DataLoader batch size.
        loss_func: one of ``'ridge'``, ``'ssim'`` or ``'mse'``.
        model_load_path: checkpoint to load before training, or the string ``'none'``.
        model_save_path: directory for checkpoints (created if missing).
        check_num: offset added to the 0-based epoch index in checkpoint names.
        save_interval: save a checkpoint every this many epochs.
        visualize_interval: plot predictions every this many batches.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``loss_func`` is unknown or no training pairs are found.
    """
    # Validate up front; otherwise an unknown name leaves `criterion` unbound
    # and only fails inside the training loop.
    if loss_func not in ('ridge', 'ssim', 'mse'):
        raise ValueError(f"Unknown loss_func {loss_func!r}; expected 'ridge', 'ssim' or 'mse'")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create model with the correct output size
    model = UNet(n_channels=3, n_classes=3, bilinear=True, output_size=(449, 449)).to(device)

    if model_load_path != 'none':
        print(f'Loading Path: {model_load_path}')
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(model_load_path, map_location=device))
        print("Model weights loaded.")

    # DataLoader setup
    dataset = ResponseGTImageDataset(response_dir, gt_dir, transform=transform, target_transform=target_transform)
    if len(dataset) == 0:
        raise ValueError(f"No training pairs found (response_dir={response_dir!r}, gt_dir={gt_dir!r})")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Check sizes from the first batch
    sample_batch = next(iter(dataloader))
    input_shape = sample_batch[0].shape
    target_shape = sample_batch[1].shape
    print(f"Input shape: {input_shape}, Target shape: {target_shape}")

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    if loss_func == 'ridge':
        criterion = RidgeLoss(alpha=0.5).to(device)
        print('Using Ridge Loss with alpha=0.5')
    elif loss_func == 'ssim':
        criterion = SSIMLoss(alpha=0.05).to(device)
        print('Using SSIM Loss with alpha=0.05')
    else:  # 'mse' (validated above)
        criterion = nn.MSELoss().to(device)
        print('Using MSE Loss')

    print(f'Start training with {loss_func} loss, lr = {learning_rate}, batch size = {batch_size}')

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for i, (responses, labels) in enumerate(dataloader):
            # Convert tensors to float
            responses = responses.float().to(device)
            labels = labels.float().to(device)

            optimizer.zero_grad()
            outputs = model(responses)

            # Check the sizes
            if i == 0 and epoch == 0:
                print(f"Response shape: {responses.shape}")
                print(f"Output shape: {outputs.shape}")
                print(f"Label shape: {labels.shape}")

            # Make sure outputs match the size of labels
            if outputs.shape != labels.shape:
                outputs = F.interpolate(outputs, size=(labels.shape[2], labels.shape[3]),
                                        mode='bilinear', align_corners=True)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            if i % visualize_interval == 0:
                visualize_predictions(responses, outputs, labels)

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

        if (epoch + 1) % save_interval == 0:
            # Saving was commented out while the message below still referenced
            # save_path (NameError); write the checkpoint the message reports.
            os.makedirs(model_save_path, exist_ok=True)
            save_path = os.path.join(model_save_path, f'UNet_epoch{epoch+check_num}_lr{learning_rate}_batch{batch_size}_{loss_func}.pth')
            torch.save(model.state_dict(), save_path)
            print(f"Model saved to {save_path}")

    return model


'Unet Example'
'''
response_file_path='/home/share/0125_dataset/01092025 test slide 1/'
gt_file_path='/home/share/0125_dataset/01092025 test slide 1/'


train_unet_model(
     response_dir=response_file_path, 
     gt_dir= gt_file_path, 
     num_epochs=10, 
     learning_rate=1e-4, 
     batch_size=10,
     loss_func='ssim', 
     model_load_path='none', 
     model_save_path='/home/dan5/optics_recon/Optics_Recon_Project/unet_param',
     check_num=0, 
     save_interval=2,
     visualize_interval=50)

'''

# /home/share/0125_dataset/Good/

# /home/share/0125_dataset/01092025 test slide 1/