from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from diffusers.utils import deprecate
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, SwiGLU

from .modeling_normalization import (
    AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm, RMSNorm
)


def apply_rope(xq, xk, freqs_cis_q, freqs_cis_k=None):
    """Rotate query and key tensors with precomputed rotary factors.

    The last dimension of each input is viewed as consecutive pairs
    ``(x0, x1)``. Each pair is combined with the two columns of the
    matching rotary entry: ``freqs[..., 0] * x0 + freqs[..., 1] * x1``.
    The arithmetic runs in float32, and each result is cast back to the
    dtype of its input.

    Args:
        xq: query tensor; its last dimension must be even.
        xk: key tensor; its last dimension must be even.
        freqs_cis_q: rotary factors for the queries. They must broadcast
            against ``xq`` reshaped to ``(*xq.shape[:-1], D // 2, 1, 2)``,
            with a trailing axis of size 2 selected by ``[..., 0]`` and
            ``[..., 1]``.
        freqs_cis_k: rotary factors for the keys; defaults to
            ``freqs_cis_q``.

    Returns:
        ``(rotated_q, rotated_k)`` with the original shapes and dtypes.
    """

    def _rotate(x, freqs):
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        mixed = freqs[..., 0] * pairs[..., 0] + freqs[..., 1] * pairs[..., 1]
        return mixed.reshape(*x.shape).type_as(x)

    key_freqs = freqs_cis_q if freqs_cis_k is None else freqs_cis_k
    return _rotate(xq, freqs_cis_q), _rotate(xk, key_freqs)


class FeedForward(nn.Module):
    """Feed-forward network: activation projection -> dropout -> linear (-> dropout).

    Args:
        dim: input feature size.
        dim_out: output feature size; defaults to ``dim``.
        mult: hidden-size multiplier, used when ``inner_dim`` is not given.
        dropout: dropout probability after the activation projection. It is
            also used for the trailing dropout when ``final_dropout`` is set.
        activation_fn: one of ``"gelu"``, ``"gelu-approximate"`` (tanh GELU),
            ``"geglu"``, ``"geglu-approximate"`` (``ApproximateGELU``) or
            ``"swiglu"``.
        final_dropout: append a dropout layer after the output projection.
        inner_dim: explicit hidden size; overrides ``int(dim * mult)``.
        bias: whether the projections carry a bias.

    Raises:
        ValueError: if ``activation_fn`` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim: Optional[int] = None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # One if/elif chain: exactly one branch can match. An unknown name
        # fails here with a clear message instead of later raising
        # UnboundLocalError on `act_fn`.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        elif activation_fn == "swiglu":
            act_fn = SwiGLU(dim, inner_dim, bias=bias)
        else:
            raise ValueError(
                f"Unsupported activation_fn {activation_fn!r}; expected one of "
                "'gelu', 'gelu-approximate', 'geglu', 'geglu-approximate', 'swiglu'."
            )

        self.net = nn.ModuleList([])
        # project in (the activation module owns the input projection)
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Apply the layers in ``self.net`` in order.

        Extra positional arguments or a ``scale`` keyword are accepted only
        for backward compatibility. They trigger a deprecation warning and
        are otherwise ignored.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states

class VarlenSelfAttentionWithT5Mask:
    """Joint encoder + visual attention, computed separately for each stage.

    For stage ``i``, the encoder rows ``encoder_*[i::len(hidden_length)]``
    are concatenated in front of the visual chunk
    ``query/key/value[:, start:start + hidden_length[i]]``. The joint
    sequence is then attended with ``attention_mask[i]``.

    Visual outputs of all stages are concatenated along dim 1. Encoder
    outputs are written back in the same interleaved row order they were
    read from.

    ``heads``, ``scale``, ``audio_rotary_emb`` and ``temperature`` are
    accepted but not read by this class.
    """

    def __init__(self):
        pass

    def __call__(
            self, query, key, value, encoder_query, encoder_key, encoder_value,
            heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None, audio_rotary_emb=None,
            temperature=1.0
        ):
        """Run per-stage joint attention.

        Returns:
            ``(visual_output, encoder_output)``, with the head axes merged
            into the last dimension.
        """
        assert attention_mask is not None, "The attention mask needed to be set"

        text_len = encoder_query.shape[1]
        n_stages = len(hidden_length)

        # Pack q/k/v along a new axis 2 so one slice per stage grabs all three.
        packed_text = torch.stack([encoder_query, encoder_key, encoder_value], dim=2)
        packed_visual = torch.stack([query, key, value], dim=2)

        text_chunks = []
        visual_chunks = []
        start = 0
        for stage, stage_len in enumerate(hidden_length):
            joint = torch.cat(
                [packed_text[stage::n_stages], packed_visual[:, start:start + stage_len]],
                dim=1,
            )
            if image_rotary_emb is not None:
                # Rotate the q (slot 0) and k (slot 1) planes in place inside the packed tensor.
                joint[:, :, 0], joint[:, :, 1] = apply_rope(
                    joint[:, :, 0], joint[:, :, 1], image_rotary_emb[stage]
                )

            # Move the heads axis in front of the token axis for SDPA.
            # Assumes (batch, tokens, heads, head_dim) inputs, as built by
            # FluxAttnProcessor2_0 -- TODO confirm for other callers.
            q, k, v = (plane.transpose(1, 2) for plane in joint.unbind(2))
            stage_mask = attention_mask[stage]

            if hasattr(self, 'attn_dict'):
                # Optional capture hook: another component may attach `attn_dict`.
                assert len(hidden_length) == 1, "Only support single stage for attention map"
                self.attn_dict["query"] = q
                self.attn_dict["key"] = k
                self.attn_dict["attn_mask"] = stage_mask

            attended = F.scaled_dot_product_attention(
                q, k, v, dropout_p=0.0, is_causal=False, attn_mask=stage_mask,
            ).transpose(1, 2).flatten(2, 3)

            text_chunks.append(attended[:, :text_len])
            visual_chunks.append(attended[:, text_len:])
            start += stage_len

        # Re-interleave the encoder outputs: row b * n_stages + i holds stage i.
        text_out = rearrange(torch.stack(text_chunks, dim=1), 'b n s d -> (b n) s d')
        visual_out = torch.cat(visual_chunks, dim=1)
        return visual_out, text_out


class VarlenSelfAttnSingle:
    """Stage-wise scaled-dot-product attention over packed variable-length sequences.

    ``query`` is cut along dim 1 into consecutive chunks of
    ``hidden_length[i]`` tokens. ``key``/``value`` are cut the same way, or
    by ``audio_time_length_list`` when that list is given (cross-attention to
    a separately packed sequence). Every chunk pair is attended
    independently with ``attention_mask[i]``, and the per-stage outputs are
    concatenated back along dim 1.

    ``heads`` and ``scale`` are accepted but not read. ``temperature`` and
    ``audio_temperature`` are only recorded in ``attn_dict`` (when present);
    they do not change the attention computed here.
    """

    def __init__(self):
        pass

    def __call__(
            self, query, key, value, heads, scale,
            hidden_length=None, image_rotary_emb=None, attention_mask=None, audio_time_length_list=None, audio_rotary_emb=None,
            temperature=1.0, audio_temperature=1.0
        ):
        """Attend each stage separately and return the concatenated outputs.

        The result has the head axes merged into the last dimension.
        """
        assert attention_mask is not None, "The attention mask needed to be set"
        stage_count = len(hidden_length)

        # Without a dedicated key/value length list this is self-attention:
        # keys/values are chunked exactly like the queries.
        is_cross = audio_time_length_list is not None
        kv_lengths = audio_time_length_list if is_cross else hidden_length

        # Keys reuse the query rotary factors unless their own are supplied.
        if image_rotary_emb is not None and audio_rotary_emb is None:
            audio_rotary_emb = image_rotary_emb

        stage_outputs = []
        q_start = kv_start = 0
        for stage, (q_len, kv_len) in enumerate(zip(hidden_length, kv_lengths)):
            q_chunk = query[:, q_start:q_start + q_len]
            k_chunk = key[:, kv_start:kv_start + kv_len]
            v_chunk = value[:, kv_start:kv_start + kv_len]
            if image_rotary_emb is not None:
                q_chunk, k_chunk = apply_rope(
                    q_chunk, k_chunk, image_rotary_emb[stage], audio_rotary_emb[stage]
                )
            q_chunk, k_chunk, v_chunk = (
                chunk.transpose(1, 2).contiguous() for chunk in (q_chunk, k_chunk, v_chunk)
            )
            stage_mask = attention_mask[stage]

            if hasattr(self, 'attn_dict'):
                # Optional capture hook: another component may attach `attn_dict`.
                assert stage_count == 1, "Only support single stage for attention map"
                self.attn_dict["query"] = q_chunk
                self.attn_dict["key"] = k_chunk
                self.attn_dict["attn_mask"] = stage_mask
                self.attn_dict["temperature"] = audio_temperature if is_cross else temperature

            attended = F.scaled_dot_product_attention(
                q_chunk, k_chunk, v_chunk, dropout_p=0.0, is_causal=False, attn_mask=stage_mask,
            )
            stage_outputs.append(attended.transpose(1, 2).flatten(2, 3))

            q_start += q_len
            kv_start += kv_len

        return torch.cat(stage_outputs, dim=1)


class Attention(nn.Module):
    """Multi-head attention container whose computation is delegated to a processor.

    The module owns the projection layers and the optional per-head RMS
    norms. ``forward`` passes itself plus all inputs to ``self.processor``.

    Args:
        query_dim: feature size of the query input.
        cross_attention_dim: feature size of the key/value input; defaults
            to ``query_dim``.
        heads: number of heads. Ignored when ``out_dim`` is given; then
            ``out_dim // dim_head`` is used.
        dim_head: size of each head.
        dropout: dropout probability used in ``to_out``.
        bias: bias flag for ``to_q``/``to_k``/``to_v``.
        qk_norm: ``None`` or ``"rms_norm"``.
        added_kv_proj_dim: if set, creates ``add_k_proj``/``add_v_proj``,
            plus ``add_q_proj`` when ``context_pre_only`` is not ``None``,
            for a second input stream.
        added_proj_bias: bias flag for the added projections.
        out_bias: bias flag for the output projections.
        only_cross_attention: skip ``to_k``/``to_v``; requires
            ``added_kv_proj_dim``.
        eps: epsilon for the normalisation layers.
        processor: callable invoked as ``processor(self, hidden_states, **kwargs)``.
        out_dim: if given, used both as the projection width and as the
            output size.
        context_pre_only: ``None`` disables the added-stream query/output
            projections; ``False`` additionally creates ``to_add_out``.
        pre_only: if ``True``, no ``to_out`` is created.

    Raises:
        ValueError: ``only_cross_attention`` without ``added_kv_proj_dim``,
            or an unsupported ``qk_norm``.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        processor: Optional["AttnProcessor"] = None,
        out_dim: Optional[int] = None,
        context_pre_only=None,
        pre_only=False,
    ):
        super().__init__()

        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim
        self.query_dim = query_dim
        self.use_bias = bias
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only

        self.scale = dim_head**-0.5
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if qk_norm is None:
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "rms_norm":
            self.norm_q = RMSNorm(dim_head, eps=eps)
            self.norm_k = RMSNorm(dim_head, eps=eps)
        else:
            # The message must name the value that is actually accepted above.
            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'rms_norm'")

        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        self.added_proj_bias = added_proj_bias
        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
            if self.context_pre_only is not None:
                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)

        if not self.pre_only:
            self.to_out = nn.ModuleList([])
            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
            self.to_out.append(nn.Dropout(dropout))

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)

        if qk_norm is not None and added_kv_proj_dim is not None:
            # NOTE: only "rms_norm" can reach this point (anything else is
            # rejected above), so the fp32_layer_norm branch is currently dead.
            if qk_norm == "fp32_layer_norm":
                self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
                self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
            elif qk_norm == "rms_norm":
                self.norm_added_q = RMSNorm(dim_head, eps=eps)
                self.norm_added_k = RMSNorm(dim_head, eps=eps)
        else:
            self.norm_added_q = None
            self.norm_added_k = None

        # set attention processor
        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor") -> None:
        """Replace the callable that performs the attention computation."""
        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        hidden_length: List = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        audio_time_length_list: Optional[List] = None,
        audio_rotary_emb: Optional[torch.Tensor] = None,
        audio_temperature: float = 1.0,
    ) -> torch.Tensor:
        """Call ``self.processor(self, hidden_states, **kwargs)`` and return its result."""
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            attention_mask=attention_mask,
            hidden_length=hidden_length,
            image_rotary_emb=image_rotary_emb,
            audio_time_length_list=audio_time_length_list,
            audio_rotary_emb=audio_rotary_emb,
            audio_temperature=audio_temperature
        )


class FluxSingleAttnProcessor2_0:
    """Processor for single-stream Flux attention and the audio cross-attention.

    Queries come from ``hidden_states``. Keys and values come from
    ``encoder_hidden_states``, or from ``hidden_states`` when that is
    ``None`` (self-attention). After head splitting and the optional q/k
    norms, the stage-wise attention is delegated to ``VarlenSelfAttnSingle``.

    No output projection is applied here; the result is returned as
    computed by ``varlen_attn``. ``encoder_attention_mask`` is accepted but
    not read.
    """

    def __init__(self, use_flash_attn=False):
        # Stored but not read by this class; only the SDPA path exists here.
        self.use_flash_attn = use_flash_attn

        self.varlen_attn = VarlenSelfAttnSingle()

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        hidden_length: List = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        audio_time_length_list: Optional[List] = None,
        audio_rotary_emb: Optional[torch.Tensor] = None,
        audio_temperature: float = 1.0,
    ) -> torch.Tensor:
        """Project, split heads, normalise q/k and run the varlen attention.

        Returns:
            The tensor produced by ``self.varlen_attn``, with heads merged
            into the last dimension.
        """
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(query.shape[0], -1, attn.heads, head_dim)
        key = key.view(key.shape[0], -1, attn.heads, head_dim)
        value = value.view(value.shape[0], -1, attn.heads, head_dim)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Attention must always run. Previously this call sat in an `else:`
        # attached to the norm_k check. Whenever norm_k was configured, the
        # processor then silently returned its unchanged input.
        hidden_states = self.varlen_attn(
            query, key, value,
            attn.heads, attn.scale, hidden_length,
            image_rotary_emb, attention_mask,
            audio_time_length_list,
            audio_rotary_emb=audio_rotary_emb,
            audio_temperature=audio_temperature,
        )

        return hidden_states


class FluxAttnProcessor2_0:
    """Processor for the dual-stream Flux block (visual tokens + encoder tokens).

    ``hidden_states`` is projected with ``to_q/to_k/to_v``, and
    ``encoder_hidden_states`` with ``add_q_proj/add_k_proj/add_v_proj``.
    Heads are split and the optional q/k norms applied. Joint attention is
    delegated to ``VarlenSelfAttentionWithT5Mask``. Finally ``to_out``
    (projection + dropout) is applied to the visual output and
    ``to_add_out`` to the encoder output.
    """

    def __init__(self, use_flash_attn=False):
        # Stored but not read by this class.
        self.use_flash_attn = use_flash_attn
        self.varlen_attn = VarlenSelfAttentionWithT5Mask()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        hidden_length: List = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        audio_time_length_list: Optional[List] = None,
        audio_rotary_emb: Optional[torch.Tensor] = None,
        audio_temperature: float = 1.0,
    ) -> torch.FloatTensor:
        """Return ``(visual_output, encoder_output)`` after joint attention."""
        n_heads = attn.heads

        vis_q = attn.to_q(hidden_states)
        vis_k = attn.to_k(hidden_states)
        vis_v = attn.to_v(hidden_states)
        head_dim = vis_k.shape[-1] // n_heads

        def split_heads(t):
            # (batch, tokens, heads * head_dim) -> (batch, tokens, heads, head_dim)
            return t.view(t.shape[0], -1, n_heads, head_dim)

        vis_q, vis_k, vis_v = map(split_heads, (vis_q, vis_k, vis_v))
        if attn.norm_q is not None:
            vis_q = attn.norm_q(vis_q)
        if attn.norm_k is not None:
            vis_k = attn.norm_k(vis_k)

        txt_q = attn.add_q_proj(encoder_hidden_states)
        txt_k = attn.add_k_proj(encoder_hidden_states)
        txt_v = attn.add_v_proj(encoder_hidden_states)
        txt_q, txt_k, txt_v = map(split_heads, (txt_q, txt_k, txt_v))
        if attn.norm_added_q is not None:
            txt_q = attn.norm_added_q(txt_q)
        if attn.norm_added_k is not None:
            txt_k = attn.norm_added_k(txt_k)

        vis_out, txt_out = self.varlen_attn(
            vis_q, vis_k, vis_v,
            txt_q, txt_k, txt_v,
            attn.heads, attn.scale, hidden_length,
            image_rotary_emb, attention_mask,
        )

        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        vis_out = out_dropout(out_proj(vis_out))
        txt_out = attn.to_add_out(txt_out)

        return vis_out, txt_out


class FluxSingleTransformerBlock(nn.Module):
    """Single-stream Flux block with optional audio cross-attention.

    The main path normalises the input with ``AdaLayerNormZeroSingle`` and
    runs attention and an MLP in parallel on the normalised tokens. Both
    outputs are fused through ``proj_out``, gated, and added to the input.

    With ``use_audio_cross_attn``, the trailing tokens of each stage first
    cross-attend to audio features through ``audio_attn``. The update is
    added to the hidden states before the main path. ``attn`` and
    ``audio_attn`` share the same processor instance.
    """

    def __init__(
        self, 
        dim, 
        num_attention_heads, 
        attention_head_dim, 
        mlp_ratio=4.0, 
        use_flash_attn=False, 
        use_audio_cross_attn=False,
        audio_cross_attn_dim=1024,
    ):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) 
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) 
        processor = FluxSingleAttnProcessor2_0(use_flash_attn)
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=processor,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )
        if use_audio_cross_attn:
            self.audio_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.audio_attn = Attention(
                query_dim=dim,
                cross_attention_dim=audio_cross_attn_dim,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                bias=True,
                processor=processor,
                qk_norm="rms_norm",
                eps=1e-6,
                pre_only=True,
            )
        else:
            self.audio_norm1 = None
            self.audio_attn = None

    def _audio_cross_attention(
        self,
        hidden_states,
        encoder_attention_mask,
        hidden_length,
        image_rotary_emb,
        audio_hidden_states,
        audio_attention_mask_list,
        audio_time_length_list,
        audio_rotary_emb,
        audio_temperature,
    ):
        """Return ``hidden_states`` plus the audio cross-attention update.

        ``hidden_states`` is split into stages by ``hidden_length``. In stage
        ``i``, only the last ``audio_attention_mask_list[i].shape[2]`` tokens
        (and the matching tail of ``image_rotary_emb[i]``) are sent through
        ``audio_attn``. The leading tokens of each stage get a zero update.
        """
        text_states = []
        query_states = []
        query_lengths = []
        query_rope = []
        for i_p, stage_state in enumerate(torch.split(hidden_states, hidden_length, dim=1)):
            total_length = stage_state.shape[1]
            total_emb_length = image_rotary_emb[i_p].shape[1]
            query_length = audio_attention_mask_list[i_p].shape[2]
            text, tail = torch.split(stage_state, [total_length - query_length, query_length], dim=1)
            text_states.append(text)
            query_states.append(tail)
            query_lengths.append(query_length)
            query_rope.append(
                torch.split(image_rotary_emb[i_p], [total_emb_length - query_length, query_length], dim=1)[1]
            )

        audio_output = self.audio_attn(
            hidden_states=self.audio_norm1(torch.cat(query_states, dim=1)),
            encoder_hidden_states=audio_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            attention_mask=audio_attention_mask_list,
            hidden_length=query_lengths,
            audio_time_length_list=audio_time_length_list,
            image_rotary_emb=query_rope,
            audio_rotary_emb=audio_rotary_emb,
            audio_temperature=audio_temperature,
        )

        # Re-insert zeros for the leading tokens so the update lines up with hidden_states.
        restored = [
            torch.cat([torch.zeros_like(text, device=out.device), out], dim=1)
            for text, out in zip(text_states, audio_output.split(query_lengths, dim=1))
        ]
        return torch.cat(restored, dim=1) + hidden_states

    def _parallel_attn_mlp(
        self,
        hidden_states,
        residual,
        temb,
        encoder_attention_mask,
        attention_mask,
        hidden_length,
        image_rotary_emb,
    ):
        """Run the gated parallel attention + MLP branch and add it to ``residual``."""
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb, hidden_length=hidden_length)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
            encoder_attention_mask=encoder_attention_mask,
            attention_mask=attention_mask,
            hidden_length=hidden_length,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states
        if hidden_states.dtype == torch.float16:
            # Keep values inside the finite float16 range.
            hidden_states = hidden_states.clip(-65504, 65504)
        return hidden_states

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        encoder_attention_mask=None,
        attention_mask=None,
        hidden_length=None,
        image_rotary_emb=None,
        audio_hidden_states=None, 
        audio_attention_mask_list=None,
        audio_time_length_list=None,
        audio_rotary_emb=None,
        audio_temperature=1.0
    ):
        """Apply the optional audio cross-attention, then the gated attention + MLP branch."""
        # NOTE(review): `residual` is captured before the audio update, so the
        # audio contribution reaches the output only through the attention/MLP
        # branch, not through the skip connection -- confirm this is intended.
        residual = hidden_states
        if audio_hidden_states is not None and self.audio_attn is not None:
            hidden_states = self._audio_cross_attention(
                hidden_states,
                encoder_attention_mask,
                hidden_length,
                image_rotary_emb,
                audio_hidden_states,
                audio_attention_mask_list,
                audio_time_length_list,
                audio_rotary_emb,
                audio_temperature,
            )
        return self._parallel_attn_mlp(
            hidden_states,
            residual,
            temb,
            encoder_attention_mask,
            attention_mask,
            hidden_length,
            image_rotary_emb,
        )

    def forward_audio_skip_guidance(
        self,
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        encoder_attention_mask=None, 
        attention_mask=None,
        hidden_length=None,
        image_rotary_emb=None,
        audio_hidden_states=None,
        audio_attention_mask_list=None,
        audio_time_length_list=None,
        audio_rotary_emb=None,
        audio_temperature=1.0
    ):
        """Same as ``forward``, but the last third of the batch skips the audio update.

        The batch is treated as three equal parts. Rows ``2 * (B // 3):``
        are reset to their pre-audio values before the main branch runs.
        Presumably this is the audio-free branch for guidance -- confirm
        against the caller.
        """
        residual = hidden_states

        num_prompt = hidden_states.size(0) // 3
        hidden_states_ptb = hidden_states[2*num_prompt:]

        if audio_hidden_states is not None and self.audio_attn is not None:
            hidden_states = self._audio_cross_attention(
                hidden_states,
                encoder_attention_mask,
                hidden_length,
                image_rotary_emb,
                audio_hidden_states,
                audio_attention_mask_list,
                audio_time_length_list,
                audio_rotary_emb,
                audio_temperature,
            )

        # In-place restore of the un-updated rows. If the audio branch was
        # skipped, this writes the input tensor's rows onto themselves.
        hidden_states[2*num_prompt:] = hidden_states_ptb

        return self._parallel_attn_mlp(
            hidden_states,
            residual,
            temb,
            encoder_attention_mask,
            attention_mask,
            hidden_length,
            image_rotary_emb,
        )


class FluxTransformerBlock(nn.Module):
    """Two-stream transformer block with an optional audio cross-attention branch.

    The main stream (``hidden_states``) and the context stream
    (``encoder_hidden_states``) are each modulated by their own
    ``AdaLayerNormZero``, attended jointly through ``self.attn`` (which returns
    one output per stream), and then passed through separate feed-forward
    networks. When built with ``use_audio_cross_attn=True``, a trailing part of
    each packed segment of ``hidden_states`` additionally cross-attends to
    audio features before the joint attention; that result is added as a
    residual.
    """

    def __init__(
        self, 
        dim, 
        num_attention_heads, 
        attention_head_dim, 
        qk_norm="rms_norm", 
        eps=1e-6, 
        use_flash_attn=False,
        use_audio_cross_attn=False,
        audio_cross_attn_dim=1920,
    ):
        """Build the block.

        Args:
            dim: Hidden size of both streams.
            num_attention_heads: Number of attention heads.
            attention_head_dim: Size of each attention head.
            qk_norm: Query/key normalization type forwarded to ``Attention``.
            eps: Epsilon forwarded to ``Attention``.
            use_flash_attn: Forwarded to both attention processors.
            use_audio_cross_attn: If True, create ``audio_norm1`` / ``audio_attn``;
                otherwise both are set to None and the audio branch is skipped.
            audio_cross_attn_dim: ``cross_attention_dim`` of ``audio_attn``, i.e.
                the feature size expected for ``audio_hidden_states``.

        Raises:
            ValueError: If ``torch.nn.functional.scaled_dot_product_attention``
                is unavailable.
        """
        super().__init__()

        # NOTE: submodule creation order is kept as-is so that parameter
        # initialization consumes the RNG in the same order as before.
        self.norm1 = AdaLayerNormZero(dim)

        self.norm1_context = AdaLayerNormZero(dim)

        if hasattr(F, "scaled_dot_product_attention"):
            processor = FluxAttnProcessor2_0(use_flash_attn)
            processor_audio = FluxSingleAttnProcessor2_0(use_flash_attn)
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=processor,
            qk_norm=qk_norm,
            eps=eps,
        )
        if use_audio_cross_attn:
            self.audio_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.audio_attn = Attention(
                query_dim=dim,
                cross_attention_dim=audio_cross_attn_dim,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                bias=True,
                processor=processor_audio,
                qk_norm=qk_norm,
                eps=eps,
            )
        else:
            self.audio_attn = None
            self.audio_norm1 = None

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    # ------------------------------------------------------------------ #
    # Private helpers shared by `forward` and `forward_audio_skip_guidance`
    # ------------------------------------------------------------------ #

    def _apply_audio_cross_attn(
        self,
        hidden_states: torch.FloatTensor,
        encoder_attention_mask,
        hidden_length,
        image_rotary_emb,
        audio_hidden_states,
        audio_attention_mask_list,
        audio_time_length_list,
        audio_rotary_emb,
        audio_temperature,
    ):
        """Return ``hidden_states`` plus the audio cross-attention residual.

        If ``audio_hidden_states`` is None or the block was built without audio
        cross-attention, ``hidden_states`` is returned unchanged (same object).

        Otherwise:

        - ``hidden_states`` is split along dim 1 into packed segments of
          lengths ``hidden_length``. The lengths must sum to
          ``hidden_states.shape[1]``.
        - For segment ``i``, only the trailing
          ``audio_attention_mask_list[i].shape[2]`` tokens take part. The
          matching trailing slice (along dim 1) of ``image_rotary_emb[i]`` is
          used with them.
        - Those tokens are normalized with ``audio_norm1`` and passed to
          ``audio_attn`` with ``audio_hidden_states`` as
          ``encoder_hidden_states``.
        - The leading tokens of each segment (named ``text`` below) receive a
          zero update.
        - A new tensor is returned; the input is not modified in place.
        """
        if audio_hidden_states is None or self.audio_attn is None:
            return hidden_states

        text_hidden_states = []
        audio_query_segments = []
        batch_length_list = []
        image_rotary_emb_list = []

        for i_p, segment in enumerate(torch.split(hidden_states, hidden_length, dim=1)):
            total_length = segment.shape[1]
            total_emb_length = image_rotary_emb[i_p].shape[1]
            # Number of trailing tokens in this segment that attend to audio.
            batch_length = audio_attention_mask_list[i_p].shape[2]
            text, audio_query = torch.split(segment, [total_length - batch_length, batch_length], dim=1)
            text_hidden_states.append(text)
            audio_query_segments.append(audio_query)
            batch_length_list.append(batch_length)
            # Keep only the rotary embeddings aligned with the audio-attending tokens.
            image_rotary_emb_list.append(
                torch.split(image_rotary_emb[i_p], [total_emb_length - batch_length, batch_length], dim=1)[1]
            )

        batch_hidden_states = torch.cat(audio_query_segments, dim=1)

        batch_norm_hidden_states = self.audio_norm1(batch_hidden_states)
        audio_output = self.audio_attn(
            hidden_states=batch_norm_hidden_states,
            encoder_hidden_states=audio_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            attention_mask=audio_attention_mask_list,
            hidden_length=batch_length_list,
            audio_time_length_list=audio_time_length_list,
            image_rotary_emb=image_rotary_emb_list,
            audio_rotary_emb=audio_rotary_emb,
            audio_temperature=audio_temperature,
        )

        # Re-insert zeros for the leading (non-audio) tokens of every segment so
        # the residual has the same layout as `hidden_states`.
        restored_hidden_states_list = []
        for text, batch_state in zip(text_hidden_states, audio_output.split(batch_length_list, dim=1)):
            zero_padding = torch.zeros_like(text, device=batch_state.device)
            restored_hidden_states_list.append(torch.cat([zero_padding, batch_state], dim=1))

        return torch.cat(restored_hidden_states_list, dim=1) + hidden_states

    def _joint_attn_and_ff(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        encoder_attention_mask,
        temb,
        attention_mask,
        hidden_length,
        image_rotary_emb,
    ):
        """Run gated joint attention and gated feed-forward on both streams.

        Returns:
            ``(encoder_hidden_states, hidden_states)``, in that order. Only
            ``encoder_hidden_states`` is clipped to the float16 range when its
            dtype is ``torch.float16``.
        """
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, emb=temb, hidden_length=hidden_length
        )

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        # Joint attention: one call returns an output for each stream.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            attention_mask=attention_mask,
            hidden_length=hidden_length,
            image_rotary_emb=image_rotary_emb,
        )

        # Main stream: gated attention residual, then modulated, gated feed-forward.
        # NOTE(review): the modulation tensors from `self.norm1` are used
        # without `[:, None]`, unlike the context stream below. Presumably
        # `AdaLayerNormZero(..., hidden_length=...)` already returns them
        # broadcastable over the sequence dim — confirm in modeling_normalization.
        hidden_states = hidden_states + gate_msa * attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        hidden_states = hidden_states + gate_mlp * self.ff(norm_hidden_states)

        # Context stream: same pattern, with per-sample modulation broadcast
        # over the sequence dimension.
        encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        # Guard against fp16 overflow (inf) in the context stream.
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states

    # ------------------------------------------------------------------ #
    # Public entry points
    # ------------------------------------------------------------------ #

    def forward(
        self,
        hidden_states: torch.FloatTensor, 
        encoder_hidden_states: torch.FloatTensor,
        encoder_attention_mask: torch.FloatTensor,
        temb: torch.FloatTensor, 
        attention_mask: torch.FloatTensor = None,
        hidden_length: List = None,
        image_rotary_emb=None, 
        audio_hidden_states=None, 
        audio_attention_mask_list=None,
        audio_time_length_list=None,
        audio_rotary_emb=None,
        audio_temperature=1.0
    ):
        """Apply optional audio cross-attention, then joint attention + FF.

        Args:
            hidden_states: Main stream. Packed along dim 1 into segments of
                lengths ``hidden_length``.
            encoder_hidden_states: Context stream for joint attention.
            encoder_attention_mask: Forwarded to both attention modules.
            temb: Conditioning embedding for the ``AdaLayerNormZero`` modulations.
            attention_mask: Forwarded to the joint attention.
            hidden_length: Segment lengths along dim 1 of ``hidden_states``.
            image_rotary_emb: Per-segment rotary embeddings (indexed by segment
                and sliced along dim 1 for the audio branch).
            audio_hidden_states: Audio features for cross-attention. If None,
                the audio branch is skipped.
            audio_attention_mask_list: Per-segment masks. ``.shape[2]`` gives the
                number of trailing tokens in the segment that attend to audio.
            audio_time_length_list: Forwarded unchanged to ``audio_attn``.
            audio_rotary_emb: Forwarded unchanged to ``audio_attn``.
            audio_temperature: Forwarded unchanged to ``audio_attn``.

        Returns:
            ``(encoder_hidden_states, hidden_states)``.
        """
        hidden_states = self._apply_audio_cross_attn(
            hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            hidden_length=hidden_length,
            image_rotary_emb=image_rotary_emb,
            audio_hidden_states=audio_hidden_states,
            audio_attention_mask_list=audio_attention_mask_list,
            audio_time_length_list=audio_time_length_list,
            audio_rotary_emb=audio_rotary_emb,
            audio_temperature=audio_temperature,
        )
        return self._joint_attn_and_ff(
            hidden_states,
            encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            temb=temb,
            attention_mask=attention_mask,
            hidden_length=hidden_length,
            image_rotary_emb=image_rotary_emb,
        )

    def forward_audio_skip_guidance(
        self,
        hidden_states: torch.FloatTensor, 
        encoder_hidden_states: torch.FloatTensor,
        encoder_attention_mask: torch.FloatTensor,
        temb: torch.FloatTensor, 
        attention_mask: torch.FloatTensor = None, 
        hidden_length: List = None, 
        image_rotary_emb=None, 
        audio_hidden_states=None,
        audio_attention_mask_list=None,
        audio_time_length_list=None,
        audio_rotary_emb=None,
        audio_temperature=1.0
    ):
        """Like ``forward``, but the last third of the batch skips the audio branch.

        The batch dimension is assumed to hold ``3 * num_prompt`` rows. Rows
        ``[2*num_prompt:]`` (named ``ptb`` — presumably the perturbed guidance
        branch) are reset to their pre-audio values before the joint attention.
        Arguments and return value are the same as ``forward``.
        """
        num_prompt = hidden_states.size(0) // 3
        # View of the rows that must not receive the audio residual. The audio
        # helper returns a new tensor, so this view keeps the original values.
        hidden_states_ptb = hidden_states[2 * num_prompt:]

        hidden_states = self._apply_audio_cross_attn(
            hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            hidden_length=hidden_length,
            image_rotary_emb=image_rotary_emb,
            audio_hidden_states=audio_hidden_states,
            audio_attention_mask_list=audio_attention_mask_list,
            audio_time_length_list=audio_time_length_list,
            audio_rotary_emb=audio_rotary_emb,
            audio_temperature=audio_temperature,
        )

        # When the audio branch was skipped this copies the slice onto itself
        # (same memory), which is a no-op.
        hidden_states[2 * num_prompt:] = hidden_states_ptb

        return self._joint_attn_and_ff(
            hidden_states,
            encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            temb=temb,
            attention_mask=attention_mask,
            hidden_length=hidden_length,
            image_rotary_emb=image_rotary_emb,
        )
