import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional


def cosine_similarity(embedded_fg: torch.Tensor, embedded_bg: torch.Tensor) -> torch.Tensor:
    """Return the clamped pairwise cosine-similarity matrix of two embedding sets.

    Each input is L2-normalised along ``dim=1`` and the two are multiplied as
    ``fg @ bg.T``, so entry ``(i, j)`` compares row ``i`` of ``embedded_fg``
    with row ``j`` of ``embedded_bg``.

    The result is clamped to ``[0.0005, 0.9995]``; in particular, negative
    similarities are reported as ``0.0005``.
    """
    fg_unit = F.normalize(embedded_fg, dim=1)
    bg_unit = F.normalize(embedded_bg, dim=1)
    pairwise = fg_unit @ bg_unit.T
    # Keep values strictly inside (0, 1) so callers can take log(sim) / log(1 - sim).
    return pairwise.clamp(min=0.0005, max=0.9995)


def cosine_distance(embedded_fg: torch.Tensor, embedded_bg: torch.Tensor) -> torch.Tensor:
    """Return the pairwise cosine-distance matrix ``1 - cos(fg_i, bg_j)``.

    Both inputs are L2-normalised along ``dim=1`` before the dot products are
    taken. Unlike ``cosine_similarity`` in this file, no clamping is applied.
    """
    fg_unit = F.normalize(embedded_fg, dim=1)
    bg_unit = F.normalize(embedded_bg, dim=1)
    pairwise_sim = fg_unit @ bg_unit.T
    return 1 - pairwise_sim


def l2_distance(embedded_fg: torch.Tensor, embedded_bg: torch.Tensor) -> torch.Tensor:
    """Compute the pairwise mean squared difference between two embedding sets.

    Args:
        embedded_fg: [N, C] embeddings.
        embedded_bg: [M, C] embeddings. ``M`` does not have to equal ``N``.

    Returns:
        [N, M] tensor whose entry ``(i, j)`` is
        ``sum((fg_i - bg_j) ** 2) / C`` (squared L2 distance divided by the
        feature dimension; no square root is taken).
    """
    # Unpacking two values keeps the requirement that embedded_fg is 2-D.
    _, C = embedded_fg.size()
    # Broadcasting [N, 1, C] against [1, M, C] yields every pairwise difference
    # without forcing both inputs to have the same number of rows.
    diff = embedded_fg.unsqueeze(1) - embedded_bg.unsqueeze(0)
    return diff.pow(2).sum(2) / C


class SimMinLoss(nn.Module):
    """Similarity-minimisation loss: penalises high cosine similarity.

    The per-pair loss is ``-log(1 - sim)``, which grows as the (clamped)
    cosine similarity between a background and a foreground embedding
    approaches 1.
    """

    def __init__(self, metric: str = 'cos', reduction: str = 'mean'):
        super().__init__()
        self.metric = metric
        self.reduction = reduction

    def forward(self, embedded_bg: torch.Tensor, embedded_fg: torch.Tensor) -> torch.Tensor:
        """Evaluate the loss on every background/foreground pair.

        Args:
            embedded_bg: [N, C] background embeddings
            embedded_fg: [N, C] foreground embeddings
        Returns:
            Scalar for ``reduction`` of ``'mean'`` or ``'sum'``; for any other
            value the unreduced pairwise loss matrix is returned.
        Raises:
            NotImplementedError: if ``metric`` is ``'l2'``.
            ValueError: if ``metric`` is neither ``'l2'`` nor ``'cos'``.
        """
        # Reject unsupported metrics up front so the main path stays flat.
        if self.metric == 'l2':
            raise NotImplementedError("L2 metric not implemented for SimMinLoss")
        if self.metric != 'cos':
            raise ValueError(f"Unknown metric: {self.metric}")

        pairwise_sim = cosine_similarity(embedded_bg, embedded_fg)
        # The 1e-8 offset guards the logarithm against a zero argument.
        pairwise_loss = -torch.log(1 - pairwise_sim + 1e-8)

        if self.reduction == 'mean':
            return pairwise_loss.mean()
        if self.reduction == 'sum':
            return pairwise_loss.sum()
        return pairwise_loss


class SimMaxLoss(nn.Module):
    """Similarity-maximisation loss with rank-based down-weighting.

    The per-pair loss is ``-log(sim)``. Within each row, pairs are ranked by
    similarity (rank 0 = most similar) and the loss of a pair with rank ``r``
    is scaled by ``exp(-r * alpha)``, so the most similar pairs dominate.
    """

    def __init__(self, metric: str = 'cos', alpha: float = 0.25, reduction: str = 'mean'):
        super().__init__()
        self.metric = metric
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, embedded_bg: torch.Tensor, embedded_bg1: torch.Tensor) -> torch.Tensor:
        """Evaluate the rank-weighted loss on every pair of rows.

        Args:
            embedded_bg: [N, C] first set of embeddings
            embedded_bg1: [N, C] second set of embeddings
        Returns:
            Scalar for ``reduction`` of ``'mean'`` or ``'sum'``; for any other
            value the unreduced, rank-weighted pairwise loss matrix.
        Raises:
            NotImplementedError: if ``metric`` is ``'l2'``.
            ValueError: if ``metric`` is neither ``'l2'`` nor ``'cos'``.
        """
        if self.metric == 'l2':
            raise NotImplementedError("L2 metric not implemented for SimMaxLoss")
        if self.metric != 'cos':
            raise ValueError(f"Unknown metric: {self.metric}")

        pairwise_sim = cosine_similarity(embedded_bg, embedded_bg1)
        # 1e-8 guards the logarithm; the clamp drops any negative loss values.
        pairwise_loss = torch.clamp(-torch.log(pairwise_sim + 1e-8), min=0)

        # Sorting the sort order converts it into per-element ranks
        # (0 for the most similar entry of each row).
        order = torch.sort(pairwise_sim, dim=1, descending=True)[1]
        rank = torch.sort(order, dim=1)[1]
        weights = torch.exp(-rank.float() * self.alpha)
        weighted = pairwise_loss * weights

        if self.reduction == 'mean':
            return weighted.mean()
        if self.reduction == 'sum':
            return weighted.sum()
        return weighted


class LSGANLoss(nn.Module):
    """Least-squares GAN objectives built on a mean-squared-error criterion."""

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()

    def discriminator_loss(self, real_pred: torch.Tensor, fake_pred: torch.Tensor) -> torch.Tensor:
        """Return the discriminator loss.

        L_D = 0.5 * [(D(real) - 1)^2 + D(fake)^2]

        Real predictions are regressed towards 1 and fake predictions towards 0.
        """
        target_real = torch.full_like(real_pred, 1.0)
        target_fake = torch.full_like(fake_pred, 0.0)
        loss_on_real = self.mse_loss(real_pred, target_real)
        loss_on_fake = self.mse_loss(fake_pred, target_fake)
        return 0.5 * (loss_on_real + loss_on_fake)

    def generator_loss(self, fake_pred: torch.Tensor) -> torch.Tensor:
        """Return the generator loss.

        L_G = 0.5 * (D(fake) - 1)^2

        The generator is rewarded when fake predictions are close to 1.
        """
        target_real = torch.full_like(fake_pred, 1.0)
        return 0.5 * self.mse_loss(fake_pred, target_real)


class ContrastiveLoss(nn.Module):
    """Contrastive loss over a source batch with positive and negative samples.

    ``forward`` combines ``SimMaxLoss`` (source vs. positives) with
    ``SimMinLoss`` (source vs. negatives). ``temperature`` is only used by the
    alternative ``info_nce_loss`` formulation.
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature
        self.min_loss = SimMinLoss()
        self.max_loss = SimMaxLoss()

    def forward(self, source: torch.Tensor, source_emb_P: torch.Tensor,
                source_emb_N: torch.Tensor) -> torch.Tensor:
        """Return the sum of the pull-together and push-apart terms.

        Args:
            source: Source embeddings [B, D]
            source_emb_P: Positive sample embeddings [B, D]
            source_emb_N: Negative sample embeddings [B, D]
        Returns:
            Contrastive loss
        """
        pull_term = self.max_loss(source, source_emb_P)   # attract positives
        push_term = self.min_loss(source, source_emb_N)   # repel negatives
        return pull_term + push_term

    def info_nce_loss(self, source: torch.Tensor, positive: torch.Tensor,
                      negative: torch.Tensor) -> torch.Tensor:
        """Alternative InfoNCE formulation.

        Each source row is scored against its own positive row and against
        every negative row using temperature-scaled cosine similarity; the
        result is the cross-entropy of picking the positive.
        """
        anchor = F.normalize(source, dim=-1)
        pos_unit = F.normalize(positive, dim=-1)
        neg_unit = F.normalize(negative, dim=-1)

        # Row-wise similarity to the matching positive
        pos_logit = (anchor * pos_unit).sum(dim=-1) / self.temperature
        # Similarity of every source row to every negative row
        neg_logits = (anchor @ neg_unit.T) / self.temperature

        # The positive logit sits in column 0, so the target class is always 0.
        logits = torch.cat((pos_logit.unsqueeze(1), neg_logits), dim=1)
        targets = logits.new_zeros(logits.size(0), dtype=torch.long)
        return F.cross_entropy(logits, targets)


class TextureDetailLoss(nn.Module):
    """Texture-detail preservation loss based on edge responses.

    The edge map of the fused image is pulled (L1) towards the element-wise
    maximum of the edge maps of the two source images. The edge operator is
    the Sobel gradient magnitude when ``opera == 'Sobel'`` (the default) and
    the Laplacian for any other value.
    """

    # 3x3 kernels, stored once and shared by all operators below.
    _LAPLACE = ((0., 1., 0.), (1., -4., 1.), (0., 1., 0.))
    _SOBEL_X = ((-1., 0., 1.), (-2., 0., 2.), (-1., 0., 1.))    # responds to vertical edges
    _SOBEL_Y = ((-1., -2., -1.), (0., 0., 0.), (1., 2., 1.))    # responds to horizontal edges

    def __init__(self, opera='Sobel'):
        super(TextureDetailLoss, self).__init__()
        self.l1_loss = nn.L1Loss()
        self.op = opera

    @staticmethod
    def _depthwise_conv3x3(image: torch.Tensor, weights) -> torch.Tensor:
        """Apply one 3x3 kernel to every channel of ``image`` independently.

        The kernel is created on the image's device and dtype; ``padding=1``
        keeps the spatial size unchanged.
        """
        channels = image.size(1)
        kernel = torch.tensor(weights, device=image.device, dtype=image.dtype)
        kernel = kernel.view(1, 1, 3, 3).expand(channels, 1, 3, 3)
        return F.conv2d(image, kernel, padding=1, stride=1, groups=channels)

    @staticmethod
    def laplace_operator(image: torch.Tensor) -> torch.Tensor:
        """Apply the 4-neighbour Laplacian operator for edge detection.

        Args:
            image: Input tensor of shape [B, C, H, W]

        Returns:
            Signed Laplacian response of shape [B, C, H, W]
        """
        return TextureDetailLoss._depthwise_conv3x3(image, TextureDetailLoss._LAPLACE)

    @staticmethod
    def sobel_operator_xy(image: torch.Tensor):
        """Apply the Sobel operator and return the X and Y gradients separately.

        Args:
            image: Input tensor of shape [B, C, H, W]

        Returns:
            Tuple of (grad_x, grad_y) tensors of shape [B, C, H, W]
        """
        grad_x = TextureDetailLoss._depthwise_conv3x3(image, TextureDetailLoss._SOBEL_X)
        grad_y = TextureDetailLoss._depthwise_conv3x3(image, TextureDetailLoss._SOBEL_Y)
        return grad_x, grad_y

    @staticmethod
    def sobel_operator(image: torch.Tensor) -> torch.Tensor:
        """Apply the Sobel operator for edge detection.

        Args:
            image: Input tensor of shape [B, C, H, W]

        Returns:
            Sobel gradient magnitude of shape [B, C, H, W]
        """
        grad_x, grad_y = TextureDetailLoss.sobel_operator_xy(image)
        # The 1e-8 term keeps sqrt differentiable where both gradients are zero.
        return torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-8)

    @staticmethod
    def sobel_operator_weighted(image: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
        """Apply the Sobel operator with directional weighting.

        Args:
            image: Input tensor of shape [B, C, H, W]
            alpha: Weight of the X gradient; the Y gradient is weighted by
                ``1 - alpha`` (0.5 = equal weight)

        Returns:
            Weighted gradient magnitude of shape [B, C, H, W]
        """
        grad_x, grad_y = TextureDetailLoss.sobel_operator_xy(image)
        return torch.sqrt(
            (alpha * grad_x) ** 2 + ((1 - alpha) * grad_y) ** 2 + 1e-8
        )

    def forward(self, fused: torch.Tensor, image1: torch.Tensor,
                image2: torch.Tensor) -> torch.Tensor:
        """
        Args:
            fused: Fused image [B, C, H, W]
            image1: First input image [B, C, H, W]
            image2: Second input image [B, C, H, W]
        Returns:
            Texture detail loss: L1 distance between the fused edge map and
            the element-wise maximum of the two source edge maps.
        """
        edge_op = self.sobel_operator if self.op == 'Sobel' else self.laplace_operator

        # abs() matters for the signed Laplacian; it is a no-op for the
        # non-negative Sobel magnitude.
        g1 = edge_op(image1).abs()
        g2 = edge_op(image2).abs()
        gf = edge_op(fused).abs()

        target = torch.maximum(g1, g2)
        return self.l1_loss(gf, target)


class WaveletHighFrequencyLoss(nn.Module):
    """
    Wavelet high-frequency loss.

    Decomposes both inputs with a single-level 2-D Haar wavelet transform and
    measures the difference of the resulting sub-bands. With ``opt == "high"``
    only the three high-frequency sub-bands (LH, HL, HH) are compared;
    with any other value the low-frequency LL sub-band is included as well.
    """

    def __init__(self, in_channels=1, opt="high"):
        super().__init__()
        self.in_channels = in_channels
        self.opt = opt

        # Build the fixed Haar filters once so they are not recreated per call.
        filters = self._create_wavelet_filters()
        self.LL, self.LH, self.HL, self.HH = filters
        # NOTE: created here but not used by forward(), which relies on torch.norm.
        self.l1_loss = nn.L1Loss()

    def forward(self, x, y):
        """
        Compute the wavelet-domain loss between two tensors.

        Args:
            x: first input tensor
            y: second input tensor (target)

        Returns:
            The L1 norm (sum of absolute differences, via ``torch.norm(p=1)``)
            of each compared sub-band difference, averaged over the number of
            compared sub-bands.
        """
        x_bands = self._wavelet_decompose(x)  # (LL, LH, HL, HH)
        y_bands = self._wavelet_decompose(y)

        if self.opt == "high":
            # Index 0 is the low-frequency LL band; skip it.
            band_ids = [1, 2, 3]
        else:
            band_ids = list(range(len(x_bands)))

        total_loss = 0.0
        for idx in band_ids:
            total_loss = total_loss + torch.norm(x_bands[idx] - y_bands[idx], p=1)
        return total_loss / len(band_ids)

    def _create_wavelet_filters(self):
        """Create the four Haar wavelet filters as frozen depthwise convolutions."""
        # 1-D Haar basis: low-pass [1, 1] / sqrt(2) and high-pass [-1, 1] / sqrt(2)
        low = 1 / np.sqrt(2) * np.ones((1, 2))
        high = 1 / np.sqrt(2) * np.ones((1, 2))
        high[0, 0] = -1 * high[0, 0]

        # 2x2 separable kernels in the order LL, LH, HL, HH
        kernels_2d = (
            low.T @ low,     # low-low
            low.T @ high,    # low-high
            high.T @ low,    # high-low
            high.T @ high,   # high-high
        )

        convs = []
        for kernel in kernels_2d:
            weight = torch.from_numpy(kernel).float().unsqueeze(0).unsqueeze(0)
            # Stride 2 with a 2x2 kernel halves the spatial resolution;
            # groups=in_channels filters each channel independently.
            conv = nn.Conv2d(self.in_channels, self.in_channels,
                             kernel_size=2, stride=2, padding=0, bias=False,
                             groups=self.in_channels)
            conv.weight.data = weight.repeat(self.in_channels, 1, 1, 1)
            # The filters are fixed and must not be trained.
            conv.weight.requires_grad = False
            convs.append(conv)

        return tuple(convs)

    def _wavelet_decompose(self, x):
        """Decompose ``x`` into its (LL, LH, HL, HH) Haar sub-bands."""
        # Move the filters to the input's device if they live elsewhere.
        target_device = x.device
        if self.LL.weight.device != target_device:
            for band in ("LL", "LH", "HL", "HH"):
                setattr(self, band, getattr(self, band).to(target_device))

        return (
            self.LL(x),  # low-frequency component
            self.LH(x),  # horizontal high-frequency component
            self.HL(x),  # vertical high-frequency component
            self.HH(x),  # diagonal high-frequency component
        )


def parameter_free_ir_fusion(IR, VIS):
    """
    Compute content-adaptive fusion weight maps for an infrared/visible pair.

    Design idea:
    1. Every decision adapts to the image content.
    2. There are no user-facing tuning parameters.
    3. Thresholds come from image statistics (means, standard deviations,
       medians and quantiles) rather than fixed values.

    Args:
        IR: infrared image [B, C, H, W] (floating point)
        VIS: visible-light image [B, C, H, W], same dtype and device as ``IR``

    Returns:
        Tuple ``(w_ir_final, w_vis_final)``:
            w_ir_final: infrared weight map [B, 1, H, W] with values in [0, 1].
                Pixels whose decision confidence is above the per-image median
                confidence are binarised to 0/1; the rest keep a soft weight.
            w_vis_final: visible weight map [B, 1, H, W], equal to
                ``1 - w_ir_final``.

        The decision confidence map is computed internally but is not returned.
    """
    _, C, H, W = IR.shape
    eps = 1e-8

    # 1. Grayscale conversion (channel mean)
    if C > 1:
        ir_gray = torch.mean(IR, dim=1, keepdim=True)
        vis_gray = torch.mean(VIS, dim=1, keepdim=True)
    else:
        ir_gray, vis_gray = IR, VIS

    # 2. Adaptive infrared saliency (driven entirely by the data distribution)
    def adaptive_ir_saliency(img):
        img_std = torch.std(img, dim=[2, 3], keepdim=True) + eps

        # Local window size derived from the image size: min(H, W) // 32
        # limited to [3, 7], then forced odd by setting the lowest bit.
        kernel_size = max(3, min(7, min(H, W) // 32)) | 1
        padding = kernel_size // 2
        # Box filter created with the image's dtype so conv2d accepts any
        # floating-point input, not just float32.
        ones = torch.ones(1, 1, kernel_size, kernel_size,
                          device=img.device, dtype=img.dtype) / (kernel_size ** 2)

        local_mean = F.conv2d(img, ones, padding=padding)
        local_var = F.conv2d((img - local_mean) ** 2, ones, padding=padding)
        local_std = torch.sqrt(local_var + eps)

        # Local contrast
        local_contrast = (img - local_mean) / (local_std + eps)

        # Robust standardisation using the median and the median absolute
        # deviation (MAD); 1.4826 rescales the MAD to be comparable to a
        # standard deviation for normally distributed data.
        contrast_median = torch.median(local_contrast.flatten(2), dim=-1, keepdim=True)[0].unsqueeze(-1)
        contrast_mad = torch.median(torch.abs(local_contrast.flatten(2) -
                                              contrast_median.squeeze(-1)), dim=-1, keepdim=True)[0].unsqueeze(-1)
        normalized_contrast = (local_contrast - contrast_median) / (1.4826 * contrast_mad + eps)

        # Hot-target indicator: soft threshold at the 75th intensity percentile
        intensity_75 = torch.quantile(img.flatten(2), 0.75, dim=-1, keepdim=True).unsqueeze(-1)
        thermal_indicator = torch.sigmoid(4 * (img - intensity_75) / (img_std + eps))

        # Combined saliency (plain product, no weighting parameters)
        saliency = torch.sigmoid(normalized_contrast) * thermal_indicator
        return saliency

    ir_saliency = adaptive_ir_saliency(ir_gray)

    # 3. Adaptive visible-light saliency (texture and edges)
    def adaptive_vis_saliency(img):
        # Sobel edge detection; kernels follow the image dtype/device
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                               dtype=img.dtype, device=img.device).view(1, 1, 3, 3)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                               dtype=img.dtype, device=img.device).view(1, 1, 3, 3)

        grad_x = F.conv2d(img, sobel_x, padding=1)
        grad_y = F.conv2d(img, sobel_y, padding=1)
        gradient_mag = torch.sqrt(grad_x ** 2 + grad_y ** 2 + eps)

        # Soft threshold at the 75th percentile of the gradient magnitude
        texture_75 = torch.quantile(gradient_mag.flatten(2), 0.75, dim=-1, keepdim=True).unsqueeze(-1)
        texture_std = torch.std(gradient_mag, dim=[2, 3], keepdim=True) + eps

        texture_saliency = torch.sigmoid(2 * (gradient_mag - texture_75) / texture_std)
        return texture_saliency

    vis_saliency = adaptive_vis_saliency(vis_gray)

    # 4. Complementarity analysis (based on the absolute image difference)
    def compute_complementarity(img1, img2):
        diff = torch.abs(img1 - img2)
        diff_norm = (diff - torch.mean(diff, dim=[2, 3], keepdim=True)) / \
                    (torch.std(diff, dim=[2, 3], keepdim=True) + eps)

        # Variance of the standardised difference, used as a proxy for entropy
        diff_entropy = torch.var(diff_norm, dim=[2, 3], keepdim=True)

        # Per-image weight relative to the median over the whole batch
        complement_weight = torch.sigmoid(2 * (diff_entropy - torch.median(diff_entropy.flatten())))
        return complement_weight * torch.sigmoid(diff_norm)

    complementarity = compute_complementarity(ir_gray, vis_gray)

    # 5. Information content (based on local variance)
    def local_information_content(img):
        # 3x3 local variance
        local_var_kernel = torch.ones(1, 1, 3, 3, device=img.device, dtype=img.dtype) / 9
        local_mean = F.conv2d(img, local_var_kernel, padding=1)
        local_variance = F.conv2d((img - local_mean) ** 2, local_var_kernel, padding=1)

        # Standardise per image, then squash to (0, 1)
        var_mean = torch.mean(local_variance, dim=[2, 3], keepdim=True)
        var_std = torch.std(local_variance, dim=[2, 3], keepdim=True) + eps
        info_content = (local_variance - var_mean) / var_std
        return torch.sigmoid(info_content)

    ir_info = local_information_content(ir_gray)
    vis_info = local_information_content(vis_gray)

    # 6. Combined score: a geometric mean rather than a weighted sum,
    #    so that no mixing weights are needed.
    ir_score = torch.pow(ir_saliency * ir_info * (1 + complementarity), 1 / 3)
    vis_score = torch.pow(vis_saliency * vis_info * (1 + complementarity), 1 / 3)

    # 7. Adaptive decision (local competition between the two scores)
    score_ratio = (ir_score + eps) / (vis_score + eps)

    # The per-image median ratio acts as the neutral point of the decision
    ratio_median = torch.median(score_ratio.flatten(2), dim=-1, keepdim=True)[0].unsqueeze(-1)

    # Soft decision on the log-ratio
    soft_ir_weight = torch.sigmoid(4 * torch.log(score_ratio / ratio_median + eps))

    # Confidence = distance of the soft weight from 0.5, rescaled to [0, 1]
    decision_confidence = torch.abs(soft_ir_weight - 0.5) * 2
    confidence_threshold = torch.quantile(decision_confidence.flatten(2), 0.5, dim=-1, keepdim=True).unsqueeze(-1)

    # Binarise high-confidence pixels; keep the soft weight elsewhere.
    # The hard mask is cast to the soft weight's dtype so torch.where does not
    # mix dtypes.
    high_confidence = decision_confidence > confidence_threshold
    hard_ir_weight = (soft_ir_weight > 0.5).to(soft_ir_weight.dtype)
    w_ir_final = torch.where(high_confidence, hard_ir_weight, soft_ir_weight)
    w_vis_final = 1 - w_ir_final

    return w_ir_final, w_vis_final
def optimal_ir_fusion(IR, VIS, k=0.1):
    """
    Compute binary infrared/visible fusion masks from an energy comparison.

    Design (the original notes refer to a "method 1" and a "method 2" that are
    not defined in this function):
    1. Use the enhanced hot-target detection attributed to method 2.
    2. Keep the compact mathematical framework attributed to method 1.
    3. Balance complexity against performance.

    Args:
        IR: infrared image [B, C, H, W] (floating point)
        VIS: visible-light image [B, C, H, W], same dtype and device as ``IR``
        k: infrared bias (0.0 = balanced, > 0 favours infrared,
            < 0 favours visible light)

    Returns:
        w_ir_final: binary infrared mask [B, 1, H, W]
        w_vis_final: binary visible mask [B, 1, H, W], equal to
            ``1 - w_ir_final``
    """
    _, C, _, _ = IR.shape
    eps = 1e-8

    # 1. Grayscale conversion (channel mean)
    if C > 1:
        ir_gray = torch.mean(IR, dim=1, keepdim=True)
        vis_gray = torch.mean(VIS, dim=1, keepdim=True)
    else:
        ir_gray, vis_gray = IR, VIS

    # 2. Enhanced infrared saliency
    def enhanced_ir_saliency(img):
        # Global standardisation
        img_mean = torch.mean(img, dim=[2, 3], keepdim=True)
        img_std = torch.std(img, dim=[2, 3], keepdim=True) + eps
        global_norm = (img - img_mean) / (img_std + eps)

        # A 5x5 window to capture hot targets. The box filter follows the
        # image dtype so conv2d accepts any floating-point input.
        kernel_size = 5
        padding = kernel_size // 2
        ones = torch.ones(1, 1, kernel_size, kernel_size,
                          device=img.device, dtype=img.dtype) / (kernel_size ** 2)

        local_mean = F.conv2d(img, ones, padding=padding)
        local_var = F.conv2d((img - local_mean) ** 2, ones, padding=padding)
        local_std = torch.sqrt(local_var + eps)
        relative_contrast = local_std / (img_std + eps)

        # Hot-target emphasis: brighter-than-average pixels get more weight
        intensity_weight = torch.sigmoid(2 * global_norm)

        # Combined saliency
        saliency = torch.sigmoid(global_norm + relative_contrast) * intensity_weight
        return saliency

    ir_saliency = enhanced_ir_saliency(ir_gray)

    # 3. Simplified residual handling
    residual = torch.abs(ir_gray - vis_gray)
    residual_prob = torch.sigmoid(residual)

    # Mean binary entropy of the residual "probability" per image
    residual_entropy = -torch.mean(
        residual_prob * torch.log(residual_prob + eps) +
        (1 - residual_prob) * torch.log(1 - residual_prob + eps),
        dim=[2, 3], keepdim=True
    )
    lambda_complement = torch.sigmoid(residual_entropy)

    # 4. Energy functions (lower energy = preferred source)
    ir_complement_bonus = lambda_complement * residual_prob * ir_saliency
    vis_complement_bonus = lambda_complement * residual_prob * (1 - ir_saliency)

    ir_energy = -(ir_saliency + ir_complement_bonus)
    vis_energy = -((1 - ir_saliency) + vis_complement_bonus)

    # 5. Biased binary decision.
    # A negative difference means IR is preferred; k shifts the threshold in
    # units of the standard deviation of the difference, which torch.std
    # computes over the whole tensor (all images in the batch together).
    energy_diff = ir_energy - vis_energy
    w_ir_final = (energy_diff < k * torch.std(energy_diff)).to(energy_diff.dtype)
    w_vis_final = 1 - w_ir_final

    return w_ir_final, w_vis_final

class FusionLoss(nn.Module):
    """Fusion loss combining an attention-weighted L1 term and a structural term.

    The structural term uses ``kornia.losses.SSIMLoss`` (window size 3) when
    kornia can be imported and falls back to ``nn.MSELoss`` otherwise.
    """

    def __init__(self, l1_weight: float = 1.0, ssim_weight: float = 0.5,
                 texture_weight: float = 0.5):
        super(FusionLoss, self).__init__()
        self.l1_weight = l1_weight
        self.ssim_weight = ssim_weight
        # Stored for API compatibility; forward() has no texture term, so this
        # weight currently has no effect on the returned loss.
        self.texture_weight = texture_weight

        self.l1_loss = nn.L1Loss()

        # SSIM loss (requires kornia); fall back to MSE when it is unavailable.
        try:
            import kornia
            self.ssim_loss = kornia.losses.SSIMLoss(3, reduction='mean')
        except ImportError:
            print("Warning: kornia not available, using MSE instead of SSIM")
            self.ssim_loss = nn.MSELoss()

    def forward(self, fused: torch.Tensor, ir: torch.Tensor, vis: torch.Tensor,
                w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
        """
        Args:
            fused: Fused image [B, C, H, W]
            ir: IR image [B, C, H, W]
            vis: VIS image [B, C, H, W]
            w1: IR attention weights [B, 1, H, W]
            w2: VIS attention weights [B, 1, H, W]
        Returns:
            ``l1_weight * l1 + ssim_weight * structural`` where ``l1`` compares
            the fused image with each source under that source's weight map
            and ``structural`` is the sum of the SSIM (or MSE fallback) losses
            of the fused image against both sources.
        """
        # Weighted L1 loss: each source only constrains the pixels it is
        # responsible for according to its weight map.
        l1_loss = (self.l1_loss(ir * w1, fused * w1) +
                   self.l1_loss(vis * w2, fused * w2))

        # Structural loss. The same expression applies to the SSIM criterion
        # and to the MSE fallback, so no type check is needed.
        ssim_loss = self.ssim_loss(fused, vis) + self.ssim_loss(fused, ir)

        return self.l1_weight * l1_loss + self.ssim_weight * ssim_loss


class PerceptualLoss(nn.Module):
    """Perceptual loss using pre-trained VGG-19 features.

    The loss is the sum of L1 distances between the activations of the
    predicted and the target image at the requested layers. Layers are named
    ``relu{block}_{index}`` and ``pool{block}``, where ``block`` counts the
    groups separated by max-pooling and ``index`` counts the convolutions
    inside a block (for example ``relu3_1`` is the ReLU after the first
    convolution of the third block).
    """

    def __init__(self, feature_layers: Optional[list] = None):
        super(PerceptualLoss, self).__init__()

        if feature_layers is None:
            feature_layers = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1']

        self.feature_layers = feature_layers
        self.l1_loss = nn.L1Loss()

        try:
            import torchvision.models as models
        except ImportError:
            print("Warning: torchvision not available, perceptual loss disabled")
            self.vgg = None
        else:
            # Newer torchvision releases replace ``pretrained=True`` with a
            # ``weights`` enum; IMAGENET1K_V1 is the same checkpoint.
            weights_enum = getattr(models, 'VGG19_Weights', None)
            if weights_enum is not None:
                vgg = models.vgg19(weights=weights_enum.IMAGENET1K_V1).features
            else:
                vgg = models.vgg19(pretrained=True).features
            self.vgg = vgg.eval()
            # The feature extractor is fixed; it must not receive gradients
            # for its own parameters.
            for param in self.vgg.parameters():
                param.requires_grad = False

    @staticmethod
    def _layer_names(layers) -> list:
        """Return one name per layer (``None`` for layers that are not exposed).

        Names are derived from the layer types instead of a fixed index table,
        so they stay aligned with the actual conv/ReLU/pool sequence. Only
        ReLU and pooling outputs are named: a convolution output would be
        overwritten by a following in-place ReLU.
        """
        names = []
        block, conv_index = 1, 0
        for layer in layers:
            if isinstance(layer, nn.Conv2d):
                conv_index += 1
                names.append(None)
            elif isinstance(layer, nn.ReLU):
                names.append(f'relu{block}_{conv_index}')
            elif isinstance(layer, nn.MaxPool2d):
                names.append(f'pool{block}')
                block += 1
                conv_index = 0
            else:
                names.append(None)
        return names

    def extract_features(self, x: torch.Tensor) -> dict:
        """Run ``x`` through VGG and collect the requested activations.

        Returns:
            Dict mapping each name in ``feature_layers`` that exists in the
            network to its activation; empty when VGG is unavailable.
        """
        if self.vgg is None:
            return {}

        wanted = set(self.feature_layers)
        features = {}
        for name, layer in zip(self._layer_names(self.vgg), self.vgg):
            x = layer(x)
            if name is not None and name in wanted:
                features[name] = x
                # Deeper layers are not needed once everything is collected.
                if len(features) == len(wanted):
                    break

        return features

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: Predicted image [B, C, H, W]
            target: Target image [B, C, H, W]
        Returns:
            Perceptual loss as a scalar tensor (zero when VGG is unavailable
            or none of the requested layers exist)
        """
        if self.vgg is None:
            return torch.tensor(0.0, device=pred.device)

        # Convert to RGB if grayscale
        if pred.size(1) == 1:
            pred = pred.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)

        # Upsample inputs narrower than 224 pixels to 224x224
        if pred.size(-1) < 224:
            pred = F.interpolate(pred, size=(224, 224), mode='bilinear', align_corners=False)
            target = F.interpolate(target, size=(224, 224), mode='bilinear', align_corners=False)

        pred_features = self.extract_features(pred)
        target_features = self.extract_features(target)

        # Start from a tensor so the return type does not depend on whether
        # any layer matched.
        loss = pred.new_zeros(())
        for layer in self.feature_layers:
            if layer in pred_features and layer in target_features:
                loss = loss + self.l1_loss(pred_features[layer], target_features[layer])

        return loss


class GradientLoss(nn.Module):
    """Gradient preservation loss based on Sobel responses."""

    def __init__(self):
        super().__init__()
        self.l1_loss = nn.L1Loss()

    @staticmethod
    def _sobel_response(img: torch.Tensor, rows) -> torch.Tensor:
        """Filter every channel of ``img`` with the 3x3 kernel given by ``rows``."""
        kernel = torch.tensor(rows, dtype=img.dtype, device=img.device).view(1, 1, 3, 3)
        channels = img.size(1)
        groups = 1
        if channels > 1:
            # One copy of the kernel per channel, applied depthwise
            kernel = kernel.repeat(channels, 1, 1, 1)
            groups = channels
        return F.conv2d(img, kernel, padding=1, groups=groups)

    def gradient_x(self, img: torch.Tensor) -> torch.Tensor:
        """Compute horizontal gradient"""
        return self._sobel_response(img, [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])

    def gradient_y(self, img: torch.Tensor) -> torch.Tensor:
        """Compute vertical gradient"""
        return self._sobel_response(img, [[-1, -2, -1], [0, 0, 0], [1, 2, 1]])

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pred: Predicted image [B, C, H, W]
            target: Target image [B, C, H, W]
        Returns:
            Gradient loss: L1 distance of the horizontal gradients plus
            L1 distance of the vertical gradients
        """
        horizontal = self.l1_loss(self.gradient_x(pred), self.gradient_x(target))
        vertical = self.l1_loss(self.gradient_y(pred), self.gradient_y(target))
        return horizontal + vertical


class TotalVariationLoss(nn.Module):
    """Total variation loss for smoothness regularization"""

    def __init__(self, reduction: str = 'mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input image [B, C, H, W]
        Returns:
            Total variation loss. For ``reduction`` of ``'mean'`` or ``'sum'``
            this is the reduced vertical term plus the reduced horizontal
            term; for any other value the two unreduced absolute-difference
            maps are returned as a tuple ``(h_tv, w_tv)``.
        """
        # Absolute differences between vertically / horizontally adjacent pixels
        rows_below, rows_above = x[:, :, 1:, :], x[:, :, :-1, :]
        cols_right, cols_left = x[:, :, :, 1:], x[:, :, :, :-1]
        h_tv = (rows_below - rows_above).abs()
        w_tv = (cols_right - cols_left).abs()

        if self.reduction == 'mean':
            return h_tv.mean() + w_tv.mean()
        if self.reduction == 'sum':
            return h_tv.sum() + w_tv.sum()
        return h_tv, w_tv


# Smoke test: run the similarity and fusion losses once on random data
if __name__ == '__main__':
    # Similarity losses on two random [4, 64] embedding batches
    fg_embedding = torch.randn(4, 64)
    bg_embedding = torch.randn(4, 64)

    similarity_criteria = (
        ("SimMax", SimMaxLoss(metric='cos')),
        ("SimMin", SimMinLoss(metric='cos')),
    )
    for label, criterion in similarity_criteria:
        value = criterion(fg_embedding, bg_embedding)
        print(f"{label} Loss: {value.item():.4f}")

    # Fusion loss on random single-channel 128x128 images; the two weight
    # maps are complementary (they sum to 1 at every pixel).
    fusion_loss = FusionLoss()
    image_shape = (2, 1, 128, 128)
    fused = torch.randn(*image_shape)
    ir = torch.randn(*image_shape)
    vis = torch.randn(*image_shape)
    w1 = torch.rand(*image_shape)
    w2 = 1 - w1

    loss = fusion_loss(fused, ir, vis, w1, w2)
    print(f"Fusion Loss: {loss.item():.4f}")