from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn
from einops import rearrange
import torch.nn.functional as F
from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
from diffusers.utils import is_torch_version
from diffusers.loaders import  PeftAdapterMixin
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor2_0
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero, AdaLayerNormZeroSingle
from diffusers.configuration_utils import ConfigMixin, register_to_config


class CogVideoXControlnetPCD(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """ControlNet-style branch built from `CogVideoXBlock`s, conditioned on "anchor" states.

    The model concatenates the input `hidden_states` with the anchor states along
    the channel dimension, patch-embeds the result together with the text
    embeddings, runs it through `num_layers` transformer blocks and returns the
    hidden states after every block (optionally passed through a per-block
    linear projection).

    Args:
        num_attention_heads: Number of attention heads in every block.
        use_zero_conv: If True, the anchor states are first encoded by a small
            convolutional encoder (pixel-unshuffle, two 1x1 conv stages, two
            temporal poolings) before being concatenated. If False, they are
            concatenated as given.
        attention_head_dim: Channels per attention head; the inner width of the
            model is `num_attention_heads * attention_head_dim`.
        vae_channels: Channel count of `hidden_states`.
        in_channels: Channel count of the anchor states on the `use_zero_conv`
            path (before pixel-unshuffle).
        downscale_coef: Pixel-unshuffle factor used on the `use_zero_conv` path.
        flip_sin_to_cos, freq_shift: Forwarded to `Timesteps`.
        time_embed_dim: Output width of the timestep embedding.
        num_layers: Number of transformer blocks (and of output projectors).
        dropout: Dropout probability for the embedding dropout and the blocks.
        attention_bias, activation_fn, norm_elementwise_affine, norm_eps:
            Forwarded to every `CogVideoXBlock`.
        sample_width, sample_height, sample_frames, patch_size,
        temporal_compression_ratio, spatial_interpolation_scale,
        temporal_interpolation_scale: Forwarded to `CogVideoXPatchEmbed`.
        max_text_seq_length: Stored in the config only; not used by this class.
        timestep_activation_fn: Activation used inside `TimestepEmbedding`.
        use_rotary_positional_embeddings: If True, the patch embedding does not
            add positional embeddings.
        use_learned_positional_embeddings: Forwarded to `CogVideoXPatchEmbed`;
            only allowed together with rotary embeddings.
        out_proj_dim: If given, every block output is projected to this width
            by its own `nn.Linear`.
        out_proj_dim_zero_init: If True, the output projectors start at zero.

    Raises:
        ValueError: If learned positional embeddings are requested without
            rotary positional embeddings.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        use_zero_conv: bool = False,
        attention_head_dim: int = 64,
        vae_channels: int = 16,
        in_channels: int = 3,
        downscale_coef: int = 8,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        num_layers: int = 8,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        out_proj_dim: Optional[int] = None,
        out_proj_dim_zero_init: bool = False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
                "issue at https://github.com/huggingface/diffusers/issues."
            )

        self.vae_channels = vae_channels
        self.use_zero_conv = use_zero_conv
        self.unshuffle = nn.PixelUnshuffle(downscale_coef)

        # 0. Optional anchor encoder: pixel-unshuffle multiplies the channel
        # count by downscale_coef**2, then each 1x1 conv stage halves it.
        if use_zero_conv:
            start_channels = in_channels * (downscale_coef ** 2)
            mid_channels = start_channels // 2
            end_channels = start_channels // 4

            self.controlnet_encode_first = nn.Sequential(
                nn.Conv2d(start_channels, mid_channels, kernel_size=1, stride=1, padding=0),
                nn.GroupNorm(2, mid_channels),
                nn.ReLU(),
            )
            self.controlnet_encode_second = nn.Sequential(
                nn.Conv2d(mid_channels, end_channels, kernel_size=1, stride=1, padding=0),
                nn.GroupNorm(2, end_channels),
                nn.ReLU(),
            )
            patch_embed_in_channels = vae_channels + end_channels
        else:
            # Anchor states are concatenated as-is, so they are expected to
            # carry `vae_channels` channels themselves.
            patch_embed_in_channels = vae_channels * 2

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            in_channels=patch_embed_in_channels,
            embed_dim=inner_dim,
            bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 3. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Optional per-block output projections
        self.out_projectors = None
        if out_proj_dim is not None:
            self.out_projectors = nn.ModuleList(
                [nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)]
            )
            if out_proj_dim_zero_init:
                for out_projector in self.out_projectors:
                    self.zeros_init_linear(out_projector)

        self.gradient_checkpointing = False

    def zeros_init_linear(self, linear: nn.Module):
        """Zero the weight (and bias, if present) of a `nn.Linear` / `nn.Conv1d` in place.

        Modules of any other type are left untouched.
        """
        if not isinstance(linear, (nn.Linear, nn.Conv1d)):
            return
        nn.init.zeros_(linear.weight)
        # Layers built with `bias=False` still expose a `bias` attribute, but it
        # is None, so an attribute-existence check is not enough here.
        if linear.bias is not None:
            nn.init.zeros_(linear.bias)

    def _set_gradient_checkpointing(self, module, value=False):
        """Turn gradient checkpointing of the transformer blocks on or off.

        `module` is accepted for interface compatibility and is not used.
        """
        self.gradient_checkpointing = value

    def compress_time(self, x, num_frames):
        """Halve the number of frames of a frame-flattened feature map by average pooling.

        Args:
            x: Tensor of shape `(batch * num_frames, channels, height, width)`.
            num_frames: Number of frames folded into the first dimension of `x`.

        Returns:
            Tensor of shape `(batch * new_frames, channels, height, width)`.
            For an even frame count, consecutive pairs of frames are averaged
            (`new_frames = num_frames // 2`). For an odd frame count, the first
            frame is kept unchanged and the remaining frames are averaged in
            pairs (`new_frames = 1 + (num_frames - 1) // 2`).
        """
        x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
        batch_size, _, _, height, width = x.shape
        # Put time last so that 1D pooling runs along the frame axis.
        x = rearrange(x, 'b f c h w -> (b h w) c f')

        if x.shape[-1] % 2 == 1:
            x_first, x_rest = x[..., :1], x[..., 1:]
            if x_rest.shape[-1] > 0:
                x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
            x = torch.cat([x_first, x_rest], dim=-1)
        else:
            x = F.avg_pool1d(x, kernel_size=2, stride=2)

        return rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)

    def _encode_anchor_states(self, anchor_states: torch.Tensor) -> torch.Tensor:
        """Run `(batch, frames, channels, height, width)` anchor states through the conv encoder.

        Applies pixel-unshuffle, then twice a 1x1 conv stage followed by
        temporal pooling, and returns a tensor laid out again as
        `(batch, frames, channels, height, width)`.
        """
        # Sizes are taken from the tensor that is actually reshaped below, so
        # the rearranges always agree with its real layout.
        batch_size, num_frames = anchor_states.shape[:2]

        x = rearrange(anchor_states, 'b f c h w -> (b f) c h w')
        x = self.unshuffle(x)

        x = self.controlnet_encode_first(x)
        x = self.compress_time(x, num_frames=num_frames)
        # The first pooling changed the frame count; recompute it for the second.
        num_frames = x.shape[0] // batch_size

        x = self.controlnet_encode_second(x)
        x = self.compress_time(x, num_frames=num_frames)
        return rearrange(x, '(b f) c h w -> b f c h w', b=batch_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_states: Tuple[torch.Tensor, torch.Tensor],
        timestep: Union[int, float, torch.LongTensor],
        controlnet_output_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """Compute the per-block control features.

        Args:
            hidden_states: Input states; concatenated with the anchor states
                along dim 2 (the channel dimension).
            encoder_hidden_states: Text embeddings, embedded jointly with the
                concatenated states by the patch embedding.
            controlnet_states: Pair `(controlnet_states, anchor_states)`. Only
                the second element is used by this model; the first one is
                currently ignored. With `use_zero_conv` the anchor states must
                be laid out as `(batch, frames, channels, height, width)` and,
                after encoding, must match `hidden_states` in every dimension
                except channels.
            timestep: Diffusion timestep(s), passed to the timestep projection.
            controlnet_output_mask: Optional multiplier applied to every
                projected block output. It is only applied when output
                projectors are configured.
            image_rotary_emb: Optional rotary embeddings forwarded to every block.
            timestep_cond: Optional conditioning forwarded to the timestep embedding.
            return_dict: If False, return a plain 1-tuple instead of an output object.

        Returns:
            `Transformer2DModelOutput` whose `sample` is a tuple with one tensor
            per transformer block, or `(that_tuple,)` if `return_dict` is False.
        """
        _, anchor_states = controlnet_states

        # 0. Controlnet encoder
        if self.use_zero_conv:
            anchor_states = self._encode_anchor_states(anchor_states)
        hidden_states = torch.cat([hidden_states, anchor_states], dim=2)

        # 1. Time embedding
        t_emb = self.time_proj(timestep)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding: the output holds the text tokens first, followed
        # by the visual tokens, so split it back into the two streams.
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 3. Transformer blocks
        use_checkpointing = self.training and self.gradient_checkpointing
        ckpt_kwargs: Dict[str, Any] = {}
        if use_checkpointing and is_torch_version(">=", "1.11.0"):
            ckpt_kwargs = {"use_reentrant": False}

        block_outputs = []
        for i, block in enumerate(self.transformer_blocks):
            if use_checkpointing:
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

            if self.out_projectors is None:
                block_outputs.append(hidden_states)
                continue

            projected = self.out_projectors[i](hidden_states)
            if controlnet_output_mask is not None:
                projected = projected * controlnet_output_mask
            block_outputs.append(projected)

        controlnet_hidden_states = tuple(block_outputs)

        if not return_dict:
            return (controlnet_hidden_states,)
        return Transformer2DModelOutput(sample=controlnet_hidden_states)