# Copyright (2024) Bytedance Ltd. and/or its affiliates 

# Licensed under the Apache License, Version 2.0 (the "License"); 
# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at 

#     http://www.apache.org/licenses/LICENSE-2.0 

# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and 
# limitations under the License. 

#import sys
#sys.path.append('/workspace/model/Paint-by-Example-main-2')
#print(sys.path)


from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from ldm.modules.diffusionmodules.util import checkpoint
#from ldm.modules.transformer_decoder.mask2former_transformer_decoder import MLP

# Added: optional xformers import (used by MemoryEfficientCrossAttention) and attention-precision setting.
from typing import Optional, Any
import os
# Precision for CrossAttention's score computation; "fp32" runs it with CUDA autocast disabled.
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Any import failure (package missing, or a broken/incompatible build that raises
    # something other than ImportError) just marks xformers as unavailable, while
    # KeyboardInterrupt/SystemExit are no longer swallowed as they were by a bare except.
    XFORMERS_IS_AVAILBLE = False
print(f'XFORMERS_IS_AVAILBLE:{XFORMERS_IS_AVAILBLE}')




def exists(val):
    """Report whether *val* is set, i.e. is anything other than ``None`` (0 and "" count as set)."""
    if val is None:
        return False
    return True


def uniq(arr):
    """Return the distinct elements of *arr* in first-seen order, as a dict keys view."""
    seen = dict.fromkeys(arr, True)
    return seen.keys()


def default(val, d):
    """Return *val* unless it is ``None``, otherwise the fallback *d*.

    When *d* is a plain Python function (``inspect.isfunction``, which includes
    lambdas) it is called with no arguments to build the fallback lazily; any
    other object, including classes and builtins, is returned unchanged.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d


def max_neg_value(t):
    """Return the most negative finite number representable in the dtype of tensor *t*."""
    info = torch.finfo(t.dtype)
    return -info.max


def init_(tensor):
    """Fill *tensor* in place from U(-1/sqrt(fan), 1/sqrt(fan)) and return it.

    ``fan`` is the size of the tensor's last dimension.
    """
    fan = tensor.shape[-1]
    limit = 1 / math.sqrt(fan)
    tensor.uniform_(-limit, limit)
    return tensor


# feedforward
class GEGLU(nn.Module):
    """Gated GELU projection.

    One linear layer maps ``dim_in`` to ``2 * dim_out`` features. The first half is
    the value and the second half the gate; the result is ``value * gelu(gate)``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # A single projection emits both the value half and the gate half.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        gated = F.gelu(gate)
        return value * gated


class FeedForward(nn.Module):
    """Position-wise MLP used inside the transformer blocks.

    Layout: expand ``dim -> int(dim * mult)`` (Linear + GELU, or a GEGLU gate when
    ``glu`` is True), dropout, then Linear to ``dim_out`` (``dim`` when not given).
    The ``net`` Sequential layout is what checkpoints' state-dict keys refer to.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        out_features = dim if dim_out is None else dim_out

        if glu:
            expand = GEGLU(dim, hidden)
        else:
            expand = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())

        self.net = nn.Sequential(expand, nn.Dropout(dropout), nn.Linear(hidden, out_features))

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """Set every parameter of *module* to zero in place and return the same module.

    The writes go through ``detach()`` so autograd does not record them; the
    parameters keep their ``requires_grad`` flag.
    """
    for _name, param in module.named_parameters():
        param.detach().zero_()
    return module


def Normalize(in_channels):
    """Build the GroupNorm used throughout this module: 32 groups, eps=1e-6, learnable affine."""
    settings = dict(num_groups=32, eps=1e-6, affine=True)
    return torch.nn.GroupNorm(num_channels=in_channels, **settings)


class SpatialSelfAttention(nn.Module):
    """Single-head self-attention across the spatial positions of a (b, c, h, w) map.

    Queries, keys and values are 1x1 convolutions of the group-normalised input.
    Every position attends to every other position with scores scaled by
    ``c ** -0.5``. The attended features go through a final 1x1 convolution and
    are added back onto the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)
        # Created in the fixed order q, k, v, proj_out so seeded initialisation and
        # state-dict layout match the straightforward one-by-one construction.
        for name in ("q", "k", "v", "proj_out"):
            setattr(self, name, torch.nn.Conv2d(in_channels, in_channels,
                                                kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        normed = self.norm(x)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)

        _, channels, height, _ = query.shape
        # (b, hw, c) x (b, c, hw) -> (b, hw, hw): affinity of every position pair.
        query = rearrange(query, 'b c h w -> b (h w) c')
        key = rearrange(key, 'b c h w -> b c (h w)')
        weights = torch.einsum('bij,bjk->bik', query, key)
        weights = weights * (int(channels) ** (-0.5))
        weights = torch.nn.functional.softmax(weights, dim=2)

        # Weighted sum of values: (b, c, hw) x (b, hw, hw)^T -> (b, c, hw).
        value = rearrange(value, 'b c h w -> b c (h w)')
        weights = rearrange(weights, 'b i j -> b j i')
        attended = torch.einsum('bij,bjk->bik', value, weights)
        attended = rearrange(attended, 'b c (h w) -> b c h w', h=height)

        return x + self.proj_out(attended)



class DownsampleAvgPool(nn.Module):
    """Average-pool the last axis with kernel and stride ``sr_ratio`` (AvgPool1d).

    Input is (N, C, L); output is (N, C, floor((L - sr_ratio) / sr_ratio) + 1).
    """

    def __init__(self, sr_ratio):
        super().__init__()
        self.sr_ratio = sr_ratio
        self.avgpool = nn.AvgPool1d(kernel_size=sr_ratio, stride=sr_ratio)

    def forward(self, x):
        return self.avgpool(x)
    

        

class CrossAttention(nn.Module):
    """Multi-head (cross-)attention with optional average-pool downsampling of keys/values.

    This is the vanilla LDM attention extended with ``sr_size``. When it is given,
    the key and value sequences are average-pooled along the sequence axis with
    kernel and stride ``sr_size`` (see ``DownsampleAvgPool``), which shrinks the
    attention matrix. Presumably the motivation is reduced compute.

    Args:
        query_dim: feature size of ``x``.
        context_dim: feature size of ``context``; defaults to ``query_dim`` (self-attention).
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout applied after the output projection.
        sr_size: pooling kernel/stride for keys and values; ``None`` disables pooling.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., sr_size=None):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)  # self-attention when no context_dim

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

        self.sr_size = sr_size
        if sr_size is not None:
            self.sr_k = DownsampleAvgPool(sr_size)
            self.sr_v = DownsampleAvgPool(sr_size)

    def _attention_scores(self, q, k, v):
        """Return ``(q @ k^T * scale, v)``, pooling k and v first when ``sr_size`` is set.

        q, k and v are (batch*heads, seq, dim_head). The returned v is
        (batch*heads, keys, dim_head), where ``keys`` is the pooled length when
        pooling is on.
        """
        if self.sr_size is not None:
            # AvgPool1d pools the last axis, so move the sequence axis there.
            k_t = self.sr_k(k.permute(0, 2, 1).contiguous())  # (b*h, dim_head, keys)
            v = self.sr_v(v.permute(0, 2, 1).contiguous())
            v = v.permute(0, 2, 1).contiguous()  # (b*h, keys, dim_head)
        else:
            k_t = k.transpose(1, 2)
        return torch.matmul(q, k_t) * self.scale, v

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: (batch, n, query_dim) queries.
            context: (batch, m, context_dim) keys/values source; defaults to ``x``.
            mask: optional boolean mask over keys (True = attend). After flattening,
                its length must equal the number of keys seen by the attention,
                i.e. the pooled length when ``sr_size`` is set.

        Returns:
            (batch, n, query_dim) tensor.

        Raises:
            ValueError: if ``mask`` does not match the number of keys.
        """
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        if _ATTN_PRECISION == "fp32":
            # Compute the scores in fp32 with CUDA autocast disabled.
            with torch.autocast(enabled=False, device_type='cuda'):
                sim, v = self._attention_scores(q.float(), k.float(), v)
        else:
            sim, v = self._attention_scores(q, k, v)
        del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            if mask.shape[-1] != sim.shape[-1]:
                raise ValueError(
                    f"mask covers {mask.shape[-1]} keys but attention has {sim.shape[-1]} keys "
                    f"(after sr_size pooling, if enabled)"
                )
            mask = repeat(mask, 'b j -> (b h) () j', h=h)  # (batch*heads, 1, keys)
            sim.masked_fill_(~mask, max_neg_value(sim))

        sim = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', sim, v)  # (batch*heads, n, dim_head)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)  # (batch, n, heads*dim_head)
        return self.to_out(out)

class MemoryEfficientCrossAttention(nn.Module):
    """xformers-backed multi-head attention with several condition-injection modes.

    Adapted from
    https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223

    ``context`` (when used) is indexed positionally:
      * ``context[0]``: semantic embedding (cross-attention keys/values);
      * ``context[1]``: spatial embedding;
      * ``context[2]``: text embedding.

    With ``disable_selfattn`` True, the modes are checked in this order:
      * ``add_cond2selfattn``: self-attention over ``x``, plus a second attention
        of the same queries over keys/values projected from ``context[1]``. The
        second result passes through ``proj_zero`` and is added.
      * ``use_cond_concat2selfattn``: ``context[1]`` is concatenated to ``x`` along
        the sequence axis before projection; the output is cropped back to
        ``x``'s length.
      * ``use_ip_adpter``: keys/values are projected from ``context[1]``.
      * otherwise: plain cross-attention with keys/values from ``context[0]``.

    With ``disable_selfattn`` False it is ordinary self-attention over ``x``;
    ``use_cond_concat2selfattn`` still concatenates first.

    Optional extras:
      * ``add_text_weigh_to_selfattn`` replaces k and v by their attention readout
        over the text embedding ``context[2]``;
      * ``sr_size`` average-pools k and v along the sequence axis.
    """

    def __init__(self, query_dim, context_dim=None, disable_selfattn=False, add_cond2selfattn=False,
                 add_text_weigh_to_selfattn=False, heads=8, dim_head=64, dropout=0.0, sr_size=None,
                 use_cond_concat2selfattn=False, use_ip_adpter=False):
        super().__init__()
        self.add_cond2selfattn = add_cond2selfattn
        self.disable_selfattn = disable_selfattn
        self.use_cond_concat2selfattn = use_cond_concat2selfattn
        self.use_ip_adpter = use_ip_adpter
        self.add_text_weigh_to_selfattn = add_text_weigh_to_selfattn

        inner_dim = dim_head * heads
        # Width of the key/value input. It is context_dim for plain cross-attention,
        # but query_dim whenever keys/values come from x or from context[1].
        dim = default(context_dim, query_dim)
        if add_cond2selfattn or use_cond_concat2selfattn or use_ip_adpter:
            dim = query_dim

        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5

        # NOTE: submodules are created in a fixed order; keep it for seeded-init and
        # state-dict compatibility with existing checkpoints.
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        if add_cond2selfattn:
            # Separate key/value projections for the injected condition context[1].
            self.i_to_k = nn.Linear(dim, inner_dim, bias=False)
            self.i_to_v = nn.Linear(dim, inner_dim, bias=False)
            # Per-head projection of the condition branch's output. Zero-initialising
            # it was tried earlier and dropped, so PyTorch's default init is used.
            self.proj_zero = nn.Linear(dim_head, dim_head, bias=False)

        if self.add_text_weigh_to_selfattn:
            # Text projections, zero-initialised.
            self.text_to_v = nn.Linear(context_dim, inner_dim, bias=False)
            self.text_to_k = nn.Linear(context_dim, inner_dim, bias=False)
            nn.init.zeros_(self.text_to_k.weight)
            nn.init.zeros_(self.text_to_v.weight)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

        self.sr_size = sr_size
        if sr_size is not None:
            self.sr_k = DownsampleAvgPool(sr_size)
            self.sr_v = DownsampleAvgPool(sr_size)

    def _split_heads(self, t):
        """Reshape (b, n, heads*dim_head) into a contiguous (b*heads, n, dim_head)."""
        return rearrange(t, 'b n (h d) -> (b h) n d', h=self.heads).contiguous()

    def _project_kv(self, x, context):
        """Project keys/values for the active mode.

        Returns ``(k, v, i_k, i_v)``. ``i_k``/``i_v`` are the projections of
        ``context[1]`` in ``add_cond2selfattn`` mode and ``None`` otherwise.
        """
        if not self.disable_selfattn:
            # Plain self-attention.
            return self.to_k(x), self.to_v(x), None, None

        if self.add_cond2selfattn:
            # Extended self-attention: keys/values from x plus a separate condition branch.
            k = self.to_k(x)
            v = self.to_v(x)
            assert context[1] is not None, "Fool!when add_cond2selfattn is true,context['spatial_embedding'] can'n be None"
            return k, v, self.i_to_k(context[1]), self.i_to_v(context[1])

        if self.use_cond_concat2selfattn:
            # x already has context[1] concatenated along the sequence axis.
            return self.to_k(x), self.to_v(x), None, None

        if self.use_ip_adpter:
            assert context[1] is not None, "Fool!when use_ip_adpter is true,context['spatial_embedding'] can'n be None"
            return self.to_k(context[1]), self.to_v(context[1]), None, None

        # Cross-attention to the semantic embedding.
        assert context[0] is not None, "Fool!when add_img2crossattn is true ,context['semantic_embedding'] can'n be None"
        return self.to_k(context[0]), self.to_v(context[0]), None, None

    def forward(self, x, context=None, crossattn_img_weight=1.0, selfattn_img_weight=1.0, mask=None):
        """Attend from ``x`` (b, n, query_dim) and return a (b, n, query_dim) tensor.

        ``crossattn_img_weight`` and ``selfattn_img_weight`` are accepted for
        interface compatibility but are currently unused.

        Raises:
            NotImplementedError: if ``mask`` is given (masks are not supported).
        """
        # Fail before doing any attention work; masks are not supported here.
        if exists(mask):
            raise NotImplementedError

        x_len = x.size()[1]
        if self.use_cond_concat2selfattn:
            x = torch.cat((x, context[1]), dim=1)

        q = self.to_q(x)
        k, v, i_k, i_v = self._project_kv(x, context)

        q, k, v = (self._split_heads(t) for t in (q, k, v))
        if i_k is not None:
            i_k, i_v = self._split_heads(i_k), self._split_heads(i_v)

        if self.add_text_weigh_to_selfattn:
            assert context[2] is not None, "Fool!when add_img2crossattn is true ,context['text_embedding'] can'n be None"
            text_k = self._split_heads(self.text_to_k(context[2]))
            text_v = self._split_heads(self.text_to_v(context[2]))
            # Replace k and v by their attention readout over the text tokens.
            k = xformers.ops.memory_efficient_attention(k, text_k, text_v, attn_bias=None, op=self.attention_op)
            v = xformers.ops.memory_efficient_attention(v, text_k, text_v, attn_bias=None, op=self.attention_op)

        if self.sr_size is not None:
            # Average-pool k and v along the sequence axis (AvgPool1d pools the last axis).
            k = self.sr_k(k.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
            v = self.sr_v(v.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()

        # (b*heads, n, dim_head)
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
        if i_k is not None:
            # The same queries attend to the injected condition; the projected result is added.
            i_out = xformers.ops.memory_efficient_attention(q, i_k, i_v, attn_bias=None, op=self.attention_op)
            out = out + self.proj_zero(i_out)

        if self.use_cond_concat2selfattn:
            # Drop the positions contributed by the concatenated condition.
            out = out[:, :x_len]

        out = rearrange(out, '(b h) n d -> b n (h d)', h=self.heads)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    """Transformer block: attn1 -> attn2 -> feed-forward, each pre-normalised with a residual.

    attn1 is the (optionally extended) self-attention. attn2 is the
    cross-attention, always built with ``disable_selfattn=True``. When
    ``use_ip_adpter`` is set, an extra attention ``attn3`` (built with
    ``use_ip_adpter=True``) reads the same input as attn2. Its residual output is
    added with weight ``scale``.
    """
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention, extended with optional k/v downsampling
        "softmax-xformers": MemoryEfficientCrossAttention # the attention described in the paper
    }
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,add_cond2selfattn=False,
                  sr_size=None,add_text_weigh_to_selfattn=False, disable_selfattn = False,
                  use_cond_concat2selfattn=False,use_ip_adpter=False,scale = 1,
                  ): # Added parameters include disable_selfattn and sr_size; disable_selfattn=True makes attn1 inject condition/spatial information instead of running plain self-attention.
        super().__init__()
        #attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" TODO: hard-coded while xformers was unavailable; restore for training.
        # NOTE(review): the xformers class is always selected, even when the
        # xformers import failed; its forward calls `xformers.ops`, which would
        # then raise NameError. The "softmax" class does not accept the extra
        # keyword arguments passed below, so it cannot simply be swapped in.
        attn_mode = "softmax-xformers"

        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]

        #self.disable_self_attn = disable_self_attn # TODO: means "not plain self-attention": use self-attention augmented with image information, or turn it into cross-attention
        # attn1: normally the UNet's self-attention layer; it only receives
        # context_dim when disable_selfattn is set.
        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                              context_dim=context_dim if disable_selfattn else None, sr_size=sr_size,add_text_weigh_to_selfattn=add_text_weigh_to_selfattn,
                              disable_selfattn = disable_selfattn,
                              #----------------
                              add_cond2selfattn=add_cond2selfattn,# add_cond2selfattn=True together with disable_selfattn=True gives the extended self-attention
                              use_cond_concat2selfattn=use_cond_concat2selfattn,
                               )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # attn2: the UNet's cross-attention layer.
        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,disable_selfattn=True,
                              heads=n_heads, dim_head=d_head, dropout=dropout)  # built with disable_selfattn=True, so it needs a context at call time
        
        self.use_ip_adpter = use_ip_adpter
        

        if self.use_ip_adpter:
            # Weight applied to attn3's residual output in _forward.
            self.scale = scale
            self.attn3 = attn_cls(query_dim=dim, context_dim=context_dim,disable_selfattn=True,use_ip_adpter = use_ip_adpter,
                              heads=n_heads, dim_head=d_head, dropout=dropout)  # extra branch enabled by use_ip_adpter
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None, crossattn_img_weight=1.0,selfattn_img_weight=1.0):
        """Run ``_forward`` through ldm's ``checkpoint`` helper.

        ``self.checkpoint`` is passed as the helper's flag; presumably it enables
        activation checkpointing (TODO confirm in ldm.modules.diffusionmodules.util).
        """
        #print(f"attntion checkpoint flag: {self.checkpoint}")
        return checkpoint(self._forward, (x, context, crossattn_img_weight,selfattn_img_weight), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None, crossattn_img_weight=1.0,selfattn_img_weight=1.0):
        """Apply attn1, attn2 (plus optional attn3) and the feed-forward, each with a residual.

        ``context`` is the per-block condition list forwarded unchanged to the
        attention layers, which index it as context[0], [1] and [2].
        """
        #print(f"context[0] len:{[i.shape if i is not None else None for i in context[0]]}")
        #print(f"context len :{ [i.shape if i is not None else None for i in context]}")
        #print(f"selfattn x_input shape:{x.shape}")
        x = self.attn1(self.norm1(x), context=context ) + x
        """
        if context[1] is not None:
            x,_ = x.chunk(2,dim=-2)
            #print(f"selfattn x_output shape:{x.shape}, context[0][1] shape:{context[0][1].shape}")
        """
        x_in = x
        x = self.attn2(self.norm2(x), context=context, crossattn_img_weight=crossattn_img_weight,selfattn_img_weight=selfattn_img_weight) + x

        if self.use_ip_adpter:
            # attn3 reuses norm2 and reads the same input (x_in) as attn2.
            # NOTE(review): x_2 already includes the residual x_in, and x includes it
            # too, so x_in ends up weighted by (1 + scale). If the intent was
            # x_in + attn2 + scale * attn3, the "+ x_in" on x_2 looks unintended.
            # Confirm against how existing checkpoints were trained before changing it.
            x_2 = self.attn3(self.norm2(x_in), context=context, crossattn_img_weight=crossattn_img_weight,selfattn_img_weight=selfattn_img_weight) + x_in
            x = x + self.scale * x_2

        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.

    Pipeline:
      1. Group-normalise the (b, c, h, w) input.
      2. Project it to ``inner_dim = n_heads * d_head``, using a 1x1 conv or, when
         ``use_linear`` is set, a Linear on the flattened tokens.
      3. Flatten to (b, h*w, inner_dim) and run ``depth`` BasicTransformerBlocks.
      4. Project back to ``in_channels``, reshape to (b, c, h, w) and add the input.

    ``proj_out`` is zero-initialised, so the module starts as the identity.

    ``context`` is a list with one condition entry per block; a single entry is
    shared by all blocks. Each entry is the condition list consumed by the
    attention layers. If the first entry has four items and the fourth is not
    None, it is added to the input before normalisation as a pose residual
    (presumably shaped like ``x`` — TODO confirm).

    ``context_dim`` may be a single value or a per-block list; a single value, or
    a one-element list, is reused for every block.
    """

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 use_linear=False, add_cond2selfattn=False, add_text_weigh_to_selfattn=False,
                 use_checkpoint=True, sr_size=None, disable_selfattn=False,
                 use_cond_concat2selfattn=False, use_ip_adpter=False,
                 ):
        super().__init__()
        self.add_cond2selfattn = add_cond2selfattn
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        if context_dim is None:
            # No context dim: every block is built with context_dim=None.
            context_dim = [None] * depth
        elif len(context_dim) == 1:
            # One context dim shared by all `depth` blocks.
            context_dim = context_dim * depth

        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        # Input projection: 1x1 conv on the feature map, or Linear on flattened tokens.
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   add_cond2selfattn=add_cond2selfattn, checkpoint=use_checkpoint,
                                   sr_size=sr_size, add_text_weigh_to_selfattn=add_text_weigh_to_selfattn,
                                   disable_selfattn=disable_selfattn,
                                   use_cond_concat2selfattn=use_cond_concat2selfattn,
                                   use_ip_adpter=use_ip_adpter,
                                   )
             for d in range(depth)]
        )

        # Output projection back to in_channels, zero-initialised.
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                                  in_channels,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0))
        else:
            # Tokens are inner_dim wide here, so map inner_dim -> in_channels.
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None, crossattn_img_weight=1.0, selfattn_img_weight=1.0):
        """Apply the transformer blocks to feature map ``x`` (b, c, h, w) and return the same shape."""
        if not isinstance(context, list):
            context = [context]

        x_in = x
        _, _, h, w = x.shape

        first = context[0]
        if first is not None and len(first) == 4 and first[3] is not None:
            # Pose residual added before normalisation.
            x = x + first[3]

        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)

        for i, block in enumerate(self.transformer_blocks):
            # A single condition entry is shared by every block; otherwise block i uses context[i].
            block_context = context[0] if len(context) == 1 else context[i]
            x = block(x, context=block_context, crossattn_img_weight=crossattn_img_weight,
                      selfattn_img_weight=selfattn_img_weight)

        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in

if __name__ == "__main__":
    # Smoke test: run one SpatialTransformer with every condition slot populated
    # and print the output shape, which should equal the input's.
    # Uses cuda:1 when more than one GPU is visible, otherwise the default CUDA device.
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda")
    spatial_cond = torch.rand([1, 1024, 320]).to(device)
    semantic_cond = torch.rand([1, 35, 768]).to(device)
    text_cond = torch.rand([1, 20, 768]).to(device)
    pose_cond = torch.rand([1, 320, 32, 32]).to(device)
    cond = [[semantic_cond, spatial_cond, text_cond, pose_cond]]
    model = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, context_dim=768,
                               add_cond2selfattn=True, add_text_weigh_to_selfattn=False,
                               disable_selfattn=True).to(device)
    x = torch.rand([1, 320, 32, 32]).to(device)
    with torch.no_grad():
        out = model(x, cond).size()
    print(out)