"""AutoencoderKL implementation compatible with diffusers weights."""

# Modified from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/autoencoder.py
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn


@dataclass
class AutoencoderKLOutput:
    sample: torch.Tensor


class AutoencoderConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def get(self, key, default=None):
        return self.__dict__.get(key, default)

    def __getattr__(self, name):
        return self.__dict__.get(name)


def swish(x):
    return x * torch.sigmoid(x)


class ResnetBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels=None, dropout=0.0, temb_channels=512, groups=32, eps=1e-6):
        super().__init__()
        out_channels = out_channels or in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = swish

        if self.in_channels != self.out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

    def forward(self, input_tensor, temb=None):
        hidden_states = input_tensor
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / 1.0
        return output_tensor


class Attention(nn.Module):
    def __init__(self, in_channels, heads=1, dim_head=None, groups=32, eps=1e-6):
        super().__init__()
        self.heads = heads
        self.in_channels = in_channels
        self.group_norm = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.to_q = nn.Linear(in_channels, in_channels)
        self.to_k = nn.Linear(in_channels, in_channels)
        self.to_v = nn.Linear(in_channels, in_channels)
        self.to_out = nn.ModuleList([nn.Linear(in_channels, in_channels)])

    def forward(self, hidden_states):
        b, c, h, w = hidden_states.shape
        residual = hidden_states
        hidden_states = self.group_norm(hidden_states)
        hidden_states = hidden_states.view(b, c, -1).transpose(1, 2)  # (B, H*W, C)

        query = self.to_q(hidden_states)
        key = self.to_k(hidden_states)
        value = self.to_v(hidden_states)

        import torch.nn.functional as F

        hidden_states = F.scaled_dot_product_attention(query, key, value)

        hidden_states = self.to_out[0](hidden_states)
        hidden_states = hidden_states.transpose(1, 2).view(b, c, h, w)

        return residual + hidden_states


class Downsample2D(nn.Module):
    def __init__(self, channels, with_conv=True, out_channels=None, padding=1):
        super().__init__()
        out_channels = out_channels or channels
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, stride=2, padding=padding)

    def forward(self, hidden_states):
        if self.with_conv:
            return self.conv(hidden_states)
        else:
            return torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2)


class Upsample2D(nn.Module):
    def __init__(self, channels, with_conv=True, out_channels=None):
        super().__init__()
        out_channels = out_channels or channels
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            hidden_states = self.conv(hidden_states)
        return hidden_states


class DownEncoderBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_downsample=True):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            in_c = in_channels if i == 0 else out_channels
            resnets.append(ResnetBlock2D(in_c, out_channels, eps=resnet_eps, groups=resnet_groups))
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [Downsample2D(out_channels, with_conv=True, out_channels=out_channels, padding=0)]
            )
        else:
            self.downsamplers = None

    def forward(self, hidden_states):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                pad = (0, 1, 0, 1)
                hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
                hidden_states = downsampler(hidden_states)

        return hidden_states


class UpDecoderBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers=1, resnet_eps=1e-6, resnet_groups=32, add_upsample=True):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            in_c = in_channels if i == 0 else out_channels
            resnets.append(ResnetBlock2D(in_c, out_channels, eps=resnet_eps, groups=resnet_groups))
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, with_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

    def forward(self, hidden_states):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states


class UNetMidBlock2D(nn.Module):
    def __init__(self, in_channels, resnet_eps=1e-6, resnet_groups=32, attention_head_dim=None):
        super().__init__()
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(in_channels, in_channels, eps=resnet_eps, groups=resnet_groups),
                ResnetBlock2D(in_channels, in_channels, eps=resnet_eps, groups=resnet_groups),
            ]
        )
        self.attentions = nn.ModuleList([Attention(in_channels, heads=1, groups=resnet_groups, eps=resnet_eps)])

    def forward(self, hidden_states):
        hidden_states = self.resnets[0](hidden_states)
        for attn in self.attentions:
            hidden_states = attn(hidden_states)
        hidden_states = self.resnets[1](hidden_states)
        return hidden_states


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        double_z=True,
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        self.down_blocks = nn.ModuleList([])
        output_channel = block_out_channels[0]
        for i, block_out_channel in enumerate(block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channel
            is_final_block = i == len(block_out_channels) - 1

            block = DownEncoderBlock2D(
                input_channel,
                output_channel,
                num_layers=layers_per_block,
                resnet_groups=norm_num_groups,
                add_downsample=not is_final_block,
            )
            self.down_blocks.append(block)

        self.mid_block = UNetMidBlock2D(
            block_out_channels[-1],
            resnet_groups=norm_num_groups,
        )

        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
    ):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)

        self.mid_block = UNetMidBlock2D(
            block_out_channels[-1],
            resnet_groups=norm_num_groups,
        )

        self.up_blocks = nn.ModuleList([])
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]

        for i, block_out_channel in enumerate(reversed_block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channel
            is_final_block = i == len(block_out_channels) - 1
            block = UpDecoderBlock2D(
                input_channel,
                output_channel,
                num_layers=layers_per_block + 1,
                resnet_groups=norm_num_groups,
                add_upsample=not is_final_block,
            )
            self.up_blocks.append(block)

        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv_in(x)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)
        return x


class AutoencoderKL(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.config = AutoencoderConfig(
            in_channels=in_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            latent_channels=latent_channels,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
        )

        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)

        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return AutoencoderKLOutput(sample=dec)
