import math

from einops import rearrange
import torch as pt
import torch.nn as nn
import torch.nn.functional as ptnf

from .basic import TransformDecodeBlock


class SlotDiffusionImage(nn.Module):
    """SlotDiffusion model for single images.

    The image is encoded into a flat sequence of features, ``correct``
    aggregates them into slots, and ``decode_backbone`` receives a noised
    version of the ``mediat`` quantised latent together with those slots.
    The inline annotations on the constructor arguments (VQVAE, resnet18, ...)
    are the module types the original author had in mind; nothing here
    enforces them.
    """

    def __init__(
        self,
        mediat,  # VQVAE
        encode_backbone,  # resnet18
        h2w2,
        encode_posit_embed,  # CartesianPositionEmbedding
        encode_project,  # LN+MLP
        initializ,  # LearntGaussian
        correct,  # SlotAttention
        noise_sched,  # NoiseSched
        decode_backbone,  # dmUNetCondition
    ):
        super().__init__()
        # ``mediat`` is put into eval mode once, here, before registration.
        mediat.eval()

        # Registration order is kept stable: it defines the state-dict layout.
        self.mediat = mediat  # type: VQVAE
        self.encode_backbone = encode_backbone
        self.h2w2 = h2w2  # ``h2w2[0]`` is the attention-map height used in ``forward``
        self.encode_posit_embed = encode_posit_embed
        self.encode_project = encode_project
        self.initializ = initializ
        self.correct = correct
        self.noise_sched = noise_sched
        self.decode_backbone = decode_backbone

        self.reset_parameters()

    def reset_parameters(self):
        """Zero the biases of Conv2d / Linear / GRUCell layers in the slot encoder parts.

        Only ``encode_posit_embed``, ``encode_project``, ``initializ`` and
        ``correct`` are visited; every other sub-module is left as constructed.
        """
        visited_roots = (
            self.encode_posit_embed,
            self.encode_project,
            self.initializ,
            self.correct,
        )
        for root in visited_roots:
            for module in root.modules():
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)
                elif isinstance(module, nn.GRUCell) and module.bias:
                    # For GRUCell ``bias`` is a flag; the tensors are bias_ih / bias_hh.
                    nn.init.zeros_(module.bias_ih)
                    nn.init.zeros_(module.bias_hh)

    def forward(self, input, condit=None):
        """Run slot encoding and conditional denoising on a batch of images.

        input: image in shape (b,c,h,w)
        condit: optional condition, documented by the author as (b,n,c); when
            omitted, the batch size is handed to ``initializ`` instead
        return: ``(zidx, noise, decode, segment, correct, attent)`` where
            ``segment`` is the argmax of ``attent`` over its slot dimension
        """
        batch, _, _, _ = input.shape  # also enforces a 4-d input

        # Backbone features -> channel-last -> position embedding -> flat sequence.
        feature = self.encode_backbone(input)
        feature = feature.permute(0, 2, 3, 1)  # (b,h,w,c)
        feature = self.encode_posit_embed(feature)
        feature = self.encode_project(feature.flatten(1, 2))  # (b,h*w,c)

        # Initial slots come from the condition if given, else from the batch size.
        slot_seed = batch if condit is None else condit
        slots, attent_flat = self.correct(feature, self.initializ(slot_seed))
        attent = rearrange(attent_flat, "b n (h w) -> b n h w", h=self.h2w2[0])

        # Only the code indices and the (detached) quantised latent are used.
        _, zidx, quant, _ = self.mediat(input)
        noise, timestep, noisy = self.noise_sched(quant.detach())
        decode = self.decode_backbone(noisy, timestep, slots)

        return zidx, noise, decode, attent.argmax(1), slots, attent


class SlotDiffusionVideo(nn.Module):
    """SlotDiffusion model for videos.

    Frames are encoded independently, slots are refined frame by frame
    (``correct``) with ``predict`` producing the starting slots of the next
    frame, and ``decode_backbone`` denoises the per-frame ``mediat`` latents
    conditioned on the slots. The inline annotations on the constructor
    arguments are the module types the original author had in mind.
    """

    def __init__(
        self,
        mediat,  # VQVAE
        encode_backbone,  # resnet18
        h2w2,
        encode_posit_embed,  # CartesianPositionEmbedding
        encode_project,  # LN+MLP
        initializ,  # LearntGaussian
        correct,  # SlotAttention
        predict,
        noise_sched,  # NoiseSched
        decode_backbone,  # dmUNetCondition
    ):
        super().__init__()

        # Registration order is kept stable: it defines the state-dict layout.
        self.mediat = mediat  # type: VQVAE
        self.encode_backbone = encode_backbone
        self.h2w2 = h2w2  # ``h2w2[0]`` is the attention-map height used in ``forward``
        self.encode_posit_embed = encode_posit_embed
        self.encode_project = encode_project
        self.initializ = initializ
        self.correct = correct
        self.predict = predict
        self.noise_sched = noise_sched
        self.decode_backbone = decode_backbone

        self.reset_parameters()

    def reset_parameters(self):
        """Zero the biases of Conv2d / Linear / GRUCell layers in the slot encoder parts.

        Only ``encode_posit_embed``, ``encode_project``, ``initializ`` and
        ``correct`` are visited; every other sub-module (including ``predict``)
        is left as constructed.
        """
        visited_roots = (
            self.encode_posit_embed,
            self.encode_project,
            self.initializ,
            self.correct,
        )
        for root in visited_roots:
            for module in root.modules():
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)
                elif isinstance(module, nn.GRUCell) and module.bias:
                    # For GRUCell ``bias`` is a flag; the tensors are bias_ih / bias_hh.
                    nn.init.zeros_(module.bias_ih)
                    nn.init.zeros_(module.bias_hh)

    def forward(self, input, condit=None):
        """Run slot encoding and conditional denoising on a batch of clips.

        input: video in shape (b,t,c,h,w)
        condit: optional condition, documented by the author as (b,t,n,c); only
            its first time step is used, and when omitted the batch size is
            handed to ``initializ`` instead
        return: ``(zidx, noise, decode, segment, correct, attent)``, each with
            separate batch and time dimensions; ``segment`` is the argmax of
            ``attent`` over its slot dimension
        """
        b, t, _, _, _ = input.shape  # also enforces a 5-d input
        frames = input.flatten(0, 1)  # (b*t,c,h,w)

        # Per-frame features -> channel-last -> position embedding -> flat sequence.
        feature = self.encode_backbone(frames)
        feature = feature.permute(0, 2, 3, 1)  # (b*t,h,w,c)
        feature = self.encode_posit_embed(feature)
        feature = self.encode_project(feature.flatten(1, 2))  # (b*t,h*w,c)
        feature = rearrange(feature, "(b t) hw c -> b t hw c", b=b)

        # Recurrent slot refinement: ``predict`` turns the slots of frame i
        # into the starting slots of frame i+1 (it is also called after the
        # last frame, although that result is not used).
        hidden = self.initializ(b if condit is None else condit[:, 0, :, :])
        slots_per_frame = []
        attent_per_frame = []
        for step in range(t):
            slots_t, attent_t = self.correct(feature[:, step, :, :], hidden)
            hidden = self.predict(slots_t)
            slots_per_frame.append(slots_t)
            attent_per_frame.append(attent_t)
        slots = pt.stack(slots_per_frame, 1)  # (b,t,n,c)
        attent = rearrange(
            pt.stack(attent_per_frame, 1),  # (b,t,n,h*w)
            "b t n (h w) -> b t n h w",
            h=self.h2w2[0],
        )

        # Only the code indices and the (detached) quantised latent are used.
        _, zidx, quant, _ = self.mediat(frames)
        zidx = zidx.unflatten(0, [b, t])

        noise, timestep, noisy = self.noise_sched(quant.detach())
        decode = self.decode_backbone(noisy, timestep, slots.flatten(0, 1))
        noise = rearrange(noise, "(b t) c h w -> b t c h w", t=t)
        decode = rearrange(decode, "(b t) c h w -> b t c h w", t=t)

        return zidx, noise, decode, attent.argmax(2), slots, attent


# class NoiseSched(nn.Module):
#     """``diffusers.schedulers.DDPMScheduler``"""

#     def __init__(  # 0.0015 0.0195
#         self, num_timestep=1000, beta_start=0.0001, beta_end=0.02, beta_sched="linear"
#     ):
#         super().__init__()
#         if beta_sched == "linear":  # scaled-linear
#             # beta = pt.linspace(beta_start, beta_end, num_timestep)  # XXX
#             beta = (
#                 pt.linspace(
#                     beta_start**0.5, beta_end**0.5, num_timestep, dtype=pt.float64
#                 )
#                 ** 2
#             )
#         else:  # new scheds here
#             raise "ValueError"
#         # self.register_buffer("a", alpha.cumprod(0).sqrt())
#         self.register_buffer("a", (1 - beta).cumprod(0).sqrt())  # TODO fp64 ???
#         # self.register_buffer("b", 1 - self.a)
#         self.register_buffer("b", (1 - (1 - beta).cumprod(0)).sqrt())
#         self.num_timestep = num_timestep

#     @pt.no_grad()
#     def forward(self, input):
#         """
#         input: image in shape (b,c,h,w)
#         """
#         b, c, h, w = input.shape
#         noise = pt.randn_like(input)
#         timestep = pt.randint(
#             0, self.num_timestep, [b], dtype=pt.long, device=input.device
#         )
#         a = self.a[timestep][:, None, None, None]
#         b = self.b[timestep][:, None, None, None]
#         noisy = a * input + b * noise
#         return noise, timestep, noisy


# class NoiseSched(nn.Module):

#     def __init__(self, *args, **kwargs) -> None:
#         super().__init__()
#         from diffusers import DDPMScheduler

#         scheduler_config = "./lsdcfg-movi-e/scheduler/scheduler_config.json"
#         noise_scheduler_config = DDPMScheduler.load_config(scheduler_config)
#         self.noise_scheduler = DDPMScheduler.from_config(noise_scheduler_config)

#     def forward(self, input):
#         b = input.size(0)
#         timesteps = pt.randint(
#             0,
#             self.noise_scheduler.config.num_train_timesteps,
#             (b,),
#             device=input.device,
#         )
#         timesteps = timesteps.long()
#         noise = pt.randn_like(input)
#         noisy_model_input = self.noise_scheduler.add_noise(input, noise, timesteps)
#         return noise, timesteps, noisy_model_input


# class SlotDiffuzUNetDecode(nn.Module):

#     def __init__(self):
#         super().__init__()
#         import sys

#         sys.path.append(
#             "/media/GeneralZ/Storage/Active/ocl/active-code/SlotDiffusion-Wuziyi616"
#         )
#         from slotdiffusion.img_based.models.ddpm.ddpm import DDPM

#         latent_ch = 4
#         slot_size = 256  # TODO XXX official: 192
#         unet_dict = dict(
#             in_channels=latent_ch,  # latent feature in  # TODO XXX ``// 4``
#             model_channels=128,  # >=64 can eliminate the `color-bias` problem
#             out_channels=latent_ch,  # latent feature noise out  # TODO XXX ``// 4``
#             num_res_blocks=2,  # 2  # XXX TODO 1 fails with zc=4
#             attention_resolutions=(8, 4, 2),  # actually the downsampling factor
#             dropout=0.1,
#             channel_mult=(1, 2, 3, 4),
#             dims=2,  # 2D data
#             use_checkpoint=False,  # LDM saves 4x memory
#             num_head_channels=32,
#             resblock_updown=False,  # usually False
#             conv_resample=True,  # up/downsample followed by Conv
#             transformer_depth=1,
#             context_dim=slot_size,  # condition on slots
#             n_embed=None,  # VQ codebook support for LDM
#         )
#         resolution = (128, 128)
#         self.ddpm = DDPM(
#             resolution=tuple(res // 4 for res in resolution),
#             unet_dict=unet_dict,
#             use_ema=False,
#             diffusion_dict=dict(
#                 pred_target="eps",  # 'eps' or 'x0', predict noise or direct x0
#                 z_scale_factor=1.0,  # 1.05
#                 timesteps=1000,  # XXX 1000 700 1500
#                 beta_schedule="linear",
#                 # the one used in LDM
#                 linear_start=0.0015,
#                 linear_end=0.0195,  # XXX 0.0195
#                 cosine_s=8e-3,  # doesn't matter for linear schedule
#                 log_every_t=200,  # log every t steps in denoising sampling
#                 logvar_init=0.0,
#             ),
#             conditioning_key="crossattn",
#         )
#         self.ddpm.clip_denoised = False  # latent features are unbounded values
#         self.ddpm.vq_denoised = True  # LDM uses this by default
#         print(self.ddpm.num_timesteps)

#     def forward(self, x_noisy, t, context):
#         return self.ddpm.forward(x_noisy, t, context=context)


# class UNet2dCondition(nn.Module):

#     def __init__(self, *args, **kwargs) -> None:
#         super().__init__()
#         from diffusers import (
#             UNet2DConditionModel,
#             DDPMPipeline,
#             StableDiffusionPipeline,
#             DiffusionPipeline,
#         )
#         from diffusers.models.unet_2d_blocks import (
#             CrossAttnDownBlock2D,
#             CrossAttnUpBlock2D,
#             DownBlock2D,
#             UpBlock2D,
#             UNetMidBlock2DCrossAttn,
#         )

#         import sys

#         sys.path.append(
#             "/media/GeneralZ/Storage/Active/ocl/active-code/latent-slot-diffusion-JindongJiang"
#         )
#         from src.models.unet_with_pos import UNet2DConditionModelWithPos

#         unet_config = "./lsdcfg-movi-e/unet/config.json"
#         unet_config = UNet2DConditionModelWithPos.load_config(unet_config)
#         unet = UNet2DConditionModelWithPos.from_config(unet_config)
#         self.module = unet

#         # model_id = "runwayml/stable-diffusion-v1-5"
#         # pipeline = StableDiffusionPipeline.from_pretrained(
#         #     model_id,
#         #     in_channels=4,
#         #     out_channels=4,
#         #     down_block_types=[
#         #         "CrossAttnDownBlock2D",
#         #         "CrossAttnDownBlock2D",
#         #         "CrossAttnDownBlock2D",
#         #         "DownBlock2D",
#         #     ],
#         #     up_block_types=[
#         #         "UpBlock2D",
#         #         "CrossAttnUpBlock2D",
#         #         "CrossAttnUpBlock2D",
#         #         "CrossAttnUpBlock2D",
#         #     ],
#         #     block_out_channels=[128, 256, 512, 512],
#         #     cross_attention_dim=[256] * 4,
#         # )
#         # self.module = pipeline.unet
#         # print([(k, v.requires_grad) for k, v in self.module.named_parameters()])

#         # self.module = UNet2DConditionModel(  # good arifg=20 in 10 epoch and then bad
#         #     in_channels=4,
#         #     out_channels=4,
#         #     down_block_types=[
#         #         "DownBlock2D",  # CrossAttnDownBlock2D
#         #         "CrossAttnDownBlock2D",
#         #         "CrossAttnDownBlock2D",
#         #         "CrossAttnDownBlock2D",  # DownBlock2D
#         #     ],
#         #     mid_block_type="UNetMidBlock2DCrossAttn",
#         #     up_block_types=[
#         #         "CrossAttnUpBlock2D",  # UpBlock2D
#         #         "CrossAttnUpBlock2D",
#         #         "CrossAttnUpBlock2D",
#         #         "UpBlock2D",  # CrossAttnUpBlock2D
#         #     ],
#         #     block_out_channels=[128, 256, 384, 512],  # 128, 256, 512, 512
#         #     cross_attention_dim=[256] * 4,
#         #     dropout=0.1,
#         #     layers_per_block=2,
#         #     transformer_layers_per_block=1,
#         # )
#         # # print(self.module)
#         # nn.init.zeros_(self.module.conv_out.weight)
#         # nn.init.zeros_(self.module.conv_out.bias)

#         # path = "segmind/portrait-finetuned"
#         # pipeline = DiffusionPipeline.from_pretrained(
#         #     path, safety_checker=None, requires_safety_checker=False
#         # )
#         # self.project = nn.Linear(256, 768)
#         # self.module = pipeline.unet
#         # self.module.enable_gradient_checkpointing()

#     def forward(self, input, temb, context):
#         # context = self.project(context)
#         output = self.module(input, temb, context)
#         return output.sample


class dmUNetCondition(nn.Module):
    """UNet with timestep embedding and cross-attention conditioning.

    Layout: input conv -> ``num_stage - 1`` down stages -> middle ->
    ``num_stage - 1`` up stages -> output conv. The channel width doubles at
    every down stage and halves at every up stage; each up stage consumes the
    pre-downsample activation of the matching down stage as a skip connection.
    Attention blocks are omitted in the outermost (index 0) stage on both sides.
    """

    def __init__(
        self,
        in_dim,  # 3
        base_dim,  # 128
        temb_dim,
        cond_dim,
        out_dim=None,  # 3
        dropout=0.1,
        num_group=32,
        head_dim=32,
        num_stage=3,
        num_block=1,
        num_layer=1,
    ):
        super().__init__()
        self.base_dim = base_dim

        def resnet(width_in, width_out=None):
            # Residual block modulated by the timestep embedding.
            return dmResnetBlock2d(width_in, temb_dim, dropout, width_out, num_group)

        def attent(width):
            # Cross-attention block; head count follows from the fixed head width.
            return dmTransformDecode2d(
                width,
                num_layer,
                width // head_dim,
                width * 4,
                dropout,
                True,
                num_group,
                kv_dim=cond_dim,
            )

        # NOTE: sub-modules are created in the same order as before so that a
        # fixed random seed yields the same initial weights.
        self.time_embed = nn.Sequential(
            nn.Linear(base_dim, temb_dim),
            nn.GELU(),
            nn.Linear(temb_dim, temb_dim),
        )

        self.input = nn.Conv2d(in_dim, base_dim, 3, padding=1)

        width = base_dim
        self.downsample = nn.ModuleList()
        for stage in range(num_stage - 1):
            blocks = []
            for _ in range(num_block):
                blocks.append(resnet(width))
                if stage > 0:
                    blocks.append(attent(width))
            # Strided conv halves the resolution and doubles the width.
            blocks.append(nn.Conv2d(width, width * 2, 3, stride=2, padding=1))
            self.downsample.append(dmUNetDown(*blocks))
            width *= 2

        self.middle = dmUNetMiddle(resnet(width), attent(width), resnet(width))

        self.upsample = nn.ModuleList()
        for stage in reversed(range(num_stage - 1)):
            blocks = [
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.Conv2d(width, width // 2, 3, padding=1),
            ]
            width //= 2
            for block_idx in range(num_block):
                # The first block of a stage sees the skip concatenation (2x width).
                block_in = width * 2 if block_idx == 0 else width
                blocks.append(resnet(block_in, width))
                if stage > 0:
                    blocks.append(attent(width))
            self.upsample.append(dmUNetUp(*blocks))

        assert width == base_dim
        self.output = nn.Conv2d(base_dim, out_dim or in_dim, 3, padding=1)

    def forward(self, input, timestep, condition):
        """
        input: (b,c,h,w)
        timestep: (b,)
        condition: (b,n,c), conditioning plugged in via crossattn
        return: output of the final conv applied to the last up stage
        """
        temb = dmUNetCondition.create_sinusoidal_timestep_embedding(
            timestep, self.base_dim
        )
        temb = self.time_embed(temb)

        x = self.input(input)

        # Each down stage yields (pre-downsample activation, downsampled activation).
        skips = []
        for stage in self.downsample:
            skip, x = stage(x, temb, condition)
            skips.append(skip)

        x = self.middle(x, temb, condition)

        # Skips are consumed last-in-first-out.
        for stage in self.upsample:
            x = stage(x, skips.pop(), temb, condition)

        return self.output(x)

    @staticmethod
    def create_sinusoidal_timestep_embedding(
        timestep, embed_dim, max_period=10000, downscale_freq_shift=1, scale=1
    ):
        """Sinusoidal timestep embedding (Denoising Diffusion Probabilistic Models).

        timestep: in shape (b,); may be fractional
        embed_dim: the dimension of the output
        max_period: controls the minimum frequency of the embeddings
        downscale_freq_shift: subtracted from ``embed_dim // 2`` in the
            denominator of the frequency exponent
        scale: multiplier applied to the angles before sin/cos
        return: in shape (b,embed_dim); all sines first, then all cosines, with
            one trailing zero column when ``embed_dim`` is odd
        """
        assert timestep.ndim == 1
        num_freq = embed_dim // 2
        index = pt.arange(num_freq, dtype=pt.float32, device=timestep.device)
        exponent = -math.log(max_period) * index
        exponent = exponent / (num_freq - downscale_freq_shift)
        freq = exponent.exp()

        angle = timestep.float()[:, None] * freq[None, :]
        angle = scale * angle
        temb = pt.cat([angle.sin(), angle.cos()], dim=1)
        if embed_dim % 2 == 1:
            temb = ptnf.pad(temb, (0, 1, 0, 0))  # zero-pad the last dimension
        return temb


class dmUNetDown(nn.Sequential):
    """resnet-attent-resnet-attent...-downsample"""

    def forward(self, input, te, condition):
        """Apply the feature layers, then the trailing downsample layers.

        input: feature map fed to the first layer
        te: timestep embedding, passed to every ``dmResnetBlock2d``
        condition: conditioning, passed to every ``dmTransformDecode2d``
        return: ``(x, y)`` where ``x`` is the activation before downsampling
            (used as the skip connection) and ``y`` is the downsampled one; if
            the stage holds no downsample layer, ``y`` is ``x``
        raise: ValueError if a resnet/attention layer follows a downsample layer
        """
        layers = list(self)

        # Leading run of resnet / attention layers; ``split`` marks where the
        # downsample layers begin (== len(layers) when there are none).
        x = input
        split = len(layers)
        for idx, layer in enumerate(layers):
            if isinstance(layer, dmResnetBlock2d):
                x = layer(x, te)
            elif isinstance(layer, dmTransformDecode2d):
                x = layer(x, condition)
            else:
                split = idx
                break

        # Trailing run: plain layers that take the feature map only.
        y = x
        for layer in layers[split:]:
            if isinstance(layer, (dmResnetBlock2d, dmTransformDecode2d)):
                raise ValueError(
                    f"{type(layer).__name__} must not follow a downsample layer "
                    "in dmUNetDown"
                )
            y = layer(y)
        return x, y


class dmUNetMiddle(nn.Sequential):
    """resnet-attent-resnet-attent..."""

    def forward(self, input, te, condition):
        """Apply every layer in order, routing the extra argument by layer type.

        input: feature map fed to the first layer
        te: timestep embedding, passed to every ``dmResnetBlock2d``
        condition: conditioning, passed to every ``dmTransformDecode2d``
        return: output of the last layer (``input`` itself if there are no layers)
        raise: ValueError if a layer is neither a resnet nor an attention block
        """
        x = input
        for layer in self:
            if isinstance(layer, dmResnetBlock2d):
                x = layer(x, te)
            elif isinstance(layer, dmTransformDecode2d):
                x = layer(x, condition)
            else:
                # A string cannot be raised in Python 3; raise a real exception.
                raise ValueError(
                    f"unsupported layer type in dmUNetMiddle: {type(layer).__name__}"
                )
        return x


class dmUNetUp(nn.Sequential):
    """upsample-resnet-attent-resnet-attent..."""

    def forward(self, input, last, te, condition):
        """Upsample, concatenate the skip connection, then apply the feature layers.

        input: feature map fed to the leading upsample layers
        last: skip activation, concatenated on the channel dimension (dim 1)
            after upsampling
        te: timestep embedding, passed to every ``dmResnetBlock2d``
        condition: conditioning, passed to every ``dmTransformDecode2d``
        return: output of the last layer; if the stage holds no resnet/attention
            layer this is the concatenation itself
        raise: ValueError if a plain layer follows a resnet/attention layer
        """
        layers = list(self)

        # Leading run of plain (upsample) layers; ``split`` marks where the
        # resnet / attention layers begin (== len(layers) when there are none).
        x = input
        split = len(layers)
        for idx, layer in enumerate(layers):
            if isinstance(layer, (dmResnetBlock2d, dmTransformDecode2d)):
                split = idx
                break
            x = layer(x)

        y = pt.cat([x, last], dim=1)

        for layer in layers[split:]:
            if isinstance(layer, dmResnetBlock2d):
                y = layer(y, te)
            elif isinstance(layer, dmTransformDecode2d):
                y = layer(y, condition)
            else:
                # A string cannot be raised in Python 3; raise a real exception.
                raise ValueError(
                    f"{type(layer).__name__} must not follow a resnet/attention "
                    "layer in dmUNetUp"
                )
        return y


class dmResnetBlock2d(nn.Module):
    """``diffusers.models.resnet.ResnetBlock2D``

    Residual block: norm -> SiLU -> optional resample -> conv, plus a projected
    timestep embedding, then norm -> SiLU -> dropout -> conv, added to a
    (resampled, optionally 1x1-projected) copy of the input.
    """

    def __init__(
        self, in_dim, te_dim, dropout=0.0, out_dim=None, num_group=32, downup=0
    ):
        """
        in_dim: input channels
        te_dim: size of the timestep embedding
        dropout: dropout probability before the second conv
        out_dim: output channels; falls back to ``in_dim`` when falsy
        num_group: number of groups for both GroupNorm layers
        downup: 1 -> nearest 2x upsampling, 2 -> 2x average pooling, anything
            else -> no resampling
        """
        super().__init__()
        out_dim = out_dim or in_dim

        self.norm1 = nn.GroupNorm(num_group, in_dim)
        self.act1 = nn.SiLU(True)
        if downup == 1:  # 0,1,2
            self.downup0 = nn.UpsamplingNearest2d(scale_factor=2)
        elif downup == 2:
            self.downup0 = nn.AvgPool2d(2)
        else:
            # nn.Identity instead of a lambda: parameter-free (state dict is
            # unchanged) and keeps the module picklable.
            self.downup0 = nn.Identity()
        # Same resampler object is used on the main path and on the skip path.
        self.downup = self.downup0
        self.conv1 = nn.Conv2d(in_dim, out_dim, 3, stride=1, padding=1)

        self.proj_t = nn.Linear(te_dim, out_dim)

        self.norm2 = nn.GroupNorm(num_group, out_dim)
        self.act2 = nn.SiLU(True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1)

        if in_dim != out_dim:
            # 1x1 conv so the skip path matches the output channel count.
            self.skip = nn.Conv2d(in_dim, out_dim, 1, stride=1, padding=0)
        else:
            self.skip = nn.Identity()

    def forward(self, input, te):
        """
        input: in shape (b,c,h,w)
        te: timestep embedding in shape (b,te_dim)
        return: in shape (b,out_dim,h',w'), with h',w' set by ``downup``
        """
        z = self.norm1(input)
        z = self.act1(z)
        z = self.downup(z)
        z = self.conv1(z)

        # Broadcast the projected embedding over the spatial dimensions.
        z = z + self.proj_t(te)[:, :, None, None]

        z = self.norm2(z)
        z = self.act2(z)
        z = self.dropout(z)
        z = self.conv2(z)

        x = self.downup0(input)
        x = self.skip(x)
        return x + z


class dmTransformDecode2d(nn.Module):
    """``diffusers.models.transformers.Transformer2DModel``

    Wraps a stack of ``TransformDecodeBlock`` layers so they can run on a 2-d
    feature map: normalise and project with a 1x1 conv, flatten the spatial
    grid into a token sequence, run the layers against ``condition``, restore
    the grid, project again and add the input back.
    """

    def __init__(
        self,
        embed_dim,
        num_layer,
        num_head,
        ffn_dim,
        dropout=0,
        pre_norm=True,
        num_group=32,
        kv_dim=None,
    ):
        super().__init__()
        self.norm0 = nn.GroupNorm(num_group, embed_dim)
        # TODO rm
        self.proj0 = nn.Conv2d(embed_dim, embed_dim, 1, stride=1, padding=0)

        # TODO set first norm to Identity like ``TfdOCL``
        blocks = []
        for _ in range(num_layer):
            blocks.append(
                TransformDecodeBlock(
                    embed_dim, num_head, ffn_dim, dropout, pre_norm, kv_dim=kv_dim
                )
            )
        self.layers = nn.ModuleList(blocks)

        # TODO rm
        self.proj9 = nn.Conv2d(embed_dim, embed_dim, 1, stride=1, padding=0)

    def forward(self, input, condition):
        """
        input: in shape (b,c,h,w)
        condition: in shape (b,n,c)
        return: same shape as ``input`` (residual connection around the stack)
        """
        _, _, height, _ = input.shape  # also enforces a 4-d input

        tokens = self.proj0(self.norm0(input))
        tokens = rearrange(tokens, "b c h w -> b (h w) c")

        for block in self.layers:
            tokens = block(tokens, condition)

        grid = rearrange(tokens, "b (h w) c -> b c h w", h=height)
        return self.proj9(grid) + input
