import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import scipy.io
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.transforms.functional import normalize
from transformers import ViTModel
import torch.optim as optim
import pytorch_ssim
from sklearn.linear_model import OrthogonalMatchingPursuit
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import scipy.io
from torchvision.transforms.functional import to_tensor
from torch.utils.data import Dataset
import math

import torch
import torch.optim as optim
from transformers import AutoModelForImageClassification
from utilies import normalize_matrix, down_sample_matrix, mat2vec, vec2mat, mask_response_circle






#import torch
from torch import nn
from transformers import ViTModel

'''
class ViTForImageReconstruction(nn.Module):
    def __init__(self, pretrained_model_name):
        super().__init__()

        #self.vit = ViTModel.from_pretrained(pretrained_model_name)

        self.vit = AutoModelForImageClassification.from_pretrained(pretrained_model_name)

        self.reshape = nn.Linear(1000, 768 * 31 * 31)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((31, 31)) 

        #self.adaptive_pool = nn.AdaptiveAvgPool2d((15,15))  # This is a strategic choice to simplify upsampling

        # Decoder to upscale to 449x449
        self.decoder = nn.Sequential(
            # Step up gradually to the nearest size that can be fine-tuned to 449
            nn.Upsample(size=(30, 30), mode='bilinear', align_corners=True),  # First upscale
            nn.Conv2d(768, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(60, 60), mode='bilinear', align_corners=True),  # Further upscale
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(120, 120), mode='bilinear', align_corners=True),  # And upscale more
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(240, 240), mode='bilinear', align_corners=True),  # Nearly there
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(449, 449), mode='bilinear', align_corners=True),  # Final size
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, pixel_values):
        features = self.vit(pixel_values=pixel_values).logits


        # Reshape features to have spatial dimensions
        features = self.reshape(features)
        features = features.view(-1, 768, 31, 31)  # Reshaping to (batch_size, channels, height, width)

        pooled_features = self.adaptive_pool(features)
        reconstructed_image = self.decoder(pooled_features)
        return reconstructed_image


'''

class ViTForImageReconstruction(nn.Module):
    """ViT encoder followed by a convolutional decoder that emits a 449x449 image.

    The patch tokens produced by the ViT are laid out on their square spatial
    grid, resampled to 15x15 by adaptive average pooling, and then upsampled
    in five bilinear steps (30, 60, 120, 240, 449) with a 3x3 convolution after
    each step. The final layer is a sigmoid, so outputs lie in [0, 1].

    The first decoder convolution expects 768 input channels, so the chosen
    checkpoint must have a hidden size of 768.

    Args:
        pretrained_model_name: Identifier or path passed to
            ``ViTModel.from_pretrained``.
    """

    def __init__(self, pretrained_model_name):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model_name)

        # Resample the patch grid to a fixed 15x15 so the decoder sizes below
        # do not depend on the encoder's patch count. Depending on the
        # checkpoint's grid this may shrink or slightly enlarge the map.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((15, 15))

        # Decoder: bilinear upsampling interleaved with 3x3 convolutions.
        # The layer order is kept as-is so existing state_dict keys
        # (decoder.1, decoder.4, ...) keep loading.
        self.decoder = nn.Sequential(
            nn.Upsample(size=(30, 30), mode='bilinear', align_corners=True),
            nn.Conv2d(768, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(60, 60), mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(120, 120), mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(240, 240), mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(size=(449, 449), mode='bilinear', align_corners=True),
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, pixel_values):
        """Encode ``pixel_values`` with the ViT and decode to an image.

        Args:
            pixel_values: Input batch forwarded to the ViT as ``pixel_values``.

        Returns:
            Tensor of shape (batch, 3, 449, 449) with values in [0, 1].

        Raises:
            ValueError: If the number of patch tokens (all tokens except the
                leading CLS token) is not a positive perfect square.
        """
        features = self.vit(pixel_values=pixel_values).last_hidden_state

        batch_size, num_tokens, hidden_dim = features.shape
        num_patches = num_tokens - 1  # the first token is the CLS token
        grid_size = int(round(num_patches ** 0.5)) if num_patches > 0 else 0
        if num_patches < 1 or grid_size * grid_size != num_patches:
            raise ValueError(
                f"expected a square grid of patch tokens, got {num_patches} "
                f"patch tokens ({num_tokens} tokens including CLS)"
            )

        # Drop the CLS token, then move the hidden dimension in front of the
        # token axis and unfold the tokens onto their (row, column) grid.
        patch_tokens = features[:, 1:, :]
        feature_map = patch_tokens.permute(0, 2, 1).reshape(
            batch_size, hidden_dim, grid_size, grid_size
        )

        pooled_features = self.adaptive_pool(feature_map)
        return self.decoder(pooled_features)





import torch
import torch.nn.functional as F

def gaussian_window(size, sigma):
    """Return a 1-D Gaussian kernel of length ``size`` that sums to one.

    Sample positions are ``0 .. size - 1`` shifted by ``size // 2``, so the
    peak sits at index ``size // 2``.

    Args:
        size: Number of taps in the kernel.
        sigma: Standard deviation of the Gaussian, in taps.

    Returns:
        A float32 tensor of shape ``(size,)``.
    """
    offsets = torch.arange(size, dtype=torch.float32) - size // 2
    weights = torch.exp(offsets.pow(2).neg() / (2 * sigma ** 2))
    total = weights.sum()
    return weights / total

def create_window(window_size, channel):
    """Build a per-channel 2-D Gaussian window for a depthwise convolution.

    The 2-D kernel is the outer product of the 1-D ``gaussian_window`` of
    length ``window_size`` with sigma 1.5, replicated once per channel.

    Args:
        window_size: Side length of the square kernel.
        channel: Number of channels the window is replicated for.

    Returns:
        A contiguous tensor of shape ``(channel, 1, window_size, window_size)``.
    """
    column = gaussian_window(window_size, 1.5)[:, None]
    kernel_2d = (column @ column.t()).float()
    kernel_4d = kernel_2d[None, None]
    return kernel_4d.expand(channel, 1, window_size, window_size).contiguous()

def ssim(img1, img2, window_size=11, size_average=True):
    """Compute the structural similarity (SSIM) index between two image batches.

    Local means, variances and the covariance are estimated with the Gaussian
    window from ``create_window``, applied per channel (depthwise convolution)
    with ``window_size // 2`` pixels of zero padding.

    Args:
        img1: Tensor indexed as (batch, channel, height, width). Its channel
            count, device and dtype determine those of the window.
        img2: Tensor with the same shape as ``img1``.
        window_size: Side length of the Gaussian window.
        size_average: If true, return the mean over the whole SSIM map;
            otherwise return one mean per batch item.

    Returns:
        A scalar tensor when ``size_average`` is true, else a 1-D tensor with
        one entry per batch item.
    """
    channel = img1.size(1)
    # Match the input's dtype as well as its device: F.conv2d rejects a
    # float32 window when the images are float64 or float16.
    window = create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)
    pad = window_size // 2

    def _local_mean(x):
        # Gaussian-weighted local average, computed independently per channel.
        return F.conv2d(x, window, padding=pad, groups=channel)

    mu1 = _local_mean(img1)
    mu2 = _local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Var[x] = E[x^2] - E[x]^2 and Cov[x, y] = E[xy] - E[x]E[y].
    sigma1_sq = _local_mean(img1 * img1) - mu1_sq
    sigma2_sq = _local_mean(img2 * img2) - mu2_sq
    sigma12 = _local_mean(img1 * img2) - mu1_mu2

    # Stabilising constants (0.01)^2 and (0.03)^2; these correspond to a data
    # range of 1, so inputs are presumably scaled to [0, 1] -- TODO confirm.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator

    if size_average:
        return ssim_map.mean()
    # Average over channel, height and width, keeping the batch axis.
    return ssim_map.mean(dim=(1, 2, 3))

class SSIMLoss(torch.nn.Module):
    """Weighted blend of an SSIM dissimilarity term and mean squared error.

    The loss is ``alpha * (1 - ssim) + (1 - alpha) * mse``.

    Args:
        window_size: Gaussian window size forwarded to ``ssim``.
        size_average: Forwarded to ``ssim``; when false the SSIM term has one
            entry per batch item and the scalar MSE is broadcast onto it.
        alpha: Weight of the SSIM term; the MSE term gets ``1 - alpha``.
    """

    def __init__(self, window_size=9, size_average=True, alpha=1e-3):
        super().__init__()
        self.alpha = alpha
        self.window_size = window_size
        self.size_average = size_average
        self.mse_loss = nn.MSELoss()

    def forward(self, img1, img2):
        """Return the blended loss between ``img1`` and ``img2``."""
        structural_term = 1 - ssim(img1, img2, self.window_size, self.size_average)
        pixel_term = self.mse_loss(img1, img2)
        return self.alpha * structural_term + (1 - self.alpha) * pixel_term



'''
Ridge Loss (MSE + L2 penalty on the outputs)
'''
import torch
import torch.nn as nn

class RidgeLoss(nn.Module):
    """Mean squared error plus a squared-L2 penalty on the predictions.

    The loss is ``mean((outputs - targets)^2) + alpha * sum(outputs^2)``. Note
    that the penalty is a sum (not a mean) over every output element.

    Args:
        alpha: Weight of the penalty term.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha

    def forward(self, outputs, targets):
        """Return the loss as a float32 scalar tensor."""
        data_term = (outputs - targets).pow(2).mean()
        penalty = outputs.pow(2).sum()
        total = data_term + self.alpha * penalty
        # Cast so the result is float32 regardless of the input dtype.
        return total.float()


'''
Ridge Loss End
'''


'''
DataLoader
'''


class ResponseGTImageDataset(Dataset):
    """Paired (response, ground-truth) samples stored as MATLAB ``.mat`` files.

    Sample ``idx`` is read from ``s{idx + 1}.mat`` (variable ``img1``) in the
    response directory and from ``gt{idx + 1}.mat`` (variable ``img0``) in the
    ground-truth directory.
    """

    def __init__(self, response_dir, gt_dir, transform=None, target_transform=None):
        """
        Initialize dataset.

        Args:
        response_dir (str): Directory holding the response ``.mat`` files.
        gt_dir (str): Directory holding the ground-truth ``.mat`` files.
        transform (callable, optional): Applied to the preprocessed response.
        target_transform (callable, optional): Applied to the preprocessed ground truth.
        """
        self.transform = transform
        self.target_transform = target_transform
        self.response_dir = response_dir
        self.gt_dir = gt_dir

        # Only the sorted ``.mat`` entries are kept; the list is used for
        # sizing the dataset, not for resolving sample paths.
        mat_names = []
        for entry in sorted(os.listdir(response_dir)):
            if entry.endswith('.mat'):
                mat_names.append(entry)
        self.response_files = mat_names

    def __len__(self):
        # Half the number of ``.mat`` files found in the response directory.
        # NOTE(review): the reason for halving is not visible here -- confirm.
        return len(self.response_files) // 2

    def __getitem__(self, idx):
        """Load, crop, normalize and down-sample sample ``idx``.

        Returns:
            A ``(response, ground_truth)`` pair after the optional transforms
            and a final ``normalize_matrix`` call on each element.
        """
        sample_no = idx + 1  # files on disk are numbered from 1
        response_file = os.path.join(self.response_dir, f's{sample_no}.mat')
        gt_file = os.path.join(self.gt_dir, f'gt{sample_no}.mat')

        response_mat = scipy.io.loadmat(response_file)
        gt_mat = scipy.io.loadmat(gt_file)

        # Response: fixed 1200x1200 crop, circular mask, negatives clamped to 0.
        response = response_mat['img1'].astype(np.float32)[300:1500, 480:1680]
        response = response * mask_response_circle(response)
        response[response < 0] = 0

        # Ground truth: negatives clamped to 0, then a fixed 500x500 crop.
        gt = gt_mat['img0'].astype(np.float32)
        gt[gt < 0] = 0
        gt = gt[1070:1570, 1080:1580]

        # Normalize, down-sample (factor 25 for the response, 5 for the
        # ground truth), then normalize again.
        response = normalize_matrix(response)
        gt = normalize_matrix(gt)

        response = down_sample_matrix(response, 25)
        gt = down_sample_matrix(gt, 5)

        response = normalize_matrix(response)
        gt = normalize_matrix(gt)

        # Optional caller-supplied transforms.
        if self.transform:
            response = self.transform(response)
        if self.target_transform:
            gt = self.target_transform(gt)

        return normalize_matrix(response), normalize_matrix(gt)


from torchvision import transforms

def transform_single_to_three_channel(image):
    """Tile ``image`` three times along its leading (channel) axis.

    A ``(1, H, W)`` or ``(H, W)`` tensor becomes ``(3, H, W)``; the height and
    width axes are left untouched.
    """
    channel_copies = 3
    return image.repeat(channel_copies, 1, 1)

# Input pipeline for the response image fed to the ViT encoder:
# array -> tensor, single channel tiled to three, resized to 224x224, then
# normalized with the ImageNet channel statistics.
# NOTE(review): ResponseGTImageDataset.__getitem__ calls normalize_matrix on
# the result of this pipeline, so the ImageNet normalization may not survive
# unchanged -- confirm against normalize_matrix.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(transform_single_to_three_channel),
    transforms.Resize((224, 224)),  # Resize to fit the ViT model input
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standard ImageNet normalization
])

# Target pipeline for the ground-truth image: array -> tensor, single channel
# tiled to three, resized to 449x449 (the decoder's final upsample size).
# No mean/std normalization is applied to the target.
target_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(transform_single_to_three_channel),
    transforms.Resize((449, 449))  # Resize GT images to desired output size
])



'''
DataLoader End
'''




from torch.utils.data import DataLoader

def get_dataloader(response_dir, gt_dir, batch_size, shuffle=True, num_workers=0):
    """Build a ``DataLoader`` over ``ResponseGTImageDataset``.

    The dataset is created with the module-level ``transform`` and
    ``target_transform`` pipelines.

    Args:
        response_dir: Directory holding the response ``.mat`` files.
        gt_dir: Directory holding the ground-truth ``.mat`` files.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the samples every epoch. Defaults to
            ``True`` (the previous hard-coded behaviour); pass ``False`` for a
            deterministic order, e.g. during evaluation.
        num_workers: Number of loader worker processes; ``0`` loads in the
            main process.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding (response, ground-truth)
        batches.
    """
    dataset = ResponseGTImageDataset(
        response_dir,
        gt_dir,
        transform=transform,
        target_transform=target_transform,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


def _first_image_plane(tensor):
    """Return the first 2-D plane of ``tensor`` as a NumPy array on the CPU."""
    plane = tensor.detach().cpu()
    # Index the leading axes explicitly instead of calling squeeze(): squeeze()
    # also removes a batch (or channel) axis of size 1, after which a fixed
    # ``[0, 0]`` index would select a single row rather than an image.
    while plane.dim() > 2:
        plane = plane[0]
    return plane.numpy()


def visualize_predictions(low_res, outputs, high_res):
    """Plot input, prediction and ground truth side by side in grayscale.

    For each tensor, the first sample's first channel is shown after being
    passed through ``normalize_matrix``. Works for any batch size, including
    1, and for tensors on either CPU or GPU.

    Args:
        low_res: Model input batch.
        outputs: Model prediction batch.
        high_res: Ground-truth batch.
    """
    panels = (
        ('Input', low_res),
        ('Predicted', outputs),
        ('Ground Truth', high_res),
    )

    plt.figure(figsize=(12, 6))
    for position, (title, tensor) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(normalize_matrix(_first_image_plane(tensor)), cmap='gray')
        plt.title(title)
        plt.axis('off')

    plt.show()