"""
@author: Yanzuo Lu
@author: oliveryanzuolu@gmail.com
"""

import torch
import torch.nn as nn
from diffusers.models.attention import BasicTransformerBlock
from torchvision.transforms import Resize,InterpolationMode
from lavis.models.blip2_models.blip2 import LayerNorm


class AppearanceEncoder(nn.Module):
    """Turn multi-scale appearance features into attention residuals.

    For every configured block, a token sequence taken from ``features`` is reshaped
    to a square feature map, projected by a zero-initialised conv, refined by a stack
    of ``BasicTransformerBlock``s, projected to the target context width by a second
    zero-initialised conv, and flattened back to tokens. Results are collected in a
    dict whose keys name the block they are meant for.

    Two branches exist, selected by ``use_down_or_up_residual``:

    * ``"up"`` or ``"all"``   -> keys ``output_block_{idx}_{suffix}`` (``self.blocks``)
    * ``"down"`` or ``"all"`` -> keys ``input_block_{idx}_{suffix}`` (``self.input_blocks``)

    Any other value builds no branch and ``forward`` returns an empty dict.

    Args:
        attn_residual_block_idx: block index used in the key of each up-branch block.
        inner_dims: channel count of the feature consumed by each up-branch block.
        ctx_dims: output channel count of each up-branch block.
        embed_dims: transformer width of each up-branch block.
        heads: attention heads of each up-branch block.
        depth: number of stacked transformer blocks per branch block (both branches).
        pose_dims: only takes part in the ``zip`` that builds the down branch, so it
            caps the number of down-branch blocks; the pose projection is disabled.
        to_self_attn, to_queries, to_keys, to_values: select the key suffix
            (see ``_residual_suffix``).
        aspect_ratio: stored on the instance but not used by this module.
        detach_input: if true, input features are detached before use, so no gradient
            flows back into them.
        convin_kernel_size, convin_stride, convin_padding: per-block settings of the
            up-branch input conv.
        input_attn_residual_block_idx, input_ctx_dims, input_inner_dims,
        input_embed_dims, input_heads: down-branch counterparts of the above.
        use_down_or_up_residual: ``"all"``, ``"up"`` or ``"down"``.
    """

    def __init__(self, attn_residual_block_idx, inner_dims, ctx_dims, embed_dims, heads, depth, pose_dims,
                 to_self_attn, to_queries, to_keys, to_values, aspect_ratio, detach_input,
                 convin_kernel_size, convin_stride, convin_padding,
                 input_attn_residual_block_idx, input_ctx_dims, input_inner_dims, input_embed_dims, input_heads,
                 use_down_or_up_residual,
                 ):
        super().__init__()
        self.attn_residual_block_idx = attn_residual_block_idx
        self.inner_dims = inner_dims
        self.ctx_dims = ctx_dims
        self.embed_dims = embed_dims
        self.to_self_attn = to_self_attn
        self.to_queries = to_queries
        self.to_keys = to_keys
        self.to_values = to_values
        self.aspect_ratio = aspect_ratio  # kept for config compatibility; not used here
        self.detach_input = detach_input
        self.use_down_or_up_residual = use_down_or_up_residual

        if use_down_or_up_residual in ("all", "up"):
            # Up branch: produces the "output_block_*" residuals.
            zero_conv_ins, zero_conv_outs, blocks = [], [], []
            for inner_dim, embed_dim, ctx_dim, num_head, kernel_size, stride, padding in zip(
                    inner_dims, embed_dims, ctx_dims, heads,
                    convin_kernel_size, convin_stride, convin_padding):
                zero_conv_ins.append(nn.Conv2d(inner_dim, embed_dim, kernel_size=kernel_size,
                                               stride=stride, padding=padding))
                zero_conv_outs.append(nn.Conv2d(embed_dim, ctx_dim, kernel_size=1, stride=1, padding=0))
                blocks.append(self._make_transformer_stack(embed_dim, num_head, depth))

            self.blocks = nn.ModuleList(blocks)
            self.zero_conv_ins = nn.ModuleList(zero_conv_ins)
            self.zero_conv_outs = nn.ModuleList(zero_conv_outs)
            # With all-zero output convs the branch initially emits all-zero residuals.
            self._zero_init(self.zero_conv_ins)
            self._zero_init(self.zero_conv_outs)

        self.input_attn_residual_block_idx = input_attn_residual_block_idx
        self.input_ctx_dims = input_ctx_dims
        # NOTE: the pose-projection branch (LayerNorm + convs over pose features) is
        # disabled; forward() therefore always returns pose_residuals=None.

        if use_down_or_up_residual in ("all", "down"):
            # Down branch: produces the "input_block_*" residuals.
            self.input_embed_dims = input_embed_dims
            self.input_heads = input_heads
            self.input_inner_dims = input_inner_dims

            input_zero_conv_ins, input_zero_conv_outs, input_blocks = [], [], []
            # pose_dims is only zipped here, so its length caps the number of down blocks.
            for inner_dim, embed_dim, ctx_dim, num_head, _pose_dim in zip(
                    input_inner_dims, input_embed_dims, input_ctx_dims, input_heads, pose_dims):
                input_zero_conv_ins.append(nn.Conv2d(inner_dim, embed_dim, kernel_size=1, stride=1, padding=0))
                input_zero_conv_outs.append(nn.Conv2d(embed_dim, ctx_dim, kernel_size=1, stride=1, padding=0))
                input_blocks.append(self._make_transformer_stack(embed_dim, num_head, depth))

            self.input_blocks = nn.ModuleList(input_blocks)
            self.input_zero_conv_ins = nn.ModuleList(input_zero_conv_ins)
            self.input_zero_conv_outs = nn.ModuleList(input_zero_conv_outs)
            self._zero_init(self.input_zero_conv_ins)
            self._zero_init(self.input_zero_conv_outs)

        # Enable xformers memory-efficient attention on every submodule exposing the setter.
        for module in self.modules():
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(True, attention_op=None)

    @staticmethod
    def _make_transformer_stack(dim, num_heads, depth):
        """Return ``depth`` BasicTransformerBlocks of width ``dim`` (head dim = dim // num_heads)."""
        return nn.Sequential(*[BasicTransformerBlock(
            dim=dim,
            num_attention_heads=num_heads,
            attention_head_dim=dim // num_heads,
            double_self_attention=True,
        ) for _ in range(depth)])

    @staticmethod
    def _zero_init(module):
        """Set every parameter of ``module`` to zero in place."""
        for param in module.parameters():
            nn.init.zeros_(param)

    def _residual_suffix(self):
        """Return the key suffix selected by the ``to_*`` flags, or None to emit nothing."""
        if self.to_self_attn:
            # Queries, keys and values all share a single self-attention key.
            if self.to_queries or self.to_keys or self.to_values:
                return "self_attn"
            return None
        if self.to_keys and self.to_values:
            return "cross_attn_c"
        if self.to_queries:
            return "cross_attn_q"
        if self.to_keys:
            return "cross_attn_k"
        if self.to_values:
            return "cross_attn_v"
        return None

    @staticmethod
    def _tokens_to_map(tokens, channels):
        """Reshape ``(batch, side*side, channels)`` tokens into a ``(batch, channels, side, side)`` map.

        Raises:
            ValueError: if the token count is not a perfect square or the channel
                count differs from ``channels`` (a plain reshape with a ``-1`` batch
                dimension could otherwise silently mix samples across the batch).
        """
        batch, num_tokens, num_channels = tokens.shape
        side = int(num_tokens ** 0.5)
        if side * side != num_tokens or num_channels != channels:
            raise ValueError(
                f"expected features of shape (batch, side*side, {channels}), "
                f"got {tuple(tokens.shape)}"
            )
        return tokens.permute(0, 2, 1).reshape(batch, channels, side, side)

    def _run_branch(self, features, blocks, conv_ins, conv_outs, inner_dims, embed_dims, ctx_dims,
                    block_indices, prefix, suffix, should_advance, residuals):
        """Run one branch and store its outputs in ``residuals`` (mutated in place).

        Block ``i`` reads ``features[index]``; ``index`` starts at 0 and is incremented
        after block ``i`` whenever ``should_advance(i)`` is true.
        """
        index = 0
        for i, block in enumerate(blocks):
            hidden_states = features[index]
            if self.detach_input:
                hidden_states = hidden_states.detach()
            # Reshape the (possibly detached) tensor itself, so detach_input takes effect.
            hidden_states = self._tokens_to_map(hidden_states, inner_dims[i])

            hidden_states = conv_ins[i](hidden_states)
            H, W = hidden_states.shape[-2:]
            hidden_states = hidden_states.reshape(-1, embed_dims[i], H * W).permute(0, 2, 1)

            hidden_states = block(hidden_states)

            hidden_states = hidden_states.permute(0, 2, 1).reshape(-1, embed_dims[i], H, W)
            hidden_states = conv_outs[i](hidden_states)
            hidden_states = hidden_states.reshape(-1, ctx_dims[i], H * W).permute(0, 2, 1)

            if suffix is not None:
                residuals[f"{prefix}_block_{block_indices[i]}_{suffix}"] = hidden_states

            if should_advance(i):
                index += 1  # move on to the next entry of ``features``

    def forward(self, features, pose_feature, pose_mask=None):
        """Compute attention residuals from multi-scale features.

        Args:
            features: sequence of ``(batch, tokens, channels)`` tensors with a square
                token count; consumed in order by the enabled branches.
            pose_feature: unused (the pose branch is disabled); kept for interface
                compatibility.
            pose_mask: unused; kept for interface compatibility.

        Returns:
            ``(additional_residuals, pose_residuals)``: a dict mapping key strings to
            ``(batch, tokens, ctx_dim)`` tensors, and ``None``.
        """
        del pose_feature, pose_mask  # pose branch disabled

        additional_residuals = {}
        suffix = self._residual_suffix()

        if self.use_down_or_up_residual in ("all", "up"):
            ctx_dims = self.ctx_dims
            last = len(self.blocks) - 1
            self._run_branch(
                features, self.blocks, self.zero_conv_ins, self.zero_conv_outs,
                self.inner_dims, self.embed_dims, ctx_dims, self.attn_residual_block_idx,
                prefix="output", suffix=suffix,
                # Advance to the next feature whenever the target context width changes
                # (this used to compare inner_dims, which are currently all identical).
                should_advance=lambda i: i != last and ctx_dims[i] != ctx_dims[i + 1],
                residuals=additional_residuals,
            )

        if self.use_down_or_up_residual in ("all", "down"):
            self._run_branch(
                features, self.input_blocks, self.input_zero_conv_ins, self.input_zero_conv_outs,
                self.input_inner_dims, self.input_embed_dims, self.input_ctx_dims,
                self.input_attn_residual_block_idx,
                prefix="input", suffix=suffix,
                # The down branch advances to the next feature after every second block.
                should_advance=lambda i: i % 2 == 1,
                residuals=additional_residuals,
            )

        pose_residuals = None
        return additional_residuals, pose_residuals
    
def build_Decoder(config):
    """Construct the appearance encoder from a mapping of constructor keyword arguments.

    Despite its name, this factory returns an ``AppearanceEncoder`` instance; every
    entry of ``config`` is forwarded verbatim as a keyword argument.
    """
    encoder = AppearanceEncoder(**config)
    return encoder

if __name__ == "__main__":
    # Smoke test: build the encoder from the YAML config and print residual shapes.
    # No positional information is added to the features.
    #
    # Reference: the 12 UNet down-block feature shapes for a [1, 9, 32, 32] input are
    # [1, 320, 32, 32] x3, [1, 320, 16, 16], [1, 640, 16, 16] x2, [1, 640, 8, 8],
    # [1, 1280, 8, 8] x2, [1, 1280, 4, 4] x3.
    import yaml

    device = "cuda:3"

    with open('./configs/configs2_5_viton_revise.yaml', 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)['model']['params']['Decoder']

    feature = [torch.rand([1, 4096, 64]), torch.rand([1, 1024, 64]), torch.rand([1, 256, 64])]
    pose_feature = [torch.rand([1, 35, 64, 64]), torch.rand([1, 35, 32, 32]), torch.rand([1, 35, 16, 16])]
    feature = [f.to(device) for f in feature]
    pose_feature = [p.to(device) for p in pose_feature]

    model = build_Decoder(config=config).to(device)
    with torch.no_grad():  # shape check only; no gradients needed
        residual, pose_residual = model(feature, pose_feature)

    print({k: v.shape for k, v in residual.items()})
    # pose_residual is None while the pose branch is disabled; guard instead of crashing.
    print(None if pose_residual is None else {k: v.shape for k, v in pose_residual.items()})
