import torch
import torch.nn as nn

from . import clip
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

# Module-level tokenizer instance; PromptLearner calls its encode() to count
# how many tokens each class name occupies in a tokenized prompt.
_tokenizer = _Tokenizer()


class cfgc:
    """Default prompt-learning configuration, exposed as class attributes.

    Attributes:
        backbonename: key into ``clip._MODELS`` selecting the CLIP backbone.
        NCTX: number of learnable context tokens (used when CTXINIT is empty).
        CTXINIT: initial context words; an empty string means random init.
        CSC: if True, learn a separate context for every class.
        CLASS_TOKEN_POSITION: where the class name goes relative to the
            context: 'end', 'middle' or 'front'.
    """

    backbonename = 'ViT-B/16'
    NCTX = 16
    CTXINIT = ''
    CSC = False
    CLASS_TOKEN_POSITION = 'end'
    
def load_clip_to_cpu(cfg):
    """Load the CLIP backbone named by ``cfg.backbonename`` onto the CPU.

    The checkpoint is fetched with ``clip._download``, which returns the local
    file path. It is first opened as a TorchScript archive. If that raises
    ``RuntimeError``, the file is read as a plain ``state_dict`` instead. In
    both cases the returned model is rebuilt from the weights by
    ``clip.build_model``.

    Args:
        cfg: object with a ``backbonename`` attribute (e.g. ``'ViT-B/16'``),
            used as the key into ``clip._MODELS``.

    Returns:
        The model produced by ``clip.build_model``.
    """
    backbone_name = cfg.backbonename  # e.g. 'ViT-B/16'
    url = clip._MODELS[backbone_name]
    # Download the weights if necessary; returns where the file is stored.
    model_path = clip._download(url)

    try:
        # TorchScript (JIT) archive, loaded onto the CPU.
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        # Not a JIT archive: the file holds a plain state_dict.
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        # Mapping of parameter names to tensors taken from the JIT model.
        state_dict = jit_model.state_dict()

    # Exactly one weights source is chosen above. This avoids relying on the
    # truthiness of the loaded dict and never touches an unbound model name.
    return clip.build_model(state_dict)


class TextEncoder(nn.Module):
    """Run CLIP's text transformer on already-embedded prompts.

    The input is token embeddings rather than token ids, so learnable context
    vectors can be spliced in before encoding. The token ids are used only
    to locate the end-of-text position whose feature is returned.
    """

    def __init__(self, clip_model):
        super().__init__()
        # Reuse the text-side components of the given CLIP model.
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """Encode prompt embeddings into projected text features.

        Args:
            prompts: token embeddings laid out as (batch, seq_len, width);
                the positional embedding is added to them.
            tokenized_prompts: (batch, seq_len) token ids. The position of
                the largest id in each row is taken as the end-of-text token.

        Returns:
            For every row, the layer-normed final hidden state at that row's
            end-of-text position, multiplied by ``text_projection``.
        """
        pos_emb = self.positional_embedding.type(self.dtype)
        hidden = prompts + pos_emb

        # The transformer consumes (seq_len, batch, width); swap in and back.
        hidden = self.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
        hidden = self.ln_final(hidden).type(self.dtype)

        # The end-of-text token has the highest id in each sequence.
        eot_positions = tokenized_prompts.argmax(dim=-1)
        row_index = torch.arange(hidden.shape[0])
        pooled = hidden[row_index, eot_positions]
        return pooled @ self.text_projection

# Prompt-specific learner (learnable prompt context).
# The CLIP model's parameters are not updated (it is not stored as self.clip_model); only self.ctx is optimized.
class PromptLearner(nn.Module):
    """Learnable prompt context for CLIP's text encoder.

    Only ``self.ctx`` is a trainable parameter. During construction the CLIP
    model is used to embed the fixed parts of every class prompt. Those parts
    are the start token, the class name, the trailing period, the end token
    and the padding. They are stored as buffers, and the CLIP model itself is
    not kept on the module.

    Args:
        cfg: config providing ``NCTX`` (number of context tokens),
            ``CTXINIT`` (initial context words; empty for random init),
            ``CSC`` (class-specific contexts) and ``CLASS_TOKEN_POSITION``
            ('end', 'middle' or 'front').
        classnames: list of class-name strings; underscores become spaces.
        clip_model: CLIP model providing ``dtype``, ``ln_final`` and
            ``token_embedding``.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.NCTX  # number of learnable context vectors
        ctx_init = cfg.CTXINIT  # optional initial words for the context
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]  # text embedding width
        # Anything passed to token_embedding must be on its device, so look
        # the device up before embedding anything.
        device = clip_model.token_embedding.weight.device

        if ctx_init:
            # Initialize the context from the embeddings of the given words.
            ctx_init = ctx_init.replace("_", " ")
            # split() rather than split(" "): repeated spaces are not words.
            n_ctx = len(ctx_init.split())
            prompt = clip.tokenize(ctx_init).to(device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            # Skip the start token at index 0; take the next n_ctx embeddings.
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # Random initialization.
            if cfg.CSC:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            # "X" placeholders reserve n_ctx token slots that forward()
            # replaces with the learnable context.
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        # Move the data BEFORE wrapping it. nn.Parameter(t).to(other_device)
        # returns a plain non-leaf tensor, which would not be registered as a
        # parameter and therefore would never be optimized.
        self.ctx = nn.Parameter(ctx_vectors.to(device))

        classnames = [name.replace("_", " ") for name in classnames]
        # Token count of each class name, used to cut the class-name tokens
        # out of the suffix for the 'middle' and 'front' layouts.
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # Fixed (non-learnable) parts of each prompt, registered as buffers.
        # NOTE(review): per the original author these are saved by
        # save_model() but should be ignored by load_model(), so the
        # embeddings always match the current class names.
        self.register_buffer("token_prefix", embedding[:, :1, :])  # start token
        # Class name, ".", end token and padding follow the placeholder slots.
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # token ids of each class prompt
        self.name_lens = name_lens
        self.class_token_position = cfg.CLASS_TOKEN_POSITION

    def forward(self):
        """Assemble the prompt embeddings for every class.

        Returns:
            Tensor of shape (n_cls, seq_len, dim), where seq_len is
            1 + n_ctx + suffix length. The learnable context is placed
            relative to the class name according to ``class_token_position``.

        Raises:
            ValueError: if ``class_token_position`` is not 'end', 'middle' or
                'front'.
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            # Generic context: share one context across all classes.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        # Otherwise (class-specific) ctx is already (n_cls, n_ctx, dim).

        prefix = self.token_prefix  # (n_cls, 1, dim)
        suffix = self.token_suffix  # (n_cls, *, dim)

        if self.class_token_position == "end":
            # [start][context][class name . end pad]
            return torch.cat([prefix, ctx, suffix], dim=1)

        if self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i, name_len in enumerate(self.name_lens):
                prompts.append(
                    torch.cat(
                        [
                            prefix[i: i + 1, :, :],                # (1, 1, dim)
                            ctx[i: i + 1, :half_n_ctx, :],          # (1, n_ctx // 2, dim)
                            suffix[i: i + 1, :name_len, :],         # (1, name_len, dim)
                            ctx[i: i + 1, half_n_ctx:, :],          # (1, n_ctx - n_ctx // 2, dim)
                            suffix[i: i + 1, name_len:, :],         # (1, *, dim)
                        ],
                        dim=1,
                    )
                )
            return torch.cat(prompts, dim=0)

        if self.class_token_position == "front":
            prompts = []
            for i, name_len in enumerate(self.name_lens):
                prompts.append(
                    torch.cat(
                        [
                            prefix[i: i + 1, :, :],         # (1, 1, dim)
                            suffix[i: i + 1, :name_len, :],  # (1, name_len, dim)
                            ctx[i: i + 1, :, :],             # (1, n_ctx, dim)
                            suffix[i: i + 1, name_len:, :],  # (1, *, dim)
                        ],
                        dim=1,
                    )
                )
            return torch.cat(prompts, dim=0)

        raise ValueError(
            f"Unknown class_token_position {self.class_token_position!r}; "
            "expected 'end', 'middle' or 'front'"
        )

