import copy
from typing import Any, Dict, List, Optional, Tuple, Union

from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint

from diffusers import AutoencoderKL
from diffusers.models.autoencoders.vae import Decoder, Encoder, DecoderOutput
from diffusers.models.embeddings import PatchEmbed
from diffusers.models.attention_processor import Attention
from diffusers.models.attention import FeedForward, _chunked_feed_forward
from diffusers.configuration_utils import register_to_config
from diffusers.utils import is_torch_version


def zero_module(module):
    """Set every parameter of ``module`` to zero, in place.

    The update is not recorded by autograd; parameters keep their
    ``requires_grad`` flag. Used in this file to zero-initialise the last
    layer of a branch.

    Returns:
        The same ``module`` object, so the call can be used inline.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module


class DetailEncoder(Encoder):
    """diffusers ``Encoder`` that can additionally expose its feature pyramid.

    Construction is identical to the parent ``Encoder``. The extra method
    ``pyramid_feature_forward`` returns the intermediate activations (after
    ``conv_in``, after each down block and after the mid block) so they can be
    fed to ``DetailDecoder`` as reference features.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
    ):
        # Pass by keyword so the call does not depend on the parent's
        # positional parameter order.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            double_z=double_z,
            mid_block_add_attention=mid_block_add_attention,
        )

    def pyramid_feature_forward(self, sample: torch.Tensor) -> List[torch.Tensor]:
        """Run ``conv_in``, the down blocks and the mid block, keeping every output.

        The encoder's post-processing (``conv_norm_out`` / ``conv_act`` /
        ``conv_out``) is not applied, so no latent moments are produced.
        When the module is training with ``gradient_checkpointing`` enabled,
        each down block and the mid block are run under activation
        checkpointing.

        Args:
            sample: input image batch, presumably ``(B, in_channels, H, W)``.

        Returns:
            ``len(self.down_blocks) + 2`` tensors ordered shallowest to deepest:
            ``[conv_in, down_block_0, ..., down_block_{n-1}, mid_block]``.
            Example (4 down blocks, widths 128/256/512/512, 512x512 input):
            128@512², 128@256², 256@128², 512@64², 512@64², 512@64².
        """
        use_checkpoint = self.training and self.gradient_checkpointing
        ckpt_kwargs: Dict[str, Any] = {}
        if use_checkpoint and is_torch_version(">=", "1.11.0"):
            ckpt_kwargs["use_reentrant"] = False

        def run(module, x):
            # Same computation either way; checkpointing only trades memory for recompute.
            if use_checkpoint:
                return torch.utils.checkpoint.checkpoint(module, x, **ckpt_kwargs)
            return module(x)

        sample = self.conv_in(sample)
        features = [sample]

        # down
        for down_block in self.down_blocks:
            sample = run(down_block, sample)
            features.append(sample)

        # middle
        sample = run(self.mid_block, sample)
        features.append(sample)

        return features


class RefCrossAttn(nn.Module):
    """Pre-norm block: cross-attention to reference tokens, then a feed-forward.

    The query tokens and the reference tokens are both normalised by the same
    ``norm1`` (``LayerNorm(dim)``), so the reference tokens must also have
    feature size ``dim``. Both sub-layers are added back residually.

    Args:
        dim: token feature size.
        num_attention_heads, attention_head_dim: attention layout.
        dropout: dropout used by the attention and feed-forward layers.
        cross_attention_dim: key/value input size passed to ``Attention``.
        activation_fn, final_dropout, ff_inner_dim, ff_bias: ``FeedForward`` options.
        norm_elementwise_affine, norm_eps: options of ``norm1`` / ``norm2``.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int = 8,
        attention_head_dim: int = 16,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
    ):
        super().__init__()

        # Cross-Attn (norm1 is applied to both queries and reference tokens)
        self.norm1 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout
        )

        # Feed-forward
        self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # Feed-forward chunking is off by default.
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        """Run the feed-forward through ``_chunked_feed_forward`` with ``chunk_size`` / ``dim``.

        Passing ``chunk_size=None`` disables chunking.
        """
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from ``hidden_states`` to ``cross_hidden_states``, then apply the feed-forward.

        Args:
            hidden_states: query tokens; the batch size is ``hidden_states.shape[0]``.
            cross_hidden_states: reference tokens with the same feature size.
                If their batch equals the query batch they are paired per sample.
                If it divides the query batch, each reference is repeated for
                ``batch // ref_batch`` consecutive queries (a batch of 1 is
                broadcast to all queries). If ``None``, no reference is passed
                to ``Attention``, which then attends over the queries themselves.

        Returns:
            Tensor with the same shape as ``hidden_states``.

        Raises:
            ValueError: if the reference batch does not divide the query batch.
        """
        batch_size = hidden_states.shape[0]

        norm_hidden_states = self.norm1(hidden_states)

        if cross_hidden_states is None:
            norm_cross_hidden_states = None
        else:
            norm_cross_hidden_states = self.norm1(cross_hidden_states)
            ref_batch = norm_cross_hidden_states.shape[0]
            if ref_batch != batch_size:
                # Unconditionally repeating by batch_size would turn an already
                # matching batch B into B*B rows and break the residual add.
                if ref_batch == 0 or batch_size % ref_batch != 0:
                    raise ValueError(
                        f"cross_hidden_states batch ({ref_batch}) must equal or evenly "
                        f"divide hidden_states batch ({batch_size})"
                    )
                norm_cross_hidden_states = norm_cross_hidden_states.repeat_interleave(
                    batch_size // ref_batch, dim=0
                )

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=norm_cross_hidden_states,
        )
        hidden_states = attn_output + hidden_states

        norm_hidden_states = self.norm2(hidden_states)

        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states

        return hidden_states


class RefAttnModel(nn.Module):
    """Fuse reference features into decoder features through a gated residual.

    Two branches read the reference features:

    * warping branch: a conv stack on ``[hidden_states, reference]`` predicts
      a 2-channel offset field. The offsets are added to an identity sampling
      grid, and the reference is resampled with ``F.grid_sample`` on that grid.
    * cross-attention branch: both inputs are patchified by one shared
      ``PatchEmbed``. The query patches attend to the reference patches
      through ``num_layers`` ``RefCrossAttn`` blocks and are then
      un-patchified.

    The output is ``hidden_states + tanh(gate) * (warped_ref + attn_out)``,
    where ``gate`` is predicted from the concatenated inputs. The last conv of
    both the gating and the warping stacks is zero-initialised. At
    initialisation the gate is therefore 0, and the module returns
    ``hidden_states`` unchanged.

    Args:
        num_attention_heads, attention_head_dim: attention layout; the token
            width is ``num_attention_heads * attention_head_dim``.
        in_channels: channel count of both inputs (required).
        out_channels: defaults to ``in_channels``. NOTE(review): ``forward``
            adds tensors with ``in_channels`` and ``out_channels`` channels,
            so any other value looks unsupported.
        num_layers: number of ``RefCrossAttn`` blocks.
        dropout, activation_fn, norm_elementwise_affine, norm_eps: forwarded
            to every ``RefCrossAttn`` block.
        sample_size: height and width given to ``PatchEmbed``.
        patch_size: patch side length used for patchify / un-patchify.
    """
    def __init__(
        self,
        num_attention_heads: int = 8,
        attention_head_dim: int = 16,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        sample_size: Optional[int] = None,
        patch_size: int = 4,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
    ):
        super().__init__()
        # patch_size has a non-None default, so in practice this checks in_channels.
        assert in_channels is not None and patch_size is not None
        self.gradient_checkpointing = False
        
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.num_attention_heads * self.attention_head_dim
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.dropout = dropout
        self.activation_fn = activation_fn
        self.norm_elementwise_affine = norm_elementwise_affine
        self.norm_eps = norm_eps
        self.num_layers = num_layers

        # Gate logits (out_channels maps) from the channel-concatenated
        # [hidden_states, reference] input; passed through tanh in forward.
        self.output_gating = nn.Sequential(
            nn.Conv2d(2*self.in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(64, self.out_channels, kernel_size=3, padding=1),
        )
        
        # Per-pixel 2-channel offset added to the identity sampling grid in forward.
        self.warping_convolution = nn.Sequential(
            nn.Conv2d(2*self.in_channels, 256, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(64, 16, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(16, 4, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(4, 2, kernel_size=3, padding=1),
        )

        # self.warp_out = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1)

        self.patch_size = patch_size
        # Shared by the query and the reference inputs in forward.
        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=self.in_channels,
            embed_dim=self.inner_dim,
        )
        # Reference tokens come from the same pos_embed, hence
        # cross_attention_dim == inner_dim.
        self.transformer_blocks = nn.ModuleList(
            [
                RefCrossAttn(
                    self.inner_dim,
                    self.num_attention_heads,
                    self.attention_head_dim,
                    dropout=self.dropout,
                    cross_attention_dim=self.inner_dim,
                    activation_fn=self.activation_fn,
                    norm_elementwise_affine=self.norm_elementwise_affine,
                    norm_eps=self.norm_eps,
                )
                for _ in range(self.num_layers)
            ]
        )
        # Each token is projected back to one patch_size x patch_size x out_channels patch.
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(
            self.inner_dim, self.patch_size * self.patch_size * self.out_channels
        )
        # zero init: gate = tanh(0) = 0 and offsets = 0, so the block starts as an identity map
        self.output_gating[-1] = zero_module(self.output_gating[-1])
        self.warping_convolution[-1] = zero_module(self.warping_convolution[-1])
        # self.warp_out = zero_module(self.warp_out)
        # self.proj_out = zero_module(self.proj_out)

    
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Fuse ``encoder_hidden_states`` (reference) into ``hidden_states``.

        Args:
            hidden_states: ``(B, in_channels, H, W)`` decoder features. H and W
                should be multiples of ``patch_size``: the un-patchified
                attention output has size ``(H // p * p, W // p * p)`` and is
                added to a tensor of size ``(H, W)``.
            encoder_hidden_states: ``(B, in_channels, H_ref, W_ref)`` reference
                features. Despite the ``Optional`` annotation it is required
                (its ``.shape`` is read unconditionally). Its batch must equal
                ``B`` (concatenated along the channel dim). The gating and
                warping branches use a bilinear resize to ``(H, W)``; the
                cross-attention branch patchifies it at its own resolution.

        Returns:
            ``skip + tanh(gates) * (warped_reference + attention_output)``,
            shaped like ``hidden_states`` when ``out_channels == in_channels``.
        """
        # hidden_states: (b,c,h,w)
        # encoder_hidden_states: (b,c,h,w)
        skip_connection = hidden_states


        # Resize the reference to the query's spatial size for the conv branches.
        if hidden_states.shape[-2] != encoder_hidden_states.shape[-2] or hidden_states.shape[-1] != encoder_hidden_states.shape[-1]:
            _encoder_hidden_states = torch.nn.functional.interpolate(encoder_hidden_states, (hidden_states.shape[-2], hidden_states.shape[-1]), mode='bilinear')
        else:
            _encoder_hidden_states = encoder_hidden_states
        cat_features = torch.cat([hidden_states, _encoder_hidden_states], dim=1)
        ## output gating
        gates = self.output_gating(cat_features)
        gates = torch.tanh(gates)

        ## warping branch
        local_motion = self.warping_convolution(cat_features)
        # (b, 2, h, w) -> (b, h, w, 2) to match the grid layout used by grid_sample
        local_motion = rearrange(local_motion, "b c h w -> b h w c")

        # Identity affine transform per batch element -> grid of normalised
        # sampling coordinates; local_motion is an offset in those coordinates.
        identity_grid = F.affine_grid(torch.stack([torch.tensor([[1, 0, 0], [0, 1, 0]], dtype=cat_features.dtype).to(cat_features.device).view(2,3)]*cat_features.shape[0], dim=0), _encoder_hidden_states.size(), align_corners=False)
        grid = identity_grid + local_motion

        warp_encoder_hidden_states = F.grid_sample(_encoder_hidden_states, grid, align_corners=False)

        # warp_encoder_hidden_states = self.warp_out(warp_encoder_hidden_states)


        ## cross-attention branch 
        # 1. Input: patchify both inputs (the reference at its original, un-resized size)
        height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
        hidden_states = self.pos_embed(hidden_states)
        encoder_hidden_states = self.pos_embed(encoder_hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # use_reentrant is only passed on torch >= 1.11
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                )

        # 3. Output: project tokens back to patches and un-patchify to (b, c, h, w)
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        
        ## combination of two branch
        hidden_states = hidden_states.reshape(
            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
        ) 
        
        # output = (1 - gates[:,0,:,:].unsqueeze(1)) * skip_connection + gates[:,0,:,:].unsqueeze(1) * ((1 - gates[:,1,:,:].unsqueeze(1)) * warp_encoder_hidden_states + gates[:,1,:,:].unsqueeze(1) * hidden_states)
        # Gated residual: both branches share one gate and are summed before gating.
        output = skip_connection + gates * (warp_encoder_hidden_states + hidden_states)

        return output


class DetailDecoder(Decoder):
    """diffusers ``Decoder`` that fuses reference features at every resolution.

    One ``RefAttnModel`` is inserted after ``conv_in`` (stage 0), after the
    mid block (stage 1) and after every up block (stages 2, 3, ...). Stage
    ``k`` consumes ``ref_sample_list[-(k + 1)]``, i.e. the reference pyramid
    is walked from its last (deepest) entry back to its first one. When the
    reference and decoder widths differ at a stage, a 1x1 convolution maps
    the reference to the decoder width first.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
    ):
        # Pass by keyword so the call does not depend on the parent's
        # positional parameter order.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            norm_type=norm_type,
            mid_block_add_attention=mid_block_add_attention,
        )

        # Copy into a list: ``block_out_channels`` is declared (and defaults to)
        # a tuple, which has no ``append``/``insert``.
        channels = list(block_out_channels)

        # Decoder width per stage: conv_in, mid block, then each up block.
        pyramid_channels = list(reversed(channels + [channels[-1], channels[-1]]))
        # Width of the reference entry used per stage (pyramid walked from the end):
        # mid block, last down block, ..., first down block, conv_in.
        ref_pyramid_channels = list(reversed([channels[0]] + channels + [channels[-1]]))

        # Nominal spatial size given to each stage's PatchEmbed: 64 for the first
        # two stages, doubled at stages 2, 3 and 4, then kept.
        resolutions = []
        resolution = 64
        for stage in range(len(pyramid_channels)):
            if stage in (2, 3, 4):
                resolution *= 2
            resolutions.append(resolution)

        # Construction order (conv, then attention block, per stage) is kept
        # stable so parameter initialisation is reproducible.
        self.channel_matching_convs = nn.ModuleList([])
        self.ref_attn_blocks = nn.ModuleList([])
        for channel, ref_channel, resolution in zip(pyramid_channels, ref_pyramid_channels, resolutions):
            if channel != ref_channel:
                conv = nn.Conv2d(ref_channel, channel, kernel_size=1)
            else:
                conv = None
            block = RefAttnModel(in_channels=channel, sample_size=resolution)
            self.channel_matching_convs.append(conv)
            self.ref_attn_blocks.append(block)

    def _maybe_checkpoint(self, use_checkpoint: bool, module, *args):
        """Call ``module(*args)``, through activation checkpointing when requested."""
        if not use_checkpoint:
            return module(*args)
        # use_reentrant is only passed on torch >= 1.11, as elsewhere in this file.
        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        return torch.utils.checkpoint.checkpoint(module, *args, **ckpt_kwargs)

    def _inject_reference(self, stage: int, sample: torch.Tensor, ref_sample_list, use_checkpoint: bool) -> torch.Tensor:
        """Fuse ``ref_sample_list[-(stage + 1)]`` into ``sample`` with the stage's RefAttnModel."""
        ref_sample = ref_sample_list[-(stage + 1)]
        conv = self.channel_matching_convs[stage]
        if conv is not None:
            ref_sample = self._maybe_checkpoint(use_checkpoint, conv, ref_sample)
        return self._maybe_checkpoint(use_checkpoint, self.ref_attn_blocks[stage], sample, ref_sample)

    def forward(
        self,
        sample: torch.Tensor,
        ref_sample_list: Optional[Tuple[torch.Tensor, ...]] = None,
        latent_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Decode latents, optionally fusing reference features at every stage.

        Args:
            sample: latent batch fed to ``conv_in``.
            ref_sample_list: reference feature pyramid with
                ``len(self.up_blocks) + 2`` entries, presumably as returned by
                ``DetailEncoder.pyramid_feature_forward``; stage ``k`` reads
                entry ``-(k + 1)``. ``None`` decodes without reference fusion.
            latent_embeds: optional conditioning forwarded to the mid block,
                the up blocks and ``conv_norm_out``.

        Returns:
            Output of ``conv_out``.
        """
        use_checkpoint = self.training and self.gradient_checkpointing
        has_ref = ref_sample_list is not None

        sample = self.conv_in(sample)
        if has_ref:
            # The stage right after conv_in is never checkpointed.
            sample = self._inject_reference(0, sample, ref_sample_list, use_checkpoint=False)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        # middle
        sample = self._maybe_checkpoint(use_checkpoint, self.mid_block, sample, latent_embeds)
        sample = sample.to(upscale_dtype)
        if has_ref:
            sample = self._inject_reference(1, sample, ref_sample_list, use_checkpoint)

        # up: up block i is followed by reference stage i + 2
        for stage, up_block in enumerate(self.up_blocks, start=2):
            sample = self._maybe_checkpoint(use_checkpoint, up_block, sample, latent_embeds)
            if has_ref:
                sample = self._inject_reference(stage, sample, ref_sample_list, use_checkpoint)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class DetailAutoencoderKL(AutoencoderKL):
    """``AutoencoderKL`` whose decoder can be conditioned on reference features.

    The parent builds a stock encoder/decoder, which are then replaced by a
    ``DetailEncoder`` (able to return its feature pyramid) and a
    ``DetailDecoder`` (fuses that pyramid while decoding). Intended flow:
    ``ref = vae.extract_encoder_features(ref_pixels)`` then
    ``vae.decode(z, ref)``.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        force_upcast: float = True,
    ):
        # Pass everything by keyword: parameters have been inserted before
        # ``force_upcast`` in the parent signature across diffusers releases,
        # so a positional call could bind ``force_upcast`` to the wrong name.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            force_upcast=force_upcast,
        )

        # Replace the stock encoder/decoder built by the parent.
        self.encoder = DetailEncoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        self.decoder = DetailDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
        )

    def _set_gradient_checkpointing(self, module, value=False):
        # Presumably invoked per sub-module by diffusers' (older)
        # enable/disable_gradient_checkpointing; also reaches RefAttnModel.
        if isinstance(module, (DetailEncoder, DetailDecoder, RefAttnModel)):
            module.gradient_checkpointing = value

    def extract_encoder_features(self, ref_pixel_values):
        """Return the encoder feature pyramid for reference images (see ``DetailEncoder``)."""
        return self.encoder.pyramid_feature_forward(ref_pixel_values)

    def _decode(self, z: torch.Tensor, ref: Tuple[torch.Tensor], return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        """Decode latents ``z`` with reference features ``ref`` in a single pass.

        Raises:
            NotImplementedError: if tiling is enabled and ``z`` exceeds the tile size.
        """
        if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            # Explicit raise: an ``assert`` would be stripped under ``python -O``.
            raise NotImplementedError("Do not support tiling.")

        z = self.post_quant_conv(z)
        dec = self.decoder(z, ref)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def decode(self, z: torch.Tensor, ref: Tuple[torch.Tensor], return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.Tensor]:
        """Decode latents ``z`` conditioned on the reference pyramid ``ref``.

        With slicing enabled and more than one latent, latents are decoded one
        at a time. Latent ``i`` uses reference sample
        ``i // (z.shape[0] // ref[0].shape[0])``, so consecutive groups of
        latents share one reference. ``generator`` is accepted but unused.
        """
        if self.use_slicing and z.shape[0] > 1:
            length = z.shape[0] // ref[0].shape[0]
            decoded_slices = []
            for i, z_slice in enumerate(z.split(1)):
                ref_ind = i // length
                ref_slice = [feat[ref_ind].unsqueeze(0) for feat in ref]
                decoded_slices.append(self._decode(z_slice, ref_slice).sample)
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z, ref, return_dict=False)[0]

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)