"""
@author: Yanzuo Lu
@author: oliveryanzuolu@gmail.com
"""

import torch
import torch.nn as nn
from diffusers.models.attention import BasicTransformerBlock


class AppearanceEncoder(nn.Module):
    """Turn multi-scale appearance tokens (and pose maps) into attention residuals.

    For every branch ``i`` the encoder builds
    ``zero_conv_ins[i] -> nn.Sequential(depth x BasicTransformerBlock) -> zero_conv_outs[i]``.
    Both convolutions are zero-initialised, so every appearance residual is
    exactly zero before training. A separate list of zero-initialised 1x1
    convolutions (``pose_proj``) maps pose feature maps to ``pose_ctx_dims``
    channels.

    Args:
        attn_residual_block_idx: per-branch target block index; used only to
            build the keys of the returned appearance-residual dict.
        inner_dims: per-branch channel count of the incoming token features.
        ctx_dims: per-branch output channel count.
        embed_dims: per-branch transformer width.
        heads: per-branch number of attention heads
            (``attention_head_dim = embed_dim // heads``).
        depth: number of ``BasicTransformerBlock`` s per branch.
        pose_dim: input channel count of every pose feature map.
        to_self_attn, to_queries, to_keys, to_values: select the suffix of the
            appearance-residual keys (see ``_residual_suffix``).
        aspect_ratio: stored on the instance but not used by this class.
        detach_input: if True, appearance features are detached before use, so
            no gradient flows back into them.
        convin_kernel_size, convin_stride, convin_padding: per-branch
            parameters of the input convolution.
        pose_attn_residual_block_idx: per-projection target block index; used
            only to build the keys of the returned pose-residual dict.
        pose_ctx_dims: per-projection output channel count.

    The per-branch lists are consumed with ``zip``, so the number of branches
    equals the length of the shortest of ``inner_dims``, ``embed_dims``,
    ``ctx_dims``, ``heads`` and the three ``convin_*`` lists.

    Example configuration from the original code::

        convin_kernel_size=[1] * 9, convin_stride=[1] * 9, convin_padding=[0] * 9,
        attn_residual_block_idx=[11, 10, 9, 8, 7, 6, 5, 4, 3],
        inner_dims=[128, 128, 128, 256, 256, 256, 512, 512, 512],
        ctx_dims=[320, 320, 320, 640, 640, 640, 1280, 1280, 1280],
        embed_dims=[64, 64, 64, 128, 128, 128, 256, 256, 256],
        heads=[2, 2, 2, 4, 4, 4, 8, 8, 8], depth=4,
        to_self_attn=False, to_queries=True, to_keys=False, to_values=False,
        detach_input=False
    """

    def __init__(self, attn_residual_block_idx, inner_dims, ctx_dims, embed_dims, heads, depth, pose_dim,
                 to_self_attn, to_queries, to_keys, to_values, aspect_ratio, detach_input,
                 convin_kernel_size, convin_stride, convin_padding,
                 pose_attn_residual_block_idx, pose_ctx_dims):
        super().__init__()
        self.attn_residual_block_idx = attn_residual_block_idx
        self.inner_dims = inner_dims
        self.ctx_dims = ctx_dims
        self.embed_dims = embed_dims
        self.to_self_attn = to_self_attn
        self.to_queries = to_queries
        self.to_keys = to_keys
        self.to_values = to_values
        self.aspect_ratio = aspect_ratio  # kept for config compatibility; unused below
        self.detach_input = detach_input

        # Modules are created in the same order as before (conv_in, conv_out,
        # transformer stack per branch) so parameter init / RNG use is unchanged.
        zero_conv_ins, zero_conv_outs, blocks = [], [], []
        for inner_dim, embed_dim, ctx_dim, num_head, kernel_size, stride, padding in zip(
                inner_dims, embed_dims, ctx_dims, heads,
                convin_kernel_size, convin_stride, convin_padding):
            zero_conv_ins.append(nn.Conv2d(inner_dim, embed_dim, kernel_size=kernel_size,
                                           stride=stride, padding=padding))
            zero_conv_outs.append(nn.Conv2d(embed_dim, ctx_dim, kernel_size=1, stride=1, padding=0))
            blocks.append(nn.Sequential(*[BasicTransformerBlock(
                dim=embed_dim,
                num_attention_heads=num_head,
                attention_head_dim=embed_dim // num_head,
                double_self_attention=True,
            ) for _ in range(depth)]))

        # Registration order (blocks, zero_conv_ins, zero_conv_outs, pose_proj)
        # matches the original, so state_dict keys are unchanged.
        self.blocks = nn.ModuleList(blocks)
        self.zero_conv_ins = nn.ModuleList(zero_conv_ins)
        self.zero_conv_outs = nn.ModuleList(zero_conv_outs)

        # ---- pose projections ----
        self.pose_attn_residual_block_idx = pose_attn_residual_block_idx
        self.pose_ctx_dims = pose_ctx_dims
        self.pose_proj = nn.ModuleList([
            nn.Conv2d(pose_dim, ctx_dim, kernel_size=1, stride=1, padding=0)
            for ctx_dim in pose_ctx_dims
        ])

        # Zero-initialise all projections so every residual starts at exactly zero.
        for module_list in (self.zero_conv_ins, self.zero_conv_outs, self.pose_proj):
            for param in module_list.parameters():
                nn.init.zeros_(param)

        # Enable xformers memory-efficient attention on every descendant that supports it.
        # NOTE(review): diffusers presumably raises here when xformers or CUDA is
        # unavailable — confirm before constructing this model on CPU.
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(True, attention_op=None)
            for child in module.children():
                fn_recursive_set_mem_eff(child)

        for module in self.children():
            fn_recursive_set_mem_eff(module)

    def _residual_suffix(self):
        """Return the key suffix for appearance residuals, or None if none are emitted.

        Self-attention mode yields ``"self_attn"`` when any of to_queries /
        to_keys / to_values is set. Cross-attention mode yields
        ``"cross_attn_c"`` when both keys and values are targeted, otherwise
        ``"cross_attn_q"``, ``"cross_attn_k"`` or ``"cross_attn_v"`` in that
        priority order.
        """
        if self.to_self_attn:
            if self.to_queries or self.to_keys or self.to_values:
                return "self_attn"
            return None
        if self.to_keys and self.to_values:
            return "cross_attn_c"
        if self.to_queries:
            return "cross_attn_q"
        if self.to_keys:
            return "cross_attn_k"
        if self.to_values:
            return "cross_attn_v"
        return None

    def forward(self, features, pose_feature):
        """Compute appearance and pose residuals.

        Args:
            features: list of appearance token tensors, each reshaped as
                (batch, tokens, channels) with a square token count. Branch 0
                reads ``features[0]``; the encoder advances to the next entry
                each time ``ctx_dims`` changes between consecutive branches.
                The list is NOT modified.
            pose_feature: list of 4-D tensors (batch, pose_dim, h, w). Entry
                ``k`` feeds pose projections ``2k`` and ``2k + 1``. The list is
                NOT modified.

        Returns:
            ``(additional_residuals, pose_residuals)``:
            - ``additional_residuals`` maps ``block_{idx}_{suffix}`` to
              (batch, H*W, ctx_dims[i]) tensors. It is empty if no
              query/key/value flag is set.
            - ``pose_residuals`` maps ``block_{idx}_pose_residual`` to
              (batch, pose_ctx_dims[i], h, w) tensors.
        """
        suffix = self._residual_suffix()
        additional_residuals = {}
        num_blocks = len(self.blocks)
        feat_idx = 0  # index into `features`; replaces destructive pop(0)

        for i, block in enumerate(self.blocks):
            hidden_states = features[feat_idx]
            if self.detach_input:
                hidden_states = hidden_states.detach()

            # Tokens -> square map: (B, N, C) -> (B, C, sqrt(N), sqrt(N)).
            # Uses the (possibly detached) tensor so `detach_input` takes effect.
            in_H = in_W = int(hidden_states.shape[1] ** 0.5)
            hidden_states = hidden_states.permute(0, 2, 1).reshape(-1, self.inner_dims[i], in_H, in_W)

            hidden_states = self.zero_conv_ins[i](hidden_states)
            H, W = hidden_states.shape[-2:]
            hidden_states = hidden_states.reshape(-1, self.embed_dims[i], H * W).permute(0, 2, 1)

            hidden_states = block(hidden_states)

            hidden_states = hidden_states.permute(0, 2, 1).reshape(-1, self.embed_dims[i], H, W)
            hidden_states = self.zero_conv_outs[i](hidden_states)
            hidden_states = hidden_states.reshape(-1, self.ctx_dims[i], H * W).permute(0, 2, 1)

            if suffix is not None:
                additional_residuals[f"block_{self.attn_residual_block_idx[i]}_{suffix}"] = hidden_states

            # Move to the next feature level whenever the output width changes.
            # (An earlier version compared inner_dims instead; per the original
            # note, both change at the same positions in current configs.)
            if i != num_blocks - 1 and self.ctx_dims[i] != self.ctx_dims[i + 1]:
                feat_idx += 1

        pose_residuals = {}
        for i, proj in enumerate(self.pose_proj):
            # Each pose map is shared by two consecutive projections.
            out_pose = proj(pose_feature[i // 2])  # [b, dim, h, w]
            pose_residuals[f"block_{self.pose_attn_residual_block_idx[i]}_pose_residual"] = out_pose

        return additional_residuals, pose_residuals
    
def build_Decoder(config):
    """Construct an ``AppearanceEncoder`` from a keyword-argument mapping.

    Args:
        config: mapping whose keys are ``AppearanceEncoder.__init__``
            parameter names.

    Returns:
        The newly built ``AppearanceEncoder`` instance.
    """
    return AppearanceEncoder(**config)

if __name__ == "__main__":
    # Smoke test: build the decoder from the project config and run a single
    # forward pass on random inputs (no positional information is added).
    #
    # Reference: outputs of the 12 UNet down layers for an input of [1, 9, 32, 32]:
    # [torch.Size([1, 320, 32, 32]), torch.Size([1, 320, 32, 32]), torch.Size([1, 320, 32, 32]), torch.Size([1, 320, 16, 16]),
    #  torch.Size([1, 640, 16, 16]), torch.Size([1, 640, 16, 16]), torch.Size([1, 640, 8, 8]), torch.Size([1, 1280, 8, 8]),
    #  torch.Size([1, 1280, 8, 8]), torch.Size([1, 1280, 4, 4]), torch.Size([1, 1280, 4, 4]), torch.Size([1, 1280, 4, 4])]
    #
    # NOTE(review): the constructor enables xformers attention, which presumably
    # requires xformers and CUDA — confirm before running this on CPU.
    import yaml

    with open('./configs/configs.yaml', 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)['model']['params']['Decoder']

    feature = [torch.rand([1, 4096, 800]), torch.rand([1, 1024, 800]), torch.rand([1, 256, 800])]
    # forward() feeds pose map k to projections 2k and 2k + 1, so
    # ceil(len(pose_ctx_dims) / 2) maps are needed. The projections are 1x1
    # convolutions, so the spatial size chosen here is arbitrary.
    num_pose_maps = (len(config['pose_ctx_dims']) + 1) // 2
    pose_feature = [torch.rand([1, config['pose_dim'], 8, 8]) for _ in range(num_pose_maps)]

    model = build_Decoder(config=config)
    residual, pose_residual = model(feature, pose_feature)
    print({k: v.shape for k, v in residual.items()})
    print({k: v.shape for k, v in pose_residual.items()})
