import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models import VGG19_Weights



class ContrastiveLoss(torch.nn.Module):
    """Pairwise contrastive loss on the Euclidean distance between two embeddings.

    A label of 0 marks a similar pair, which is penalised by its squared
    distance. A label of 1 marks a dissimilar pair, which is penalised by the
    squared amount by which its distance falls short of ``margin``.
    """

    def __init__(self, margin=3.0):
        """
        Args:
            margin: Distance beyond which dissimilar pairs contribute no loss.
        """
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, vector1, vector2, label):
        """
        Args:
            vector1: First batch of embeddings.
            vector2: Second batch of embeddings, paired row-by-row with ``vector1``.
            label: Per-pair label, 0 for similar and 1 for dissimilar. Any shape
                holding one value per pair is accepted (e.g. ``(B,)`` or ``(B, 1)``).

        Returns:
            Scalar tensor: the mean loss over all pairs.
        """
        # Euclidean distance between the paired rows, kept as a column (B, 1).
        distance = torch.nn.functional.pairwise_distance(vector1, vector2, keepdim=True)

        # A flat (B,) label multiplied with the (B, 1) distance would broadcast
        # to (B, B) and silently mix every label with every pair. Align the
        # label with the distance whenever it carries one value per pair.
        if (
            torch.is_tensor(label)
            and label.shape != distance.shape
            and label.numel() == distance.numel()
        ):
            label = label.reshape(distance.shape)

        # Small offset added for numerical stability.
        epsilon = 1e-9
        distance = distance + epsilon

        # Similar pairs: pull together.
        loss_similar = (1 - label) * torch.pow(distance, 2)
        # Dissimilar pairs: push apart until they are at least `margin` away.
        loss_dissimilar = label * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)

        return torch.mean(loss_similar + loss_dissimilar)

class SupConLoss(torch.nn.Module):
    """Supervised contrastive loss computed within a single batch of embeddings.

    Every sample acts as an anchor; the other samples sharing its label are its
    positives, and all other samples in the batch form the contrast set.
    """

    def __init__(self, temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """
        Args:
            features: Embeddings of shape (batch_size, embedding_dim). Similarity
                is a plain dot product, so it equals cosine similarity only when
                the rows are already L2-normalised.
            labels: Labels of shape (batch_size)

        Returns:
            Scalar tensor: mean over anchors of the negative average
            log-probability of that anchor's positives. Anchors with no
            positive in the batch contribute zero.
        """
        n = features.shape[0]
        dev = features.device
        label_col = labels.contiguous().view(-1, 1)

        # same_class[i, j] == 1 when samples i and j carry the same label.
        same_class = torch.eq(label_col, label_col.T).float().to(dev)

        # Temperature-scaled pairwise dot products.
        similarity = torch.matmul(features, features.T) / self.temperature

        # Subtract the per-row maximum (no gradient) so exp() cannot overflow.
        row_max = similarity.max(dim=1, keepdim=True)[0]
        shifted = similarity - row_max.detach()

        # Zero on the diagonal: an anchor is never contrasted with itself.
        not_self = torch.ones_like(same_class) - torch.eye(n, device=dev)
        positives = same_class * not_self

        # Log-softmax over every non-self entry of each row.
        denominator = (torch.exp(shifted) * not_self).sum(1, keepdim=True)
        log_prob = shifted - torch.log(denominator + 1e-9)

        # Average log-likelihood of the positives for each anchor; the epsilon
        # keeps the division finite for anchors without positives.
        per_anchor = (positives * log_prob).sum(1) / (positives.sum(1) + 1e-9)

        return (-per_anchor).mean()
    

class SSIMLoss(nn.Module):
    """Structural-similarity loss: ``1 - SSIM(img1, img2)``.

    Local statistics are taken with a per-channel Gaussian window applied via
    a grouped convolution with zero padding of ``window_size // 2``.
    """

    def __init__(self, window_size=11, size_average=True):
        """
        Args:
            window_size: Side length of the square Gaussian window.
            size_average: If True, ``forward`` returns a scalar (mean over the
                SSIM map); otherwise it returns the full ``1 - SSIM`` map.
        """
        super(SSIMLoss, self).__init__()
        self.window_size = window_size
        self.size_average = size_average

    def gaussian_window(self, window_size, sigma):
        """Return a 1-D Gaussian of length ``window_size`` normalised to sum to 1."""
        x = torch.arange(window_size, dtype=torch.float32) - window_size // 2
        gauss = torch.exp(-(x ** 2) / (2 * sigma ** 2))
        gauss = gauss / gauss.sum()
        return gauss

    def create_window(self, window_size, channel):
        """Return a float32 ``(channel, 1, window_size, window_size)`` Gaussian kernel.

        The 2-D kernel is the outer product of the 1-D Gaussian (sigma =
        ``window_size / 6``) with itself, replicated once per channel so it can
        be used as a grouped-convolution weight.
        """
        _1D_window = self.gaussian_window(window_size, window_size / 6).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
        return window

    def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
        """Compute SSIM between ``img1`` and ``img2`` using ``window``.

        Returns the mean of the SSIM map when ``size_average`` is True,
        otherwise the map itself.
        """
        pad = window_size // 2

        # Local (windowed) means.
        mu1 = F.conv2d(img1, window, padding=pad, groups=channel)
        mu2 = F.conv2d(img2, window, padding=pad, groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local variances and covariance: E[xy] - E[x]E[y].
        sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channel) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channel) - mu1_mu2

        # Constants that keep the ratio finite where means/variances vanish.
        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
                   ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        if size_average:
            return ssim_map.mean()
        return ssim_map

    def forward(self, img1, img2):
        """
        Args:
            img1: 4-D image batch; dimension 1 is used as the channel count.
            img2: Image batch compared against ``img1``.

        Returns:
            ``1 - SSIM``: a scalar when ``size_average`` is True, else a map.
        """
        (_, channel, _, _) = img1.size()
        # Match the input's dtype as well as its device: conv2d rejects a
        # float32 kernel applied to float64 / half images.
        window = self.create_window(self.window_size, channel).to(
            device=img1.device, dtype=img1.dtype
        )
        return 1 - self._ssim(img1, img2, window, self.window_size, channel, self.size_average)


class PerceptualLoss(nn.Module):
    """Sum of L1 distances between VGG-19 feature maps of two image batches.

    Both inputs are normalised with fixed per-channel mean/std constants and
    passed through the frozen ``features`` stack of an ImageNet-pretrained
    VGG-19; activations are collected at the sub-modules named in ``layers``.
    """

    def __init__(self, layers=None):
        """
        Args:
            layers: Names (string indices into ``vgg19().features``) of the
                sub-modules whose outputs are compared. Defaults to
                ``['3', '8', '17', '26']``.
        """
        super(PerceptualLoss, self).__init__()
        self.vgg = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features.eval()
        # Build the default per instance instead of sharing one mutable list
        # object across every PerceptualLoss created without arguments.
        self.layers = ['3', '8', '17', '26'] if layers is None else layers
        # The feature extractor is fixed; gradients flow only to the inputs.
        for param in self.vgg.parameters():
            param.requires_grad = False

    def normalize(self, x):
        """Apply per-channel ``(x - mean) / std`` to a 3-channel ``(N, 3, H, W)`` batch."""
        # Create the constants directly on x's device (no CPU round-trip).
        mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1, 3, 1, 1)
        return (x - mean) / std

    def forward(self, x, y):
        """
        Args:
            x: First image batch, 3 channels.
            y: Second image batch, same shape as ``x``.

        Returns:
            Scalar tensor: sum over the selected layers of the mean absolute
            difference between the two feature maps. A zero tensor is returned
            when no layer name matches.
        """
        x = self.normalize(x)
        y = self.normalize(y)
        x_features = self.get_features(x)
        y_features = self.get_features(y)
        # Start from a tensor so the result is a tensor even when no layer
        # was selected (a bare Python float has no .item()/.backward()).
        loss = x.new_zeros(())
        for xf, yf in zip(x_features, y_features):
            loss = loss + nn.functional.l1_loss(xf, yf)
        return loss

    def get_features(self, x):
        """Run ``x`` through the feature stack and return the selected outputs.

        Returns:
            List of activations, in network order, of every sub-module whose
            name appears in ``self.layers``.
        """
        # NOTE(review): if the stack contains in-place activations, an output
        # captured just before one is modified by it afterwards - confirm when
        # selecting layers other than the defaults.
        features = []
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.layers:
                features.append(x)
        return features
    


class PhysicalLoss(nn.Module):
    """Mean squared error between ``A @ x`` and ``y``, computed in float64.

    ``A`` is stored as a buffer, so it follows the module across devices.
    """

    def __init__(self, A):
        """
        Args:
            A: Matrix left-multiplied onto the flattened input.
               NOTE(review): presumably of shape (M, N), with N the flattened
               size of ``x`` per channel and M that of ``y`` - confirm.
        """
        super(PhysicalLoss, self).__init__()
        self.register_buffer('A', A)

    def forward(self, x, y):
        """
        Args:
            x: Tensor of shape (B, C, ...); trailing dims are flattened.
            y: Tensor of shape (B, C, ...); trailing dims are flattened.

        Returns:
            Scalar float64 tensor: mean of ``(A @ x_flat - y_flat) ** 2``.

        Raises:
            AssertionError: If batch or channel sizes of ``x`` and ``y`` differ.
        """
        assert x.shape[0] == y.shape[0], "Batch sizes do not match."
        assert x.shape[1] == y.shape[1], "Channel sizes do not match."
        batch_size, channels = x.shape[0], x.shape[1]

        # reshape (not view) so non-contiguous inputs, e.g. after a
        # transpose/permute, are accepted as well.
        x_flat = x.reshape(batch_size, channels, -1, 1).to(torch.float64)  # (B, C, N, 1)
        y_flat = y.reshape(batch_size, channels, -1, 1).to(torch.float64)  # (B, C, M, 1)

        # matmul needs matching dtypes; x is float64, so cast A alongside it
        # (a no-op when A is already float64).
        transformed_x = torch.matmul(self.A.to(torch.float64), x_flat)

        squared_diff = (transformed_x - y_flat) ** 2
        return squared_diff.mean()



class EdgeLoss(nn.Module):
    """L1 distance between the Sobel gradient magnitudes of two image batches."""

    def __init__(self):
        super(EdgeLoss, self).__init__()
        # Sobel kernels for the horizontal and vertical derivative.
        sobel_kernel_x = torch.tensor([[-1, 0, 1],
                                       [-2, 0, 2],
                                       [-1, 0, 1]], dtype=torch.float32)
        sobel_kernel_y = torch.tensor([[-1, -2, -1],
                                       [0, 0, 0],
                                       [1, 2, 1]], dtype=torch.float32)
        # Shape (1, 1, 3, 3) so the kernels can be tiled per input channel.
        # Non-persistent buffers: they are never trained, they follow the
        # module across devices, and they stay out of the state_dict.
        self.register_buffer('weight_x', sobel_kernel_x.unsqueeze(0).unsqueeze(0),
                             persistent=False)
        self.register_buffer('weight_y', sobel_kernel_y.unsqueeze(0).unsqueeze(0),
                             persistent=False)

    @staticmethod
    def _gradient_magnitude(image, weight_x, weight_y, channels):
        """Return the per-channel Sobel gradient magnitude of ``image``."""
        grad_x = F.conv2d(image, weight_x, padding=1, groups=channels)
        grad_y = F.conv2d(image, weight_y, padding=1, groups=channels)
        # The small constant keeps sqrt differentiable where the gradient is 0.
        return torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-6)

    def forward(self, prediction, target):
        """
        Args:
            prediction: Tensor of shape [batch_size, channels, height, width].
            target: Tensor of the same shape as ``prediction``.

        Returns:
            Scalar tensor: mean absolute difference of the gradient magnitudes.
        """
        channels = prediction.size(1)

        # Tile the kernels to one filter per channel, in the input's dtype and
        # on its device (conv2d rejects a float32 kernel on float64/half data).
        weight_x = self.weight_x.to(
            device=prediction.device, dtype=prediction.dtype
        ).repeat(channels, 1, 1, 1)
        weight_y = self.weight_y.to(
            device=prediction.device, dtype=prediction.dtype
        ).repeat(channels, 1, 1, 1)

        # The target shares the prediction's kernels, so align its dtype too.
        target = target.to(dtype=prediction.dtype)

        grad_pred = self._gradient_magnitude(prediction, weight_x, weight_y, channels)
        grad_target = self._gradient_magnitude(target, weight_x, weight_y, channels)

        return F.l1_loss(grad_pred, grad_target)
    

class TotalVariationLoss(nn.Module):
    """Squared total-variation penalty on neighbouring-pixel differences."""

    def __init__(self):
        super(TotalVariationLoss, self).__init__()

    def forward(self, x):
        """
        Args:
            x: Tensor of shape [batch_size, channels, height, width].

        Returns:
            Scalar tensor: the sum of squared vertical differences divided by
            ``(height - 1) * width``, plus the sum of squared horizontal
            differences divided by ``height * (width - 1)``, all divided by
            the batch size.
        """
        batch_size = x.size(0)
        height = x.size(2)
        width = x.size(3)

        # Squared differences between vertically / horizontally adjacent pixels.
        tv_h = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).sum()
        tv_w = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).sum()

        # Number of neighbouring pairs per channel in each direction. With a
        # single row or column there are no pairs and the matching sum is 0;
        # clamp the divisor to 1 so that case gives 0 instead of 0 / 0 = NaN.
        count_h = max((height - 1) * width, 1)
        count_w = max(height * (width - 1), 1)

        return (tv_h / count_h + tv_w / count_w) / max(batch_size, 1)


class RegularizationLoss(nn.Module):
    """Weighted sum of the p-norms of a model's trainable parameters."""

    def __init__(self, model, weight_decay=1e-4, p=1):
        """
        Args:
            model: The model whose parameters are regularised.
            weight_decay: Regularisation coefficient (lambda) scaling the penalty.
            p: Order of the norm; 1 gives an L1 penalty, 2 an L2-norm penalty.
        """
        super(RegularizationLoss, self).__init__()
        self.model = model
        self.p = p
        self.weight_decay = weight_decay

    def forward(self):
        """Return ``weight_decay`` times the summed norms of all trainable parameters."""
        # Frozen parameters (requires_grad == False) are left out of the penalty.
        trainable = (w for w in self.model.parameters() if w.requires_grad)
        penalty = sum(torch.norm(w, p=self.p) for w in trainable)
        return self.weight_decay * penalty

