# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
import logging
import fvcore.nn.weight_init as weight_init
from typing import Optional
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from transformers.activations import QuickGELUActivation as QuickGELU
import pickle
#from ldm.models.transformer_decoder.config import configurable,get_cfg
from ldm.models.TaskFormer.layer import Conv2d
from einops import rearrange

from ldm.models.TaskFormer.position_encoding import PositionEmbeddingSine
#from ldm.models.transformer_decoder.maskformer_transformer_decoder import TRANSFORMER_DECODER_REGISTRY
from lavis.models.blip2_models.blip2 import LayerNorm
class SemanticProjLayer(nn.Module):  # TODO: later inject QFormer information into TaskFormer, i.e. Qformer_dim -> Task_dim
    """Residual MLP projection from ``in_dim`` to ``out_dim``.

    BLIP-Diffusion uses this pattern to map the extracted subject embedding
    into the text-embedding space (there the widths are 768 -> 3072 -> 768).

    Output: ``dropout(dense2(act(dense1(LayerNorm(x))))) + proj(x)``. The
    LayerNorm is applied to the input of the MLP branch (pre-norm). The
    shortcut is a linear projection of the raw input so both branches end
    at ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12):
        super().__init__()
        # Submodule names and creation order are kept stable: they define the
        # state_dict keys and the order of random initialization.
        self.dense1 = nn.Linear(in_dim, hidden_dim)
        self.act_fn = QuickGELU()
        self.dense2 = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(drop_p)
        self.proj = nn.Linear(in_dim, out_dim)
        self.LayerNorm = nn.LayerNorm(in_dim, eps=eps)

    def forward(self, x):
        shortcut = self.proj(x)
        branch = self.dense1(self.LayerNorm(x))
        branch = self.dense2(self.act_fn(branch))
        return self.dropout(branch) + shortcut
    
class SpatialProjLayer(nn.Module):
    """Map the channel dimension of a spatial feature map to another size.

    Output: ``dropout(conv2(act(conv1(InstanceNorm(x))))) + proj(x)``. All
    convolutions are 1x1 with stride 1 and no padding, so the spatial size is
    unchanged.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, drop_p=0.1, eps=1e-12):
        super().__init__()
        pointwise = dict(kernel_size=1, stride=1, padding=0)
        # Submodule names and creation order define state_dict keys and init RNG order.
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, **pointwise)
        self.act_fn = QuickGELU()
        self.conv2 = nn.Conv2d(hidden_channels, out_channels, **pointwise)
        self.dropout = nn.Dropout(drop_p)
        # Normalization swapped here: InstanceNorm2d rather than LayerNorm.
        self.InstanceNorm = nn.InstanceNorm2d(num_features=in_channels, eps=eps)
        self.proj = nn.Conv2d(in_channels, out_channels, **pointwise)

    def forward(self, x):
        shortcut = self.proj(x)
        branch = self.conv1(self.InstanceNorm(x))
        branch = self.conv2(self.act_fn(branch))
        return self.dropout(branch) + shortcut


class SemanticFeatureToText(nn.Module):
    """Project semantic features to the QFormer text dimension.

    Pipeline, applied along the last axis:
    ``Linear(d, d) -> ReLU -> Linear(d, text_dim) -> LayerNorm(text_dim)``.
    """

    def __init__(self, semantic_hidden_dim, QFormer_text_dim):
        super().__init__()
        stages = [
            nn.Linear(semantic_hidden_dim, semantic_hidden_dim),
            nn.ReLU(),
            nn.Linear(semantic_hidden_dim, QFormer_text_dim),
            nn.LayerNorm(QFormer_text_dim),  # normalizes over the last dimension
        ]
        # Kept as one Sequential named `layers` so state_dict keys stay layers.<idx>.*
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        return self.layers(x)
class SelfAttentionLayer(nn.Module):
    """Self-attention block over the query sequence.

    Residual connection plus LayerNorm: post-norm by default, pre-norm when
    ``normalize_before`` is True. ``query_pos`` is added to the queries and
    keys but not to the values. Inputs use the default
    ``nn.MultiheadAttention`` (sequence, batch, channel) layout, since
    ``batch_first`` is not set.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Attribute names and creation order define state_dict keys and init RNG order.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Resolved for API parity; not used in this layer's forward pass.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform init for weight matrices; 1-D params keep their defaults."""
        weights = [p for p in self.parameters() if p.dim() > 1]
        for weight in weights:
            nn.init.xavier_uniform_(weight)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        qk = self.with_pos_embed(tgt, query_pos)
        attended, _ = self.self_attn(qk, qk, value=tgt, attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        return self.norm(tgt + self.dropout(attended))

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        normed = self.norm(tgt)
        # TODO: positional embedding goes into q/k only, not v (DETR convention)
        qk = self.with_pos_embed(normed, query_pos)
        attended, _ = self.self_attn(qk, qk, value=normed, attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        return tgt + self.dropout(attended)

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt, tgt_mask, tgt_key_padding_mask, query_pos)


class CrossAttentionLayer(nn.Module):
    """Cross-attention block: queries ``tgt`` attend to keys/values from ``memory``.

    Residual connection plus LayerNorm: post-norm by default, pre-norm when
    ``normalize_before`` is True. Positional embeddings are added to the
    query (``query_pos``) and the key (``pos``). Values use the raw
    ``memory``. Inputs use the default ``nn.MultiheadAttention``
    (sequence, batch, channel) layout.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Attribute names and creation order define state_dict keys and init RNG order.
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # Resolved for API parity; not used in this layer's forward pass.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform init for weight matrices; 1-D params keep their defaults."""
        for param in filter(lambda t: t.dim() > 1, self.parameters()):
            nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        # memory supplies K/V; tgt supplies the queries.
        attended, _ = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        return self.norm(tgt + self.dropout(attended))

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        normed = self.norm(tgt)
        attended, _ = self.multihead_attn(
            query=self.with_pos_embed(normed, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        return tgt + self.dropout(attended)

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt, memory, memory_mask,
                      memory_key_padding_mask, pos, query_pos)


class FFNLayer(nn.Module):
    """Position-wise feed-forward block with residual connection and LayerNorm.

    ``linear1 -> activation -> dropout -> linear2``, then ``tgt + dropout(...)``.
    LayerNorm is applied after the residual by default (post-norm). With
    ``normalize_before=True`` it is applied to the block input instead (pre-norm).
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Attribute names and creation order define state_dict keys and init RNG order.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform init for weight matrices; 1-D params keep their defaults."""
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Not used by this layer; kept for interface parity with the attention layers.
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt):
        hidden = self.dropout(self.activation(self.linear1(tgt)))
        residual = tgt + self.dropout(self.linear2(hidden))
        return self.norm(residual)

    def forward_pre(self, tgt):
        hidden = self.dropout(self.activation(self.linear1(self.norm(tgt))))
        return tgt + self.dropout(self.linear2(hidden))

    def forward(self, tgt):
        if not self.normalize_before:
            return self.forward_post(tgt)
        return self.forward_pre(tgt)


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


class MLP(nn.Module):
    """Simple multi-layer perceptron (also called FFN).

    For ``num_layers >= 1`` this stacks ``num_layers`` Linear layers mapping
    ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``. ReLU follows
    every layer except the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim, *hidden]
        out_dims = [*hidden, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)]
        )

    def forward(self, x):
        last_idx = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if idx < last_idx:
                x = F.relu(x)
        return x


#@TRANSFORMER_DECODER_REGISTRY.register()
class MultiScaleTaskTransformerDecoder(nn.Module):

    #_version = 2
    """
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ): # 加载键值对，将原来的键static_query替换为键query_feat
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # Do not warn if train from scratch
            scratch = True
            logger = logging.getLogger(__name__)
            for k in list(state_dict.keys()):
                newk = k
                if "static_query" in k:
                    newk = k.replace("static_query", "query_feat")
                if newk != k:
                    state_dict[newk] = state_dict[k]
                    del state_dict[k]
                    scratch = False

            if not scratch:
                logger.warning(
                    f"Weight format of {self.__class__.__name__} have changed! "
                    "Please upgrade your models. Applying automatic conversion now ..."
                )
    """
    #@configurable
    def __init__(
        self,
        in_channels, # TODO: channel dims of the inputs; a list because the input scales have different channel counts (see docstring)
        #mask_classification=False, # originally True
        *,
        #num_classes: int,
        hidden_dim: int,
        num_queries: int,
        nheads: int,
        dim_feedforward: int,
        dec_layers: int,
        pre_norm: bool,
        mask_dim: int,
        enforce_input_project: bool,
        #semantic_feature_start:int , #= 62,
        #img_spatial_feature_start:int, #=20
        mask_q_num: int,
        pose_q_num:int,
        predict_attention_mask:bool,
        QFormer_num_queries:int,
        use_crossattn_and_feature: bool,
        crossattn_mask_threshold: float,
        QFormer_detach: bool,
    ):
        """Build the multi-scale task decoder.

        NOTE: this interface is experimental.

        Args:
            in_channels: indexable with at least 4 channel counts.
                ``in_channels[0]`` is the QFormer feature width, projected to
                ``hidden_dim`` by a ``SemanticProjLayer`` (``input_proj[0]``).
                ``in_channels[1:4]`` are the channels of the three spatial inputs.
                Each is projected by a 1x1 conv followed by a stride-2 3x3 conv
                (``input_proj[1:4]``) when it differs from ``hidden_dim`` or
                ``enforce_input_project`` is set; otherwise an empty Sequential
                is used.
            hidden_dim: Transformer feature dimension.
            num_queries: number of learnable queries.
            nheads: number of attention heads.
            dim_feedforward: hidden size of every FFN layer.
            dec_layers: number of decoder layers (each is cross-attention,
                self-attention and FFN).
            pre_norm: use pre-LayerNorm (``normalize_before``) in every layer.
            mask_dim: output width of the ``mask_embed`` and
                ``spatial_feature_embed`` MLPs.
            enforce_input_project: add the spatial input projection even when
                the input channels already equal ``hidden_dim``.
            mask_q_num: number of leading queries that share one query
                positional-embedding row in ``forward``.
            pose_q_num: number of following queries that share a second row.
            predict_attention_mask: whether ``forward`` feeds predicted
                attention masks into the cross-attention layers.
            QFormer_num_queries: number of QFormer tokens; sizes
                ``QFormer_kv_pos_embed``.
            use_crossattn_and_feature: if True, create the 1x1 convs
                ``to_k`` / ``to_v1`` / ``to_v2`` and store
                ``crossattn_mask_threshold`` and ``scale``.
            crossattn_mask_threshold: must be non-zero when
                ``use_crossattn_and_feature`` is True.
            QFormer_detach: if True, ``forward`` detaches the QFormer input
                (``x[0]``) from the autograd graph.
        """
        super().__init__()

        #assert mask_classification, "Only support mask classification model"
        #self.mask_classification = mask_classification

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) # sine/cosine positional encoding (DETR-style) for the spatial feature maps

        self.hidden_dim = hidden_dim
        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.predict_attention_mask = predict_attention_mask

        # One (cross-attn, self-attn, FFN) triple per decoder layer; forward applies them in that order.
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim) # initial query content; forward uses its weight as the starting `output`
        # learnable query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim) # query positional information; forward gathers rows from it per query group
        #self.reconstructed_embeddings  = nn.Embedding(num_queries, hidden_dim)
        self.pose_q_num = pose_q_num
        self.mask_q_num = mask_q_num
        """
        self.register_buffer('reconstructed_embeddings', torch.zeros(num_queries, hidden_dim))

        #self.reconstructed_embeddings = torch.zeros(num_queries, hidden_dim)
        self.reconstructed_embeddings[0:mask_q_num] = self.query_embed(torch.tensor(0))
        self.reconstructed_embeddings[mask_q_num:mask_q_num+pose_q_num] = self.query_embed(torch.tensor(1))
        for i in range(num_queries - mask_q_num - pose_q_num):
            self.reconstructed_embeddings[mask_q_num+pose_q_num+i] = self.query_embed(torch.tensor(i+2))
        
        print(f"self.reconstructed_embeddings grad:{self.reconstructed_embeddings.requires_grad}")
        """


        self.QFormer_kv_pos_embed = nn.Parameter(torch.zeros(1, QFormer_num_queries,hidden_dim)) # QFormer pos embedding
        nn.init.normal_(self.QFormer_kv_pos_embed, std=0.02)

        self.use_crossattn_and_feature = use_crossattn_and_feature
        if use_crossattn_and_feature:
            assert crossattn_mask_threshold != 0.0,"crossattn_mask_threshold can not be zero !"
            self.crossattn_mask_threshold = crossattn_mask_threshold
            # 1/sqrt(head_dim) scaling for the query/key dot products in forward_prediction_heads
            self.scale = (hidden_dim / nheads) ** -0.5
            #self.to_k = nn.Linear(mask_dim, mask_dim, bias=False)
            #self.to_v = nn.Linear(mask_dim, mask_dim, bias=False)
            # 1x1 convs applied to the projected mask_features in forward
            self.to_k = nn.Conv2d(hidden_dim,hidden_dim,kernel_size=1,stride=1,padding=0)
            self.to_v1 = torch.nn.Conv2d(hidden_dim,hidden_dim,kernel_size=1,stride=1,padding=0)
            self.to_v2 = torch.nn.Conv2d(hidden_dim,hidden_dim,kernel_size=1,stride=1,padding=0)




        #self.QFormer_kv_attn_mask = nn.Parameter(torch.zeros(1,num_queries, QFormer_num_queries)) # mask
        #nn.init.normal_(self.kv_pos_embed, std=0.02)



        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()
        self.input_proj.append( # index 0: projects the QFormer output
            SemanticProjLayer(in_dim=in_channels[0], out_dim=hidden_dim,hidden_dim=2*in_channels[0])
        )
        for i in range(self.num_feature_levels): # indices 1..3: project image features to hidden_dim (forward uses them for x[1], x[2] and mask_features)
            if in_channels[i+1] != hidden_dim or enforce_input_project:
                self.input_proj.append(nn.Sequential(Conv2d(in_channels[i+1], hidden_dim, kernel_size=1),Conv2d(hidden_dim, hidden_dim, kernel_size=3,stride=2 ,padding=1)))
                #self.input_proj.append(SpatialProjLayer(in_channels=in_channels[i],out_channels=hidden_dim,hidden_channels=2*in_channels[i]))
                for layer in self.input_proj[-1]:
                    weight_init.c2_xavier_fill(layer) # Xavier (Glorot) initialization, to keep activation variance stable in the forward and backward passes
            else:
                self.input_proj.append(nn.Sequential())

        # NOTE(review): widths are hard-coded; they presumably must match in_channels[3], [2], [1] respectively — confirm.
        self.ln_vision1 = LayerNorm(128) # swin-1 out layernorm (applied to mask_features in forward)
        self.ln_vision2 = LayerNorm(256) # swin-2 (applied to x[2] in forward)
        self.ln_vision3 = LayerNorm(512) # swin-3 (applied to x[1] in forward)

        # output FFNs
        #if self.mask_classification:
        #    self.class_embed = nn.Linear(hidden_dim, num_classes + 1)


        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) # output width of mask_embed is mask_dim
        # NOTE(review): QFormer_attn_mask is not referenced in the visible part of this file.
        self.QFormer_attn_mask = nn.Linear(hidden_dim,hidden_dim)
        #---------------------------- added: a separate projection for the learned features, distinct from the mask_feature one
        self.spatial_feature_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
        #self.QFormer_feature_embed = nn.Linear(hidden_dim, hidden_dim) -- probably not needed for QFormer features; only the other features are passed downstream


        self.QFormer_detach = QFormer_detach
        """
        self.semantic_feature_start = semantic_feature_start
        self.img_spatial_feature_start= img_spatial_feature_start
        self.spatial_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) # todo- mask_embed输出维度就是mask_dim
        self.spatial_norm = nn.LayerNorm(hidden_dim)
        self.semantic_norm = nn.LayerNorm(hidden_dim)
        self.img_spatial_norm = nn.LayerNorm([self.semantic_feature_start - img_spatial_feature_start,64,48]) # 20是通道数，cloth_mask、agn_mask、openpose_map(18)
        """





    """
    @classmethod
    def from_config(cls, cfg, in_channels, mask_classification):
        ret = {}
        ret["in_channels"] = in_channels
        ret["mask_classification"] = mask_classification
        
        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        # Transformer parameters:
        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD

        # NOTE: because we add learnable query features which requires supervision,
        # we add minus 1 to decoder layers to be consistent with our loss
        # implementation: that is, number of auxiliary losses is always
        # equal to number of decoder layers. With learnable query features, the number of
        # auxiliary losses equals number of decoders plus 1.
        assert cfg.MODEL.MASK_FORMER.DEC_LAYERS >= 1
        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1 # todo- 这里这个-1还是不太懂,因为配置文件中给出的是10，但本来是9层所以-1
        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ

        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM

        return ret
    """
    @classmethod
    def from_config(cls, cfg):
        """Instantiate the decoder from a mapping of constructor arguments.

        Every name in ``arg_names`` is read from ``cfg`` by subscription, so a
        missing key raises ``KeyError``. Any other keys in ``cfg`` are ignored.
        """
        arg_names = (
            "in_channels",
            "hidden_dim",
            "num_queries",
            "nheads",
            "dim_feedforward",
            "dec_layers",
            "pre_norm",
            "enforce_input_project",
            "mask_dim",
            "predict_attention_mask",
            "QFormer_num_queries",
            "use_crossattn_and_feature",
            "crossattn_mask_threshold",
            "mask_q_num",
            "pose_q_num",
            "QFormer_detach",
        )
        kwargs = {name: cfg[name] for name in arg_names}
        return cls(**kwargs)

    def forward(self, x, mask_features,target_mask = None, mask = None): # todo-mask_features就是最大图片的特征，每经或一层都会利用这个预测每一个q需要观察的atten_mask，这其实就是最终的每个q的mask预测结果
        # x is a list of multi-scale feature， todo-x是输入的多尺度图片信息
        ## print(f"taskformer x shape:{[i.shape for i in x]}, mask_feature shape：{mask_features.shape}")
        bs,_,_ = mask_features.shape
        if self.QFormer_detach:
            x[0] = x[0].detach()

        #print(f"QFormer_Feature in TaskFormer: {x[0].shape}, grad:{x[0].requires_grad}, x1:{x[1].requires_grad}")


        # (b,l,c) -> (b,c,h,w)
        in_H = in_W = int(x[1].shape[1] ** 0.5)
        x[1] = self.ln_vision3(x[1]).permute(0, 2, 1).reshape(bs, -1, in_H, in_W)

        in_H = in_W = int(x[2].shape[1] ** 0.5)
        x[2] = self.ln_vision2(x[2]).permute(0, 2, 1).reshape(bs, -1, in_H, in_W)

        in_H = in_W = int(mask_features.shape[1] ** 0.5)
        mask_features = self.ln_vision1(mask_features)
        mask_features = mask_features.permute(0, 2, 1).reshape(bs, -1, in_H, in_W)


        mask_features = self.input_proj[3](mask_features) # 将输入的最大尺度图片的channel投影到hidden_dim

        assert len(x) == self.num_feature_levels
        src = []
        pos = []
        size_list = []

        # disable mask, it does not affect performance
        del mask
        #temp = x[0]
        #x[0] = x[0].unsqueeze(3).permute(0,2,1,3)
        for i in range(self.num_feature_levels):
            if i == 0:
                pos.append(self.QFormer_kv_pos_embed.repeat(x[i].shape[0], 1, 1).permute(0,2,1)) # (b,c, n)
                src.append(self.input_proj[i](x[i]).permute(0, 2, 1) + self.level_embed.weight[i][None, :, None]) # 将图片特征维度投影到dec需要的维度并加上level信息 [b,c,hw]
                size_list.append(x[i].shape[-2:]) # 记录每个维度图片特征的[h, w],x[0]的不会用到，因为x[0]维度是(b,num,hiddn)
            else:
                temp = self.input_proj[i](x[i])
                #print(x[i].shape,temp.shape)
                src.append(temp.flatten(2) + self.level_embed.weight[i][None, :, None]) # 将图片特征维度投影到dec需要的维度并加上level信息 [b,c,hw]
                pos.append(self.pe_layer(temp, None).flatten(2)) #
                size_list.append(temp.shape[-2:]) # 记录每个维度图片特征的[h, w],x[0]的不会用到，因为x[0]维度是(b,num,hiddn)
                del temp


            # flatten NxCxHxW to HWxNxC, todo-pos、src里的维度最终都调整成了这样
            pos[-1] = pos[-1].permute(2, 0, 1) # -> [hw,b,c]
            src[-1] = src[-1].permute(2, 0, 1)

        #_, bs, _ = src[0].shape

        # QxNxC
        reconstructed_embeddings = torch.zeros_like(self.query_embed.weight).to(mask_features.device)
        reconstructed_embeddings[0:self.mask_q_num] = self.query_embed(torch.tensor(0).to(mask_features.device))
        reconstructed_embeddings[self.mask_q_num:self.mask_q_num+self.pose_q_num] = self.query_embed(torch.tensor(1).to(mask_features.device))
        for i in range(self.num_queries - self.mask_q_num - self.pose_q_num):
            reconstructed_embeddings[self.mask_q_num+self.pose_q_num+i] = self.query_embed(torch.tensor(i+2).to(mask_features.device))
        
        query_embed = reconstructed_embeddings.unsqueeze(1).repeat(1, bs, 1) # [num_queries, hidden_dim] -> [num_queries,1, hidden_dim] -> [num_queries,bs, hidden_dim]
        ##print(f"TaskFormer query_position_embedding shape:{query_embed.shape}, grad:{query_embed.requires_grad}")
        ## query_embed = self.reconstructed_embeddings.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1) # -> [num_queries,bs, hidden_dim]

        #predictions_class = []
        #predictions_mask = []

        # prediction heads on learnable query features
        # todo-这里就是利用query_feat提前产生第一个attn_mask，然后逐步送入每层再生成下一层需要的attn_mask
        mask_features_to_k = None
        mask_features_to_v1 = None
        mask_features_to_v2 = None
        if self.use_crossattn_and_feature:
            mask_features_to_k = self.to_k(mask_features)
            mask_features_to_v1 = self.to_v1(mask_features)
            mask_features_to_v2 = self.to_v2(mask_features)

        attn_mask = None
        if self.predict_attention_mask:
            attn_mask, feature, _ ,_= self.forward_prediction_heads(output, mask_features, QFormer_feature=src[0],
                                                      feature_to_k=mask_features_to_k,feature_to_v1=mask_features_to_v1 ,feature_to_v2=mask_features_to_v2 )
        #predictions_class.append(outputs_class)
        #predictions_mask.append(outputs_mask)
        features = []
        attn_masks = []
        pose_features = []
        pred_masks = []
        level_index = 0

        for i in range(self.num_layers):
            #level_index = i % self.num_feature_levels
            if self.predict_attention_mask:
                # 实际上如果得到的attn_mask中其中一个q屏蔽了所有图片token即(bh,q,len*len) 最后一行都是true，就全改成false让其继续学习
                attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False # 检查每一行是否所有元素都是 True。如果是，它将这些元素设置为 False
            # attention: cross-attention first
            output = self.transformer_cross_attention_layers[i]( # todo-这里用到了src、pos
                output, src[level_index],
                memory_mask=attn_mask,
                memory_key_padding_mask=None,  # here we do not apply masking on padded region
                pos=pos[level_index], query_pos=query_embed
            )

            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed
            )
            
            # FFN
            output = self.transformer_ffn_layers[i](output)

            level_index = (i + 1) % self.num_feature_levels
            if level_index == 0:
                if i + 1 == self.num_layers: # 最后一次得到mask_feature尺寸相同的feature
                    attn_mask,feature,pose_feature,pred_mask = self.forward_prediction_heads(output, mask_features, target_mask=target_mask,attn_mask_target_size=mask_features.shape[-2:],
                                                    feature_to_k=mask_features_to_k ,feature_to_v1=mask_features_to_v1 ,feature_to_v2=mask_features_to_v2)
                else:
                    attn_mask,feature, pose_feature,pred_mask = self.forward_prediction_heads(output, mask_features, QFormer_feature=src[0])
            else:
                attn_mask,feature, pose_feature,pred_mask = self.forward_prediction_heads(output, mask_features, target_mask=target_mask,attn_mask_target_size=size_list[level_index],
                                                    feature_to_k=mask_features_to_k ,feature_to_v1=mask_features_to_v1 ,feature_to_v2=mask_features_to_v2)
    

            features.append(feature)
            attn_masks.append(attn_mask)
            pose_features.append(pose_feature)
            pred_masks.append(pred_mask)

            if not self.predict_attention_mask: 
                attn_mask = None

            #predictions_class.append(outputs_class)
            #predictions_mask.append(outputs_mask)
                
        ## ---------------第一阶段不考虑这里
        ## pred_mask_feature = rearrange(features[-1].permute(0,2,1), ' b (head q) l -> b (q head) l',head = self.num_heads)[:,:self.num_heads*self.mask_q_num,:]
        pred_mask_feature = None

        

        #semantic_feature,spatial_feature = self.forward_finnal_predict(output,mask_features)




        #attn_mask = self.forward_prediction_heads(spatial_feature, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels])
        #predictions_class.append(outputs_class)
        #predictions_mask.append(outputs_mask)
        #assert len(predictions_class) == self.num_layers + 1
        """
        out = {
            #'pred_logits': predictions_class[-1],
             'spatial_feature': spatial_feature,
            "semantic_feature":semantic_feature,
            #'aux_outputs': self._set_aux_loss(
                #predictions_class if self.mask_classification else None,
            #    predictions_mask
            #)
        }
        """
        return attn_masks, features, pose_features, pred_masks #, pred_mask_feature

    def forward_prediction_heads(self, output, mask_features, target_mask=None, attn_mask_target_size=None, QFormer_feature=None, feature_to_k=None, feature_to_v1=None, feature_to_v2=None):
        """Compute a layer's attention mask and the query-derived features.

        Args:
            output: query states. They go through ``self.decoder_norm`` and are then
                transposed with ``transpose(0, 1)``, so presumably
                ``(q, b, hidden_dim)``. TODO: confirm against the caller.
            mask_features: pixel features. The non-cross-attention path consumes
                them via ``einsum("bqc,bchw->bqhw")``, i.e. ``(b, c, h, w)``.
            target_mask: optional mask, presumably ``(b, 1, H, W)`` (TODO: confirm).
                When given, it overrides the attention mask of the first
                ``mask_q_num + pose_q_num`` queries on the non-cross-attention path.
            attn_mask_target_size: ``(H, W)`` that masks and features are resized
                to. Required when ``QFormer_feature`` is None.
            QFormer_feature: optional Q-Former tokens. When given, the attention
                mask is computed against these tokens instead of against pixels,
                and no features are returned.
            feature_to_k, feature_to_v1, feature_to_v2: ``(b, c, h, w)`` maps.
                They are used only when ``self.use_crossattn_and_feature`` is True.

        Returns:
            ``(attn_mask, return_feature, pose_feature, pred_mask)``.
            ``attn_mask`` is boolean, where ``True`` means "not allowed to attend".
            The other three values are ``None`` on the Q-Former path.
        """
        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)  # [q, b, hidden_dim] -> [b, q, hidden_dim]
        if QFormer_feature is None:
            assert attn_mask_target_size is not None,"attn_mask_target_size can not be None!"
            mask_embed = self.mask_embed(decoder_output)  # per-query mask embedding
            feature_embed = self.spatial_feature_embed(decoder_output)

            if self.use_crossattn_and_feature:
                assert feature_to_k is not None and feature_to_v1 is not None and feature_to_v2 is not None,"to_k,to_v1,to_v2 can not be None!"
                _, _, height, width = feature_to_k.shape

                # Flatten the spatial dims and split channels into heads:
                # (b, c, h, w) -> (b*head, h*w, c/head).
                feature_to_k = rearrange(feature_to_k, 'b c h w -> b (h w) c')
                feature_to_k = rearrange(feature_to_k, 'b l (head c) -> (b head) l c', head=self.num_heads)

                feature_to_v1 = rearrange(feature_to_v1, 'b c h w -> b (h w) c')
                feature_to_v1 = rearrange(feature_to_v1, 'b l (head c) -> (b head) l c', head=self.num_heads)

                feature_to_v2 = rearrange(feature_to_v2, 'b c h w -> b (h w) c')
                feature_to_v2 = rearrange(feature_to_v2, 'b l (head c) -> (b head) l c', head=self.num_heads)

                # Pixel-to-pixel affinity, shape (b*head, l, l). The original author
                # noted that how this (l, l) matrix is produced is still open.
                pixel_affinity = torch.einsum('b l c , b n c -> b l n', feature_to_v1, feature_to_v2)

                mask_embed = rearrange(mask_embed, 'b l (head c) -> (b head) l c', head=self.num_heads)

                sim = torch.einsum("bqc,blc->bql", mask_embed, feature_to_k) * self.scale
                sim = sim.softmax(dim=-1)  # [b*head, q, l]
                pred_mask = rearrange(sim, '(b head) q l -> b (q head) l', head=self.num_heads)[:, :self.num_heads*self.mask_q_num, :]

                # The leading dim here is b*head, not b.
                attn_mask = rearrange(sim, 'b q (h w) -> b q h w', h=height, w=width)
                attn_mask = F.interpolate(attn_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False)
                # Block pixels whose attention weight is below the threshold.
                # The original author left open whether detach() is wanted here.
                attn_mask = (attn_mask.flatten(2) < self.crossattn_mask_threshold).bool().detach()

                feature = torch.einsum('b q l, b l n -> b q n', sim, pixel_affinity)
                feature = rearrange(feature, 'b q (h w) -> b q h w', h=height, w=width)
                feature = F.interpolate(feature, size=attn_mask_target_size, mode="bilinear", align_corners=False)
                feature = feature.flatten(2)  # (b*head, q, HW)
                # Regroup the heads under each query and put pixels first: (b, HW, q*head).
                feature = rearrange(feature, '(b head) q l -> b (q head) l', head=self.num_heads).permute(0, 2, 1)
                pose_start = self.mask_q_num * self.num_heads
                spatial_start = self.mask_q_num * self.num_heads + self.pose_q_num * self.num_heads
                pose_feature = feature[:, :, pose_start:spatial_start]
                # Previously this path never assigned `spatial_feature`, so reaching
                # `return_feature = spatial_feature` below raised UnboundLocalError.
                # Mirror the sibling branch: spatial features are the queries after
                # the mask + pose queries.
                spatial_feature = feature[:, :, spatial_start:]
            else:
                mask_feature = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
                feature = torch.einsum("bqc,bchw->bqhw", feature_embed, mask_features)

                # Low-resolution mask logits of the first query only: [b, hw].
                pred_mask = mask_feature.flatten(2)[:, 0, :]

                # NOTE: prediction is of higher-resolution
                mask_feature = F.interpolate(mask_feature, size=attn_mask_target_size, mode="bilinear", align_corners=False)
                feature = F.interpolate(feature, size=attn_mask_target_size, mode="bilinear", align_corners=False)
                # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW]
                # Must be bool: positions with True are not allowed to attend,
                # while False positions are left unchanged.
                attn_mask = (mask_feature.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
                num_fixed_q = self.mask_q_num + self.pose_q_num
                if target_mask is not None:
                    # NOTE(review): after .bool(), nonzero target pixels become True,
                    # which *blocks* the mask/pose queries from attending there. If
                    # target_mask marks the region to attend to, this may need to be
                    # ~target_mask. Confirm the intent.
                    # The result is [B*h, 1, HW] for a single-channel target_mask.
                    target_mask = F.interpolate(target_mask, size=attn_mask_target_size, mode="nearest").flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1).bool().detach()
                    attn_mask[:, 0:num_fixed_q, :] = target_mask.repeat(1, num_fixed_q, 1)

                # attn_mask now encodes, per query, the region that query may attend to.
                # Queries [0, mask_q_num + pose_q_num) give pose_feature and the rest
                # give spatial_feature, each permuted to (b, HW, n_queries).
                feature = feature.flatten(2)
                pose_feature = feature[:, 0:num_fixed_q, :].permute(0, 2, 1)
                spatial_feature = feature[:, num_fixed_q:, :].permute(0, 2, 1)
            return_feature = spatial_feature
        else:
            QFormer_mask = self.QFormer_attn_mask(decoder_output)
            # NOTE(review): transpose(0, 1) implies QFormer_feature arrives as
            # (num_tokens, b, c). Confirm against the caller.
            QFormer_feature = QFormer_feature.transpose(0, 1)

            if self.use_crossattn_and_feature:
                # Split channels into heads, take a softmax over the Q-Former tokens,
                # then block tokens whose weight is below the threshold.
                QFormer_mask = rearrange(QFormer_mask, "b l (head c) -> (b head) l c", head=self.num_heads)
                QFormer_feature = rearrange(QFormer_feature, "b l (head c) -> (b head) l c", head=self.num_heads)
                attn_mask = torch.einsum("bqc,bnc->bqn", QFormer_mask, QFormer_feature)
                attn_mask = attn_mask.softmax(dim=-1)
                attn_mask = (attn_mask < self.crossattn_mask_threshold).bool().detach()
            else:
                attn_mask = torch.einsum("bqc,bnc->bqn", QFormer_mask, QFormer_feature)
                attn_mask = (attn_mask.sigmoid().unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool().detach()
            pose_feature = None
            return_feature = None
            pred_mask = None
        return attn_mask, return_feature, pose_feature, pred_mask
    """
    def forward_finnal_predict(self, output, mask_features):
        semantic_feature = output[self.semantic_feature_start:, :, :] # 对semantic_feature的处理
        semantic_feature = self.semantic_norm(semantic_feature)
        semantic_feature = semantic_feature.transpose(0, 1)

        spatial_feature = output[:self.semantic_feature_start, :, :] # 格式是[q,b,hiddn_dim]
        spatial_feature = self.spatial_norm(spatial_feature)
        spatial_feature = spatial_feature.transpose(0, 1)
        spatial_embed = self.spatial_embed(spatial_feature)
        spatial_embed = torch.einsum("bqc,bchw->bqhw", spatial_embed, mask_features)
        #print(f"taskformer spatial_embed shape:{spatial_embed.shape}")
        #print(self.img_spatial_feature_start)
        spatial_embed[:,0:self.img_spatial_feature_start,:,:] = spatial_embed[:,0:self.img_spatial_feature_start,:,:].sigmoid() #[0,1]
        spatial_embed[:,self.img_spatial_feature_start:self.semantic_feature_start,:,:] =self.img_spatial_norm(spatial_embed[:,self.img_spatial_feature_start:self.semantic_feature_start,:,:]) #由于可学习的权重和偏移，因此范围不一定是[-1,1]，如果没有可学习参数是这个范围

        return semantic_feature, spatial_embed

    """

    #@torch.jit.unused # 用于指示被装饰的函数参数在TorchScript中不会被使用。TorchScript是PyTorch的一个子集，用于生成序列化和优化的模型，以便在不依赖Python解释器的情况下运行。
    #def _set_aux_loss(self, outputs_seg_masks):
        # 返回每一层dec除最后一层的输出结果字典列表
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        # 注释中说明TorchScript不支持包含非同质值的字典，比如字典中既包含张量（Tensor）又包含列表（list）
        #if self.mask_classification:
        #    return [
        #        {"pred_logits": a, "pred_masks": b}
        #        for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) # 不包含最后一个输出
        #    ]
        #else:
            #return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]


def _remap_mask2former_state_dict(state_dict):
    """Rename Mask2Former predictor keys to this module's names and convert them to tensors.

    Only keys containing ``"predictor"`` are touched; every other key is left
    as-is. The dict is modified in place and also returned. Remapping rules:

    * ``...static_query`` -> ``query_feat``
    * ``...mask_embed`` -> copied to ``spatial_embed``, ``spatial_feature_embed``
      and ``mask_embed``
    * ``...decoder_norm`` -> copied to ``spatial_norm``, ``semantic_norm`` and
      ``decoder_norm``
    * any other ``sem_seg_head.predictor.`` prefix is stripped

    Values are assumed to be numpy arrays (they go through ``torch.from_numpy``).
    """
    # Snapshot the keys: the loop inserts and pops entries.
    for name in list(state_dict.keys()):
        if "predictor" not in name:
            continue
        if "static_query" in name:
            state_dict[name.replace("sem_seg_head.predictor.static_query", "query_feat")] = torch.from_numpy(state_dict.pop(name))
        elif "predictor.mask_embed" in name:
            # Pop first so that a no-op rename cannot leave a tensor behind for a
            # later torch.from_numpy call. load_state_dict copies into each
            # parameter, so sharing one tensor across the three keys is safe.
            value = torch.from_numpy(state_dict.pop(name))
            for target in ("spatial_embed", "spatial_feature_embed", "mask_embed"):
                state_dict[name.replace("sem_seg_head.predictor.mask_embed", target)] = value
        elif "predictor.decoder_norm" in name:
            value = torch.from_numpy(state_dict.pop(name))
            for target in ("spatial_norm", "semantic_norm", "decoder_norm"):
                state_dict[name.replace("sem_seg_head.predictor.decoder_norm", target)] = value
        else:
            state_dict[name.replace("sem_seg_head.predictor.", "")] = torch.from_numpy(state_dict.pop(name))
    return state_dict


def build_TaskFormer_encoding(config):
    """Build a MultiScaleTaskTransformerDecoder and optionally load pretrained weights.

    Args:
        config: mapping passed to ``MultiScaleTaskTransformerDecoder.from_config``.
            An optional ``"pretrained_path"`` entry names a pickled checkpoint
            whose ``"model"`` entry maps parameter names to numpy arrays.
            The mapping is not modified.

    Returns:
        The constructed model, with checkpoint weights loaded (``strict=False``)
        when ``pretrained_path`` is set.
    """
    # TODO: consider loading the whole configuration from a config file instead of CLI flags.
    model = MultiScaleTaskTransformerDecoder.from_config(config)
    # Read rather than pop: popping mutated the caller's config, so a second call
    # with the same dict silently skipped loading the weights.
    pretrained_path = config.get("pretrained_path", None)
    print(pretrained_path)
    if pretrained_path is None:
        return model

    # SECURITY: pickle.load can execute arbitrary code; only load trusted checkpoints.
    with open(pretrained_path, 'rb') as f:
        checkpoint = pickle.load(f)
    state_dict = _remap_mask2former_state_dict(checkpoint["model"])

    msg = model.load_state_dict(state_dict, strict=False)
    print("TaskFormer_missing_keys:", list(msg.missing_keys))
    # Only the input projection and img_spatial_norm layers are expected to be
    # missing from the checkpoint. The previous check
    # (`"input_proj" or "img_spatial_norm" in k`) was always truthy, so it never
    # fired. Warn instead of asserting, so loads that used to succeed keep working.
    expected_missing = ("input_proj", "img_spatial_norm")
    unexpected_missing = [k for k in msg.missing_keys if not any(p in k for p in expected_missing)]
    if unexpected_missing:
        logging.getLogger(__name__).warning(
            "TaskFormer: keys other than input_proj/img_spatial_norm were not loaded from the checkpoint: %s",
            unexpected_missing,
        )
    print("achieve load TaskFormer weight successfully!")

    torch.cuda.empty_cache()
    return model





if __name__ == "__main__" :
    # Smoke test: build the decoder from the YAML config and print output shapes.
    import yaml
    with open('./configs/configs.yaml', 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)['model']['params']['TaskFormer']

    # Create the model instance.
    model = build_TaskFormer_encoding(config)

    x = [torch.rand([1,16,768]), torch.rand([1,256*4,512]), torch.rand([1,1024*4,256])]
    mask_feature = torch.rand([1,4096*4, 128])

    # forward returns four values. A fifth output (pred_mask_feature) was removed
    # from its return statement, and unpacking five values raised ValueError.
    attn_masks, features, pose_features, pred_masks = model(x, mask_feature)
    print(f"mask size: {[i.shape for i in attn_masks]}")
    print(f"feature size: {[i.shape if i is not None else i for i in features]}")
    print(f"pose_feature size: {[i.shape if i is not None else i for i in pose_features]}")
    print(f"pred_mask size: {[i.shape if i is not None else i for i in pred_masks]}")