# 项目创建时间：2024/9/2 00:56
"""
 Copyright (c) 2023, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
#import sys
#sys.path.append('/workspace/model/Paint-by-Example-main-2')
#print(sys.path)
import logging

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.cuda.amp import autocast as autocast
from torch.nn import functional as F

from lavis.common.registry import registry
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
from lavis.models.blip2_models.blip2 import (
    Blip2Base,
    compute_sim_matrix,
    disabled_train,
)
from lavis.models.blip_models.blip_outputs import BlipOutput, BlipOutputFeatures
from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
from transformers import BertTokenizer
#from lavis.models.blip2_models.blip2_qformer import Blip2Qformer
from ldm.models.TaskFormer.TaskFormer_revise_3 import SemanticProjLayer
#@registry.register_model("blip2")
#@registry.register_model("blip2_feature_extractor")


class Blip2Qformer(Blip2Base):
    """
    BLIP2 first-stage model with Q-former and ViT.
    Supported model types:
        - pretrained: pretrained model with vit-g
        - pretrain_vitL: pretrained model with vit-large
        - coco: finetuned model on coco
    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip2", "pretrain")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain": "configs/models/blip2/blip2_pretrain.yaml",
        "pretrain_vitL": "configs/models/blip2/blip2_pretrain_vitL.yaml",
        "coco": "configs/models/blip2/blip2_coco.yaml",
    }

    def __init__(
            self,
            vit_model="eva_clip_g",
            use_vit_out=False,
            vit_encoder_dim=1024,
            img_size=224,
            drop_path_rate=0,
            use_grad_checkpoint=False,
            vit_precision="fp32",
            freeze_vit=True,
            num_query_token=32,
            cross_attention_freq=2,
            embed_dim=256,
            max_txt_len=32,
    ):
        """Build the tokenizer, (optional) vision encoder, Q-Former and heads.

        Args:
            vit_model: Name forwarded to ``init_vision_encoder``; only used when
                ``use_vit_out`` is True.
            use_vit_out: If True, create ``self.visual_encoder`` through
                ``init_vision_encoder``. If False, no vision encoder is created
                and only a LayerNorm over externally supplied image features is
                built (see ``extract_features(image_embeds_frozen=...)``).
            vit_encoder_dim: Feature width of the externally supplied image
                features. Used for the LayerNorm size and as the Q-Former
                cross-attention width when ``use_vit_out`` is False.
            img_size: Forwarded to ``init_vision_encoder``.
            drop_path_rate: Forwarded to ``init_vision_encoder``.
            use_grad_checkpoint: Forwarded to ``init_vision_encoder``.
            vit_precision: Forwarded to ``init_vision_encoder``.
            freeze_vit: If True (and ``use_vit_out``), the vision encoder's
                parameters stop requiring grad, it is put in eval mode and its
                ``train`` method is replaced by ``disabled_train``.
            num_query_token: Number of learned query tokens.
            cross_attention_freq: Forwarded to ``init_Qformer``.
            embed_dim: Output width of ``vision_proj`` and ``text_proj``.
            max_txt_len: Stored as ``self.max_txt_len`` (tokenizer ``max_length``
                in ``forward``).
        """
        super().__init__()
        self.vit_encoder_dim = vit_encoder_dim
        self.num_queries = num_query_token
        self.hidden_dim = embed_dim

        self.use_vit_out = use_vit_out

        # init_tokenizer registers "[DEC]" as the bos token on the BERT tokenizer.
        self.tokenizer = self.init_tokenizer()
        if use_vit_out:
            self.visual_encoder, self.ln_vision = self.init_vision_encoder(
                vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
            )
            if freeze_vit:
                for param in self.visual_encoder.parameters():
                    param.requires_grad = False
                self.visual_encoder = self.visual_encoder.eval()
                # Swap in disabled_train so later .train() calls go through it
                # instead of nn.Module.train.
                self.visual_encoder.train = disabled_train
                logging.info("freeze vision encoder")
                print(f"QFormer初始化中visual_encoder.num_features:{self.visual_encoder.num_features}")
            vision_width = self.visual_encoder.num_features
        else:
            from lavis.models.blip2_models.blip2 import LayerNorm
            # Size the LayerNorm from vit_encoder_dim so it always agrees with
            # the cross-attention width handed to the Q-Former below (this used
            # to be a hard-coded 1024, which only matched the default).
            self.ln_vision = LayerNorm(vit_encoder_dim)
            vision_width = vit_encoder_dim

        self.proj_layer = SemanticProjLayer(in_dim=768, out_dim=768, hidden_dim=3072)

        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, vision_width, cross_attention_freq
        )

        # The tokenizer gained a special token, so the embedding table must be
        # resized to the new vocabulary size.
        self.Qformer.resize_token_embeddings(len(self.tokenizer))

        # Initialise every parameter whose name contains "_query" from the
        # parameter with the same name minus "_query".
        state_dict = self.Qformer.state_dict()
        for name, param in self.Qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        self.vision_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)

        # Two-way head used for image-text matching logits.
        self.itm_head = nn.Linear(self.Qformer.config.hidden_size, 2)

        # Learnable temperature dividing the contrastive similarities.
        self.temp = nn.Parameter(0.07 * torch.ones([]))

        self.max_txt_len = max_txt_len

    @classmethod
    def init_tokenizer(
            cls,
            *,
            bert_path="/nvfile-heatstorage/AIGC_H100/basemodel_exp/ckpts/zhangangang/vton/huggingface_models/bert-base-uncased",
    ):
        """Load the BERT tokenizer and register ``[DEC]`` as its bos token.

        Declared as a classmethod (the first parameter was already named
        ``cls``), so it can be called on the class or on an instance.

        Args:
            bert_path: Keyword-only. Directory or model id passed to
                ``BertTokenizer.from_pretrained``. Defaults to the server-local
                copy of ``bert-base-uncased`` that was previously hard-coded.

        Returns:
            The tokenizer with the extra ``bos_token``.
        """
        tokenizer = BertTokenizer.from_pretrained(bert_path)
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    @classmethod
    def init_Qformer(
            cls,
            num_query_token,
            vision_width,
            cross_attention_freq=2,
            *,
            bert_path="/nvfile-heatstorage/AIGC_H100/basemodel_exp/ckpts/zhangangang/vton/huggingface_models/bert-base-uncased",
    ):
        """Create the Q-Former (BERT with cross-attention) and its query tokens.

        Declared as a classmethod (the first parameter was already named
        ``cls``), so it can be called on the class or on an instance.

        Args:
            num_query_token: Number of query tokens; stored as
                ``config.query_length`` and used as the second dimension of the
                returned parameter.
            vision_width: Stored as ``config.encoder_width``.
            cross_attention_freq: Stored as ``config.cross_attention_freq``.
            bert_path: Keyword-only. Directory or model id used for both
                ``BertConfig.from_pretrained`` and
                ``BertLMHeadModel.from_pretrained``. Defaults to the
                server-local copy of ``bert-base-uncased`` that was previously
                hard-coded.

        Returns:
            tuple: ``(Qformer, query_tokens)`` where ``query_tokens`` is an
            ``nn.Parameter`` of shape ``(1, num_query_token, hidden_size)``.
        """
        encoder_config = BertConfig.from_pretrained(bert_path)
        encoder_config.encoder_width = vision_width
        # Enable cross-attention; its placement is controlled by
        # cross_attention_freq inside the Q-Former implementation.
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel.from_pretrained(bert_path, config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens


    def forward(self, samples):
        """Compute the image-text contrastive, matching and captioning losses.

        Args:
            samples (dict): Must contain ``"image"`` (tensor fed to
                ``self.visual_encoder``) and ``"text_input"`` (text handed to
                the tokenizer). If ``"image_id"`` is present, positives for the
                contrastive targets and the hard-negative mask are determined
                by equal ids instead of by batch position.

        Returns:
            BlipOutput: ``loss`` (sum of the three terms) together with
            ``loss_itc``, ``loss_itm`` and ``loss_lm``.

        Note:
            This path uses ``self.visual_encoder``, which ``__init__`` only
            creates when ``use_vit_out`` is True. It also calls
            ``dist.get_rank()``, so it presumably expects an initialised
            ``torch.distributed`` process group -- TODO confirm.
        """
        image = samples["image"]
        text = samples["text_input"]

        image_embeds = self.ln_vision(self.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        # One copy of the learned queries per image in the batch.
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            use_cache=True,  # past_key_values are reused by the captioning pass below
            return_dict=True,
        )

        # Per-query image features for the contrastive objective.
        image_feats = F.normalize(
            self.vision_proj(query_output.last_hidden_state), dim=-1
        )

        text_tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(image.device)
        text_output = self.Qformer.bert(
            text_tokens.input_ids,
            attention_mask=text_tokens.attention_mask,
            return_dict=True,
        )
        # Text feature: projection of the first-token hidden state. Text and
        # queries are encoded in two separate passes, so neither sees the other.
        text_feat = F.normalize(
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )

        ###============== Image-text Contrastive ===================###
        # Gather features from all processes so they can act as negatives.
        image_feats_all = concat_all_gather(
            image_feats
        )  # [batch_size*num_gpu, num_query_tokens, embed_dim]
        text_feat_all = concat_all_gather(text_feat)  # [batch_size*num_gpu, embed_dim]

        # [batch_size, 1, num_query_tokens, embed_dim] @ [batch_size*num_gpu, embed_dim, 1]
        # Only the trailing singleton is squeezed: a bare .squeeze() would also
        # drop the batch / gathered / query axes whenever one of them has size 1.
        sim_q2t = torch.matmul(
            image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
        ).squeeze(-1)
        # [batch_size, batch_size*num_gpu, num_query_tokens]

        # image-text similarity: aggregate across all query tokens
        sim_i2t, _ = sim_q2t.max(-1)  # [batch_size, batch_size*num_gpu]
        sim_i2t = sim_i2t / self.temp

        # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]
        # [batch_size, 1, 1, embed_dim] @ [batch_size*num_gpu, embed_dim, num_query_tokens]
        # Again squeeze only the singleton row axis introduced by unsqueeze.
        sim_t2q = torch.matmul(
            text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
        ).squeeze(-2)

        # text-image similarity: aggregate across all query tokens
        sim_t2i, _ = sim_t2q.max(-1)
        sim_t2i = sim_t2i / self.temp  # [batch_size, batch_size*num_gpu]

        rank = dist.get_rank()
        bs = image.size(0)
        # Column index of each local sample inside the gathered batch:
        # rank*bs, rank*bs+1, ..., rank*bs+bs-1.
        targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
            image.device
        )

        if "image_id" in samples.keys():  # coco retrieval finetuning
            # Positives are all gathered samples sharing the same image id.
            image_ids = samples["image_id"].view(-1, 1)  # [b, 1]
            image_ids_all = concat_all_gather(image_ids)  # [b*num_gpu, 1]
            pos_idx = torch.eq(image_ids, image_ids_all.t()).float()  # [b, b*num_gpu]
            sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
            # Smooth the target distribution (0.9 on positives, 0.1 uniform).
            sim_targets = 0.9 * sim_targets + 0.1 * torch.ones_like(sim_targets) / sim_targets.size(
                1)

            loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()
            loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
            loss_itc = (loss_t2i + loss_i2t) / 2
        else:
            loss_itc = (
                F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
                + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
            ) / 2

        ###============== Image-text Matching ===================###
        # The ITM batch has 3*bs rows: (pos image, pos text), (neg image, pos
        # text), (pos image, neg text). Only the first bs rows are labelled 1.
        text_input_ids_world = concat_all_gather(text_tokens.input_ids)
        text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)
        image_embeds_world = all_gather_with_grad(image_embeds)
        with torch.no_grad():
            # Push the positive pairs to a large negative value so the softmax
            # below gives them (almost) zero sampling weight.
            if "image_id" in samples.keys():
                mask = torch.eq(image_ids, image_ids_all.t())
                sim_t2i.masked_fill_(mask, -10000)
                sim_i2t.masked_fill_(mask, -10000)
            else:
                sim_t2i[:, rank * bs: rank * bs + bs].fill_diagonal_(-10000)
                sim_i2t[:, rank * bs: rank * bs + bs].fill_diagonal_(-10000)

            weights_t2i = F.softmax(sim_t2i, dim=1)
            weights_i2t = F.softmax(sim_i2t, dim=1)

        # select a negative image for each text, sampled by similarity weight
        image_embeds_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
            image_embeds_neg.append(image_embeds_world[neg_idx])
        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

        # select a negative text for each image
        text_ids_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
            text_ids_neg.append(text_input_ids_world[neg_idx])
            text_atts_neg.append(text_attention_mask_world[neg_idx])

        text_ids_neg = torch.stack(text_ids_neg, dim=0)
        text_atts_neg = torch.stack(text_atts_neg, dim=0)

        text_ids_all = torch.cat(
            [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
        )  # pos, pos, neg
        text_atts_all = torch.cat(
            [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
            dim=0,
        )

        query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
        # Queries are never masked: all-ones attention mask for them.
        query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
            image.device
        )
        attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

        image_embeds_all = torch.cat(
            [image_embeds, image_embeds_neg, image_embeds], dim=0
        )  # pos, neg, pos
        image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
            image.device
        )

        output_itm = self.Qformer.bert(
            text_ids_all,
            query_embeds=query_tokens_itm,
            attention_mask=attention_mask_all,
            encoder_hidden_states=image_embeds_all,
            encoder_attention_mask=image_atts_all,
            return_dict=True,
        )

        # Keep only the query positions, classify each, then average the logits.
        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
        vl_output = self.itm_head(vl_embeddings)
        logits = vl_output.mean(dim=1)

        itm_labels = torch.cat(
            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
            dim=0,
        ).to(image.device)
        loss_itm = F.cross_entropy(logits, itm_labels)

        ##================= Image Captioning ========================##
        # Clone so that replacing the first token does not touch text_tokens.
        decoder_input_ids = text_tokens.input_ids.clone()
        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
        # Padding positions get label -100.
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == self.tokenizer.pad_token_id, -100
        )

        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
            image.device
        )
        attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
        # Only this call goes through the LM head (self.Qformer); the others use
        # self.Qformer.bert directly.
        lm_output = self.Qformer(
            decoder_input_ids,
            attention_mask=attention_mask,
            past_key_values=query_output.past_key_values,
            return_dict=True,
            labels=labels,
        )

        loss_lm = lm_output.loss

        return BlipOutput(
            loss=loss_itc + loss_itm + loss_lm,
            loss_itc=loss_itc,
            loss_itm=loss_itm,
            loss_lm=loss_lm,
        )

    @torch.no_grad()
    def generate(
            self,
            samples,
            use_nucleus_sampling=False,
            num_beams=3,
            max_length=30,
            min_length=10,
            top_p=0.9,
            repetition_penalty=1.0,
    ):
        """Generate captions for a batch of images.

        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W)
            use_nucleus_sampling (bool): Whether to sample (``do_sample=True``
                with ``top_p``). If False, beam search with ``num_beams`` is used.
            num_beams (int): Number of beams for beam search. Forced to 1 when
                ``use_nucleus_sampling`` is True.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty. Forwarded to ``self.Qformer.generate``
                (it was previously accepted but ignored).
        Returns:
            captions (list): Decoded strings returned by
            ``tokenizer.batch_decode`` with special tokens skipped.

        Note:
            Uses ``self.visual_encoder``, which ``__init__`` only creates when
            ``use_vit_out`` is True.
        """
        image = samples["image"]
        image_embeds = self.ln_vision(self.visual_encoder(image))

        if not use_nucleus_sampling:
            # Beam search: repeat every image's features once per beam.
            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
        else:
            num_beams = 1
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        model_kwargs = {
            "encoder_hidden_states": image_embeds,
            "encoder_attention_mask": image_atts,
        }

        # Every sequence starts from the bos token.
        input_ids = (
            torch.LongTensor(image.size(0), 1)
            .fill_(self.tokenizer.bos_token_id)
            .to(image.device)
        )
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        outputs = self.Qformer.generate(
            input_ids=input_ids,
            query_embeds=query_tokens,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            do_sample=use_nucleus_sampling,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            **model_kwargs
        )
        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions

    def forward_image(self, image):
        """Encode images and run the learned queries over them.

        Args:
            image: Tensor fed to ``self.visual_encoder``.

        Returns:
            tuple: ``(query_hidden_states, image_embeds)`` -- the Q-Former's
            last hidden state for the queries and the layer-normalised vision
            encoder output.
        """
        visual_feats = self.ln_vision(self.visual_encoder(image))
        batch = visual_feats.shape[0]

        # All image positions are attendable.
        visual_mask = torch.ones(visual_feats.size()[:-1], dtype=torch.long).to(
            image.device
        )
        queries = self.query_tokens.expand(batch, -1, -1)

        bert_out = self.Qformer.bert(
            query_embeds=queries,
            encoder_hidden_states=visual_feats,
            encoder_attention_mask=visual_mask,
            return_dict=True,
        )
        return bert_out.last_hidden_state, visual_feats

    def forward_text(self, text_tokens):
        """Encode tokenized text and return the first-token hidden state.

        Args:
            text_tokens: Object exposing ``input_ids`` and ``attention_mask``.

        Returns:
            ``last_hidden_state[:, 0, :]`` of the text-only Q-Former pass.
        """
        bert_out = self.Qformer.bert(
            text_tokens.input_ids,
            attention_mask=text_tokens.attention_mask,
            return_dict=True,
        )
        first_token_state = bert_out.last_hidden_state[:, 0, :]
        return first_token_state

    def compute_itm(self, image_inputs, text_ids, text_atts):
        """Score image-text pairs with the image-text-matching head.

        Args:
            image_inputs: Image features used as encoder hidden states.
            text_ids: Token ids of the texts.
            text_atts: Attention mask belonging to ``text_ids``.

        Returns:
            Per-sample logit of class index 1, averaged over the query tokens.
        """
        device = image_inputs.device
        queries = self.query_tokens.expand(image_inputs.shape[0], -1, -1)
        num_queries = queries.size(1)

        # Image positions and query positions are all attendable.
        encoder_mask = torch.ones(image_inputs.size()[:-1], dtype=torch.long).to(device)
        query_mask = torch.ones(queries.size()[:-1], dtype=torch.long).to(device)
        joint_mask = torch.cat([query_mask, text_atts], dim=1)

        bert_out = self.Qformer.bert(
            text_ids,
            query_embeds=queries,
            attention_mask=joint_mask,
            encoder_hidden_states=image_inputs,
            encoder_attention_mask=encoder_mask,
            return_dict=True,
        )

        # Keep the query positions only, classify, and take the class-1 logit.
        query_states = bert_out.last_hidden_state[:, :num_queries, :]
        match_logits = self.itm_head(query_states)[:, :, 1]
        return match_logits.mean(dim=1)

    #@torch.no_grad()
    def extract_features(self, samples, mode="multimodal", image_embeds_frozen=None, require_QFormer_text_embedding=False):
        """
        Extract features for multimodal or unimodal samples.

        Gradients are intentionally not disabled here (the ``torch.no_grad``
        decorator is commented out above this method).

        Args:
            samples (dict): A dictionary of samples, containing the following keys:
                - image (torch.Tensor): Image batch fed to ``self.visual_encoder``.
                    Required for mode "image", and for mode "multimodal" when
                    ``use_vit_out`` is True.
                - text_input: Text passed to the tokenizer. Required for modes
                    "text" and "multimodal".
            mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image".
                Default: "multimodal".
            image_embeds_frozen: Pre-computed image features. Required in mode
                "multimodal" when ``use_vit_out`` is False; they are cast to
                float and passed through ``self.ln_vision``.
            require_QFormer_text_embedding (bool): Only used in mode
                "multimodal". If True, an extra text-only pass is run and the
                L2-normalised first-token hidden state (not passed through
                ``text_proj``) is returned as ``text_embeds``.
        Returns:
            BlipOutputFeatures with
                - image_embeds: the layer-normalised image features (in mode
                  "text" this is the ``image_embeds_frozen`` argument unchanged),
                - image_embeds_proj: normalised ``vision_proj`` of the query
                  outputs (mode "image" only),
                - text_embeds / text_embeds_proj: text hidden states and their
                  normalised ``text_proj`` (mode "text"); see
                  ``require_QFormer_text_embedding`` for mode "multimodal",
                - multimodal_embeds: ``proj_layer`` applied to the query
                  positions of the joint pass (mode "multimodal" only).
        Raises:
            AssertionError: on an unknown ``mode`` or a missing required input.
        """
        image = samples.get("image")
        caption = samples.get("text_input")

        # assert mode is one of "image", "text", "multimodal"
        assert mode in [
            "image",
            "text",
            "multimodal",
        ], "mode must be one of 'image', 'text', 'multimodal'"

        # initialize output
        image_embeds, text_embeds, multimodal_embeds = None, None, None
        image_features, text_features = None, None

        if mode == "image":
            assert (
                    image is not None
            ), "Image is not provided for mode 'image' or 'multimodal'"
            # return query features
            with self.maybe_autocast():
                image_embeds_frozen = self.ln_vision(self.visual_encoder(image))
            image_embeds_frozen = image_embeds_frozen.float()
            image_atts = torch.ones(
                image_embeds_frozen.size()[:-1], dtype=torch.long
            ).to(self.device)
            query_tokens = self.query_tokens.expand(
                image_embeds_frozen.shape[0], -1, -1
            )

            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds_frozen,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            image_embeds = query_output.last_hidden_state
            image_features = F.normalize(self.vision_proj(image_embeds), dim=-1)

        elif mode == "text":
            assert (
                    caption is not None
            ), "text input is None for mode 'text' or 'multimodal'"

            # return text features
            text = self.tokenizer(caption, return_tensors="pt", padding=True).to(
                self.device
            )

            text_output = self.Qformer.bert(
                text.input_ids,
                attention_mask=text.attention_mask,
                return_dict=True,
            )
            text_embeds = text_output.last_hidden_state
            text_features = self.text_proj(text_embeds)
            text_features = F.normalize(text_features, dim=-1)

        elif mode == "multimodal":
            # return multimodal query features
            # Validate the caption before it is tokenized; the check used to sit
            # after the tokenizer call, where it could never fire.
            assert (
                    caption is not None
            ), "text input is None for mode 'multimodal'"

            if self.use_vit_out:
                # Features come from the built-in vision encoder.
                assert (
                        image is not None
                ), "Image is not provided for mode 'image' or 'multimodal'"
                with self.maybe_autocast():
                    image_embeds_frozen = self.ln_vision(self.visual_encoder(image))
            else:
                # Features were computed by an external encoder and passed in.
                assert image_embeds_frozen is not None, "At least one of image and image_embeds_frozen should be provided"
                image_embeds_frozen = self.ln_vision(image_embeds_frozen.float())

            image_atts = torch.ones(
                image_embeds_frozen.size()[:-1], dtype=torch.long
            ).to(self.device)
            # One copy of the learned queries per sample.
            query_tokens = self.query_tokens.expand(
                image_embeds_frozen.shape[0], -1, -1
            )
            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
                self.device
            )
            text = self.tokenizer(caption, return_tensors="pt", padding=True).to(
                self.device
            )

            if require_QFormer_text_embedding:
                # Extra text-only pass; the normalised first-token state is
                # exposed as text_embeds (text_proj is not applied here).
                text_output = self.Qformer.bert(
                    text.input_ids,
                    attention_mask=text.attention_mask,
                    return_dict=True,
                )
                text_feat = F.normalize(
                    text_output.last_hidden_state[:, 0, :], dim=-1
                )
                text_embeds = text_feat

            attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
            # The joint pass runs under maybe_autocast; the original author
            # added this after hitting a float16/float32 mismatch inside the
            # BERT layers.
            with self.maybe_autocast():
                output = self.Qformer.bert(
                    text.input_ids,
                    query_embeds=query_tokens,
                    attention_mask=attention_mask,
                    encoder_hidden_states=image_embeds_frozen,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )

            # Keep the query positions only and project them.
            multimodal_embeds = self.proj_layer(output.last_hidden_state[:, : query_tokens.size(1), :])

        return BlipOutputFeatures(
            image_embeds=image_embeds_frozen,  # features as seen by the Q-Former, not the query outputs
            image_embeds_proj=image_features,
            text_embeds=text_embeds,
            text_embeds_proj=text_features,
            multimodal_embeds=multimodal_embeds,
        )

    @classmethod
    def from_config(cls, cfg):
        """Build a model from a mapping-like config.

        Only ``cfg.get`` is used, so a plain dict works. Missing keys fall back
        to defaults; ``img_size`` and ``num_query_token`` now default to the
        constructor's own defaults (224 and 32) instead of ``None``, which
        previously overrode them and made construction fail. Pretrained
        weights are not loaded here (the ``load_checkpoint_from_config`` call
        is disabled).

        Args:
            cfg: Object with a dict-style ``get(key, default)`` method.

        Returns:
            A new instance of ``cls``.
        """
        vit_model = cfg.get("vit_model", "eva_clip_g")
        img_size = cfg.get("img_size", 224)
        num_query_token = cfg.get("num_query_token", 32)
        cross_attention_freq = cfg.get("cross_attention_freq", 2)

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        # NOTE(review): this default ("fp16") differs from the constructor's
        # default ("fp32"); kept as-is to preserve behaviour.
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        use_vit_out = cfg.get("use_vit_out", False)
        vit_encoder_dim = cfg.get("vit_encoder_dim", 1024)

        max_txt_len = cfg.get("max_txt_len", 32)

        model = cls(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            num_query_token=num_query_token,
            cross_attention_freq=cross_attention_freq,
            max_txt_len=max_txt_len,
            use_vit_out=use_vit_out,
            vit_encoder_dim=vit_encoder_dim,
        )

        return model

    def compute_sim_matrix(self, data_loader, task_cfg):
        """
        Delegate to the module-level ``compute_sim_matrix`` helper, passing this
        model, the data loader and ``task_cfg.k_test``.
        """
        return compute_sim_matrix(
            model=self,
            data_loader=data_loader,
            k_test=task_cfg.k_test,
        )


def build_QFormer_encoding(config):
    """Build a ``Blip2Qformer`` from ``config`` and optionally load weights.

    Args:
        config: Mapping passed to ``Blip2Qformer.from_config``. The optional
            key ``"pretrained_path"`` is popped from it (the mapping is
            mutated); when it is set, weights are loaded from that checkpoint.

    Returns:
        The constructed model.

    Raises:
        AssertionError: if the checkpoint leaves Q-Former keys missing, has no
            ``"query_tokens"`` entry, holds a query count that cannot be tiled
            to ``model.num_queries``, or contains ``proj_layer`` keys the model
            does not know.
    """
    model = Blip2Qformer.from_config(config)
    pretrained_path = config.pop("pretrained_path", None)
    if not pretrained_path:
        return model

    proj_state = {}
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    state_dict = torch.load(pretrained_path, map_location="cpu")

    # Rename checkpoint attention keys ("...attention.{query,key,value}...")
    # to the Q-Former layout ("...self.{query,key,value}..."), and collect the
    # proj_layer weights separately.
    for k in list(state_dict.keys()):
        for part in ("query", "key", "value"):
            old = f"attention.{part}"
            if old in k:
                state_dict[k.replace(old, f"self.{part}")] = state_dict.pop(k)
                break
        else:
            if "proj_layer" in k:
                proj_state[k] = state_dict[k]

    msg1 = model.Qformer.bert.load_state_dict(state_dict, strict=False)
    assert len(msg1.missing_keys) == 0, "QFormer state_dict cannot match"
    print(f"Qformer load from {pretrained_path}, unexpected_keys:{msg1.unexpected_keys}")

    # Look the entry up with .get so a missing key reaches the assertion
    # message instead of raising KeyError first.
    ckpt_queries = state_dict.get("query_tokens")
    assert ckpt_queries is not None, "state_dict[query_tokens] is none"
    ckpt_num_queries = ckpt_queries.shape[1]
    if ckpt_num_queries != model.num_queries:
        # The checkpoint queries are tiled along the query axis, which only
        # works when the target count is a whole multiple of the stored one.
        assert ckpt_num_queries > 0 and model.num_queries % ckpt_num_queries == 0, (
            f"cannot tile {ckpt_num_queries} checkpoint query tokens "
            f"to {model.num_queries} model query tokens"
        )
        ckpt_queries = ckpt_queries.repeat(1, model.num_queries // ckpt_num_queries, 1)
    msg2 = model.query_tokens.data.copy_(ckpt_queries)
    print(f"load Qformer query msg:{msg2}")

    msg3 = model.load_state_dict(proj_state, strict=False)
    assert len(msg3.unexpected_keys) == 0, "QFormer proj_layer cannot match"

    print("achieve loading QFormer weight from pretrain_path successfully !")

    torch.cuda.empty_cache()

    return model
    
    
    
    #return Blip2Qformer(
    #    **config["QFormer"]
        #vit_model=config["QFormer"]["vit_model"],
        #num_query_token=config["QFormer"]["qformer_num_query_token"],
        #cross_attention_freq=config["QFormer"]["qformer_cross_attention_freq"],
        #use_vit_out=config["QFormer"]["use_vit_out"],
        #vit_encoder_dim =config["QFormer"]["vit_encoder_dim"],
        #img_size = config["QFormer"]["image_size"]
    #)

if __name__ == "__main__":
    # Smoke test: read the QFormer section of the project config and build the
    # model (loading weights when the config provides a pretrained_path).
    import yaml

    with open('./configs/configs2_5_viton_revise_3.yaml', 'r', encoding='utf-8') as cfg_file:
        full_cfg = yaml.safe_load(cfg_file)
    config = full_cfg['model']['params']["QFormer"]

    model = build_QFormer_encoding(config)

    # Example of running the multimodal path with external image features:
    #   image_embedding = torch.rand([1, 64, 1024])
    #   text = "12"
    #   x = {"image": None, "text_input": text}
    #   out = model.extract_features(x, image_embeds_frozen=image_embedding)
    #   print(out.multimodal_embeds.requires_grad)
    #
    # To inspect the parameters:
    #   for name, param in model.named_parameters():
    #       print(name, param.shape)
    

    


