import torch
import torch.nn as nn
from einops import rearrange
from typing import List, Union, Optional, Dict, Any, Callable
from diffusers.models.attention_processor import Attention, F
from diffusers.models.attention_processor import (
    Attention,
    AttentionProcessor,
    FluxAttnProcessor2_0,
    FluxAttnProcessor2_0_NPU,
    FusedFluxAttnProcessor2_0,
)
from diffusers.models.attention import FeedForward
from diffusers.utils.import_utils import is_torch_npu_available

class CameraAdapter(nn.Module):
    """Project a small conditioning vector (presumably camera parameters) to ``input_channels`` features.

    The projection is the MLP
    ``Linear(camera_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, input_channels)
    -> ReLU -> Linear(input_channels, input_channels) -> Tanh``,
    so every output element lies in ``[-1, 1]``.

    Args:
        input_channels: Width of the produced features.
        camera_dim: Size of the last dimension of the input (default 9).
        hidden_dim: Width of the first hidden layer. Defaults to 100, the
            value that was previously hard-coded, so existing checkpoints
            keep matching shapes.
    """

    def __init__(self, input_channels, camera_dim=9, hidden_dim=100):
        super().__init__()
        # Keep the submodule name and layer order unchanged so existing
        # state_dict keys (camera_mlp.0 / .2 / .4) continue to load.
        self.camera_mlp = nn.Sequential(
            nn.Linear(camera_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_channels),
            nn.ReLU(),
            nn.Linear(input_channels, input_channels),
            nn.Tanh(),
        )

    def forward(self, x):
        """Apply the MLP to ``x``.

        Args:
            x: Tensor whose last dimension is ``camera_dim``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``input_channels``, squashed into ``[-1, 1]`` by the final Tanh.
        """
        return self.camera_mlp(x)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention then feed-forward, each with a residual add.

    Layout: ``x = attn(norm1(x)) + x`` followed by ``x = ff(norm3(x)) + x``.
    After each residual add, a 4-D result has its dimension 1 squeezed
    (``squeeze(1)`` only removes that dimension when its size is 1).

    Args:
        dim: Feature width of the block's input and output.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Per-head width.
        attention_bias: Bias flag for the q/k/v projections (non-rope path only).
        attention_out_bias: Bias flag for the output projection (non-rope path only).
        norm_elementwise_affine: Whether the LayerNorms have learnable affine params.
        norm_eps: LayerNorm epsilon.
        dropout: Dropout passed to FeedForward, and to Attention on the non-rope
            path only.
        activation_fn: FeedForward activation name.
        final_dropout: Passed through to FeedForward.
        ff_inner_dim: Optional FeedForward hidden width.
        ff_bias: Bias flag for the FeedForward linears.
        rope: When True, build the attention with ``FluxAttnProcessor2_0``,
            ``qk_norm="rms_norm"``, ``bias=True`` and ``pre_only=True``. This path
            ignores ``attention_bias``, ``attention_out_bias`` and ``dropout``.
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        attention_bias=False,
        attention_out_bias=True,
        norm_elementwise_affine=True,
        norm_eps=1e-5,
        dropout=0.0,
        activation_fn="geglu",
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        rope: bool = False,
    ):
        super().__init__()

        def layer_norm():
            return nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        # Submodules are created in the same order as before (norm1, norm3, ff,
        # attn) so seeded parameter initialisation is unchanged.
        self.norm1 = layer_norm()
        self.norm3 = layer_norm()
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        if not rope:
            attn_config = dict(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=False,
                out_bias=attention_out_bias,
            )
        else:
            attn_config = dict(
                query_dim=dim,
                cross_attention_dim=None,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                bias=True,
                processor=FluxAttnProcessor2_0(),
                qk_norm="rms_norm",
                eps=1e-6,
                pre_only=True,
            )
        self.attn = Attention(**attn_config)

    @staticmethod
    def _squeeze_if_4d(tensor):
        """Drop dimension 1 of a 4-D tensor (a no-op unless that dim has size 1)."""
        return tensor.squeeze(1) if tensor.ndim == 4 else tensor

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        image_rotary_emb=None,
    ):
        """Run attention and feed-forward sub-layers with residual connections.

        Args:
            hidden_states: Block input; its last dimension is expected to be ``dim``.
            image_rotary_emb: Optional rotary embedding. It is forwarded to the
                attention only when not None, presumably because the non-rope
                processor does not expect it.

        Returns:
            Tensor after both residual sub-layers.
        """
        rotary_kwargs = {} if image_rotary_emb is None else {"image_rotary_emb": image_rotary_emb}

        attended = self.attn(hidden_states=self.norm1(hidden_states), **rotary_kwargs)
        hidden_states = self._squeeze_if_4d(attended + hidden_states)

        fed = self.ff(self.norm3(hidden_states))
        return self._squeeze_if_4d(fed + hidden_states)


class Spatial_Controller(nn.Module):
    """Gated rotary transformer applied independently to each (batch, L) slice of a 4-D input.

    The input of shape ``(B, L, N, C)`` is flattened to ``(B*L, N, C)``, so
    attention runs over the ``N`` axis. The result is passed through a
    ``rope=True`` TransformerBlock and multiplied by a per-element sigmoid gate
    ``sigmoid(w_spatial(out))``, then reshaped back to ``(B, L, N, C)``.

    ``w_spatial`` is zero-initialised, so at initialisation the gate is 0.5
    everywhere.

    Args mirror those of TransformerBlock (minus the attention-bias flags and
    ``rope``, which is fixed to True here).
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        norm_elementwise_affine=True,
        norm_eps=1e-5,
        dropout=0.0,
        activation_fn="geglu",
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
    ):
        super().__init__()

        block_kwargs = dict(
            dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            ff_inner_dim=ff_inner_dim,
            ff_bias=ff_bias,
        )
        self.spatial_transformer = TransformerBlock(**block_kwargs, rope=True)

        # Zero weight and bias -> the gate logits start at 0 (sigmoid = 0.5).
        self.w_spatial = nn.Linear(dim, dim)
        for param in (self.w_spatial.weight, self.w_spatial.bias):
            nn.init.zeros_(param)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        spatial_rotary_emb=None,
    ):
        """Gate-attend over the N axis of a ``(B, L, N, C)`` tensor.

        Args:
            hidden_states: 4-D input ``(B, L, N, C)``.
            spatial_rotary_emb: Optional rotary embedding forwarded to the
                spatial transformer.

        Returns:
            Tensor of shape ``(B, L, N, C)``.
        """
        batch, _, tokens, _ = hidden_states.shape
        flattened = rearrange(hidden_states, 'b l n c -> (b l) n c')

        transformed = self.spatial_transformer(flattened, spatial_rotary_emb)
        gate = torch.sigmoid(self.w_spatial(transformed))
        gated = transformed * gate

        return rearrange(gated, '(b l) n c -> b l n c', b=batch, n=tokens)
        

class Inter_Controller(nn.Module):
    """Two-stage gated transformer over a 4-D ``(B, L, N, C)`` tensor.

    Stage 1 (spatial): flatten to ``(B*L, N, C)``, so attention runs over N.
    The data is passed through a ``rope=True`` TransformerBlock and gated by
    ``sigmoid(w_spatial(out))``. ``w_spatial`` is zero-initialised, so this gate
    is 0.5 at initialisation.

    Stage 2 (layer): regroup to ``(B*N, L, C)``, so attention runs over L. The
    data is passed through a non-rope TransformerBlock and weighted by
    ``softmax(w_layer(out), dim=1)``. For each (batch, token, channel) these
    weights sum to 1 across the L axis. ``w_layer`` keeps ``nn.Linear``'s
    default initialisation (it is not zeroed).

    Args mirror those of TransformerBlock (minus the attention-bias flags and
    ``rope``).
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        norm_elementwise_affine=True,
        norm_eps=1e-5,
        dropout=0.0,
        activation_fn="geglu",
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
    ):
        super().__init__()

        shared = dict(
            dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            ff_inner_dim=ff_inner_dim,
            ff_bias=ff_bias,
        )

        # Construction order is kept (spatial block, w_spatial, layer block,
        # w_layer) so seeded parameter initialisation is unchanged.
        self.spatial_transformer = TransformerBlock(**shared, rope=True)
        self.w_spatial = nn.Linear(dim, dim)
        for param in (self.w_spatial.weight, self.w_spatial.bias):
            nn.init.zeros_(param)

        self.layer_transformer = TransformerBlock(**shared)
        self.w_layer = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        spatial_rotary_emb=None,
        layer_rotary_emb=None,
    ):
        """Apply the spatial stage, then the layer stage.

        Args:
            hidden_states: 4-D input ``(B, L, N, C)``.
            spatial_rotary_emb: Optional rotary embedding for the spatial stage.
            layer_rotary_emb: Accepted but currently unused.
                NOTE(review): the layer transformer is built without rope and is
                called without this argument — confirm whether that is intended.

        Returns:
            Tensor of shape ``(B, L, N, C)``.
        """
        batch, layers, tokens, _ = hidden_states.shape

        # Stage 1: attend over N within each (batch, L) slice, sigmoid-gated.
        per_layer = rearrange(hidden_states, 'b l n c -> (b l) n c')
        spatial_out = self.spatial_transformer(per_layer, spatial_rotary_emb)
        spatial_out = spatial_out * torch.sigmoid(self.w_spatial(spatial_out))

        # Stage 2: attend over L within each (batch, N) slice, softmax-weighted over L.
        per_token = rearrange(spatial_out, '(b l) n c -> (b n) l c', b=batch, l=layers)
        layer_out = self.layer_transformer(per_token)
        layer_out = layer_out * torch.softmax(self.w_layer(layer_out), dim=1)

        return rearrange(layer_out, '(b n) l c -> b l n c', b=batch, n=tokens)
        