import math
from functools import partial
from typing import Mapping, Optional, Union

import torch
import torch.nn as nn

from dae.utils.torch_utils import ACTIVATIONS
from dae.utils.train_utils import init_weights

from ...registers import DECODERS
from ..blocks.conv_res import DownBlock2D, ResnetBlock2D, UpBlock2D
from ..blocks.embeddings import LearnedPositionalEmbedding, TimestepEmbedding, Timesteps
from ..blocks.transformer_blocks import TransformerBlock, VisionTransformer


@DECODERS.register("uvit", maker="make")
class UViTDecoder(nn.Module):
    SIZES = {
        "XS": {"channels": 32, "num_attention_heads": 4, "mid_nlayers": 8, "ch_mult": (1, 2, 3, 3)},
        "S": {"channels": 48, "num_attention_heads": 4, "mid_nlayers": 8, "ch_mult": (1, 2, 3, 3)},
        "B": {"channels": 64, "mid_nlayers": 10, "ch_mult": (1, 2, 3, 3)},
        "M": {"channels": 96, "mid_nlayers": 12, "ch_mult": (1, 2, 3, 3)},
        "L": {"channels": 96, "mid_nlayers": 16, "ch_mult": (1, 2, 4, 4)},
        "XL": {"channels": 128, "mid_nlayers": 16, "ch_mult": (1, 2, 4, 4)},
        "H": {"channels": 192, "mid_nlayers": 16, "ch_mult": (1, 2, 4, 4)},
    }

    # Based on code from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py
    def __init__(
        self,
        in_channels=3,
        z_dim=4,
        channels=128,
        ch_mult=(1, 2, 4, 4),
        act_fn: str = "silu",
        vit_act_fn: str = "geglu",
        layers_per_block=2,
        num_attention_heads: Optional[int] = None,
        dropout=0.0,
        norm_num_groups=32,
        time_scale_shift=True,
        mid_nlayers=12,
        mid_theta=100.0,
        attn_window=8,
        eps=1e-5,
        ada_norm=True,
        learned_pos_embed=False,
        image_size=None,
        relative_pos_embed=True,
        rope_pos_embed=False,
        init: Optional[Mapping] = None,
    ):
        ### Config ###
        self.out_dim = in_channels
        self.ada_norm = ada_norm

        # Compute appropriate number of channels for each level, adjust for GroupNorm
        self.ch_level = [math.ceil(channels * ch_f / norm_num_groups) * norm_num_groups for ch_f in ch_mult]
        channels = self.ch_level[0]  # The first channel is the input channel
        self.mid_dim = self.ch_level[-1]

        if isinstance(dropout, (int, float)):
            dropout = [dropout] * len(self.ch_level)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(self.ch_level)

        super().__init__()

        ### Input ###
        self.conv_in = nn.Conv2d(in_channels + z_dim, channels, kernel_size=3, padding=1)

        ### Time ###
        time_embed_dim = channels * 4
        self.time_proj = Timesteps(channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(channels, time_embed_dim, act_fn=act_fn)

        ### AdaNorm Embedding ###
        if ada_norm:
            self.ada_ctx_proj = torch.nn.Sequential(
                torch.nn.Conv2d(z_dim, channels, kernel_size=3, stride=1, padding=1),
                torch.nn.SiLU(),
                torch.nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            )

        ### Down blocks ###
        self.down_blocks = nn.ModuleList([])
        output_channel = channels
        for i_level, ch in enumerate(self.ch_level):
            input_channel = output_channel
            output_channel = ch
            is_final_block = i_level == len(self.ch_level) - 1

            if layers_per_block[i_level] != 0:
                self.down_blocks.append(
                    DownBlock2D(
                        num_layers=layers_per_block[i_level],
                        in_channels=input_channel,
                        out_channels=output_channel,
                        temb_channels=time_embed_dim,
                        dropout=dropout[i_level],
                        add_downsample=not is_final_block,
                        resnet_act_fn=act_fn,
                        resnet_groups=norm_num_groups,
                        time_scale_shift=time_scale_shift,
                        resnet_eps=eps,
                        ada_norm=ada_norm,
                        ada_emb_dim=channels,
                    )
                )

        # Mid block ###
        down_scale = (2 ** (len(self.ch_level) - 1),)
        self.mid_block = UViTMiddleTransformer(
            inner_dim=output_channel,
            dropout=dropout[-1],
            num_layers=mid_nlayers,
            norm_num_groups=norm_num_groups,
            num_attention_heads=num_attention_heads,
            rope_theta=mid_theta,
            attn_window=attn_window,
            eps=eps,
            ada_norm=ada_norm,
            ada_emb_dim=channels,
            learned_pos_embed=learned_pos_embed,
            sample_size=(image_size[0] // down_scale[0], image_size[1] // down_scale[0]) if learned_pos_embed else None,
            relative_pos_embed=relative_pos_embed,
            rope_pos_embed=rope_pos_embed,
            act_fn=vit_act_fn,
        )

        ### Up blocks ###
        self.up_blocks = nn.ModuleList([])

        for i_level, ch in enumerate(reversed(self.ch_level)):
            input_channel = self.ch_level[-i_level - 2] if i_level < len(self.ch_level) - 1 else self.ch_level[0]
            prev_output_channel = output_channel
            output_channel = ch

            is_final_block = i_level == len(self.ch_level) - 1

            if layers_per_block[-i_level - 1] != 0:
                self.up_blocks.append(
                    UpBlock2D(
                        num_layers=layers_per_block[-i_level - 1] + 1,
                        in_channels=input_channel,
                        out_channels=output_channel,
                        prev_output_channel=prev_output_channel,
                        temb_channels=time_embed_dim,
                        resolution_idx=i_level,
                        dropout=dropout[-i_level - 1],
                        add_upsample=(not is_final_block),
                        resnet_act_fn=act_fn,
                        resnet_groups=norm_num_groups,
                        time_scale_shift=time_scale_shift,
                        resnet_eps=eps,
                        ada_norm=ada_norm,
                        ada_emb_dim=channels,
                    )
                )

        ### Output ###
        self.conv_norm_out = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps)
        self.conv_out_act = ACTIVATIONS[act_fn]()
        self.conv_out = nn.Conv2d(channels, in_channels, kernel_size=3, padding=1)

        ### Weights init ###
        self.init_weights(**(init or {}))

    @classmethod
    def make(cls, size=None, **kwargs):
        if size is not None:
            if size in cls.SIZES:
                kwargs = {**cls.SIZES[size], **kwargs}
            else:
                raise ValueError(f"Unknown size '{size}' for UViTDecoder. Available sizes: {list(cls.SIZES.keys())}")
        return cls(**kwargs)

    def init_weights(self, method="xavier_uniform", gating=True, ckpt_module="decoder", **kwargs):
        for m in self.modules():
            if isinstance(m, ResnetBlock2D):
                m.conv2._init_zero = True
            elif isinstance(m, TransformerBlock):
                m.ff.proj_out._init_zero = True

        init_weights(self, method=method, gating=gating, ckpt_module=ckpt_module, **kwargs)

    def get_time_embed(self, sample: torch.Tensor, timestep: Union[torch.Tensor, float]) -> torch.Tensor:
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.float64, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        timesteps = timesteps.expand(sample.shape[0])
        t_emb = self.time_proj(timesteps).to(dtype=sample.dtype)
        return t_emb

    def forward(self, x, t, z):
        # t=0.0 -> no noise ; t=1.0 -> full noise
        # t should be mapped to [1; 1000] with 1 = image, 1000 = full noise

        ### Prepare input ###

        # Concat with z and project
        z_expanded = torch.nn.functional.interpolate(z, size=(x.shape[-2], x.shape[-1]), mode="nearest")
        x = torch.cat([x, z_expanded], dim=1)
        x = self.conv_in(x)

        ctx_emb = None
        if self.ada_norm:
            ctx_emb = self.ada_ctx_proj(z)

        ### Forward pass ###

        # 1. Time embedding
        t_emb = self.get_time_embed(sample=x, timestep=t)
        t_emb = self.time_embedding(t_emb, None)

        # 2. Down blocks
        down_block_res = [x]
        for downsample_block in self.down_blocks:
            x, res_samples = downsample_block(hidden_states=x, temb=t_emb, ctx_emb=ctx_emb)
            down_block_res.extend(res_samples)

        # 4. Mid block
        x = self.mid_block(x, ctx_emb=ctx_emb)

        # 5. Up blocks
        for upsample_block in self.up_blocks:
            res_samples = down_block_res[-len(upsample_block.resnets) :]
            down_block_res = down_block_res[: -len(upsample_block.resnets)]

            x = upsample_block(hidden_states=x, temb=t_emb, res_hidden_states_tuple=res_samples, ctx_emb=ctx_emb)

        # 6. Output
        if self.conv_norm_out:
            x = self.conv_norm_out(x)
        x = self.conv_out_act(x)
        x = self.conv_out(x)

        return x


class UViTMiddleTransformer(VisionTransformer):
    def __init__(
        self,
        *args,
        sample_size=None,
        act_fn: str = "geglu",
        learned_pos_embed=True,
        norm_num_groups: int = 32,
        out_norm: bool = False,
        **kwargs,
    ):
        super().__init__(
            *args,
            learned_pos_embed=False,
            sample_size=sample_size,
            act_fn=act_fn,
            out_norm=out_norm,
            **kwargs,
        )

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True)

        self.pre_pos_embeddings = None
        if learned_pos_embed:
            self.pre_pos_embeddings = LearnedPositionalEmbedding((self.inner_dim, *sample_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        ctx_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states
        if self.pre_pos_embeddings is not None:
            hidden_states = self.pre_pos_embeddings(hidden_states)

        hidden_states = self.norm(hidden_states)

        hidden_states = super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            ctx_emb=ctx_emb,
            out_2d_map=True,
        )
        output = hidden_states + residual
        return output


DECODERS.register("uvit_XS", module=partial(UViTDecoder.make, size="XS"))
DECODERS.register("uvit_S", module=partial(UViTDecoder.make, size="S"))
DECODERS.register("uvit_B", module=partial(UViTDecoder.make, size="B"))
DECODERS.register("uvit_M", module=partial(UViTDecoder.make, size="M"))
DECODERS.register("uvit_L", module=partial(UViTDecoder.make, size="L"))
DECODERS.register("uvit_XL", module=partial(UViTDecoder.make, size="XL"))
DECODERS.register("uvit_H", module=partial(UViTDecoder.make, size="H"))
