"""
@author: Yanzuo Lu
@author: oliveryanzuolu@gmail.com
"""

import torch
import torch.nn as nn
from diffusers.models.attention import BasicTransformerBlock


class AppearanceEncoder(nn.Module):
    """Turn UNet feature tokens into attention residuals, plus pose residuals.

    Two parallel stacks share one recipe: a zero-initialised conv "in", a stack
    of ``depth`` diffusers ``BasicTransformerBlock``s, and a zero-initialised
    1x1 conv "out".

    * The output stack (``blocks``) emits residuals keyed
      ``output_block_{idx}_{suffix}``.
    * The input stack (``input_blocks``) emits residuals keyed
      ``input_block_{idx}_{suffix}``.

    ``pose_proj`` holds one zero-initialised 1x1 conv per input-stack entry. Each
    conv maps ``pose_dim`` channels to that entry's context dim and emits
    ``input_block_{idx}_pose_residual``.

    Because every "out" conv and every pose projection is zero-initialised,
    all residuals are zero at initialisation.

    Example output-stack configuration::

        'convin_kernel_size': [1, 1, 1, 1, 1, 1, 1, 1, 1],
        'convin_stride': [1, 1, 1, 1, 1, 1, 1, 1, 1],
        'convin_padding': [0, 0, 0, 0, 0, 0, 0, 0, 0],
        'attn_residual_block_idx': [11, 10, 9, 8, 7, 6, 5, 4, 3],
        'inner_dims': [128, 128, 128, 256, 256, 256, 512, 512, 512],
        'ctx_dims': [320, 320, 320, 640, 640, 640, 1280, 1280, 1280],
        'embed_dims': [64, 64, 64, 128, 128, 128, 256, 256, 256],
        'heads': [2, 2, 2, 4, 4, 4, 8, 8, 8],
        'depth': 4,
        'to_self_attn': False,
        'to_queries': True,
        'to_keys': False,
        'to_values': False,
        'detach_input': False
    """

    def __init__(self, attn_residual_block_idx, inner_dims, ctx_dims, embed_dims, heads, depth, pose_dim,
                 to_self_attn, to_queries, to_keys, to_values, aspect_ratio, detach_input,
                 convin_kernel_size, convin_stride, convin_padding,
                 input_attn_residual_block_idx, input_ctx_dims, input_inner_dims, input_embed_dims, input_heads):
        """Build both transformer stacks and the pose projections.

        Args:
            attn_residual_block_idx: block index per output-stack entry; used only
                to build the residual dict keys.
            inner_dims: channel count of the incoming feature tokens per output entry.
            ctx_dims: channel count of the emitted residual per output entry.
            embed_dims: transformer width per output entry.
            heads: attention heads per output entry (head dim = embed_dim // heads).
            depth: number of BasicTransformerBlocks per entry (both stacks).
            pose_dim: channel count of the pose maps fed to ``pose_proj``.
            to_self_attn, to_queries, to_keys, to_values: select the residual key
                suffix (see ``_residual_suffix``).
            aspect_ratio: stored on the instance; not read by this class.
            detach_input: if True, incoming features are detached before encoding.
            convin_kernel_size, convin_stride, convin_padding: conv-in geometry per
                output entry. The input stack always uses kernel 1, stride 1, padding 0.
            input_attn_residual_block_idx, input_ctx_dims, input_inner_dims,
            input_embed_dims, input_heads: the same per-entry settings for the
                input stack.
        """
        super().__init__()
        self.attn_residual_block_idx = attn_residual_block_idx
        self.inner_dims = inner_dims
        self.ctx_dims = ctx_dims
        self.embed_dims = embed_dims
        self.to_self_attn = to_self_attn
        self.to_queries = to_queries
        self.to_keys = to_keys
        self.to_values = to_values
        self.aspect_ratio = aspect_ratio  # stored only; forward assumes square token grids
        self.detach_input = detach_input

        # ---- output stack ------------------------------------------------
        # Per entry the modules are built in the order conv_in, conv_out,
        # transformer. That order fixes the sequence in which random
        # initialisation draws from the RNG, so it is kept deliberately.
        zero_conv_ins, zero_conv_outs, blocks = [], [], []
        for inner_dim, embed_dim, ctx_dim, num_head, kernel_size, stride, padding in zip(
                inner_dims, embed_dims, ctx_dims, heads,
                convin_kernel_size, convin_stride, convin_padding):
            zero_conv_ins.append(nn.Conv2d(inner_dim, embed_dim, kernel_size=kernel_size,
                                           stride=stride, padding=padding))
            zero_conv_outs.append(nn.Conv2d(embed_dim, ctx_dim, kernel_size=1, stride=1, padding=0))
            blocks.append(self._transformer_stack(embed_dim, num_head, depth))

        # Registration order (blocks, zero_conv_ins, zero_conv_outs, ...) determines
        # state_dict ordering; attribute names determine state_dict keys.
        self.blocks = nn.ModuleList(blocks)
        self.zero_conv_ins = nn.ModuleList(zero_conv_ins)
        self.zero_conv_outs = nn.ModuleList(zero_conv_outs)
        self._zero_init(self.zero_conv_ins)
        self._zero_init(self.zero_conv_outs)

        # ---- input stack + pose projections --------------------------------
        self.input_attn_residual_block_idx = input_attn_residual_block_idx
        self.input_ctx_dims = input_ctx_dims
        self.input_embed_dims = input_embed_dims
        self.input_heads = input_heads
        self.input_inner_dims = input_inner_dims

        # Per entry: conv_in, pose projection, conv_out, transformer (RNG order kept).
        input_zero_conv_ins, pose_proj, input_zero_conv_outs, input_blocks = [], [], [], []
        for inner_dim, embed_dim, ctx_dim, num_head in zip(
                input_inner_dims, input_embed_dims, input_ctx_dims, input_heads):
            input_zero_conv_ins.append(nn.Conv2d(inner_dim, embed_dim, kernel_size=1, stride=1, padding=0))
            pose_proj.append(nn.Conv2d(pose_dim, ctx_dim, kernel_size=1, stride=1, padding=0))
            input_zero_conv_outs.append(nn.Conv2d(embed_dim, ctx_dim, kernel_size=1, stride=1, padding=0))
            input_blocks.append(self._transformer_stack(embed_dim, num_head, depth))

        self.input_blocks = nn.ModuleList(input_blocks)
        self.input_zero_conv_ins = nn.ModuleList(input_zero_conv_ins)
        self.input_zero_conv_outs = nn.ModuleList(input_zero_conv_outs)
        self._zero_init(self.input_zero_conv_ins)
        self._zero_init(self.input_zero_conv_outs)

        self.pose_proj = nn.ModuleList(pose_proj)
        self._zero_init(self.pose_proj)

        # ---- attention backend ------------------------------------------------
        # Walk every submodule and switch on xformers memory-efficient attention
        # wherever the module exposes the diffusers toggle.
        # NOTE(review): this presumably requires xformers (and a CUDA device) to be
        # available at construction time; confirm before running on CPU.
        def _enable_xformers(module: nn.Module) -> None:
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(True, attention_op=None)
            for child in module.children():
                _enable_xformers(child)

        for child in self.children():
            _enable_xformers(child)

    @staticmethod
    def _transformer_stack(embed_dim, num_head, depth):
        """Return ``depth`` BasicTransformerBlocks of width ``embed_dim`` chained in sequence."""
        return nn.Sequential(*[
            BasicTransformerBlock(
                dim=embed_dim,
                num_attention_heads=num_head,
                attention_head_dim=embed_dim // num_head,
                double_self_attention=True,
            )
            for _ in range(depth)
        ])

    @staticmethod
    def _zero_init(module):
        """Zero every parameter (weights and biases) of ``module`` in place."""
        for param in module.parameters():
            nn.init.zeros_(param)

    def _residual_suffix(self):
        """Return the residual-key suffix chosen by the ``to_*`` flags, or None.

        In self-attention mode the suffix is ``"self_attn"`` if any of
        queries/keys/values is enabled. In cross-attention mode it is
        ``"cross_attn_c"`` when keys and values are both enabled; otherwise it is
        the first enabled of queries (``_q``), keys (``_k``) or values (``_v``).
        ``None`` means no residual is emitted.
        """
        if self.to_self_attn:
            if self.to_queries or self.to_keys or self.to_values:
                return "self_attn"
            return None
        if self.to_keys and self.to_values:
            return "cross_attn_c"
        if self.to_queries:
            return "cross_attn_q"
        if self.to_keys:
            return "cross_attn_k"
        if self.to_values:
            return "cross_attn_v"
        return None

    def _encode(self, tokens, inner_dim, embed_dim, ctx_dim, conv_in, block, conv_out):
        """Run one entry's conv-in -> transformer -> conv-out pipeline.

        ``tokens`` is expected to be (B, N, inner_dim) with N a perfect square;
        the grid side is ``int(sqrt(N))`` and ``aspect_ratio`` is not consulted.
        Returns a (B, N', ctx_dim) tensor, where N' is the token count after
        ``conv_in`` (equal to N for a 1x1, stride-1 conv).
        """
        if self.detach_input:
            # Detach the tensor that is actually used below, so no gradient
            # reaches the producer of ``features``.
            tokens = tokens.detach()

        side = int(tokens.shape[1] ** 0.5)
        x = tokens.permute(0, 2, 1).reshape(-1, inner_dim, side, side)

        x = conv_in(x)
        H = W = x.shape[2]
        x = x.reshape(-1, embed_dim, H * W).permute(0, 2, 1)

        x = block(x)

        x = x.permute(0, 2, 1).reshape(-1, embed_dim, H, W)
        x = conv_out(x)
        return x.reshape(-1, ctx_dim, H * W).permute(0, 2, 1)

    def forward(self, features, pose_feature):
        """Compute attention residuals and pose residuals.

        Args:
            features: list of token tensors, presumably (B, N, C) with N a
                perfect square. The output stack starts at ``features[0]`` and
                advances to the next entry whenever ``ctx_dims`` changes between
                consecutive blocks. Input-stack entry ``i`` reads ``features[i // 2]``.
            pose_feature: indexable sequence of pose maps (B, pose_dim, H, W).
                ``pose_proj[i]`` reads ``pose_feature[i // 2]``. The sequence is
                not modified.

        Returns:
            ``(additional_residuals, pose_residuals)``:
            ``additional_residuals`` maps ``{output|input}_block_{idx}_{suffix}``
            to (B, N', ctx_dim) tensors and is empty when no ``to_*`` flag
            selects a suffix. ``pose_residuals`` maps
            ``input_block_{idx}_pose_residual`` to (B, ctx_dim, H, W) tensors.
        """
        suffix = self._residual_suffix()
        additional_residuals = {}

        # Output stack: move to the next feature map whenever the context dim
        # changes. (This originally compared inner_dims, but those are currently
        # all equal.)
        index = 0
        last = len(self.blocks) - 1
        for i, block in enumerate(self.blocks):
            hidden_states = self._encode(
                features[index], self.inner_dims[i], self.embed_dims[i], self.ctx_dims[i],
                self.zero_conv_ins[i], block, self.zero_conv_outs[i])
            if suffix is not None:
                additional_residuals[f"output_block_{self.attn_residual_block_idx[i]}_{suffix}"] = hidden_states
            if i != last and self.ctx_dims[i] != self.ctx_dims[i + 1]:
                index += 1

        # Input stack: every two consecutive blocks share one feature level.
        for i, block in enumerate(self.input_blocks):
            hidden_states = self._encode(
                features[i // 2], self.input_inner_dims[i], self.input_embed_dims[i],
                self.input_ctx_dims[i], self.input_zero_conv_ins[i], block,
                self.input_zero_conv_outs[i])
            if suffix is not None:
                additional_residuals[f"input_block_{self.input_attn_residual_block_idx[i]}_{suffix}"] = hidden_states

        # Pose projections: also two per pose level. Index rather than pop so the
        # caller's list is left intact and forward can be called repeatedly.
        pose_residuals = {
            f"input_block_{self.input_attn_residual_block_idx[i]}_pose_residual": proj(pose_feature[i // 2])
            for i, proj in enumerate(self.pose_proj)
        }
        return additional_residuals, pose_residuals
    
def build_Decoder(config):
    """Construct an ``AppearanceEncoder`` from a mapping of constructor kwargs.

    Despite the name, the object returned is an ``AppearanceEncoder``. Every key
    in ``config`` must match one of its ``__init__`` parameters.

    Args:
        config: mapping of keyword arguments for ``AppearanceEncoder.__init__``
            (e.g. the ``model.params.Decoder`` section of the YAML config).

    Returns:
        The newly constructed ``AppearanceEncoder``.
    """
    return AppearanceEncoder(**config)

if __name__ == "__main__":
    # Smoke test: build the decoder from the YAML config, run one forward pass on
    # random inputs, and print the residual shapes. No positional information is
    # added to the features.
    #
    # Reference: the 12 UNet down-block outputs for an input of [1, 9, 32, 32] are
    #   [1, 320, 32, 32] x3, [1, 320, 16, 16], [1, 640, 16, 16] x2, [1, 640, 8, 8],
    #   [1, 1280, 8, 8] x2, [1, 1280, 4, 4] x3
    import yaml

    device = "cuda:3"
    config_path = './configs/configs2_5_viton_revise.yaml'
    with open(config_path, 'r', encoding='utf-8') as handle:
        decoder_config = yaml.safe_load(handle)['model']['params']['Decoder']

    # Token features: (1, N, 64) for N = 64*64, 32*32, 16*16.
    feature = [torch.rand([1, tokens, 64]).to(device) for tokens in (4096, 1024, 256)]
    # Pose maps: (1, 35, S, S) for S = 64, 32, 16.
    pose_feature = [torch.rand([1, 35, side, side]).to(device) for side in (64, 32, 16)]

    model = build_Decoder(config=decoder_config).to(device)
    residual, pose_residual = model(feature, pose_feature)
    print({name: tensor.shape for name, tensor in residual.items()})
    print({name: tensor.shape for name, tensor in pose_residual.items()})
