from PIL import Image
from torchvision import transforms
import torch
import torch.nn.init as init
import torch.nn as nn
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer


# Convenience alias for ``torch.Tensor`` (not referenced by the code visible here).
Tensor = torch.Tensor
# Module-level CLIP BPE tokenizer. The prompt learners use ``_tokenizer.encode``
# to count the BPE tokens of class names / detail descriptions.
_tokenizer = _Tokenizer()

# Emotion label -> textual description of the associated facial movements
# (CASME2 label set). Keys are used as class names by GlobalPromptLearner and
# values as detail descriptions by LocalPromptLearner; the dict's iteration
# order defines the class order of the produced prompts and hence of the logits.
emotion_detail_map_CASME2 = {
    "disgust": "A combination of lowering the brows, coming together and raising the upper lip",
    "happiness": "A combination of raising the cheeks, pulling the corners of the lips",
    "others": "A combination of lowering the brows, coming together, creating dimples, and raising the chin",
    "repression": "A combination of creating dimples, turning the corners of the mouth down and raising the chin",
    "surprise": "A combination of raising the inner and outer brows and slightly parting the lips"
}

# Same structure for the 7-class CASME3 label set.
emotion_detail_map_CASME3 = {
    "disgust": "A combination of lowering and knitting the brows, wrinkling the nose, and raising the upper lip",
    "surprise": "A combination of raising the inner and outer brows, widening the eyes, and slightly parting the lips",
    "others": "A combination of subtle non-specific movements such as slight brow lowering, creating dimples, or raising the chin",
    "fear": "A combination of raising the upper eyelids, stretching the lips horizontally, and slightly lowering the brows",
    "anger": "A combination of lowering and drawing the brows together, pressing the lips firmly, and sometimes flaring the nostrils",
    "sad": "A combination of raising the inner brows, pulling the lip corners down, and slightly raising the chin",
    "happy": "A combination of raising the cheeks, pulling the lip corners up, and creating crow’s feet around the eyes"
}

# 4-class grouping of CASME3: disgust/fear/anger/sad merged into "negative",
# happy renamed "positive"; "surprise" and "others" descriptions reused verbatim.
emotion_detail_map_CASME3_4class = {
    "negative": "A combination of disgust, fear, anger, and sad expressions, including features such as brow lowering, nose wrinkling, eyelid raising, lips pressed or stretched, and downturned mouth corners",
    "positive": "A combination of raising the cheeks, pulling the lip corners up, and creating crow’s feet around the eyes",
    "surprise": "A combination of raising the inner and outer brows, widening the eyes, and slightly parting the lips",
    "others": "A combination of subtle non-specific movements such as slight brow lowering, creating dimples, or raising the chin"
}

# 3-class grouping of CASME3: the 4-class map without "others".
emotion_detail_map_CASME3_3class = {
    "negative": "A combination of disgust, fear, anger, and sad expressions, including features such as brow lowering, nose wrinkling, eyelid raising, lips pressed or stretched, and downturned mouth corners",
    "positive": "A combination of raising the cheeks, pulling the lip corners up, and creating crow’s feet around the eyes",
    "surprise": "A combination of raising the inner and outer brows, widening the eyes, and slightly parting the lips"
}


def load_clip_to_cpu(cfg):
    """Download (if needed) and build the CLIP backbone named in ``cfg`` on the CPU.

    The checkpoint is first opened as a TorchScript archive; if that raises a
    ``RuntimeError`` it is read as a plain state dict instead. Either way a
    regular (non-JIT) model is rebuilt with ``clip.build_model``.

    Args:
        cfg: config object exposing ``cfg.MODEL.BACKBONE.NAME``, a key of
            ``clip._MODELS``.

    Returns:
        The CLIP model produced by ``clip.build_model``.
    """
    checkpoint_url = clip._MODELS[cfg.MODEL.BACKBONE.NAME]
    checkpoint_path = clip._download(checkpoint_url)

    try:
        # Preferred path: a JIT archive; its weights are extracted below.
        jit_model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
        weights = None
    except RuntimeError:
        # Not a JIT archive: the file is a plain state dict.
        weights = torch.load(checkpoint_path, map_location="cpu")

    if not weights:
        weights = jit_model.state_dict()

    return clip.build_model(weights)


def safe_tokenize(text, device):
    """Tokenize ``text`` with ``clip.tokenize`` and return the tokens on ``device``.

    If ``clip.tokenize`` returns something other than a ``torch.Tensor``, it is
    converted to a ``torch.long`` tensor first.
    """
    token_ids = clip.tokenize(text)
    if isinstance(token_ids, torch.Tensor):
        return token_ids.to(device)
    return torch.tensor(token_ids, dtype=torch.long).to(device)


class TextEncoder(nn.Module):
    """CLIP text tower that consumes prompt *embeddings* instead of token ids.

    The learnable prompts are already embedded, so this module skips CLIP's
    token-embedding lookup. It adds positional embeddings, runs the text
    transformer and final LayerNorm, and projects the feature at each
    sequence's EOT position. EOT is located as the argmax of the token ids,
    since it is the highest id in each sequence.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = torch.float32

    def forward(self, prompts, tokenized_prompts):
        """Encode prompt embeddings.

        Args:
            prompts: prompt embeddings, batch-first (N, L, D).
            tokenized_prompts: token ids matching ``prompts``; only used to
                locate the EOT position of each row.

        Returns:
            Projected text features, one row per prompt.
        """
        positions = self.positional_embedding.type(self.dtype)
        hidden = (prompts + positions).permute(1, 0, 2)  # NLD -> LND
        hidden = self.transformer(hidden).permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden).type(self.dtype)

        eot_positions = tokenized_prompts.argmax(dim=-1)
        row_ids = torch.arange(hidden.shape[0])
        return hidden[row_ids, eot_positions] @ self.text_projection


class GlobalPromptLearner(nn.Module):
    """CoOp-style learnable prompts built from the emotion class names.

    For every key of ``emotion_detail_map`` the text
    ``"<ctx> micro-expression of <class>."`` is tokenized and embedded once.
    Only the ``<ctx>`` token embeddings (``self.ctx``) are learnable; the
    ``[SOS]`` token and everything after the context are kept as fixed buffers.

    Args:
        clip_model: CLIP model providing ``token_embedding`` and ``ln_final``.
        emotion_detail_map: mapping whose keys are the class names; its
            iteration order defines the class order of the prompts.
        n_ctx: number of learnable context tokens when ``ctx_init`` is empty.
            When ``ctx_init`` is given, the length is taken from ``ctx_init``.
        ctx_init: text used to initialise the context vectors.
        class_token_position: ``"end"``, ``"middle"`` or ``"front"``.
    """

    def __init__(self, clip_model, emotion_detail_map, n_ctx=5, ctx_init="A photo of", class_token_position="end"):
        super().__init__()
        self.dtype = torch.float32
        self.token_embedding = clip_model.token_embedding
        self.class_token_position = class_token_position

        # The keys of emotion_detail_map are the class names.
        classnames = [name.replace("_", " ") for name in emotion_detail_map.keys()]
        n_cls = len(classnames)

        # Tokenize on the same device as the token embedding.
        device_embed = next(self.token_embedding.parameters()).device

        # Initialise the learnable context vectors.
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            # NOTE: assumes one BPE token per space-separated word (CoOp convention).
            n_ctx = len(ctx_init.split(" "))
            prompt = safe_tokenize(ctx_init, device_embed)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(self.dtype)
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # No initial text: random initialisation.
            ctx_vectors = torch.empty(n_ctx, clip_model.ln_final.weight.shape[0], dtype=self.dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        self.ctx = nn.Parameter(ctx_vectors)  # trainable context

        # Full prompt text, e.g. "A photo of micro-expression of disgust."
        prompts = [prompt_prefix + " micro-expression of " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([safe_tokenize(p, device_embed) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(self.dtype)

        # [SOS] token.
        self.register_buffer("token_prefix", embedding[:, :1, :])

        # clip.tokenize pads every text to the model's context length, so this is
        # context_length - 2, NOT the token count of "micro-expression of". The
        # resulting buffer split is kept unchanged so that checkpoints which store
        # these buffers still load (a new split would change their shapes);
        # forward() re-splits the fixed tail with the true phrase length instead.
        micro_exp_of_tokens = safe_tokenize("micro-expression of", device_embed)
        legacy_split_len = micro_exp_of_tokens.shape[1] - 2
        self.register_buffer("micro_expression_of_tokens",
                             embedding[:, 1 + n_ctx: 1 + n_ctx + legacy_split_len, :])
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx + legacy_split_len:, :])

        # True BPE lengths of the fixed phrase and of each class name.
        self.micro_exp_of_len = len(_tokenizer.encode("micro-expression of"))
        self.name_lens = [len(_tokenizer.encode(name)) for name in classnames]

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts

    def forward(self):
        """Assemble the prompt embeddings for every class.

        Returns:
            Tensor (n_cls, L, dim) with the same sequence length as
            ``self.tokenized_prompts``. The EOT token stays at the position
            recorded in ``tokenized_prompts`` for every layout.

        Raises:
            ValueError: if ``class_token_position`` is not a known layout.
        """
        ctx = self.ctx  # learnable context
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix  # [SOS]
        micro_exp_of = self.micro_expression_of_tokens
        suffix = self.token_suffix

        if self.class_token_position == "end":
            # [SOS] + [ctx] + [micro-expression of] + [class] + [.] + [EOS/padding]
            return torch.cat([prefix, ctx, micro_exp_of, suffix], dim=1)
        if self.class_token_position not in ("middle", "front"):
            raise ValueError("Invalid class_token_position")

        # Everything after the context: "micro-expression of", class, ".", EOS, padding.
        tail = torch.cat([micro_exp_of, suffix], dim=1)
        phrase_len = self.micro_exp_of_len
        half_n_ctx = self.n_ctx // 2

        prompts_list = []
        for i, name_len in enumerate(self.name_lens):
            prefix_i = prefix[i: i + 1]
            tail_i = tail[i: i + 1]
            phrase_i = tail_i[:, :phrase_len]
            class_i = tail_i[:, phrase_len: phrase_len + name_len]
            rest_i = tail_i[:, phrase_len + name_len:]  # "." + EOS + padding
            ctx_i = ctx[i: i + 1]
            if self.class_token_position == "middle":
                # [SOS] + [ctx first half] + [micro-expression of] + [class] + [ctx second half] + [.]
                parts = [prefix_i, ctx_i[:, :half_n_ctx], phrase_i, class_i, ctx_i[:, half_n_ctx:], rest_i]
            else:
                # [SOS] + [micro-expression of] + [class] + [ctx] + [.]
                parts = [prefix_i, phrase_i, class_i, ctx_i, rest_i]
            prompts_list.append(torch.cat(parts, dim=1))
        return torch.cat(prompts_list, dim=0)


class LocalPromptLearner(nn.Module):
    """CoOp-style learnable prompts built from the emotion detail descriptions.

    For every value of ``emotion_detail_map`` the text
    ``"<ctx> micro-expression showing <description>."`` is tokenized and
    embedded once. Only the ``<ctx>`` token embeddings (``self.ctx``) are
    learnable; the ``[SOS]`` token and everything after the context are kept as
    fixed buffers.

    Args:
        clip_model: CLIP model providing ``token_embedding`` and ``ln_final``.
        emotion_detail_map: mapping whose values are the detail descriptions;
            its iteration order defines the class order of the prompts.
        n_ctx: number of learnable context tokens when ``ctx_init_local`` is
            empty. When it is given, the length is taken from the text.
        ctx_init_local: text used to initialise the context vectors.
        local_class_token_position: ``"end"``, ``"middle"`` or ``"front"``.
    """

    def __init__(self, clip_model, emotion_detail_map, n_ctx=5, ctx_init_local="A detailed view of",
                 local_class_token_position="end"):
        super().__init__()
        self.dtype = torch.float32
        self.token_embedding = clip_model.token_embedding
        self.class_token_position = local_class_token_position

        # The values of emotion_detail_map are the detail descriptions.
        details = [d.replace("_", " ") for d in emotion_detail_map.values()]
        n_cls = len(details)

        # Tokenize on the same device as the token embedding.
        device_embed = next(self.token_embedding.parameters()).device

        # Initialise the learnable context vectors.
        if ctx_init_local:
            ctx_init_local = ctx_init_local.replace("_", " ")
            # NOTE: assumes one BPE token per space-separated word (CoOp convention).
            n_ctx = len(ctx_init_local.split(" "))
            prompt = safe_tokenize(ctx_init_local, device_embed)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(self.dtype)
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init_local
        else:
            print("Initializing a generic local context")
            ctx_vectors = torch.empty(n_ctx, clip_model.ln_final.weight.shape[0], dtype=self.dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        self.ctx = nn.Parameter(ctx_vectors)  # trainable context

        # Full prompt text, e.g.
        # "A detailed view of micro-expression showing A combination of lowering the brows..."
        prompts = [prompt_prefix + " micro-expression showing " + d + "." for d in details]

        tokenized_prompts = torch.cat([safe_tokenize(p, device_embed) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(self.dtype)

        # [SOS] token.
        self.register_buffer("token_prefix", embedding[:, :1, :])

        # clip.tokenize pads every text to the model's context length, so this is
        # context_length - 2, NOT the token count of "micro-expression showing".
        # The resulting buffer split is kept unchanged so that checkpoints which
        # store these buffers still load (a new split would change their shapes);
        # forward() re-splits the fixed tail with the true phrase length instead.
        micro_exp_showing_tokens = safe_tokenize("micro-expression showing", device_embed)
        legacy_split_len = micro_exp_showing_tokens.shape[1] - 2
        self.register_buffer("micro_expression_showing_tokens",
                             embedding[:, 1 + n_ctx: 1 + n_ctx + legacy_split_len, :])
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx + legacy_split_len:, :])

        # True BPE lengths of the fixed phrase and of each description.
        self.micro_exp_showing_len = len(_tokenizer.encode("micro-expression showing"))
        self.name_lens = [len(_tokenizer.encode(d)) for d in details]

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts

    def forward(self):
        """Assemble the prompt embeddings for every class.

        Returns:
            Tensor (n_cls, L, dim) with the same sequence length as
            ``self.tokenized_prompts``. The EOT token stays at the position
            recorded in ``tokenized_prompts`` for every layout.

        Raises:
            ValueError: if ``class_token_position`` is not a known layout.
        """
        ctx = self.ctx  # learnable context
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix  # [SOS]
        micro_exp_showing = self.micro_expression_showing_tokens
        suffix = self.token_suffix

        if self.class_token_position == "end":
            # [SOS] + [ctx] + [micro-expression showing] + [description] + [.] + [EOS/padding]
            return torch.cat([prefix, ctx, micro_exp_showing, suffix], dim=1)
        if self.class_token_position not in ("middle", "front"):
            raise ValueError("Invalid local class_token_position")

        # Everything after the context: "micro-expression showing", description, ".", EOS, padding.
        tail = torch.cat([micro_exp_showing, suffix], dim=1)
        phrase_len = self.micro_exp_showing_len
        half_n_ctx = self.n_ctx // 2

        prompts_list = []
        for i, name_len in enumerate(self.name_lens):
            prefix_i = prefix[i: i + 1]
            tail_i = tail[i: i + 1]
            phrase_i = tail_i[:, :phrase_len]
            detail_i = tail_i[:, phrase_len: phrase_len + name_len]
            rest_i = tail_i[:, phrase_len + name_len:]  # "." + EOS + padding
            ctx_i = ctx[i: i + 1]
            if self.class_token_position == "middle":
                parts = [prefix_i, ctx_i[:, :half_n_ctx], phrase_i, detail_i, ctx_i[:, half_n_ctx:], rest_i]
            else:
                parts = [prefix_i, phrase_i, detail_i, ctx_i, rest_i]
            prompts_list.append(torch.cat(parts, dim=1))
        return torch.cat(prompts_list, dim=0)


class MultiLayerAdapter(nn.Module):
    """
    Stack of residual MLP adapter blocks.

    Each of the ``num_layers`` blocks computes
    ``x = Linear2(Dropout(ReLU(LayerNorm(Linear1(x))))) + x``,
    so the output has the same last dimension (``input_dim``) as the input.
    All layers are created in ``dtype`` and the input is cast to ``dtype``.

    Attributes:
        attn_maps: detached copies of the input to each block from the most
            recent ``forward`` call (one tensor per block). Despite the name,
            these are plain activations, not attention weights.
    """

    def __init__(self, input_dim=512, hidden_dim=512, num_layers=3, dropout=0.0, dtype=torch.float32):
        super().__init__()
        self.dtype = dtype
        self.num_layers = num_layers
        self.attn_maps = []

        layers = []
        for _ in range(num_layers):
            block = nn.ModuleDict({
                "linear1": nn.Linear(input_dim, hidden_dim).to(dtype),
                "ln1": nn.LayerNorm(hidden_dim).to(dtype),
                "act": nn.ReLU(),
                "drop": nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
                "linear2": nn.Linear(hidden_dim, input_dim).to(dtype),
            })
            layers.append(block)

        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        """Apply every adapter block in sequence and return the adapted features."""
        x = x.to(self.dtype)
        # Rebuilt on every call. Previously entries were appended forever, and
        # each clone stayed attached to the autograd graph, so memory grew with
        # every training step. The copies never feed the output, so detaching
        # them loses nothing.
        self.attn_maps = []
        for block in self.layers:
            residual = x
            self.attn_maps.append(x.detach().clone())

            x = block["linear1"](x)
            x = block["ln1"](x)
            x = block["act"](x)
            x = block["drop"](x)
            x = block["linear2"](x)
            x = x + residual  # residual connection

        return x


class SimpleAdapter(nn.Module):
    """Bottleneck residual adapter whose update is deliberately tiny.

    Structure: ``down_proj`` (input_dim -> hidden_dim // 4) -> GELU ->
    Dropout(0.1) -> ``up_proj`` (-> input_dim). There is no LayerNorm.
    The adapter output is added to the input scaled by 0.01, and ``up_proj``
    weights are also scaled by 0.01 at init, so the module starts close to
    the identity.

    If the input, or an intermediate result, contains NaN, the (unmodified)
    input is returned instead.
    """

    def __init__(self, input_dim=512, hidden_dim=512, dtype=torch.float32):
        super().__init__()
        self.dtype = dtype

        bottleneck_dim = hidden_dim // 4
        self.down_proj = nn.Linear(input_dim, bottleneck_dim, bias=True).to(self.dtype)
        self.up_proj = nn.Linear(bottleneck_dim, input_dim, bias=True).to(self.dtype)
        self.dropout = nn.Dropout(0.1)
        self.act = nn.GELU()

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform weights, zero biases, then shrink ``up_proj`` by 100x."""
        with torch.no_grad():
            for layer in (self.down_proj, self.up_proj):
                nn.init.kaiming_uniform_(layer.weight, a=1)
            for layer in (self.down_proj, self.up_proj):
                nn.init.zeros_(layer.bias)
            self.up_proj.weight.mul_(0.01)

    def forward(self, x):
        """Return ``x + 0.01 * adapter(x)``, or the input itself on NaN."""
        if torch.isnan(x).any():
            return x

        identity = x.clone()

        hidden = self.down_proj(x.to(self.dtype))
        if torch.isnan(hidden).any():
            return identity

        hidden = self.up_proj(self.dropout(self.act(hidden)))
        if torch.isnan(hidden).any():
            return identity

        return identity + 0.01 * hidden


# Learned weighting layer that fuses the global and local logits.
class WeightedLogits(nn.Module):
    """Fuse two logit tensors with a learned linear combination.

    For every element: ``out = w0 * logits_global + w1 * logits_local + b``,
    where ``(w0, w1, b)`` come from a single ``nn.Linear(2, 1)``. The weights
    are Xavier-uniform initialised and the bias starts at zero. Both inputs are
    cast to float32 first.
    """

    def __init__(self):
        super().__init__()
        fuse = nn.Linear(2, 1)  # two logits in, one fused logit out
        init.xavier_uniform_(fuse.weight)
        init.zeros_(fuse.bias)
        self.weighted_fc = fuse

    def forward(self, logits_global, logits_local):
        """Return fused logits with the same shape as each input."""
        pair = (logits_global.to(torch.float32), logits_local.to(torch.float32))
        stacked = torch.stack(pair, dim=-1)  # (..., 2)
        return self.weighted_fc(stacked).squeeze(-1)


class TwoBranchCLIP(nn.Module):
    """Two-branch CLIP classifier for micro-expressions.

    * Global branch: apex image -> image encoder -> ``adapter1``, matched
      against prompts built from the class names (``GlobalPromptLearner``).
    * Local branch: flow image -> image encoder -> ``adapter2``, matched
      against prompts built from the detail descriptions
      (``LocalPromptLearner``).

    Args:
        clip_model: a loaded CLIP model.
        emotion_detail_map: class name -> detail description mapping shared by
            both prompt learners.
        n_ctx: forwarded to both prompt learners (only used there when no
            context-initialisation text is given).
    """

    def __init__(self, clip_model, emotion_detail_map, n_ctx=5):
        super().__init__()
        # Image encoders: CLIP's visual tower.
        # NOTE: both attributes reference the same module object, so the two
        # branches share (frozen) weights.
        self.image_encoder1 = clip_model.visual
        self.image_encoder2 = clip_model.visual
        # Text encoder, as in CoOp.
        self.text_encoder = TextEncoder(clip_model)

        # Global prompts use the dict keys (class names); local prompts use the values.
        self.global_prompt_learner = GlobalPromptLearner(clip_model, emotion_detail_map, n_ctx=n_ctx)
        self.local_prompt_learner = LocalPromptLearner(clip_model, emotion_detail_map, n_ctx=n_ctx)
        self.logit_scale = clip_model.logit_scale
        self.dtype = torch.float32

        # Per-branch adapters. adapter2 was used in forward() but never created,
        # so every forward pass raised AttributeError; it mirrors adapter1.
        # Width follows the visual tower's output dim (512 for ViT-B/32).
        feat_dim = getattr(clip_model.visual, "output_dim", 512)
        self.adapter1 = MultiLayerAdapter(input_dim=feat_dim, hidden_dim=512, num_layers=3, dropout=0.3,
                                          dtype=self.dtype)
        self.adapter2 = MultiLayerAdapter(input_dim=feat_dim, hidden_dim=512, num_layers=3, dropout=0.3,
                                          dtype=self.dtype)
        # Learned fusion of the two logits (not used inside forward()).
        self.logits_sum = WeightedLogits()

        # Freeze the image encoders.
        # NOTE(review): the text encoder and logit_scale are left trainable here;
        # confirm the training script freezes them if that is intended.
        for encoder in (self.image_encoder1, self.image_encoder2):
            for param in encoder.parameters():
                param.requires_grad = False

    def forward(self, apex_image, flow_image):
        """Compute per-branch logits.

        Returns:
            ``(logits_global, logits_local)``: image-feature / text-feature dot
            products scaled by ``exp(logit_scale)`` clamped to [1, 100].
        """
        logit_scale = torch.clamp(self.logit_scale.exp(), min=1.0, max=100.0)

        # Image features, cast to self.dtype and passed through each branch's adapter.
        apex_features = self.adapter1(self.image_encoder1(apex_image.type(self.dtype)))
        flow_features = self.adapter2(self.image_encoder2(flow_image.type(self.dtype)))

        # L2-normalise image features (eps guards against division by zero).
        eps = 1e-8
        apex_features = apex_features / (apex_features.norm(dim=-1, keepdim=True) + eps)
        flow_features = flow_features / (flow_features.norm(dim=-1, keepdim=True) + eps)

        # Build and encode the text prompts, one per class.
        global_prompts = self.global_prompt_learner()
        local_prompts = self.local_prompt_learner()
        global_text_features = self.text_encoder(global_prompts, self.global_prompt_learner.tokenized_prompts)
        local_text_features = self.text_encoder(local_prompts, self.local_prompt_learner.tokenized_prompts)

        # NOTE(review): unlike CLIP/CoOp, text features are not L2-normalised, so
        # these are not pure cosine similarities. Left as-is so existing trained
        # weights keep their behaviour; confirm whether this is intended.
        logits_global = logit_scale * (apex_features @ global_text_features.t())
        logits_local = logit_scale * (flow_features @ local_text_features.t())

        return logits_global, logits_local


if __name__ == "__main__":

    def load_image(image_path, device):
        """Load an image as a normalised (1, 3, 224, 224) float tensor on ``device``."""
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        transform = transforms.Compose([
            transforms.Resize((224, 224)),  # resize to 224x224
            transforms.ToTensor(),  # PIL image -> CHW float tensor in [0, 1]
            # NOTE(review): these are ImageNet statistics. CLIP's own preprocessing uses
            # mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711).
            # Kept unchanged to match the training preprocessing; confirm.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # Add the batch dimension and move to the selected device. The previous
        # unconditional .cuda() crashed on CPU-only machines.
        return transform(image).unsqueeze(0).to(device)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load CLIP and build the two-branch model.
    clip_model, _ = clip.load("ViT-B/32", device=device)
    model = TwoBranchCLIP(clip_model, emotion_detail_map_CASME3_3class).to(device).float()
    model.eval()  # disable adapter dropout for this smoke test

    # Smoke test on one apex / flow image pair.
    apex_image_path = r"E:\Xyj_CLIP\111_Double_CLIP_Visualize\test_img\apex_spNO.1_j_28_33_happy.jpg"
    flow_image_path = r"E:\Xyj_CLIP\111_Double_CLIP_Visualize\test_img\flow_spNO.1_b_166_175_disgust.jpg"
    apex_image = load_image(apex_image_path, device)
    flow_image = load_image(flow_image_path, device)

    with torch.no_grad():
        logits_global, logits_local = model(apex_image, flow_image)

    class_names = list(emotion_detail_map_CASME3_3class)
    print("global logits:", tuple(logits_global.shape), "->", class_names[logits_global.argmax(dim=-1)[0].item()])
    print("local logits:", tuple(logits_local.shape), "->", class_names[logits_local.argmax(dim=-1)[0].item()])

