import torch
import torch.nn as nn
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer


# Module-level CLIP SimpleTokenizer instance; the prompt learners in this file
# call its `encode` method to measure how many tokens a piece of text occupies.
_tokenizer = _Tokenizer()

# Emotion label -> textual description of the associated facial movements.
# The prompt learners in this file take one of these mappings as
# `emotion_detail_map`: the keys become class names and the values become the
# detailed descriptions. Dict insertion order defines the class index order.

# Five-label map (name suggests the CASME II label set -- TODO confirm).
emotion_detail_map_CASME2 = {
    "disgust": "A combination of lowering the brows, coming together and raising the upper lip",
    "happiness": "A combination of raising the cheeks, pulling the corners of the lips",
    "others": "A combination of lowering the brows, coming together, creating dimples, and raising the chin",
    "repression": "A combination of creating dimples, turning the corners of the mouth down and raising the chin",
    "surprise": "A combination of raising the inner and outer brows and slightly parting the lips",
}

# Seven-label map (name suggests the CAS(ME)^3 label set -- TODO confirm).
emotion_detail_map_CASME3 = {
    "disgust": "A combination of lowering and knitting the brows, wrinkling the nose, and raising the upper lip",
    "surprise": "A combination of raising the inner and outer brows, widening the eyes, and slightly parting the lips",
    "others": "A combination of subtle non-specific movements such as slight brow lowering, creating dimples, or raising the chin",
    "fear": "A combination of raising the upper eyelids, stretching the lips horizontally, and slightly lowering the brows",
    "anger": "A combination of lowering and drawing the brows together, pressing the lips firmly, and sometimes flaring the nostrils",
    "sad": "A combination of raising the inner brows, pulling the lip corners down, and slightly raising the chin",
    "happy": "A combination of raising the cheeks, pulling the lip corners up, and creating crow’s feet around the eyes",
}

# Four-class grouping: "positive", "surprise" and "others" reuse the
# seven-label descriptions of "happy", "surprise" and "others" verbatim;
# "negative" has its own merged description.
emotion_detail_map_CASME3_4class = {
    "negative": "A combination of disgust, fear, anger, and sad expressions, including features such as brow lowering, nose wrinkling, eyelid raising, lips pressed or stretched, and downturned mouth corners",
    "positive": emotion_detail_map_CASME3["happy"],
    "surprise": emotion_detail_map_CASME3["surprise"],
    "others": emotion_detail_map_CASME3["others"],
}

# Three-class grouping: the four-class map without the "others" entry.
emotion_detail_map_CASME3_3class = {
    label: detail
    for label, detail in emotion_detail_map_CASME3_4class.items()
    if label != "others"
}


def load_clip_to_cpu(cfg):
    """Build a CLIP model on the CPU for the backbone named in ``cfg``.

    The checkpoint for ``cfg.MODEL.BACKBONE.NAME`` is looked up in
    ``clip._MODELS``, fetched with ``clip._download`` and loaded with
    ``map_location="cpu"``. It is first opened as a TorchScript (JIT)
    archive; if ``torch.jit.load`` raises ``RuntimeError`` the file is
    loaded with ``torch.load`` instead and the result is used as the state
    dict directly.

    Args:
        cfg: Config object exposing ``MODEL.BACKBONE.NAME``.

    Returns:
        The model returned by ``clip.build_model`` for the loaded weights.

    Raises:
        KeyError: If the backbone name is not a key of ``clip._MODELS``.
    """
    backbone_name = cfg.MODEL.BACKBONE.NAME
    model_path = clip._download(clip._MODELS[backbone_name])

    try:
        # Preferred path: the checkpoint is a JIT archive.
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        # Fallback path: use whatever torch.load returns as the state dict.
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        state_dict = jit_model.state_dict()

    # Each branch binds `state_dict` explicitly, so the fallback path never
    # touches a JIT model that was not created.
    return clip.build_model(state_dict)


def safe_tokenize(text, device):
    """Tokenize ``text`` with ``clip.tokenize`` and move the result to ``device``.

    If ``clip.tokenize`` returns something other than a ``torch.Tensor``, it
    is first converted to a ``torch.long`` tensor.
    """
    token_ids = clip.tokenize(text)
    if isinstance(token_ids, torch.Tensor):
        return token_ids.to(device)
    return torch.tensor(token_ids, dtype=torch.long).to(device)


class TextEncoder(nn.Module):
    """Encode already-embedded prompts with the text components of a CLIP model.

    The transformer, positional embedding, final layer norm and text
    projection are taken from ``clip_model``; computation uses
    ``torch.float32`` (``self.dtype``).
    """

    def __init__(self, clip_model):
        super().__init__()
        # Registration order is kept: transformer, positional embedding,
        # final layer norm, text projection.
        for attr in ("transformer", "positional_embedding", "ln_final", "text_projection"):
            setattr(self, attr, getattr(clip_model, attr))
        self.dtype = torch.float32

    def forward(self, prompts, tokenized_prompts):
        """Return one projected feature vector per prompt.

        Args:
            prompts: Prompt token embeddings, batch-first (N, L, D).
            tokenized_prompts: Token ids of the same prompts; the position of
                the largest id in each row is treated as the EOT position.

        Returns:
            The hidden state at each row's EOT position multiplied by the
            text projection.
        """
        hidden = prompts + self.positional_embedding.to(self.dtype)

        # The transformer runs sequence-first: NLD -> LND, then back to NLD.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).to(self.dtype)

        # hidden.shape = [batch_size, n_ctx, transformer.width]
        # The EOT token has the highest id in each sequence, so argmax over
        # the token ids locates it.
        eot_index = tokenized_prompts.argmax(dim=-1)
        batch_index = torch.arange(hidden.shape[0])
        eot_features = hidden[batch_index, eot_index]

        return eot_features @ self.text_projection


class GlobalPromptLearner(nn.Module):
    """Build class-name prompts around a learnable context.

    Every prompt has the text form
    ``<context> micro-expression of <class name>.`` where the class names are
    the keys of ``emotion_detail_map`` (underscores replaced by spaces, in
    dict order). The context token embeddings are a trainable parameter
    (``self.ctx``); the embeddings of all other tokens are stored as buffers:

    * ``token_prefix``: the first token position (SOS).
    * ``micro_expression_of_tokens``: the fixed phrase "micro-expression of".
    * ``token_suffix``: class name, ".", EOS and padding.

    Args:
        clip_model: Model providing ``token_embedding`` and ``ln_final``
            (the length of ``ln_final.weight`` sets the context vector width
            when the context is randomly initialised).
        emotion_detail_map: Mapping whose keys are used as class names.
        n_ctx: Number of context vectors when ``ctx_init`` is falsy. With a
            non-empty ``ctx_init`` it is replaced by the token count of that
            text.
        ctx_init: Text whose token embeddings initialise the context. When
            falsy, the context is drawn from a normal distribution with
            std 0.02.
        class_token_position: ``"end"``, ``"middle"`` or ``"front"``; selects
            where the learnable context sits relative to the fixed phrase
            and class name (see ``forward``).

    Note:
        The fixed phrase length is the real token count of the phrase.
        Sizing it from the padded width of ``clip.tokenize`` output left
        ``token_suffix`` empty, so the "middle"/"front" layouts could not
        relocate the class tokens. The "end" layout output is unaffected,
        but the shapes of ``micro_expression_of_tokens`` and ``token_suffix``
        differ from state dicts saved with the padded-width sizing.
    """

    def __init__(self, clip_model, emotion_detail_map, n_ctx=5, ctx_init="A photo of", class_token_position="end"):
        super().__init__()
        self.dtype = torch.float32
        self.token_embedding = clip_model.token_embedding
        self.n_ctx = n_ctx
        self.class_token_position = class_token_position

        # Class names come from the keys of emotion_detail_map.
        classnames = [name.replace("_", " ") for name in emotion_detail_map]
        n_cls = len(classnames)

        # Device holding the token embedding weights.
        device_embed = next(self.token_embedding.parameters()).device

        # Initialise the learnable context vectors.
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            # Count tokens, not whitespace-separated words: the slices below
            # index token positions, and one word may span several tokens.
            n_ctx = len(_tokenizer.encode(ctx_init))
            prompt = safe_tokenize(ctx_init, device_embed)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(self.dtype)
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # No initial text: random initialisation, on the same device as
            # the embedding so both branches yield a context on that device.
            ctx_vectors = torch.empty(n_ctx, clip_model.ln_final.weight.shape[0],
                                      dtype=self.dtype, device=device_embed)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        self.ctx = nn.Parameter(ctx_vectors)  # trainable context

        # Full prompt template, e.g. "A photo of micro-expression of disgust."
        fixed_phrase = "micro-expression of"
        prompts = [prompt_prefix + " " + fixed_phrase + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([safe_tokenize(p, device_embed) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(self.dtype)

        # Real token count of the fixed phrase. The width of a tokenized
        # tensor cannot be used here because clip.tokenize pads every row to
        # the full context length.
        fixed_start = 1 + n_ctx
        fixed_end = fixed_start + len(_tokenizer.encode(fixed_phrase))

        # Split the embedding into its parts.
        self.register_buffer("token_prefix", embedding[:, :1, :])  # [SOS]
        self.register_buffer("micro_expression_of_tokens",
                             embedding[:, fixed_start:fixed_end, :])
        # Suffix: class name + "." + EOS (+ padding).
        self.register_buffer("token_suffix", embedding[:, fixed_end:, :])

        # Token length of each class name, used to cut it out of the suffix.
        self.name_lens = [len(_tokenizer.encode(name)) for name in classnames]

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts

    def forward(self):
        """Assemble the prompt embeddings for all classes.

        Layouts by ``class_token_position``:

        * ``"end"``:    [SOS][ctx][micro-expression of][class][. EOS ...]
        * ``"middle"``: [SOS][ctx 1st half][micro-expression of][class]
          [ctx 2nd half][. EOS ...]
        * ``"front"``:  [SOS][micro-expression of][class][ctx][. EOS ...]

        Returns:
            Tensor with one row of token embeddings per class.

        Raises:
            ValueError: If ``class_token_position`` is not one of the above.
        """
        ctx = self.ctx  # learnable context
        if ctx.dim() == 2:
            # One shared context, broadcast to every class.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix  # [SOS]
        fixed = self.micro_expression_of_tokens  # "micro-expression of"
        suffix = self.token_suffix  # class name + "." + [EOS]
        position = self.class_token_position

        if position == "end":
            return torch.cat([prefix, ctx, fixed, suffix], dim=1)
        if position not in ("middle", "front"):
            raise ValueError("Invalid class_token_position")

        # "front" is the "middle" layout with an empty leading context half.
        lead = self.n_ctx // 2 if position == "middle" else 0
        rows = []
        for i in range(self.n_cls):
            name_len = self.name_lens[i]
            row = slice(i, i + 1)
            rows.append(torch.cat([
                prefix[row, :, :],
                ctx[row, :lead, :],
                fixed[row, :, :],
                suffix[row, :name_len, :],   # class name tokens
                ctx[row, lead:, :],
                suffix[row, name_len:, :],   # "." + EOS + padding
            ], dim=1))
        return torch.cat(rows, dim=0)


class LocalPromptLearner(nn.Module):
    """Build detailed-description prompts around a learnable context.

    Every prompt has the text form
    ``<context> micro-expression showing <description>.`` where the
    descriptions are the values of ``emotion_detail_map`` (underscores
    replaced by spaces, in dict order). The context token embeddings are a
    trainable parameter (``self.ctx``); the embeddings of all other tokens
    are stored as buffers:

    * ``token_prefix``: the first token position (SOS).
    * ``micro_expression_showing_tokens``: the fixed phrase
      "micro-expression showing".
    * ``token_suffix``: description, ".", EOS and padding.

    Args:
        clip_model: Model providing ``token_embedding`` and ``ln_final``
            (the length of ``ln_final.weight`` sets the context vector width
            when the context is randomly initialised).
        emotion_detail_map: Mapping whose values are used as descriptions.
        n_ctx: Number of context vectors when ``ctx_init_local`` is falsy.
            With a non-empty ``ctx_init_local`` it is replaced by the token
            count of that text.
        ctx_init_local: Text whose token embeddings initialise the context.
            When falsy, the context is drawn from a normal distribution with
            std 0.02.
        local_class_token_position: ``"end"``, ``"middle"`` or ``"front"``;
            selects where the learnable context sits relative to the fixed
            phrase and description (see ``forward``).

    Note:
        The fixed phrase length is the real token count of the phrase.
        Sizing it from the padded width of ``clip.tokenize`` output left
        ``token_suffix`` empty, so the "middle"/"front" layouts could not
        relocate the description tokens. The "end" layout output is
        unaffected, but the shapes of ``micro_expression_showing_tokens`` and
        ``token_suffix`` differ from state dicts saved with the padded-width
        sizing.
    """

    def __init__(self, clip_model, emotion_detail_map, n_ctx=5, ctx_init_local="A detailed view of",
                 local_class_token_position="end"):
        super().__init__()
        self.dtype = torch.float32
        self.token_embedding = clip_model.token_embedding
        self.n_ctx = n_ctx
        self.class_token_position = local_class_token_position

        # Detailed descriptions come from the values of emotion_detail_map.
        details = [d.replace("_", " ") for d in emotion_detail_map.values()]
        n_cls = len(details)

        # Device holding the token embedding weights.
        device_embed = next(self.token_embedding.parameters()).device

        # Initialise the learnable context vectors.
        if ctx_init_local:
            ctx_init_local = ctx_init_local.replace("_", " ")
            # Count tokens, not whitespace-separated words: the slices below
            # index token positions, and one word may span several tokens.
            n_ctx = len(_tokenizer.encode(ctx_init_local))
            prompt = safe_tokenize(ctx_init_local, device_embed)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(self.dtype)
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init_local
        else:
            print("Initializing a generic local context")
            # Random initialisation, on the same device as the embedding so
            # both branches yield a context on that device.
            ctx_vectors = torch.empty(n_ctx, clip_model.ln_final.weight.shape[0],
                                      dtype=self.dtype, device=device_embed)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        self.ctx = nn.Parameter(ctx_vectors)  # trainable context

        # Full prompt template, e.g.
        # "A detailed view of micro-expression showing A combination of ... ."
        fixed_phrase = "micro-expression showing"
        prompts = [prompt_prefix + " " + fixed_phrase + " " + d + "." for d in details]

        tokenized_prompts = torch.cat([safe_tokenize(p, device_embed) for p in prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(self.dtype)

        # Real token count of the fixed phrase. The width of a tokenized
        # tensor cannot be used here because clip.tokenize pads every row to
        # the full context length.
        fixed_start = 1 + n_ctx
        fixed_end = fixed_start + len(_tokenizer.encode(fixed_phrase))

        # Split the embedding into its parts.
        self.register_buffer("token_prefix", embedding[:, :1, :])  # [SOS]
        self.register_buffer("micro_expression_showing_tokens",
                             embedding[:, fixed_start:fixed_end, :])
        # Suffix: description + "." + EOS (+ padding).
        self.register_buffer("token_suffix", embedding[:, fixed_end:, :])

        # Token length of each description, used to cut it out of the suffix.
        self.name_lens = [len(_tokenizer.encode(d)) for d in details]

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts

    def forward(self):
        """Assemble the prompt embeddings for all classes.

        Layouts by ``class_token_position``:

        * ``"end"``:    [SOS][ctx][micro-expression showing][description]
          [. EOS ...]
        * ``"middle"``: [SOS][ctx 1st half][micro-expression showing]
          [description][ctx 2nd half][. EOS ...]
        * ``"front"``:  [SOS][micro-expression showing][description][ctx]
          [. EOS ...]

        Returns:
            Tensor with one row of token embeddings per class.

        Raises:
            ValueError: If ``class_token_position`` is not one of the above.
        """
        ctx = self.ctx  # learnable context
        if ctx.dim() == 2:
            # One shared context, broadcast to every class.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix  # [SOS]
        fixed = self.micro_expression_showing_tokens  # "micro-expression showing"
        suffix = self.token_suffix  # description + "." + [EOS]
        position = self.class_token_position

        if position == "end":
            return torch.cat([prefix, ctx, fixed, suffix], dim=1)
        if position not in ("middle", "front"):
            raise ValueError("Invalid local class_token_position")

        # "front" is the "middle" layout with an empty leading context half.
        lead = self.n_ctx // 2 if position == "middle" else 0
        rows = []
        for i in range(self.n_cls):
            name_len = self.name_lens[i]
            row = slice(i, i + 1)
            rows.append(torch.cat([
                prefix[row, :, :],
                ctx[row, :lead, :],
                fixed[row, :, :],
                suffix[row, :name_len, :],   # description tokens
                ctx[row, lead:, :],
                suffix[row, name_len:, :],   # "." + EOS + padding
            ], dim=1))
        return torch.cat(rows, dim=0)
