# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import flax.linen as nn
import jax.numpy as jnp

from .attention_flax import FlaxTransformer2DModel
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
from .utils.common_types import BlockSizes

class FlaxCrossAttnDownBlock2D(nn.Module):
    r"""
    Cross-attention downsampling stage of a 2D UNet (UNet-transformer design: https://arxiv.org/abs/2103.06104).

    The stage stacks ``num_layers`` (ResNet -> spatial transformer) pairs. When ``add_downsample`` is set, a
    ``FlaxDownsample2D`` layer is applied after the last pair.

    Parameters:
        in_channels (:obj:`int`):
            Channels of the incoming feature map; only the first ResNet consumes this width.
        out_channels (:obj:`int`):
            Channels produced by every ResNet / transformer pair and by the downsampler.
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout probability forwarded to each ResNet block as ``dropout_prob``.
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet + transformer pairs.
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Heads per spatial transformer; each head is ``out_channels // num_attention_heads`` wide.
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to apply a downsampling layer after the last pair.
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to each ``FlaxTransformer2DModel``.
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to each ``FlaxTransformer2DModel``.
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            Enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        attention_kernel (`str`, *optional*, defaults to `dot_product`):
            Attention mechanism to be used.
        flash_min_seq_length (`int`, *optional*, defaults to 4096):
            Minimum sequence length required to apply flash attention.
        flash_block_sizes (`BlockSizes`, *optional*, defaults to None):
            Overrides the default block sizes for flash attention.
        mesh (`jax.sharding.Mesh`, *optional*, defaults to `None`):
            JAX mesh; required when the attention kernel is flash.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`.
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Passed as ``depth`` to each ``FlaxTransformer2DModel``.
        norm_num_groups (:obj:`int`, *optional*, defaults to 32):
            Forwarded to the ResNet and transformer sub-blocks.
        act_fn (:obj:`str`, *optional*, defaults to `"silu"`):
            Activation name forwarded to the ResNet and transformer sub-blocks.
        conv3d (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to every sub-block (ResNets, transformers and the downsampler).
        cross_attention_dim (:obj:`int`, *optional*, defaults to 1280):
            Forwarded to each ``FlaxTransformer2DModel``.
    """
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_downsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    attention_kernel: str = "dot_product"
    flash_min_seq_length: int = 4096
    flash_block_sizes: BlockSizes = None
    mesh: jax.sharding.Mesh = None
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1
    norm_num_groups: int = 32
    act_fn: str = "silu"
    conv3d: bool = False
    cross_attention_dim: int = 1280

    def setup(self):
        # Only layer 0 sees the stage input width; every later layer consumes `out_channels`.
        resnet_inputs = [self.out_channels if idx else self.in_channels for idx in range(self.num_layers)]

        # Flax names list submodules "<attr>_<index>", so these become resnets_0, resnets_1, ...
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
                norm_num_groups=self.norm_num_groups,
                act_fn=self.act_fn,
                conv3d=self.conv3d,
            )
            for channels in resnet_inputs
        ]

        self.attentions = [
            FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                attention_kernel=self.attention_kernel,
                flash_min_seq_length=self.flash_min_seq_length,
                flash_block_sizes=self.flash_block_sizes,
                mesh=self.mesh,
                dtype=self.dtype,
                norm_num_groups=self.norm_num_groups,
                act_fn=self.act_fn,
                conv3d=self.conv3d,
                cross_attention_dim=self.cross_attention_dim,
            )
            for _ in range(self.num_layers)
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype, conv3d=self.conv3d)

    def __call__(self, hidden_states, temb, encoder_hidden_states=None, deterministic=True):
        """Run the stage and collect the skip activations.

        Returns:
            ``(hidden_states, output_states)``. ``output_states`` is a tuple holding the output of every
            ResNet/transformer pair, followed by the downsampled output when ``add_downsample`` is set.
        """
        collected = []

        for resnet, attn in zip(self.resnets, self.attentions):
            features = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn(features, encoder_hidden_states, deterministic=deterministic)
            collected.append(hidden_states)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)


class FlaxDownBlock2D(nn.Module):
    r"""
    Attention-free downsampling stage of a 2D UNet.

    The stage stacks ``num_layers`` ResNet blocks. When ``add_downsample`` is set, a ``FlaxDownsample2D`` layer is
    applied after the last ResNet.

    Parameters:
        in_channels (:obj:`int`):
            Channels of the incoming feature map; only the first ResNet consumes this width.
        out_channels (:obj:`int`):
            Channels produced by every ResNet and by the downsampler.
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout probability forwarded to each ResNet block as ``dropout_prob``.
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet blocks.
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to apply a downsampling layer after the last ResNet.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`.
        norm_num_groups (:obj:`int`, *optional*, defaults to 32):
            Forwarded to each ResNet block.
        act_fn (:obj:`str`, *optional*, defaults to `"silu"`):
            Activation name forwarded to each ResNet block.
        conv3d (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to every ResNet block and to the downsampler.
    """
    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32
    norm_num_groups: int = 32
    act_fn: str = "silu"
    conv3d: bool = False

    def setup(self):
        # Layer 0 adapts the stage input width; the remaining layers are out_channels -> out_channels.
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.out_channels if layer_idx > 0 else self.in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
                norm_num_groups=self.norm_num_groups,
                act_fn=self.act_fn,
                conv3d=self.conv3d,
            )
            for layer_idx in range(self.num_layers)
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype, conv3d=self.conv3d)

    def __call__(self, hidden_states, temb, deterministic=True):
        """Run the stage and collect the skip activations.

        Returns:
            ``(hidden_states, output_states)``. ``output_states`` is a tuple with every ResNet output, followed by
            the downsampled output when ``add_downsample`` is set.
        """
        skip_states = []

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            skip_states.append(hidden_states)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            skip_states.append(hidden_states)

        return hidden_states, tuple(skip_states)


class FlaxCrossAttnUpBlock2D(nn.Module):
    r"""
    Cross-attention upsampling stage of a 2D UNet (UNet-transformer design: https://arxiv.org/abs/2103.06104).

    The stage stacks ``num_layers`` (ResNet -> spatial transformer) pairs. Each pair consumes one skip activation
    from ``res_hidden_states_tuple``. When ``add_upsample`` is set, a ``FlaxUpsample2D`` layer is applied after the
    last pair.

    Parameters:
        in_channels (:obj:`int`):
            Skip-connection width budgeted for the final layer.
        out_channels (:obj:`int`):
            Channels produced by every ResNet / transformer pair and by the upsampler; also the skip width budgeted
            for every non-final layer.
        prev_output_channel (:obj:`int`):
            Channels of the feature map coming from the previous (deeper) stage, consumed by the first layer.
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout probability forwarded to each ResNet block as ``dropout_prob``.
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet + transformer pairs.
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Heads per spatial transformer; each head is ``out_channels // num_attention_heads`` wide.
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to apply an upsampling layer after the last pair.
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to each ``FlaxTransformer2DModel``.
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to each ``FlaxTransformer2DModel``.
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            Enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        attention_kernel (`str`, *optional*, defaults to `dot_product`):
            Attention mechanism to be used.
        flash_min_seq_length (`int`, *optional*, defaults to 4096):
            Minimum sequence length required to apply flash attention.
        flash_block_sizes (`BlockSizes`, *optional*, defaults to None):
            Overrides the default block sizes for flash attention.
        mesh (`jax.sharding.Mesh`, *optional*, defaults to `None`):
            JAX mesh; required when the attention kernel is flash.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`.
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Passed as ``depth`` to each ``FlaxTransformer2DModel``.
        norm_num_groups (:obj:`int`, *optional*, defaults to 32):
            Forwarded to the ResNet and transformer sub-blocks.
        act_fn (:obj:`str`, *optional*, defaults to `"silu"`):
            Activation name forwarded to the ResNet and transformer sub-blocks.
        conv3d (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to every sub-block (ResNets, transformers and the upsampler).
        up_skip (:obj:`bool`, *optional*, defaults to `False`):
            If `True`, skip activations are added to each pair's output instead of being concatenated to its input
            along the last axis.
        cross_attention_dim (:obj:`int`, *optional*, defaults to 1280):
            Forwarded to each ``FlaxTransformer2DModel``.
    """
    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_upsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    attention_kernel: str = "dot_product"
    flash_min_seq_length: int = 4096
    flash_block_sizes: BlockSizes = None
    mesh: jax.sharding.Mesh = None
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1
    norm_num_groups: int = 32
    act_fn: str = "silu"
    conv3d: bool = False
    up_skip: bool = False
    cross_attention_dim: int = 1280

    def setup(self):
        def resnet_input_width(layer_idx):
            # Upstream features: the previous stage's output for layer 0, this stage's output afterwards.
            upstream = self.out_channels if layer_idx > 0 else self.prev_output_channel
            if self.up_skip:
                # The skip is added after the pair, so it does not widen the ResNet input.
                return upstream
            # The skip concatenated in the final layer is budgeted as `in_channels`, earlier ones as `out_channels`.
            skip = self.in_channels if layer_idx == self.num_layers - 1 else self.out_channels
            return upstream + skip

        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=resnet_input_width(layer_idx),
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
                norm_num_groups=self.norm_num_groups,
                act_fn=self.act_fn,
                conv3d=self.conv3d,
            )
            for layer_idx in range(self.num_layers)
        ]

        self.attentions = [
            FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                attention_kernel=self.attention_kernel,
                flash_min_seq_length=self.flash_min_seq_length,
                flash_block_sizes=self.flash_block_sizes,
                mesh=self.mesh,
                dtype=self.dtype,
                norm_num_groups=self.norm_num_groups,
                act_fn=self.act_fn,
                conv3d=self.conv3d,
                cross_attention_dim=self.cross_attention_dim,
            )
            for _ in range(self.num_layers)
        ]

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype, conv3d=self.conv3d)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
        """Run the stage, consuming skip activations from the end of ``res_hidden_states_tuple``.

        Layer k (1-based) uses ``res_hidden_states_tuple[-k]``. An ``IndexError`` is raised if there are fewer skips
        than layers.
        """
        for position, (resnet, attn) in enumerate(zip(self.resnets, self.attentions), start=1):
            skip = res_hidden_states_tuple[-position]

            layer_input = hidden_states if self.up_skip else jnp.concatenate((hidden_states, skip), axis=-1)
            layer_output = attn(
                resnet(layer_input, temb, deterministic=deterministic),
                encoder_hidden_states,
                deterministic=deterministic,
            )
            hidden_states = layer_output + skip if self.up_skip else layer_output

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUpBlock2D(nn.Module):
    r"""
    Attention-free upsampling stage of a 2D UNet.

    The stage stacks ``num_layers`` ResNet blocks. Each block consumes one skip activation from
    ``res_hidden_states_tuple``. When ``add_upsample`` is set, a ``FlaxUpsample2D`` layer is applied after the last
    ResNet.

    Parameters:
        in_channels (:obj:`int`):
            Skip-connection width budgeted for the final layer.
        out_channels (:obj:`int`):
            Channels produced by every ResNet and by the upsampler; also the skip width budgeted for every
            non-final layer.
        prev_output_channel (:obj:`int`):
            Output channels from the previous block, consumed by the first ResNet.
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout probability forwarded to each ResNet block as ``dropout_prob``.
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet blocks.
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to apply an upsampling layer after the last ResNet.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`.
        norm_num_groups (:obj:`int`, *optional*, defaults to 32):
            Forwarded to each ResNet block.
        act_fn (:obj:`str`, *optional*, defaults to `"silu"`):
            Activation name forwarded to each ResNet block.
        conv3d (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to every ResNet block and to the upsampler.
        up_skip (:obj:`bool`, *optional*, defaults to `False`):
            If `True`, skip activations are added to each ResNet's output instead of being concatenated to its input
            along the last axis.
    """
    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32
    norm_num_groups: int = 32
    act_fn: str = "silu"
    conv3d: bool = False
    up_skip: bool = False

    def setup(self):
        resnets = []

        for i in range(self.num_layers):
            # Skip width concatenated at layer i: `in_channels` for the final layer, `out_channels` otherwise.
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                # With `up_skip` the skip is added after the ResNet, so it does not widen the ResNet input.
                in_channels=resnet_in_channels if self.up_skip else resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
                norm_num_groups=self.norm_num_groups,
                act_fn=self.act_fn,
                conv3d=self.conv3d,
            )
            resnets.append(res_block)

        self.resnets = resnets

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype, conv3d=self.conv3d)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
        """Run the stage, consuming skip activations from the end of ``res_hidden_states_tuple``."""
        for resnet in self.resnets:
            # pop res hidden states (last-in, first-out)
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]

            if not self.up_skip:
                hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

            if self.up_skip:
                hidden_states = hidden_states + res_hidden_states

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states


class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    r"""
    Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104

    Layout: ``resnet -> (transformer -> resnet) * num_layers``. Every sub-block keeps the ``in_channels`` width.

    Parameters:
        in_channels (:obj:`int`):
            Input channels; also the width of every sub-block's output.
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout probability forwarded to each ResNet block as ``dropout_prob``.
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (transformer -> resnet) pairs that follow the leading ResNet.
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Heads per spatial transformer; each head is ``in_channels // num_attention_heads`` wide.
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to each ``FlaxTransformer2DModel``.
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        attention_kernel (`str`, *optional*, defaults to `dot_product`):
            Attention mechanism to be used.
        flash_min_seq_length (`int`, *optional*, defaults to 4096):
            Minimum seq length required to apply flash attention.
        flash_block_sizes (`BlockSizes`, *optional*, defaults to None):
            Overrides default block sizes for flash attention.
        mesh (`jax.sharding.Mesh`, *optional*, defaults to `None`):
            jax mesh is required if attention is set to flash.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Passed as ``depth`` to each ``FlaxTransformer2DModel``.
        norm_num_groups (:obj:`int`, *optional*, defaults to 32):
            Forwarded to the ResNet and transformer sub-blocks.
        act_fn (:obj:`str`, *optional*, defaults to `"silu"`):
            Activation name forwarded to the ResNet and transformer sub-blocks.
        conv3d (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded to every ResNet and transformer sub-block.
        cross_attention_dim (:obj:`int`, *optional*, defaults to 1280):
            Forwarded to each ``FlaxTransformer2DModel``.
    """
    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    use_linear_projection: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    attention_kernel: str = "dot_product"
    flash_min_seq_length: int = 4096
    flash_block_sizes: BlockSizes = None
    mesh: jax.sharding.Mesh = None
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1
    norm_num_groups: int = 32
    act_fn: str = "silu"
    conv3d: bool = False
    cross_attention_dim: int = 1280

    def setup(self):
        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
                norm_num_groups=self.norm_num_groups,
                act_fn=self.act_fn,
                conv3d=self.conv3d,
            )
        ]

        attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.in_channels,
                n_heads=self.num_attention_heads,
                d_head=self.in_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                attention_kernel=self.attention_kernel,
                flash_min_seq_length=self.flash_min_seq_length,
                flash_block_sizes=self.flash_block_sizes,
                mesh=self.mesh,
                dtype=self.dtype,
                norm_num_groups=self.norm_num_groups,
                act_fn=self.act_fn,
                conv3d=self.conv3d,
                cross_attention_dim=self.cross_attention_dim,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
                norm_num_groups=self.norm_num_groups,
                act_fn=self.act_fn,
                conv3d=self.conv3d,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        """Run ``resnet -> (transformer -> resnet) * num_layers``.

        ``deterministic`` is forwarded to every sub-block, including the leading ResNet.
        """
        # The leading ResNet must honour `deterministic` like every other sub-block; previously it
        # always fell back to the ResNet's own default and ignored the caller's setting.
        hidden_states = self.resnets[0](hidden_states, temb, deterministic=deterministic)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        return hidden_states
