'''
    AudioCrossAttention
    Args:
        face_mask: used as the attention mask, since the audio only controls the face region.
'''
from typing import List, Optional
import torch
from diffusers.models.attention import (AdaLayerNorm, Attention, FeedForward)
from torch import nn


class AudioTemporalBasicTransformerBlock(nn.Module):
    """
    Transformer block that injects audio features into spatial feature maps.

    ``forward`` applies, in order and each with a residual connection:

    1. self-attention (``attn1``) on the flattened spatial tokens,
    2. an optional cross-attention stage against ``encoder_hidden_states``,
    3. a feed-forward network (``ff``).

    When ``cross_attention_dim`` is given, the cross-attention stage consists of
    three attention layers (``attn2_0``, ``attn2_1``, ``attn2_2``). Their outputs
    are multiplied by the full, face and lip masks respectively, passed through
    zero-initialised 1x1 convolutions and summed onto the residual stream.
    When ``cross_attention_dim`` is ``None`` the stage is skipped entirely.

    Attributes:
        only_cross_attention (bool): Stored constructor argument.
        use_ada_layer_norm (bool): True when ``num_embeds_ada_norm`` was given;
            the norms are then ``AdaLayerNorm`` and are called with ``timestep``.
        unet_use_cross_frame_attention (Optional[bool]): When truthy,
            ``video_length`` is forwarded to ``attn1``.
        unet_use_temporal_attention (Optional[bool]): Stored constructor argument.
        unet_block_name (Optional[str]): Stored constructor argument.
        depth (int): Index used to select the mask level in ``forward``.
        attn2: Always ``None``; kept so the single cross-attention code path in
            ``forward`` stays selectable by assigning a layer to it.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        depth=0,
        unet_block_name=None,
    ):
        """
        Initializes the AudioTemporalBasicTransformerBlock module.

        Args:
            dim (int): Channel dimension of the input and output features.
            num_attention_heads (int): Number of heads of every attention layer.
            attention_head_dim (int): Dimension of each attention head.
            dropout (float, optional): Dropout probability passed to the attention
                layers and the feed-forward network. Defaults to 0.0.
            cross_attention_dim (Optional[int], optional): Feature dimension of
                ``encoder_hidden_states``. When None, no cross-attention layers
                (and no ``norm2``) are created. Defaults to None.
            activation_fn (str, optional): Activation of the feed-forward network.
                Defaults to "geglu".
            num_embeds_ada_norm (Optional[int], optional): When given, the norms
                before the attention layers are ``AdaLayerNorm`` instead of
                ``nn.LayerNorm``. Defaults to None.
            attention_bias (bool, optional): Passed as ``bias`` to the attention
                layers. Defaults to False.
            only_cross_attention (bool, optional): Stored on the module.
                Defaults to False.
            upcast_attention (bool, optional): Passed to the attention layers.
                Defaults to False.
            unet_use_cross_frame_attention (Optional[bool], optional): When truthy,
                ``forward`` passes ``video_length`` to ``attn1``. Defaults to None.
            unet_use_temporal_attention (Optional[bool], optional): Stored on the
                module. Defaults to None.
            depth (int, optional): Mask level used by ``forward``. Defaults to 0.
            unet_block_name (Optional[str], optional): Stored on the module.
                Defaults to None.
        """
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention
        self.unet_block_name = unet_block_name
        self.depth = depth

        # Zero-initialised 1x1 convolutions: each masked audio branch starts
        # out contributing exactly zero to the residual stream.
        self.zero_conv_full = zero_module(nn.Conv2d(dim, dim, kernel_size=1))
        self.zero_conv_face = zero_module(nn.Conv2d(dim, dim, kernel_size=1))
        self.zero_conv_lip = zero_module(nn.Conv2d(dim, dim, kernel_size=1))

        # Constructor arguments shared by every attention layer of the block.
        attn_kwargs = dict(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
        )

        # SC-Attn
        self.attn1 = Attention(**attn_kwargs)
        self.norm1 = self._build_norm(dim, num_embeds_ada_norm)

        # Cross-Attn: one layer per mask (full / face / lip). ``attn2`` (a single
        # global cross-attention) is intentionally left disabled.
        self.attn2 = None
        if cross_attention_dim is not None:
            self.attn2_0 = Attention(cross_attention_dim=cross_attention_dim, **attn_kwargs)
            self.attn2_1 = Attention(cross_attention_dim=cross_attention_dim, **attn_kwargs)
            self.attn2_2 = Attention(cross_attention_dim=cross_attention_dim, **attn_kwargs)
            self.norm2 = self._build_norm(dim, num_embeds_ada_norm)
        else:
            # Define every cross-attention attribute so the module has the same
            # attribute set regardless of configuration.
            self.attn2_0 = None
            self.attn2_1 = None
            self.attn2_2 = None
            self.norm2 = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)
        self.use_ada_layer_norm_zero = False

    def _build_norm(self, dim, num_embeds_ada_norm):
        """Create the pre-attention norm: AdaLayerNorm if enabled, else LayerNorm."""
        if self.use_ada_layer_norm:
            return AdaLayerNorm(dim, num_embeds_ada_norm)
        return nn.LayerNorm(dim)

    def _apply_norm(self, norm, hidden_states, timestep):
        """Call ``norm``, passing ``timestep`` only for adaptive layer norm."""
        if self.use_ada_layer_norm:
            return norm(hidden_states, timestep)
        return norm(hidden_states)

    def _masked_branch(
        self,
        attn,
        zero_conv,
        mask,
        norm_hidden_states,
        encoder_hidden_states,
        attention_mask,
        spatial_shape,
    ):
        """Run one cross-attention branch, mask it and apply its zero conv."""
        bs_f, channel, rows, cols = spatial_shape
        branch = attn(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        ) * mask[:, :, None]  # trailing axis broadcasts the mask over channels
        # (bs_f, rows*cols, c) -> (bs_f, c, rows, cols) so the conv can be applied
        branch = branch.reshape(bs_f, rows, cols, channel).permute(0, 3, 1, 2)
        # back to token layout (bs_f, rows*cols, c)
        return zero_conv(branch).permute(0, 2, 3, 1).reshape(bs_f, -1, channel)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep=None,
        attention_mask=None,
        full_mask=None,
        face_mask=None,
        lip_mask=None,
        motion_scale=None,
        video_length=None,
    ) -> torch.Tensor:
        """
        Forward pass for the AudioTemporalBasicTransformerBlock.

        Args:
            hidden_states (torch.Tensor): 4-D input ``(bs_f, channel, d2, d3)``;
                the two trailing axes are flattened into tokens.
            encoder_hidden_states (torch.Tensor, optional): Passed to the
                cross-attention layers. Defaults to None.
            timestep (optional): Passed to the norms when adaptive layer norm is
                in use. Defaults to None.
            attention_mask (optional): Passed to every attention layer.
                Defaults to None.
            full_mask (optional): Indexable by ``self.depth``; the selected entry
                multiplies the ``attn2_0`` output. Presumably shaped
                ``(bs_f, d2 * d3)`` -- confirm against the caller. Required when
                the block was built with ``cross_attention_dim``.
            face_mask (optional): Same as ``full_mask`` but for ``attn2_1``.
            lip_mask (optional): Same as ``full_mask`` but for ``attn2_2``.
            motion_scale (optional): Three weights ``[full, face, lip]`` applied to
                the branches before summation. When None the branches are summed
                unweighted. Defaults to None.
            video_length (int, optional): Forwarded to ``attn1`` only when
                ``unet_use_cross_frame_attention`` is truthy. Defaults to None.

        Returns:
            torch.Tensor: Output with the same shape as ``hidden_states``.
        """
        spatial_shape = hidden_states.shape
        bs_f, channel, rows, cols = spatial_shape
        # (bs_f, c, rows, cols) -> (bs_f, rows*cols, c)
        hidden_states = hidden_states.reshape(bs_f, channel, rows * cols).transpose(2, 1)

        # Self-attention with residual.
        norm_hidden_states = self._apply_norm(self.norm1, hidden_states, timestep)
        if self.unet_use_cross_frame_attention:
            attn_out = self.attn1(
                norm_hidden_states,
                attention_mask=attention_mask,
                video_length=video_length,
            )
        else:
            attn_out = self.attn1(norm_hidden_states, attention_mask=attention_mask)
        hidden_states = attn_out + hidden_states

        # ``norm2`` only exists when cross-attention layers were built, so only
        # normalise when one of them will actually consume the result.
        if self.attn2 is not None or self.attn2_0 is not None:
            norm_hidden_states = self._apply_norm(self.norm2, hidden_states, timestep)

        if self.attn2 is not None:
            # Globally on the whole face.
            hidden_states = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
            ) + hidden_states
        elif self.attn2_0 is not None:
            # Hierarchically: full image, face region, lip region.
            level = self.depth
            full_hidden_states = self._masked_branch(
                self.attn2_0, self.zero_conv_full, full_mask[level],
                norm_hidden_states, encoder_hidden_states, attention_mask, spatial_shape,
            )
            face_hidden_state = self._masked_branch(
                self.attn2_1, self.zero_conv_face, face_mask[level],
                norm_hidden_states, encoder_hidden_states, attention_mask, spatial_shape,
            )
            lip_hidden_state = self._masked_branch(
                self.attn2_2, self.zero_conv_lip, lip_mask[level],
                norm_hidden_states, encoder_hidden_states, attention_mask, spatial_shape,
            )

            if motion_scale is not None:
                hidden_states = (
                    motion_scale[0] * full_hidden_states
                    + motion_scale[1] * face_hidden_state
                    + motion_scale[2] * lip_hidden_state
                    + hidden_states
                )
            else:
                hidden_states = (
                    full_hidden_states + face_hidden_state + lip_hidden_state + hidden_states
                )

        # Feed-forward with residual, then restore the 4-D layout.
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
        hidden_states = hidden_states.transpose(2, 1).reshape(bs_f, channel, rows, cols)
        return hidden_states

def zero_module(module):
    """
    Set every parameter of ``module`` to zero, in place.

    Args:
        module (nn.Module): The module whose parameters are overwritten.

    Returns:
        nn.Module: The same ``module`` object, so the call can be chained.
    """
    # In-place fill outside of autograd tracking, so the parameters stay
    # leaf tensors with their ``requires_grad`` flag untouched.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module

if __name__ == '__main__':
    # Smoke test: build one block with cross-attention enabled and check that a
    # forward pass returns a tensor with the input's shape.
    frames, channels, rows, cols = 14, 320, 20, 36
    audio_tokens, audio_dim = 12, 768

    model = AudioTemporalBasicTransformerBlock(
        dim=channels,
        num_attention_heads=8,
        attention_head_dim=88,
        cross_attention_dim=audio_dim,
    )

    hidden_states = torch.zeros((frames, channels, rows, cols))
    encoder_hidden_states = torch.zeros((frames, audio_tokens, audio_dim))
    # One mask level is enough because the block's depth defaults to 0.
    full_mask = [torch.zeros((frames, rows * cols))]
    face_mask = [torch.zeros((frames, rows * cols))]
    lip_mask = [torch.zeros((frames, rows * cols))]

    # Call the module itself (not ``.forward``) so registered hooks run, and
    # skip autograd bookkeeping since only the output shape is inspected.
    with torch.no_grad():
        out = model(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            full_mask=full_mask,
            face_mask=face_mask,
            lip_mask=lip_mask,
        )
    print(out.shape)
