import torch
import torch.nn.init as init
import torch.nn as nn
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from HCP_Groups import TextEncoder, GlobalPromptLearner, LocalPromptLearner


# Short alias for the torch tensor type.
# NOTE(review): not referenced by the code visible in this chunk; presumably used
# elsewhere or by importers — confirm before removing.
Tensor = torch.Tensor
# Module-level CLIP SimpleTokenizer instance, constructed once at import time.
# NOTE(review): not referenced by the code visible in this chunk; presumably used
# elsewhere or by importers — confirm before removing.
_tokenizer = _Tokenizer()


class MultiLayerAdapter(nn.Module):
    """Stack of residual MLP adapter blocks.

    Each block computes ``x + Linear2(Dropout(ReLU(LayerNorm(Linear1(x)))))``.
    The number of blocks and the hidden width are configurable. All parameters
    are created in ``dtype``, and inputs are cast to ``dtype`` in ``forward``.

    Attributes:
        attn_maps: list with the input tensor of each block from the most
            recent ``forward`` call (one entry per block). Despite the name,
            these are intermediate activations, not attention weights. The list
            is reset on every call so it does not grow across iterations.
    """

    def __init__(self, input_dim=512, hidden_dim=512, num_layers=3, dropout=0.0, dtype=torch.float32):
        super().__init__()
        self.dtype = dtype
        self.num_layers = num_layers
        # Per-block inputs captured during the latest forward pass.
        self.attn_maps = []

        self.layers = nn.ModuleList(
            nn.ModuleDict({
                "linear1": nn.Linear(input_dim, hidden_dim).to(dtype),
                "ln1": nn.LayerNorm(hidden_dim).to(dtype),
                "act": nn.ReLU(),
                "drop": nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
                "linear2": nn.Linear(hidden_dim, input_dim).to(dtype),
            })
            for _ in range(num_layers)
        )

    def forward(self, x):
        """Apply every residual block in order; output has the same shape as ``x``."""
        x = x.to(self.dtype)
        # Start a fresh capture list. Previously this list was appended to on
        # every call and never cleared, so it grew without bound and kept old
        # activations (and their autograd graphs) alive.
        self.attn_maps = []
        for block in self.layers:
            residual = x
            # Snapshot of this block's input. It is not detached, so it still
            # carries gradient information for the current step.
            self.attn_maps.append(x.clone())

            x = block["linear1"](x)
            x = block["ln1"](x)
            x = block["act"](x)
            x = block["drop"](x)
            x = block["linear2"](x)
            x = x + residual  # residual connection

        return x


class SimpleAdapter(nn.Module):
    """Lightweight bottleneck adapter with a heavily down-weighted residual.

    Computes ``x + residual_scale * up(dropout(gelu(down(x))))``, where the
    bottleneck width is ``hidden_dim // reduction``. There is no LayerNorm.
    If a NaN is detected in the input or in either projection's output, the
    input is returned unchanged apart from dtype (see ``forward``).

    Args:
        input_dim: feature size of the input and output.
        hidden_dim: base width; the bottleneck uses ``hidden_dim // reduction``.
        dtype: dtype of the projection weights. Inputs are cast to it before
            projection.
        reduction: bottleneck shrink factor (default 4, the previously
            hard-coded value).
        dropout: dropout probability on the bottleneck activations.
        residual_scale: weight of the adapter branch in the residual sum.
    """

    def __init__(self, input_dim=512, hidden_dim=512, dtype=torch.float32,
                 reduction=4, dropout=0.1, residual_scale=0.01):
        super().__init__()
        self.dtype = dtype
        self.residual_scale = residual_scale
        bottleneck_dim = hidden_dim // reduction

        # Two linear layers only, with no LayerNorm.
        self.down_proj = nn.Linear(input_dim, bottleneck_dim, bias=True).to(self.dtype)
        self.up_proj = nn.Linear(bottleneck_dim, input_dim, bias=True).to(self.dtype)
        self.dropout = nn.Dropout(dropout)
        # GELU activation (the original note describes it as more stable).
        self.act = nn.GELU()

        # Conservative initialisation so the adapter starts close to identity.
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform (a=1) weights, zero biases, and up_proj shrunk by 100x."""
        with torch.no_grad():
            nn.init.kaiming_uniform_(self.down_proj.weight, a=1)
            nn.init.kaiming_uniform_(self.up_proj.weight, a=1)
            nn.init.zeros_(self.down_proj.bias)
            nn.init.zeros_(self.up_proj.bias)
            # Scale the output projection down so the branch contributes ~0 at start.
            self.up_proj.weight.mul_(0.01)

    def forward(self, x):
        """Return ``x`` plus the scaled adapter branch.

        Every return path yields ``torch.promote_types(x.dtype, self.dtype)``,
        the dtype the residual sum produces. Previously, the NaN fallbacks
        returned ``x.dtype`` instead, so the output dtype depended on the
        tensor's contents.
        """
        out_dtype = torch.promote_types(x.dtype, self.dtype)
        if torch.isnan(x).any():
            return x.to(out_dtype)

        # `original` is never modified in place, so no defensive clone is needed.
        original = x
        h = self.down_proj(x.to(self.dtype))
        if torch.isnan(h).any():
            return original.to(out_dtype)

        h = self.act(h)
        h = self.dropout(h)
        h = self.up_proj(h)
        if torch.isnan(h).any():
            return original.to(out_dtype)

        # Small residual contribution from the adapter branch.
        return original + self.residual_scale * h


# Learned weighting layer that fuses two logit tensors into one.
class WeightedLogits(nn.Module):
    """Fuse two same-shaped logit tensors with a learned affine combination.

    The output is ``w0 * logits_global + w1 * logits_local + b``, computed in
    float32. ``(w0, w1)`` are Xavier-uniform initialised and ``b`` starts at 0.
    """

    def __init__(self):
        super().__init__()
        # 2 inputs (one per logit source) -> 1 fused logit.
        fc = nn.Linear(2, 1)
        init.xavier_uniform_(fc.weight)
        init.zeros_(fc.bias)
        self.weighted_fc = fc

    def forward(self, logits_global, logits_local):
        """Return the fused logits with the same shape as either input."""
        # Bring both inputs to a common dtype before stacking.
        pair = (logits_global.to(torch.float32), logits_local.to(torch.float32))
        # Stacking on a new last axis gives the linear layer its 2 input features.
        stacked = torch.stack(pair, dim=-1)
        # Drop the trailing size-1 output axis.
        return self.weighted_fc(stacked).squeeze(-1)


class TwoBranchCLIP(nn.Module):
    """Two-branch CLIP classifier.

    The apex image branch is matched against global prompts, and the flow image
    branch against local prompts. ``forward`` returns both logit tensors
    separately.

    Args:
        clip_model: loaded CLIP model. Its ``visual`` encoder and
            ``logit_scale`` are reused, and it is passed to ``TextEncoder`` and
            both prompt learners.
        emotion_detail_map: dict whose keys serve as class names (global
            prompts) and whose values serve as detail descriptions (local
            prompts).
        n_ctx: accepted for interface compatibility; not used by this class.
    """

    def __init__(self, clip_model, emotion_detail_map, n_ctx=5):
        super().__init__()
        # Image encoders: CLIP's visual branch. NOTE: both attributes refer to
        # the *same* module object (clip_model.visual), so the two branches
        # share weights. They are frozen below.
        self.image_encoder1 = clip_model.visual
        self.image_encoder2 = clip_model.visual
        # Text encoder, same as in CoOp.
        self.text_encoder = TextEncoder(clip_model)

        # Global prompt branch: uses the dict keys as class names.
        self.global_prompt_learner = GlobalPromptLearner(clip_model, emotion_detail_map)
        # Local prompt branch: uses the dict values as detail descriptions.
        self.local_prompt_learner = LocalPromptLearner(clip_model, emotion_detail_map)
        self.logit_scale = clip_model.logit_scale
        self.dtype = torch.float32

        # Residual MLP adapters on each branch's image features (built in self.dtype).
        self.adapter1 = MultiLayerAdapter(input_dim=512, hidden_dim=512, num_layers=3, dropout=0.3, dtype=self.dtype)
        self.adapter2 = MultiLayerAdapter(input_dim=512, hidden_dim=512, num_layers=3, dropout=0.3, dtype=self.dtype)
        # Learned logit fusion.
        # NOTE(review): not used by forward(), which returns both logits
        # separately; presumably applied by the caller — confirm.
        self.logits_sum = WeightedLogits()

        # Freeze the visual encoder(s); only adapters and prompts are trained here.
        for encoder in (self.image_encoder1, self.image_encoder2):
            for param in encoder.parameters():
                param.requires_grad = False

    @staticmethod
    def _l2_normalize(features, eps=1e-8):
        """Scale *features* to unit L2 norm along the last dim; *eps* avoids division by zero."""
        return features / (features.norm(dim=-1, keepdim=True) + eps)

    def forward(self, apex_image, flow_image):
        """Compute scaled cosine-similarity logits for both branches.

        Args:
            apex_image: image batch for the global (apex) branch.
            flow_image: image batch for the local (flow) branch.

        Returns:
            ``(logits_global, logits_local)``. Each is ``logit_scale`` times the
            dot product of unit-norm image and text features. Presumably shaped
            (batch, n_cls); confirm against the encoders' outputs.
        """
        # Temperature, clamped to [1, 100].
        logit_scale = torch.clamp(self.logit_scale.exp(), min=1.0, max=100.0)

        apex_features = self.image_encoder1(apex_image.type(self.dtype))
        apex_features = self.adapter1(apex_features)
        flow_features = self.image_encoder2(flow_image.type(self.dtype))
        flow_features = self.adapter2(flow_features)

        apex_features = self._l2_normalize(apex_features)
        flow_features = self._l2_normalize(flow_features)

        # Generate learned text prompts and encode them.
        global_prompts = self.global_prompt_learner()
        local_prompts = self.local_prompt_learner()

        global_text_features = self.text_encoder(global_prompts, self.global_prompt_learner.tokenized_prompts)
        local_text_features = self.text_encoder(local_prompts, self.local_prompt_learner.tokenized_prompts)

        # Text features must also be unit-norm for the products below to be
        # cosine similarities. The image side already is, and this is
        # effectively a no-op if the text encoder already normalises. Casting
        # to self.dtype keeps the matmul from mixing dtypes with the image
        # features, e.g. fp16 text vs float32 image features.
        global_text_features = self._l2_normalize(global_text_features.to(self.dtype))
        local_text_features = self._l2_normalize(local_text_features.to(self.dtype))

        logits_global = logit_scale * (apex_features @ global_text_features.t())
        logits_local = logit_scale * (flow_features @ local_text_features.t())

        return logits_global, logits_local



