import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from models.vision_transformer import Prototype_Block
from functools import partial


class INP_Former(nn.Module):
    """INP-Former with optional RGB/depth feature fusion.

    The same encoder blocks are applied to the RGB input and, when given, to a
    depth input. Outputs of ``target_layers`` are collected per stream and
    fused according to ``fusion_config`` at the encoder stage, the decoder
    stage, the prototype-aggregation stage and/or the output stage. Prototypes
    aggregated from the (fused) encoder features condition the decoder;
    ``forward`` returns encoder/decoder feature maps plus the prototype loss.

    Args:
        encoder: Backbone exposing ``embed_dim``, ``prepare_tokens(x)`` and a
            ``blocks`` sequence; ``num_register_tokens`` is set to 0 if absent.
        bottleneck: Sequence of modules applied to the prototype input before
            decoding.
        aggregation: Sequence of modules called as ``blk(prototypes, tokens)``
            that return updated prototypes.
        decoder: Sequence of modules called as
            ``blk(x, prototypes, proto_relation=...)``.
        target_layers: Encoder block indices whose outputs are collected.
            Defaults to ``[2, 3, 4, 5, 6, 7, 8, 9]``.
        fuse_layer_encoder: Groups of indices into the collected encoder
            features; each group is averaged into one output map. Defaults to
            ``[[0, 1, 2, 3, 4, 5, 6, 7]]``.
        fuse_layer_decoder: Same as ``fuse_layer_encoder`` for the decoder
            features. Defaults to ``[[0, 1, 2, 3, 4, 5, 6, 7]]``.
        remove_class_token: If True, the ``1 + num_register_tokens`` leading
            tokens are dropped before prototype aggregation and decoding;
            otherwise they are dropped only from the final outputs.
        encoder_require_grad_layer: Encoder block indices run with autograd
            enabled; all other blocks run under ``torch.no_grad()``.
            Defaults to ``[]``.
        prototype_token: Indexable container whose element 0 is the prototype
            tensor (assumed ``(num_prototypes, dim)`` — TODO confirm). Required.
        use_proto_relation: If True, ``softmax(-distribution)`` computed by the
            prototype loss is passed to every decoder block as
            ``proto_relation``.
        use_three_loss: If True, add repulsion and coverage terms to the
            gather loss (see ``out_loss``).
        lambda_repulsion: Weight of the prototype repulsion term.
        lambda_coverage: Weight of the dead-prototype coverage term.
        fusion_config: Optional dict whose keys override the default fusion
            configuration defined in ``__init__``.

    Raises:
        ValueError: If ``prototype_token`` is not provided.
    """

    def __init__(
            self,
            encoder,
            bottleneck,
            aggregation,
            decoder,
            target_layers=None,
            fuse_layer_encoder=None,
            fuse_layer_decoder=None,
            remove_class_token=False,
            encoder_require_grad_layer=None,
            prototype_token=None,
            use_proto_relation=False,
            use_three_loss=True,
            lambda_repulsion=0.1,
            lambda_coverage=0.05,
            # Optional overrides for the RGB/depth fusion settings.
            fusion_config=None,
    ) -> None:
        super(INP_Former, self).__init__()
        if prototype_token is None:
            raise ValueError(
                "INP_Former requires `prototype_token` "
                "(an indexable container whose element 0 is the prototype tensor)."
            )
        # None sentinels avoid mutable default arguments shared between
        # instances; the fallbacks are the same values as the old defaults.
        if target_layers is None:
            target_layers = [2, 3, 4, 5, 6, 7, 8, 9]
        if fuse_layer_encoder is None:
            fuse_layer_encoder = [[0, 1, 2, 3, 4, 5, 6, 7]]
        if fuse_layer_decoder is None:
            fuse_layer_decoder = [[0, 1, 2, 3, 4, 5, 6, 7]]
        if encoder_require_grad_layer is None:
            encoder_require_grad_layer = []

        self.encoder = encoder
        self.bottleneck = bottleneck
        self.aggregation = aggregation
        self.decoder = decoder
        self.target_layers = target_layers
        self.fuse_layer_encoder = fuse_layer_encoder
        self.fuse_layer_decoder = fuse_layer_decoder
        self.remove_class_token = remove_class_token
        self.encoder_require_grad_layer = encoder_require_grad_layer
        self.prototype_token = prototype_token[0]
        self.use_proto_relation = use_proto_relation
        self.use_three_loss = use_three_loss
        self.lambda_repulsion = lambda_repulsion
        self.lambda_coverage = lambda_coverage

        # Fusion configuration: start from the defaults and overlay user keys.
        default_fusion_config = {
            'mode': 'all',  # 'all', 'encoder_only', 'decoder_only', 'late_fusion', 'none'
            'method': 'concat_mlp',  # 'concat_mlp', 'attention', 'gate', 'adaptive'; anything else -> weighted average
            'fusion_layers': 'all',  # 'all', 'high', 'low', or a list of encoder block indices (values of target_layers)
            'depth_weight': 0.5,  # weight of the depth stream, in [0, 1]
            'learnable_weight': True,  # learn per-layer RGB/depth weights (softmax-normalised)
            'separate_decoders': False,  # run a separate decoder stack on the depth features
            'cross_attention': False,  # apply RGB<->depth cross-attention before fusing
        }
        if fusion_config is not None:
            default_fusion_config.update(fusion_config)
        self.fusion_config = default_fusion_config

        self._setup_fusion_modules()

        if not hasattr(self.encoder, 'num_register_tokens'):
            self.encoder.num_register_tokens = 0

    def _setup_fusion_modules(self):
        """Create the fusion sub-modules requested by ``self.fusion_config``.

        One module (or module set) is created per entry of ``target_layers``;
        ``fuse_features`` selects them by layer index.
        """
        dim = self.encoder.embed_dim
        num_layers = len(self.target_layers)
        cfg = self.fusion_config

        # Indices (into target_layers) at which encoder/decoder fusion is on.
        self.fusion_layers = self._get_fusion_layers()

        if cfg['learnable_weight']:
            # fuse_features applies softmax to these logits, so initialise them
            # with log-weights: softmax(log([1 - w, w])) == [1 - w, w]. Using
            # the raw weights as logits made the initial depth weight differ
            # from the configured one for any w != 0.5. The weight is clamped
            # away from 0/1 so the logits stay finite and trainable.
            eps = 1e-4
            depth_w = min(max(float(cfg['depth_weight']), eps), 1.0 - eps)
            init_logits = torch.log(torch.tensor([1.0 - depth_w, depth_w]))
            self.fusion_weights = nn.Parameter(init_logits.repeat(num_layers, 1))

        method = cfg['method']

        if method == 'concat_mlp':
            self.fusion_mlps = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(dim * 2, dim * 4),
                    nn.GELU(),
                    nn.Dropout(0.1),
                    nn.Linear(dim * 4, dim),
                    nn.LayerNorm(dim)
                ) for _ in range(num_layers)
            ])

        elif method == 'attention':
            self.fusion_attention = nn.ModuleList([
                nn.MultiheadAttention(embed_dim=dim, num_heads=8, batch_first=True)
                for _ in range(num_layers)
            ])
            self.fusion_norm = nn.ModuleList([
                nn.LayerNorm(dim) for _ in range(num_layers)
            ])

        elif method == 'gate':
            self.gate_networks = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(dim * 2, dim),
                    nn.Sigmoid()
                ) for _ in range(num_layers)
            ])

        elif method == 'adaptive':
            # Adaptive fusion: a scalar gate per token computed from both
            # streams, applied to shared projections of each stream.
            self.adaptive_gates = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(dim * 2, 1),
                    nn.Sigmoid()
                ) for _ in range(num_layers)
            ])
            self.feature_proj = nn.ModuleList([
                nn.Linear(dim, dim) for _ in range(num_layers)
            ])

        # Optional bidirectional cross-attention applied before fusion.
        if cfg['cross_attention']:
            self.cross_attn_rgb2depth = nn.ModuleList([
                nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True)
                for _ in range(num_layers)
            ])
            self.cross_attn_depth2rgb = nn.ModuleList([
                nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True)
                for _ in range(num_layers)
            ])

        # Optional separate decoder stack for the depth stream.
        if cfg['separate_decoders']:
            self.depth_decoder = nn.ModuleList([
                Prototype_Block(
                    dim=dim,
                    num_heads=self.encoder.embed_dim // 64,  # head count derived from embed_dim
                    mlp_ratio=4.,
                    qkv_bias=True,
                    norm_layer=partial(nn.LayerNorm, eps=1e-8)
                ) for _ in range(len(self.decoder))
            ])

    def _get_fusion_layers(self):
        """Return the indices (into ``target_layers``) at which fusion is enabled.

        ``'all'`` selects every layer, ``'high'`` the second half, ``'low'`` the
        first half. A list/tuple selects the positions of ``target_layers``
        whose encoder block index it contains. Any other value disables fusion.
        """
        fusion_layers = self.fusion_config['fusion_layers']
        n = len(self.target_layers)

        if isinstance(fusion_layers, (list, tuple)):
            wanted = set(fusion_layers)
            return [i for i, layer in enumerate(self.target_layers) if layer in wanted]
        if fusion_layers == 'all':
            return list(range(n))
        if fusion_layers == 'high':
            return list(range(n // 2, n))
        if fusion_layers == 'low':
            return list(range(n // 2))
        return []

    def should_fuse_at_layer(self, layer_idx):
        """Return True if encoder features of ``layer_idx`` are fused in the encoder."""
        mode = self.fusion_config['mode']
        return mode in ['all', 'encoder_only'] and layer_idx in self.fusion_layers

    def should_fuse_at_decoder(self, layer_idx):
        """Return True if the decoder skip features of ``layer_idx`` are fused."""
        mode = self.fusion_config['mode']
        return mode in ['all', 'decoder_only'] and layer_idx in self.fusion_layers

    def fuse_features(self, rgb_features, depth_features, layer_idx=0, stage='encoder'):
        """Fuse RGB and depth token features with the configured method.

        Args:
            rgb_features: RGB tokens, or None.
            depth_features: Depth tokens, or None. Both are assumed to be
                ``(B, L, dim)`` (batch-first attention / last-dim concat) —
                TODO confirm for external callers.
            layer_idx: Index (into ``target_layers``) selecting the per-layer
                fusion weights and modules.
            stage: Caller label ('encoder', 'decoder', 'prototype', ...); it
                does not affect the computation.

        Returns:
            The fused features. If either input is None, the other one is
            returned unchanged.
        """
        if rgb_features is None or depth_features is None:
            return rgb_features if rgb_features is not None else depth_features

        cfg = self.fusion_config
        method = cfg['method']

        # Stream weights: learned (softmax over per-layer logits) or fixed.
        if cfg['learnable_weight']:
            weights = F.softmax(self.fusion_weights[layer_idx], dim=0)
            rgb_weight, depth_weight = weights[0], weights[1]
        else:
            rgb_weight = 1 - cfg['depth_weight']
            depth_weight = cfg['depth_weight']

        # Optional cross-attention: each stream attends to the other and is
        # updated residually before fusion.
        if cfg['cross_attention']:
            rgb_enhanced, _ = self.cross_attn_rgb2depth[layer_idx](
                rgb_features, depth_features, depth_features
            )
            depth_enhanced, _ = self.cross_attn_depth2rgb[layer_idx](
                depth_features, rgb_features, rgb_features
            )
            rgb_features = rgb_features + rgb_enhanced
            depth_features = depth_features + depth_enhanced

        if method == 'concat_mlp':
            concat_features = torch.cat([rgb_features * rgb_weight,
                                         depth_features * depth_weight], dim=-1)
            return self.fusion_mlps[layer_idx](concat_features)

        if method == 'attention':
            # RGB tokens (query) attend to depth tokens (key/value); the result
            # is added residually to the RGB tokens.
            norm_rgb = self.fusion_norm[layer_idx](rgb_features)
            norm_depth = self.fusion_norm[layer_idx](depth_features)
            attn_output, _ = self.fusion_attention[layer_idx](
                query=norm_rgb,
                key=norm_depth,
                value=norm_depth
            )
            return rgb_features + attn_output

        if method == 'gate':
            # Element-wise gate blending the two streams.
            concat_features = torch.cat([rgb_features, depth_features], dim=-1)
            gate = self.gate_networks[layer_idx](concat_features)
            return gate * rgb_features + (1 - gate) * depth_features

        if method == 'adaptive':
            # Scalar-per-token gate over shared projections of both streams.
            concat_features = torch.cat([rgb_features, depth_features], dim=-1)
            adaptive_weight = self.adaptive_gates[layer_idx](concat_features)
            proj_rgb = self.feature_proj[layer_idx](rgb_features)
            proj_depth = self.feature_proj[layer_idx](depth_features)
            return adaptive_weight * proj_rgb + (1 - adaptive_weight) * proj_depth

        # Fallback: plain weighted average.
        return rgb_weight * rgb_features + depth_weight * depth_features

    def forward_encoder(self, x, depth=None):
        """Run the encoder on RGB (and optional depth) input.

        The same encoder blocks process both streams. Blocks listed in
        ``encoder_require_grad_layer`` run with autograd; the others run under
        ``torch.no_grad()``.

        Returns:
            ``(en_list, depth_en_list)`` with one entry per target layer. A
            ``depth_en_list`` entry is None when no depth was given or when that
            layer was fused in the encoder; the ``en_list`` entry then holds the
            fused features.
        """
        x = self.encoder.prepare_tokens(x)
        if depth is not None:
            depth = self.encoder.prepare_tokens(depth)

        en_list = []
        depth_en_list = []
        last_layer = max(self.target_layers)

        for i, blk in enumerate(self.encoder.blocks):
            if i > last_layer:
                # Blocks past the deepest target layer are never needed.
                break

            if i in self.encoder_require_grad_layer:
                x = blk(x)
                if depth is not None:
                    depth = blk(depth)
            else:
                with torch.no_grad():
                    x = blk(x)
                    if depth is not None:
                        depth = blk(depth)

            if i not in self.target_layers:
                continue
            layer_idx = self.target_layers.index(i)

            if depth is not None and self.should_fuse_at_layer(layer_idx):
                en_list.append(self.fuse_features(x, depth, layer_idx, 'encoder'))
                depth_en_list.append(None)  # already fused into en_list
            else:
                en_list.append(x)
                depth_en_list.append(depth)  # None when no depth was given

        return en_list, depth_en_list

    def forward_decoder(self, x, en_list, depth_en_list, agg_prototype, proto_relation):
        """Non-accumulating decoder (alternative to ``forward_decoder_with_accumulation``).

        Each decoder block receives only the (optionally depth-fused) encoder
        features of its layer. The bottleneck output is computed but is
        overwritten before any decoder block uses it.

        Returns:
            ``(de_list, depth_de_list)`` in decoder order (not reversed).
        """
        for blk in self.bottleneck:
            x = blk(x)

        de_list = []
        depth_de_list = []

        for i, blk in enumerate(self.decoder):
            # Decoder-stage fusion of the skip features.
            if self.should_fuse_at_decoder(i) and depth_en_list[i] is not None:
                decoder_input = self.fuse_features(en_list[i], depth_en_list[i], i, 'decoder')
            else:
                decoder_input = en_list[i]

            if self.fusion_config['separate_decoders'] and depth_en_list[i] is not None:
                x_rgb = blk(decoder_input, agg_prototype, proto_relation=proto_relation)
                x_depth = self.depth_decoder[i](depth_en_list[i], agg_prototype, proto_relation=proto_relation)
                # Fuse the two decoder outputs.
                x = self.fuse_features(x_rgb, x_depth, i, 'decoder')
                de_list.append(x)
                depth_de_list.append(x_depth)
            else:
                x = blk(decoder_input, agg_prototype, proto_relation=proto_relation)
                de_list.append(x)
                depth_de_list.append(None)

        return de_list, depth_de_list

    def forward_decoder_with_accumulation(self, x, en_list, depth_en_list, agg_prototype, proto_relation):
        """Decoder with a running state plus encoder skip connections.

        Each block receives the previous block's output (initially the
        bottleneck output) plus that layer's encoder features, which are
        depth-fused first when decoder-stage fusion is enabled. With separate
        decoders, the depth stream keeps its own running state.

        Returns:
            ``(de_list, depth_de_list)`` in decoder order; ``forward`` reverses
            both.
        """
        for blk in self.bottleneck:
            x = blk(x)

        de_list = []
        depth_de_list = []
        depth_x = x  # running state of the depth decoder stack

        for i, blk in enumerate(self.decoder):
            depth_skip = depth_en_list[i] if i < len(depth_en_list) else None

            # Skip connection from the encoder (fused with depth if enabled).
            skip_connection = None
            if i < len(en_list):
                if self.should_fuse_at_decoder(i) and depth_skip is not None:
                    skip_connection = self.fuse_features(en_list[i], depth_skip, i, 'decoder')
                else:
                    skip_connection = en_list[i]

            decoder_input = x + skip_connection if skip_connection is not None else x

            if self.fusion_config['separate_decoders'] and depth_skip is not None:
                x_rgb = blk(decoder_input, agg_prototype, proto_relation=proto_relation)
                # The depth decoder also uses its own skip connection.
                x_depth = self.depth_decoder[i](depth_x + depth_skip, agg_prototype, proto_relation=proto_relation)

                x = self.fuse_features(x_rgb, x_depth, i, 'decoder')
                depth_x = x_depth
                de_list.append(x)
                depth_de_list.append(x_depth)
            else:
                x = blk(decoder_input, agg_prototype, proto_relation=proto_relation)
                depth_x = x  # without a separate stack, depth follows the shared state
                de_list.append(x)
                depth_de_list.append(None)

        return de_list, depth_de_list

    def _strip_prefix_tokens(self, en_list, depth_en_list):
        """Drop the class and register tokens from both encoder feature lists."""
        start = 1 + self.encoder.num_register_tokens
        en_list = [e[:, start:, :] for e in en_list]
        depth_en_list = [d[:, start:, :] if d is not None else None for d in depth_en_list]
        return en_list, depth_en_list

    def _prototype_input(self, en_list, depth_en_list, depth):
        """Build the token features from which prototypes are aggregated.

        In ``'late_fusion'`` mode (with depth), the RGB and depth layer averages
        are fused once. Otherwise, each layer whose depth features were not
        already fused in the encoder is fused here, and all layers are averaged.
        """
        if self.fusion_config['mode'] == 'late_fusion' and depth is not None:
            x_rgb = self.fuse_feature(en_list)
            x_depth = self.fuse_feature([d for d in depth_en_list if d is not None])
            return self.fuse_features(x_rgb, x_depth, 0, 'prototype')

        # NOTE(review): in modes 'none' and 'decoder_only', depth features are
        # still fused here (and at the output stage of forward); confirm that
        # this is intended for 'none'.
        fusion_list = []
        for i, (rgb_feat, depth_feat) in enumerate(zip(en_list, depth_en_list)):
            if depth_feat is not None and not self.should_fuse_at_layer(i):
                fusion_list.append(self.fuse_features(rgb_feat, depth_feat, i, 'late'))
            else:
                fusion_list.append(rgb_feat)
        return self.fuse_feature(fusion_list)

    def _aggregate_prototypes(self, x):
        """Batch the prototype tensor and refine it with each aggregation block.

        The prototypes are expanded to the batch once. Each block then refines
        the previous block's output, so more than one aggregation block works.
        Previously the prototypes were re-expanded every iteration, which fails
        from the second block on.
        """
        agg_prototype = self.prototype_token.unsqueeze(0).repeat((x.shape[0], 1, 1))
        for blk in self.aggregation:
            agg_prototype = blk(agg_prototype, x)
        return agg_prototype

    def get_inp(self, x, depth=None):
        """Return the aggregated prototypes for ``x`` (and optional ``depth``).

        Uses the same token handling and fusion as ``forward``
        (``remove_class_token``, per-layer / late fusion). The result therefore
        matches the prototypes that ``forward`` feeds to the decoder.
        """
        en_list, depth_en_list = self.forward_encoder(x, depth)
        if self.remove_class_token:
            en_list, depth_en_list = self._strip_prefix_tokens(en_list, depth_en_list)
        tokens = self._prototype_input(en_list, depth_en_list, depth)
        return self._aggregate_prototypes(tokens)

    def forward(self, x, depth=None):
        """Full forward pass.

        Returns:
            ``(en, de, g_loss)``: lists of encoder and decoder feature maps
            reshaped to ``(B, dim, side, side)`` (one per group in
            ``fuse_layer_encoder`` / ``fuse_layer_decoder``), and the prototype
            loss from ``out_loss``.
        """
        en_list, depth_en_list = self.forward_encoder(x, depth)

        num_prefix = 1 + self.encoder.num_register_tokens
        side = int(math.sqrt(en_list[0].shape[1] - num_prefix))

        if self.remove_class_token:
            en_list, depth_en_list = self._strip_prefix_tokens(en_list, depth_en_list)

        # Prototype aggregation.
        x = self._prototype_input(en_list, depth_en_list, depth)
        agg_prototype = self._aggregate_prototypes(x)

        # out_loss also stores self.distribution, which proto_relation reads.
        g_loss = self.out_loss(x, agg_prototype)
        proto_relation = torch.softmax(-self.distribution, dim=2) if self.use_proto_relation else None

        # Decoder: accumulating variant. forward_decoder is the
        # non-accumulating alternative with the same signature and return.
        de_list, depth_de_list = self.forward_decoder_with_accumulation(
            x, en_list, depth_en_list, agg_prototype, proto_relation
        )

        # Reverse both lists together so that index k refers to the same
        # decoder layer in each (only de_list used to be reversed).
        de_list = de_list[::-1]
        depth_de_list = depth_de_list[::-1]

        # Output-stage encoder maps: fuse layers not fused in the encoder.
        en = []
        for idxs in self.fuse_layer_encoder:
            layer_features = []
            for idx in idxs:
                if depth_en_list[idx] is not None and not self.should_fuse_at_layer(idx):
                    layer_features.append(self.fuse_features(en_list[idx], depth_en_list[idx], idx, 'output'))
                else:
                    layer_features.append(en_list[idx])
            en.append(self.fuse_feature(layer_features))

        # Output-stage decoder maps, merged with separate-decoder depth outputs.
        de = []
        for idxs in self.fuse_layer_decoder:
            rgb_fused = self.fuse_feature([de_list[idx] for idx in idxs])
            depth_features = []
            if self.fusion_config['separate_decoders']:
                depth_features = [depth_de_list[idx] for idx in idxs
                                  if depth_de_list[idx] is not None]
            if depth_features:
                depth_fused = self.fuse_feature(depth_features)
                de.append(self.fuse_features(rgb_fused, depth_fused, 0, 'final_decoder'))
            else:
                de.append(rgb_fused)

        if not self.remove_class_token:
            en = [e[:, num_prefix:, :] for e in en]
            de = [d[:, num_prefix:, :] for d in de]

        batch = x.shape[0]
        en = [e.permute(0, 2, 1).reshape([batch, -1, side, side]).contiguous() for e in en]
        de = [d.permute(0, 2, 1).reshape([batch, -1, side, side]).contiguous() for d in de]

        return en, de, g_loss

    def fuse_feature(self, feat_list):
        """Average a non-empty list of equally shaped tensors element-wise."""
        return torch.stack(feat_list, dim=1).mean(dim=1)

    def gather_loss(self, query, keys):
        """Mean cosine distance from each query token to its nearest key.

        Stores ``self.distribution``, ``self.distance`` and
        ``self.cluster_index`` as side effects.
        """
        self.distribution = 1. - F.cosine_similarity(query.unsqueeze(2), keys.unsqueeze(1), dim=-1)
        self.distance, self.cluster_index = torch.min(self.distribution, dim=2)
        gather_loss = self.distance.mean()
        return gather_loss

    def out_loss(self, query, keys):
        """Prototype loss: gather term plus optional repulsion/coverage terms.

        Args:
            query: Token features ``[B, L, D]``.
            keys: Prototypes ``[B, N, D]``.

        Side effects:
            Stores ``self.distribution`` ``[B, L, N]`` (cosine distances),
            ``self.distance`` and ``self.cluster_index`` ``[B, L]``.
            ``forward`` reads ``self.distribution`` for ``proto_relation``.
        """
        N = keys.shape[1]
        self.distribution = 1. - F.cosine_similarity(query.unsqueeze(2), keys.unsqueeze(1), dim=-1)  # [B, L, N]
        self.distance, self.cluster_index = torch.min(self.distribution, dim=2)  # [B, L], [B, L]
        gather_term = self.distance.mean()

        if not self.use_three_loss:
            return gather_term

        # Repulsion: penalise cosine *similarity* between distinct prototypes,
        # so minimising the loss pushes them apart. (Penalising the pairwise
        # distance, as before, pulled prototypes together.) With a single
        # prototype there are no pairs, so the term is 0 rather than NaN.
        if N > 1:
            pairwise_sim = F.cosine_similarity(keys.unsqueeze(2), keys.unsqueeze(1), dim=-1)  # [B, N, N]
            off_diag = ~torch.eye(N, dtype=torch.bool, device=keys.device)
            repulsion_term = pairwise_sim[:, off_diag].mean()
        else:
            repulsion_term = keys.new_zeros(())

        # Coverage: number of prototypes that no token is assigned to,
        # averaged over the batch.
        # NOTE(review): built from hard argmin assignments, so this term
        # carries no gradient; it changes the loss value but not training.
        one_hot_assignments = F.one_hot(self.cluster_index, num_classes=N)  # [B, L, N]
        prototype_counts = one_hot_assignments.sum(dim=1)  # [B, N]
        dead_prototype_penalty = (prototype_counts == 0).sum(dim=1).float().mean()

        return (gather_term
                + self.lambda_repulsion * repulsion_term
                + self.lambda_coverage * dead_prototype_penalty)



































































