import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
import numpy as np
import warnings
from typing import List, Tuple, Optional


class Encoder(nn.Module):
    """Densely connected convolutional encoder for a single modality.

    The encoder has four stages that each emit 16 channels. Stage 1 reads the
    input image, stage 2 reads stage 1, and stages 3 and 4 read the channel-wise
    concatenation of every earlier stage output (most recent first).
    """

    # Number of channels produced by every stage.
    _STAGE_WIDTH = 16

    def __init__(self, input_channels: int = 1):
        super().__init__()
        width = self._STAGE_WIDTH
        # Attribute names and creation order are kept stable so that
        # state_dict keys and seeded initialisation do not change.
        self.conv1 = self._make_stage(input_channels, width)
        self.conv2 = self._make_stage(width, width)
        self.conv3 = self._make_stage(2 * width, width)
        self.conv4 = self._make_stage(3 * width, width)

    @staticmethod
    def _make_stage(in_ch: int, out_ch: int) -> nn.Sequential:
        """Build one conv3x3 -> BatchNorm -> ReLU -> conv3x3 stage."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Encode ``x`` and return the stacked features plus each stage output.

        Args:
            x: Input tensor whose channel dimension (dim 1) matches
                ``input_channels``.

        Returns:
            A tuple ``(features, stages)`` where ``stages`` is
            ``[stage4, stage3, stage2, stage1]`` and ``features`` is their
            concatenation along the channel dimension (64 channels).
        """
        stage1 = self.conv1(x)
        stage2 = self.conv2(stage1)
        stage3 = self.conv3(torch.cat((stage2, stage1), dim=1))
        stage4 = self.conv4(torch.cat((stage3, stage2, stage1), dim=1))

        stages = [stage4, stage3, stage2, stage1]
        return torch.cat(stages, dim=1), stages


class DualEncoder(nn.Module):
    """Two independent ``Encoder`` branches, one for IR and one for VIS."""

    def __init__(self, input_channels: int = 1):
        super().__init__()
        # The branches do not share weights.
        self.encoder_ir = Encoder(input_channels)
        self.encoder_vis = Encoder(input_channels)

    def forward(self, ir: torch.Tensor, vis: torch.Tensor) -> Tuple[
        torch.Tensor, List[torch.Tensor], torch.Tensor, List[torch.Tensor]]:
        """Encode both modalities.

        Args:
            ir: Infrared input, forwarded to ``encoder_ir``.
            vis: Visible input, forwarded to ``encoder_vis``.

        Returns:
            ``(ir_features, ir_stages, vis_features, vis_stages)`` - the two
            values returned by each branch, IR first.
        """
        ir_features, ir_stages = self.encoder_ir(ir)
        vis_features, vis_stages = self.encoder_vis(vis)
        return ir_features, ir_stages, vis_features, vis_stages


class Decoder(nn.Module):
    """Decoder that turns fused features into a single-channel image in [0, 1].

    Three convolutional stages (``in_channels`` -> 32 -> 16 -> 1) are combined
    with residual ``AdaptiveInstanceNormalization`` blocks applied to the input
    and to the outputs of the first two stages.
    """

    def __init__(self, in_channels: int = 64):
        super(Decoder, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, 1, 1),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, 3, 1, 1),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 1, 3, 1, 1),
            nn.Tanh()
        )
        # Each AdaIN block projects the style inputs from ``in_channels`` to
        # the channel width of the feature map it modulates.
        self.adin1 = AdaptiveInstanceNormalization(in_channels, in_channels)
        self.adin2 = AdaptiveInstanceNormalization(in_channels, 32)
        self.adin3 = AdaptiveInstanceNormalization(in_channels, 16)

    def forward(self, x: torch.Tensor,
                style_features: Optional[torch.Tensor],
                ir_sty: Optional[torch.Tensor],
                vis_sty: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode ``x`` into an image, optionally modulated by style inputs.

        Args:
            x: Feature map with ``in_channels`` channels.
            style_features: 4-D style tensor ``[B, C, H, W]``, or ``None``.
            ir_sty: IR style vector ``[B, C]``; it is broadcast to the shape
                of ``style_features``. May be ``None``.
            vis_sty: VIS style vector ``[B, C]``, handled like ``ir_sty``.
                May be ``None``.

        Returns:
            Tensor with one channel and values in ``[0, 1]``
            (``tanh(.) / 2 + 0.5``).

        Note:
            Style modulation is applied only when all three style inputs are
            given; otherwise the plain convolutional path is used. Earlier
            revisions gated on ``len(style_features) >= 2``, which measured
            the batch dimension and silently disabled modulation for
            batch size 1, and dereferenced the optional inputs before
            checking them for ``None``.
        """
        use_style = (style_features is not None
                     and ir_sty is not None
                     and vis_sty is not None)

        if use_style:
            B, C, H, W = style_features.shape
            # Lift the per-sample style vectors to the spatial layout of
            # ``style_features`` so the AdaIN projections (Conv2d) accept them.
            ir_sty = ir_sty.unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)
            vis_sty = vis_sty.unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)

        feat0 = x
        if use_style:
            feat0 = self.adin1(feat0, style_features, ir_sty, vis_sty) + feat0

        feat1 = self.conv1(feat0)
        if use_style:
            feat1 = self.adin2(feat1, style_features, ir_sty, vis_sty) + feat1

        feat2 = self.conv2(feat1)
        if use_style:
            feat2 = self.adin3(feat2, style_features, ir_sty, vis_sty) + feat2

        # Map the tanh output from [-1, 1] to [0, 1].
        return self.conv3(feat2) / 2 + 0.5


class MemoryGuidedFusion(nn.Module):
    """Fuse IR and VIS descriptors with attention queried by a memory descriptor.

    The memory representation is the single attention query; the IR and VIS
    representations form a two-element key/value set.
    """

    def __init__(self, feat_dim, clip_dim=512, num_heads=8):
        super().__init__()
        self.feat_dim = feat_dim
        self.modality_selector = nn.MultiheadAttention(
            feat_dim, num_heads, batch_first=True
        )
        # Not used by ``forward``; kept so the module's parameters and
        # state_dict layout stay unchanged.
        self.clip_projector = nn.Sequential(
            nn.Linear(clip_dim, feat_dim),
            nn.LayerNorm(feat_dim),
        )

    def forward(self, ir_repr, vis_repr, memory_repr):
        """Attend over the two modality descriptors using the memory query.

        Args:
            ir_repr: IR descriptor; the two trailing dimensions are squeezed,
                so they are presumably singleton (``[B, C, 1, 1]``).
            vis_repr: VIS descriptor, same layout as ``ir_repr``.
            memory_repr: Memory descriptor used as the query, same layout.

        Returns:
            ``(fused_repr, attention_weights)`` where ``fused_repr`` has two
            trailing singleton dimensions restored and ``attention_weights``
            is the weight tensor returned by ``nn.MultiheadAttention``.
        """

        def drop_spatial(t):
            # Remove the two trailing (size-1) spatial dimensions.
            return t.squeeze(-1).squeeze(-1)

        query = drop_spatial(memory_repr).unsqueeze(1)
        candidates = drop_spatial(torch.stack([ir_repr, vis_repr], dim=1))

        attended, attention_weights = self.modality_selector(
            query=query, key=candidates, value=candidates
        )
        fused_repr = attended.squeeze(1)[..., None, None]
        return fused_repr, attention_weights

class AdaptiveInstanceNormalization(nn.Module):
    """Adaptive Instance Normalization driven by a fused style descriptor.

    The style, IR and VIS inputs are each projected to ``feature_outdim``
    channels, fused by ``MemoryGuidedFusion``, and the per-channel statistics
    of the fused result re-scale and shift the instance-normalised content
    features.
    """

    def __init__(self, feature_indim, feature_outdim, eps: float = 1e-5):
        super(AdaptiveInstanceNormalization, self).__init__()
        # Added to the variance before the square root for numerical stability.
        self.eps = eps
        self.feat_projection = self._make_projection(feature_indim, feature_outdim)
        self.feat_projection_ir = self._make_projection(feature_indim, feature_outdim)
        self.feat_projection_vis = self._make_projection(feature_indim, feature_outdim)
        self.embfuse = MemoryGuidedFusion(feature_outdim)

        # Not used by ``forward``; kept so the parameter set is unchanged.
        self.conv = nn.Sequential(
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(feature_outdim, feature_outdim, 1, 1, 0)
        )

    @staticmethod
    def _make_projection(in_dim, out_dim) -> nn.Sequential:
        """Build the conv3x3 -> LeakyReLU -> conv1x1 projection head."""
        return nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 3, 1, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_dim, out_dim, 1, 1, 0)
        )

    def calc_mean_std(self, feat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return per-sample, per-channel mean and std of a 4-D feature map.

        Args:
            feat: Tensor of shape ``[N, C, H, W]``.

        Returns:
            ``(mean, std)``, each of shape ``[N, C, 1, 1]``. ``std`` is
            ``sqrt(var + self.eps)`` with the biased (population) variance.
        """
        assert feat.dim() == 4, "calc_mean_std expects a 4-D tensor"
        n, c = feat.shape[:2]
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat = feat.reshape(n, c, -1)
        mean = flat.mean(dim=2, keepdim=True)
        second_moment = (flat ** 2).mean(dim=2, keepdim=True)
        # E[x^2] - E[x]^2 can come out slightly negative through floating
        # point cancellation; clamp so the square root never yields NaN.
        var = (second_moment - mean ** 2).clamp_min(0.0) + self.eps
        return mean.view(n, c, 1, 1), var.sqrt().view(n, c, 1, 1)

    def forward(self, content_feat, style_feat_in, ir, vis):
        """Apply AdaIN to ``content_feat`` using the fused style statistics.

        Args:
            content_feat: Content features ``[B, C, H, W]``.
            style_feat_in: Style tensor fed to ``feat_projection``
                (4-D, ``feature_indim`` channels).
            ir: IR style tensor fed to ``feat_projection_ir``.
            vis: VIS style tensor fed to ``feat_projection_vis``.

        Returns:
            Tensor with the shape of ``content_feat``: the instance-normalised
            content scaled by the style std and shifted by the style mean.
        """
        style_feat = self.fused_sty(style_feat_in, ir, vis)
        style_mean, style_std = self.calc_mean_std(style_feat)
        content_mean, content_std = self.calc_mean_std(content_feat)

        # [N, C, 1, 1] statistics broadcast over the spatial dimensions.
        normalized_feat = (content_feat - content_mean) / content_std
        return normalized_feat * style_std + style_mean

    def fused_sty(self, style_feat_in, ir, vis):
        """Project the three style inputs and fuse them into one descriptor.

        The projected ``style_feat_in`` acts as the attention query of
        ``MemoryGuidedFusion``; the projected ``ir`` and ``vis`` are the
        candidates. Only the fused representation is returned.
        """
        style_feat = self.feat_projection(style_feat_in)
        ir = self.feat_projection_ir(ir)
        vis = self.feat_projection_vis(vis)
        fused_sty, _ = self.embfuse(ir, vis, style_feat)
        return fused_sty


class MemoryDrivenFusion(nn.Module):
    """Memory-driven fusion module with dynamic thresholding.

    ``forward`` encodes the concatenated IR/VIS features, queries a bank of
    unit-norm memory vectors with the pooled encoding and returns the
    attention-weighted memory read-out. ``update_memory`` writes samples whose
    quality score exceeds a dynamic threshold back into the bank.
    """

    # Capacity of the ring buffer holding recent quality scores.
    HISTORY_SIZE = 1000

    def __init__(self, in_channels: int, memory_size: int = 100, feature_dim: int = 64,
                 momentum: float = 0.1, threshold_strategy: str = 'adaptive',
                 initial_threshold: float = 0.5):
        """
        Args:
            in_channels: Channels of each of the two input feature maps.
            memory_size: Number of vectors in the memory bank.
            feature_dim: Dimensionality of the encoding and of memory vectors.
            momentum: Blend weight of a new sample when updating a memory slot.
            threshold_strategy: ``'adaptive'``, ``'scheduled'`` or anything
                else to keep the current threshold fixed.
            initial_threshold: Starting value of the quality threshold.
        """
        super(MemoryDrivenFusion, self).__init__()

        self.encoder = nn.Conv2d(in_channels * 2, feature_dim, kernel_size=3, padding=1)
        self.decoder = nn.Conv2d(feature_dim, in_channels, kernel_size=3, padding=1)

        self.memory_size = memory_size
        self.feature_dim = feature_dim
        self.momentum = momentum

        # Memory bank: rows are L2-normalised so dot products are cosines.
        self.register_buffer(
            'memory_bank',
            F.normalize(torch.randn(memory_size, feature_dim), dim=1)
        )

        # Feature processing
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.feat_projection = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feature_dim, feature_dim)
        )

        # Multi-level processing heads and AdaIN block. They are not used by
        # any method of this class but are kept so that the parameter set
        # (and therefore existing checkpoints) stays unchanged.
        self.nlp1 = nn.Sequential(
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(32, 32, 1, 1, 0)
        )
        self.nlp2 = nn.Sequential(
            nn.Conv2d(32, 16, 3, 1, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(16, 16, 1, 1, 0)
        )
        self.adin = AdaptiveInstanceNormalization(feature_dim, feature_dim)

        # Dynamic thresholding state
        self.threshold_strategy = threshold_strategy
        self.initial_threshold = initial_threshold
        self.register_buffer('current_threshold', torch.tensor(float(initial_threshold)))
        self.register_buffer('iteration_count', torch.tensor(0))
        self.register_buffer('quality_score_history', torch.zeros(self.HISTORY_SIZE))
        self.register_buffer('history_ptr', torch.tensor(0))
        self.register_buffer('history_count', torch.tensor(0))

        # Threshold bounds
        self.threshold_min = 0.3
        self.threshold_max = 0.8
        self.warmup_iters = 1000

    def forward(self, feat_ir: torch.Tensor, feat_vis: torch.Tensor):
        """Encode the two feature maps and read a style vector from memory.

        Args:
            feat_ir: IR features ``[B, in_channels, H, W]``.
            feat_vis: VIS features ``[B, in_channels, H, W]``.

        Returns:
            ``(aggregated_style_feat, fused_encoded)`` with shapes
            ``[B, feature_dim, 1, 1]`` and ``[B, feature_dim, H, W]``.
        """
        fused_encoded = self.encoder(torch.cat([feat_ir, feat_vis], dim=1))
        B, D = fused_encoded.shape[:2]

        # Global descriptor used as the memory query.
        global_feat = self.global_pool(fused_encoded).view(B, D)
        query = F.normalize(self.feat_projection(global_feat), dim=1)

        # Soft attention over the memory slots.
        attn_weights = F.softmax(torch.matmul(query, self.memory_bank.t()), dim=-1)  # [B, M]

        aggregated_style_feat = torch.matmul(attn_weights, self.memory_bank)  # [B, D]
        return aggregated_style_feat.view(B, D, 1, 1), fused_encoded

    @torch.no_grad()
    def update_memory(self, fused_encoded: torch.Tensor, quality_score: torch.Tensor):
        """Write high-quality samples into their least similar memory slot.

        Args:
            fused_encoded: Encoded features ``[B, D, H, W]`` as returned by
                ``forward``.
            quality_score: One score per sample; samples scoring above the
                dynamic threshold are written to memory.

        The iteration counter and the score history are always updated, even
        if no sample passes the threshold.
        """
        self.iteration_count += 1
        self._record_quality_scores(quality_score)
        threshold = self._get_dynamic_threshold(quality_score)

        valid_indices = (quality_score > threshold).nonzero(as_tuple=True)[0]
        if valid_indices.numel() == 0:
            return

        # Pooling and projection do not depend on the memory state, so they
        # are computed once for all selected samples.
        selected = fused_encoded[valid_indices].detach()
        pooled = self.global_pool(selected).flatten(1)  # [K, D]
        candidates = F.normalize(self.feat_projection(pooled), dim=1)

        # The writes stay sequential: each one changes the bank that the next
        # candidate is compared against.
        for candidate in candidates:
            similarity = torch.matmul(self.memory_bank, candidate)  # [M]
            slot = torch.argmin(similarity)
            self.memory_bank[slot] = F.normalize(
                (1 - self.momentum) * self.memory_bank[slot] + self.momentum * candidate,
                dim=0
            )

    def _record_quality_scores(self, quality_score: torch.Tensor):
        """Append the batch's scores to the ring buffer of recent scores.

        Scores are written in order starting at ``history_ptr`` and wrap
        around at the buffer capacity; ``history_count`` saturates at the
        capacity.
        """
        history = self.quality_score_history
        capacity = history.numel()

        scores = quality_score.detach().flatten().to(device=history.device, dtype=history.dtype)
        n = scores.numel()
        if n == 0:
            return

        offsets = torch.arange(n, device=history.device)
        if n > capacity:
            # After a full wrap-around only the newest ``capacity`` scores
            # survive; dropping the rest also keeps the target slots unique.
            offsets = offsets[-capacity:]
            scores = scores[-capacity:]

        history[(self.history_ptr + offsets) % capacity] = scores
        self.history_ptr = (self.history_ptr + n) % capacity
        self.history_count = torch.clamp(self.history_count + n, max=capacity)

    def _get_dynamic_threshold(self, quality_score: torch.Tensor) -> torch.Tensor:
        """Return the threshold chosen by ``threshold_strategy``."""
        if self.threshold_strategy == 'adaptive':
            return self._adaptive_threshold(quality_score)
        if self.threshold_strategy == 'scheduled':
            return self._scheduled_threshold()
        return self.current_threshold

    def _adaptive_threshold(self, quality_score: torch.Tensor) -> torch.Tensor:
        """Threshold derived from the running score statistics.

        Until 100 scores have been recorded a fixed ``0.8 * initial_threshold``
        is returned and ``current_threshold`` is left untouched. Afterwards
        the target ``mean + k * std`` (clamped to the configured bounds) is
        blended into ``current_threshold`` with weight 0.1.
        """
        if self.history_count < 100:
            return torch.tensor(self.initial_threshold * 0.8,
                                device=self.current_threshold.device)

        valid_history = self.quality_score_history[:self.history_count]
        hist_mean = valid_history.mean()
        hist_std = valid_history.std()

        batch_mean = quality_score.detach().to(valid_history.device).float().mean()
        combined_mean = 0.7 * hist_mean + 0.3 * batch_mean

        # k grows from 0.5 towards 1.5 as training progresses.
        k = torch.clamp(0.5 + self.iteration_count.float() / 10000, max=1.5)
        new_threshold = torch.clamp(combined_mean + k * hist_std,
                                    self.threshold_min, self.threshold_max)

        self.current_threshold = 0.9 * self.current_threshold + 0.1 * new_threshold
        return self.current_threshold

    def _scheduled_threshold(self) -> torch.Tensor:
        """Threshold that depends only on the iteration counter.

        During warm-up it rises linearly from ``initial_threshold`` half-way
        towards ``threshold_max``; afterwards it follows a cosine between
        ``threshold_max`` and ``threshold_min`` with a period of 10000
        iterations.
        """
        if self.iteration_count < self.warmup_iters:
            progress = self.iteration_count.float() / self.warmup_iters
            threshold = self.initial_threshold + \
                (self.threshold_max - self.initial_threshold) * progress * 0.5
        else:
            t = (self.iteration_count - self.warmup_iters).float()
            T = 10000.0
            threshold = self.threshold_min + (self.threshold_max - self.threshold_min) * \
                (1 + torch.cos(np.pi * (t % T) / T)) / 2

        self.current_threshold = threshold
        return self.current_threshold

class SemanticGraphNetwork(nn.Module):
    """Semantic graph network for cross-modal fusion with modality-specific graph learning.

    For each modality a three-node graph (original features, fused features,
    residual map features) is built, refined by graph attention layers and
    encoded into one vector per sample.
    """

    def __init__(self, feature_channels: Optional[List[int]] = None,
                 hidden_dim: int = 64, num_layers: int = 2):
        """
        Args:
            feature_channels: Channel count of each encoder level; defaults to
                ``[16, 16, 16, 16]``. The list is copied, so the caller's
                object is never shared with the module.
            hidden_dim: Width of node embeddings and of the returned vectors.
            num_layers: Number of graph attention layers per modality.
        """
        super(SemanticGraphNetwork, self).__init__()

        if feature_channels is None:
            feature_channels = [16, 16, 16, 16]
        feature_channels = list(feature_channels)
        total_channels = sum(feature_channels)

        self.num_nodes = 3  # 3 nodes per graph: original, fusion, residual
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.feature_channels = feature_channels

        # Preliminary fusion convolution over the concatenated IR+VIS features.
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(total_channels * 2, total_channels, 1, 1, 0),
            nn.BatchNorm2d(total_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(total_channels, total_channels, 3, 1, 1),
            nn.BatchNorm2d(total_channels),
            nn.ReLU(inplace=True)
        )

        # Node projections. The residual nodes read the output of ``self.M``,
        # which has ``hidden_dim`` channels, so their input width is
        # ``hidden_dim`` (identical to the previous hard-coded
        # ``sum(feature_channels)`` for the default configuration).
        self.ir_original_proj = self._make_node_proj(total_channels, hidden_dim)
        self.ir_fusion_proj = self._make_node_proj(total_channels, hidden_dim)
        self.ir_residual_proj = self._make_node_proj(hidden_dim, hidden_dim)

        self.vis_original_proj = self._make_node_proj(total_channels, hidden_dim)
        self.vis_fusion_proj = self._make_node_proj(total_channels, hidden_dim)
        self.vis_residual_proj = self._make_node_proj(hidden_dim, hidden_dim)

        # Separate graph attention layers for IR and VIS
        self.ir_graph_layers = nn.ModuleList([
            CrossModalGraphAttentionLayer(hidden_dim, hidden_dim, num_heads=4)
            for _ in range(num_layers)
        ])
        self.vis_graph_layers = nn.ModuleList([
            CrossModalGraphAttentionLayer(hidden_dim, hidden_dim, num_heads=4)
            for _ in range(num_layers)
        ])

        # Layer normalization applied after each residual graph layer
        self.ir_layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(num_layers)
        ])
        self.vis_layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim) for _ in range(num_layers)
        ])

        # Graph representation generators for modality-specific learning
        self.ir_graph_encoder = self._make_graph_encoder(hidden_dim, self.num_nodes)
        self.vis_graph_encoder = self._make_graph_encoder(hidden_dim, self.num_nodes)

        # Embeds the single-channel map passed as ``M`` to ``forward``.
        self.M = nn.Sequential(
            nn.Conv2d(1, hidden_dim, 3, 1, 1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1),
        )

        # Learnable node importance (softmax-normalised in ``forward``)
        self.ir_node_weights = nn.Parameter(torch.ones(self.num_nodes) / self.num_nodes)
        self.vis_node_weights = nn.Parameter(torch.ones(self.num_nodes) / self.num_nodes)

    @staticmethod
    def _make_node_proj(in_channels: int, hidden_dim: int) -> nn.Sequential:
        """Global-average-pool a feature map and embed it as one graph node."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_channels, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _make_graph_encoder(hidden_dim: int, num_nodes: int) -> nn.Sequential:
        """MLP that maps the flattened node embeddings to one graph vector."""
        return nn.Sequential(
            nn.Linear(hidden_dim * num_nodes, hidden_dim * 2),
            nn.LayerNorm(hidden_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim)
        )

    @staticmethod
    def _propagate(nodes: torch.Tensor, layers: nn.ModuleList,
                   norms: nn.ModuleList) -> torch.Tensor:
        """Run the residual graph-attention stack over ``nodes``."""
        for layer, norm in zip(layers, norms):
            updated, _ = layer(nodes)
            nodes = norm(updated + nodes)
        return nodes

    def forward(self, ir_features: List[torch.Tensor], vis_features: List[torch.Tensor], M) -> Tuple[
        torch.Tensor, torch.Tensor]:
        """
        Args:
            ir_features: List of IR feature maps [E4, E3, E2, E1]
            vis_features: List of VIS feature maps [E4, E3, E2, E1]
            M: Single-channel map ``[B, 1, H, W]`` embedded by ``self.M`` to
                form the residual node of both graphs.

        Returns:
            ir_graph_feat: IR graph representation [B, hidden_dim]
            vis_graph_feat: VIS graph representation [B, hidden_dim]
        """
        # Concatenate features from all encoder levels
        ir_feat_concat = torch.cat(ir_features, dim=1)
        vis_feat_concat = torch.cat(vis_features, dim=1)

        fusion_feat = self.fusion_conv(torch.cat([ir_feat_concat, vis_feat_concat], dim=1))
        residual_feat = self.M(M)  # [B, hidden_dim, H, W]

        # Both graphs share the fusion and residual feature maps but embed
        # them with their own projections.
        ir_nodes = torch.stack([
            self.ir_original_proj(ir_feat_concat),
            self.ir_fusion_proj(fusion_feat),
            self.ir_residual_proj(residual_feat),
        ], dim=1)  # [B, 3, hidden_dim]
        vis_nodes = torch.stack([
            self.vis_original_proj(vis_feat_concat),
            self.vis_fusion_proj(fusion_feat),
            self.vis_residual_proj(residual_feat),
        ], dim=1)  # [B, 3, hidden_dim]

        ir_x = self._propagate(ir_nodes, self.ir_graph_layers, self.ir_layer_norms)
        vis_x = self._propagate(vis_nodes, self.vis_graph_layers, self.vis_layer_norms)

        ir_node_weights = F.softmax(self.ir_node_weights, dim=0).view(1, self.num_nodes, 1)
        vis_node_weights = F.softmax(self.vis_node_weights, dim=0).view(1, self.num_nodes, 1)

        ir_graph_flat = (ir_x * ir_node_weights).flatten(1)  # [B, 3*hidden_dim]
        vis_graph_flat = (vis_x * vis_node_weights).flatten(1)  # [B, 3*hidden_dim]

        ir_graph_representation = self.ir_graph_encoder(ir_graph_flat)  # [B, hidden_dim]
        vis_graph_representation = self.vis_graph_encoder(vis_graph_flat)  # [B, hidden_dim]

        return ir_graph_representation, vis_graph_representation


class CrossModalGraphAttentionLayer(nn.Module):
    """Multi-head additive graph attention over a small, fully connected graph.

    Every node attends to every node. The score of a (query, key) pair is a
    linear function of their concatenated per-head projections, and the same
    scoring layer is shared by all heads.
    """

    def __init__(self, in_dim: int, out_dim: int, num_heads: int = 4):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.head_dim = out_dim // num_heads

        assert out_dim % num_heads == 0, "out_dim must be divisible by num_heads"

        # Per-node projections into query / key / value space.
        self.fc_q = nn.Linear(in_dim, out_dim)
        self.fc_k = nn.Linear(in_dim, out_dim)
        self.fc_v = nn.Linear(in_dim, out_dim)

        self.cross_modal_fc = nn.Linear(out_dim, out_dim)
        # Scores one (query, key) pair of a single head.
        self.attn_fc = nn.Linear(2 * self.head_dim, 1)

        self.out_proj = nn.Linear(out_dim, out_dim)
        self.dropout = nn.Dropout(0.1)

        # One learned embedding per node position (3 nodes per graph).
        self.pos_embedding = nn.Parameter(torch.randn(3, out_dim))

    def _attend_head(self, q: torch.Tensor, k: torch.Tensor,
                     v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attention of one head; ``q``, ``k``, ``v`` are ``[B, N, head_dim]``.

        Returns the attention weights ``[B, N, N]`` and the attended values
        ``[B, N, head_dim]``.
        """
        batch, num_nodes = q.shape[:2]
        pair_shape = (batch, num_nodes, num_nodes, -1)
        # Entry [b, i, j] holds the concatenation of query i and key j.
        pairs = torch.cat(
            [q.unsqueeze(2).expand(*pair_shape), k.unsqueeze(1).expand(*pair_shape)],
            dim=-1,
        )
        weights = F.softmax(self.attn_fc(pairs).squeeze(-1), dim=-1)
        return weights, torch.bmm(weights, v)

    def forward(self, nodes: torch.Tensor, modality_mask: Optional[torch.Tensor] = None) -> Tuple[
        torch.Tensor, torch.Tensor]:
        """
        Args:
            nodes: [B, N, in_dim] where N=3 (original, fusion, residual)
            modality_mask: Accepted for interface compatibility; it is not
                used by the computation.

        Returns:
            output: [B, N, out_dim]
            avg_attention: [B, N, N], the attention weights averaged over heads
        """
        batch, num_nodes, _ = nodes.shape
        nodes = nodes + self.pos_embedding[:num_nodes].unsqueeze(0)

        head_shape = (batch, num_nodes, self.num_heads, self.head_dim)
        queries = self.fc_q(nodes).view(head_shape).unbind(dim=2)
        keys = self.fc_k(nodes).view(head_shape).unbind(dim=2)
        values = self.fc_v(nodes).view(head_shape).unbind(dim=2)

        per_head = [
            self._attend_head(q, k, v) for q, k, v in zip(queries, keys, values)
        ]
        head_weights = [weights for weights, _ in per_head]
        head_context = [context for _, context in per_head]

        merged = torch.cat(head_context, dim=-1)  # [B, N, out_dim]
        enhanced = self.dropout(self.cross_modal_fc(merged))

        # The skip connection is only possible when the widths agree.
        if self.in_dim == self.out_dim:
            enhanced = enhanced + nodes
        output = self.out_proj(enhanced)

        return output, torch.stack(head_weights).mean(dim=0)


class CLIPEvaluator(nn.Module):
    """CLIP-based quality evaluator.

    Scores an image by comparing its CLIP embedding with the embeddings of a
    set of "high quality" and a set of "low quality" text prompts.
    """

    def __init__(self, device: str = 'cuda', clip_model_name: str = "ViT-B/32"):
        """
        Args:
            device: Device the CLIP model is loaded onto.
            clip_model_name: Model identifier passed to ``clip.load``.

        Raises:
            Exception: Whatever ``clip.load`` raises is re-raised after a
                warning has been issued.
        """
        super(CLIPEvaluator, self).__init__()

        try:
            self.clip_model, _ = clip.load(clip_model_name, device=device)
            self.clip_model.eval()
            # The evaluator is a fixed critic; CLIP itself is never trained.
            for p in self.clip_model.parameters():
                p.requires_grad = False
        except Exception as e:
            warnings.warn(f"Failed to load CLIP model: {e}")
            raise

        # Quality prompts
        self.positive_prompts = [
            "A high-quality image with clear contrast, sharp details, proper brightness, clean composition without noise or artifacts.",
            "An excellent image showing sharp details, accurate tones, optimal lighting, and no noise or artifacts.",
            "A clear, well-defined image with precise textures, natural brightness, and flawless composition without imperfections."
        ]

        self.negative_prompts = [
            "Low-quality image with weak contrast, smeared details, incorrect brightness, and noticeable digital noise.",
            "An unclear image with blurred edges, poor exposure, and visible grain or compression artifacts.",
            "Distorted image with dull contrast, missing details, uneven brightness, and distracting noise patterns."
        ]

        # Tokenised prompts. They are placed on ``device`` right away so they
        # live next to the CLIP weights even if the caller never calls
        # ``.to(device)`` on this module.
        self.register_buffer('pos_tokens', clip.tokenize(self.positive_prompts).to(device))
        self.register_buffer('neg_tokens', clip.tokenize(self.negative_prompts).to(device))

    def preprocess_image(self, image: torch.Tensor) -> torch.Tensor:
        """Prepare an image batch for the CLIP image encoder.

        Single-channel input is repeated to three channels, the batch is
        resized to 224x224, min-max scaled to [0, 1] (using the minimum and
        maximum of the whole batch) and normalised with the CLIP channel
        statistics.

        Args:
            image: Tensor ``[B, 1, H, W]`` or ``[B, 3, H, W]``.

        Returns:
            Tensor ``[B, 3, 224, 224]``.
        """
        if image.shape[1] == 1:  # Convert grayscale to RGB
            image = image.repeat(1, 3, 1, 1)

        # Resize and rescale to [0, 1]
        image = F.interpolate(image, size=(224, 224), mode='bilinear', align_corners=False)
        image = (image - image.min()) / (image.max() - image.min() + 1e-8)

        # CLIP normalization; constants are created directly on the image's
        # device instead of on the CPU followed by a copy.
        mean = torch.tensor([0.48145466, 0.4578275, 0.40821073],
                            device=image.device).view(1, 3, 1, 1)
        std = torch.tensor([0.26862954, 0.26130258, 0.27577711],
                           device=image.device).view(1, 3, 1, 1)
        return (image - mean) / std

    def _encode_prompts(self, tokens: torch.Tensor) -> torch.Tensor:
        """Encode tokenised prompts with the CLIP text encoder, without gradients."""
        with torch.no_grad():
            return self.clip_model.encode_text(tokens)

    def forward(self, fused_image: torch.Tensor, pseudo_fused: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Evaluate fusion quality.

        Args:
            fused_image: Image batch to score.
            pseudo_fused: Accepted for interface compatibility; not used.

        Returns:
            One score per image in (0, 1): the sigmoid of the mean cosine
            similarity to the positive prompts minus the mean cosine
            similarity to the negative prompts. The image encoding is not
            wrapped in ``no_grad``, so gradients can flow to ``fused_image``.
        """
        image = self.preprocess_image(fused_image)

        pos_text_features = self._encode_prompts(self.pos_tokens)
        neg_text_features = self._encode_prompts(self.neg_tokens)
        image_features = self.clip_model.encode_image(image)

        # Unit-normalise so that dot products are cosine similarities
        image_features = F.normalize(image_features, dim=-1)
        pos_text_features = F.normalize(pos_text_features, dim=-1)
        neg_text_features = F.normalize(neg_text_features, dim=-1)

        pos_sim = (image_features @ pos_text_features.T).mean(dim=1)
        neg_sim = (image_features @ neg_text_features.T).mean(dim=1)

        return torch.sigmoid(pos_sim - neg_sim)

    def text(self):
        """Return the (un-normalised) CLIP text features of the positive prompts."""
        return self._encode_prompts(self.pos_tokens)

def sum_patch(x: torch.Tensor, k: int) -> torch.Tensor:
    """Apply a ``(2k+1) x (2k+1)`` box (mean) filter to every channel of ``x``.

    The filter is a depthwise convolution with zero padding ``k``, so the
    spatial size is preserved and border outputs average in the padded zeros.

    Args:
        x: Input tensor ``[B, C, H, W]``.
        k: Half-width of the window.

    Returns:
        Tensor with the same shape as ``x``.
    """
    channels = x.size(1)
    side = 2 * k + 1
    # One uniform averaging kernel per channel (groups == channels).
    box = x.new_ones((channels, 1, side, side)).div_(side * side)
    return F.conv2d(x, box, stride=1, padding=k, groups=channels)


def normalize_tensor(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Min-max normalize each sample of a batch to the [0, 1] range.

    The minimum and maximum are taken over all non-batch dimensions of each
    sample. Inputs of any rank >= 1 are supported, and the input does not have
    to be contiguous.

    Args:
        x: Tensor whose first dimension is the batch dimension.
        eps: Added to the range to avoid division by zero for constant samples.

    Returns:
        Tensor with the same shape as ``x``.
    """
    batch = x.size(0)
    # reshape (unlike view) also works for non-contiguous inputs.
    flat = x.reshape(batch, -1)
    # Shape the statistics as [B, 1, ..., 1] matching the rank of ``x`` so
    # they broadcast per sample regardless of the number of dimensions.
    stat_shape = (batch,) + (1,) * (x.dim() - 1)
    minv = flat.min(dim=1)[0].view(stat_shape)
    maxv = flat.max(dim=1)[0].view(stat_shape)
    return (x - minv) / (maxv - minv + eps)

from loss import  SimMaxLoss, SimMinLoss
class FusionNet(nn.Module):
    """End-to-end IR/VIS fusion network.

    Wires together the dual encoder, the semantic graph network, the
    memory-driven fusion module and the decoder, and computes a contrastive
    loss between graph representations and pooled encoder features.
    """

    def __init__(self):
        super().__init__()
        self.encoder = DualEncoder()
        self.decoder = Decoder()
        self.memory_fusion = MemoryDrivenFusion(
            in_channels=64,
            memory_size=100,
            feature_dim=64,
            threshold_strategy='adaptive',
        )
        self.graph_network = SemanticGraphNetwork(hidden_dim=64)
        self.sim_max_loss = SimMaxLoss(metric='cos', alpha=0.25)
        self.sim_min_loss = SimMinLoss(metric='cos')

    def forward(self, ir_img, vis_img):
        """Fuse an IR/VIS image pair.

        Args:
            ir_img: Infrared image batch.
            vis_img: Visible image batch; ``ir_img - vis_img`` is passed to
                the graph network as its third input.

        Returns:
            ``(fused_image, encoded_feat, contrast_loss)``: the decoder
            output, the encoded features returned by the memory fusion
            module, and the sum of the five contrastive loss terms.
        """
        feat_ir, ir_layers, feat_vis, vis_layers = self.encoder(ir_img, vis_img)

        # Globally pooled encoder features, one vector per sample.
        batch, _, _, _ = feat_ir.shape
        ir_global = F.adaptive_avg_pool2d(feat_ir, 1).view(batch, -1)
        vis_global = F.adaptive_avg_pool2d(feat_vis, 1).view(batch, -1)

        ir_graph, vis_graph = self.graph_network(ir_layers, vis_layers, ir_img - vis_img)

        # sim_min terms cover cross-modality pairs, sim_max terms cover each
        # graph vector paired with the pooled features of its own modality.
        contrast_loss = self.sim_min_loss(ir_graph, vis_graph)
        contrast_loss = contrast_loss + self.sim_min_loss(ir_graph, vis_global)
        contrast_loss = contrast_loss + self.sim_min_loss(vis_graph, ir_global)
        contrast_loss = contrast_loss + self.sim_max_loss(ir_graph, ir_global)
        contrast_loss = contrast_loss + self.sim_max_loss(vis_graph, vis_global)

        style_repr, encoded_feat = self.memory_fusion(feat_ir, feat_vis)
        fused_image = self.decoder(encoded_feat, style_repr, ir_graph, vis_graph)

        return fused_image, encoded_feat, contrast_loss


