# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# modified from: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/simple_diffusion.py
import math
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
from typing import Optional, Tuple, Union

import numpy as np
import torchvision
import torchvision.utils
from diffusers.models.embeddings import Timesteps, TimestepEmbedding
from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm


class ImageHead(nn.Module):
    """Normalize hidden states and project them to ``decoder_cfg.in_channels`` outputs.

    The projection is a ``ColumnParallelLinear`` built on the tensor-parallel
    group. When ``images_split_mask`` is passed to :meth:`forward`, only the
    masked positions are projected, regrouped per image.

    NOTE(review): ``AttrDict``, ``PG``, ``DropoutAddRMSNorm``,
    ``DropoutAddLayerNorm`` and ``ColumnParallelLinear`` are not imported in
    this file; presumably they come from the training framework. Confirm before
    instantiating this class.
    """

    def __init__(self, decoder_cfg, gpt_cfg, layer_id=None):
        """Build the output norm and the projection head.

        Args:
            decoder_cfg: config object; only ``in_channels`` (output width of
                the head) is read here.
            gpt_cfg: config combined with the defaults below via ``AttrDict +``
                (presumably its keys override the defaults; confirm).
            layer_id: stored on the instance; not otherwise used here.
        """
        import warnings  # used below; not imported at module level in this file

        super().__init__()
        self.layer_id = layer_id
        cfg = (
            AttrDict(
                norm_type="layernorm",
                is_exp_norm=False,
                sequence_parallel=False,
                use_userbuffer=False,
                norm_eps=1e-5,
                norm_bias=True,
                gradient_accumulation_fusion=True,
                use_fp32_head_weight=False,
            )
            + gpt_cfg
        )
        group = PG.tensor_parallel_group()
        assert cfg.norm_type in [
            "layernorm",
            "rmsnorm",
        ], f"Norm type:{cfg.norm_type} not supported"
        if cfg.norm_type == "rmsnorm":
            self.norm = DropoutAddRMSNorm(
                cfg.n_embed,
                prenorm=False,
                eps=cfg.norm_eps,
                is_exp_norm=cfg.is_exp_norm,
                sequence_parallel=cfg.sequence_parallel,
            )
        else:
            self.norm = DropoutAddLayerNorm(
                cfg.n_embed,
                prenorm=False,
                eps=cfg.norm_eps,
                is_exp_norm=cfg.is_exp_norm,
                sequence_parallel=cfg.sequence_parallel,
                bias=cfg.norm_bias,
            )

        multiple_of = 256
        if decoder_cfg.in_channels % multiple_of != 0:
            # Message: "vocab_size should be a multiple of {multiple_of},
            # otherwise matmul performance suffers".
            warnings.warn(
                f"建议把 vocab_size 设置为 {multiple_of} 的倍数, 否则会影响矩阵乘法的性能"
            )

        default_dtype = torch.get_default_dtype()
        head_dtype = default_dtype
        if cfg.use_fp32_head_weight:
            head_dtype = torch.float32
            # Message: "Using fp32 head weight!!!! Not compatible with the
            # original bf16 head weight".
            print(
                "使用 fp32 head weight!!!! 与原来的 bf16 head weight 不兼容\n",
                end="",
                flush=True,
            )

        # The head's parameters are allocated in the global default dtype, so
        # switch it temporarily. Restore it even if construction raises;
        # otherwise the changed default would leak into every later allocation.
        torch.set_default_dtype(head_dtype)
        try:
            self.head = ColumnParallelLinear(
                cfg.n_embed,
                decoder_cfg.in_channels,
                bias=True,
                group=group,
                sequence_parallel=cfg.sequence_parallel,
                use_userbuffer=cfg.use_userbuffer,
                gradient_accumulation_fusion=cfg.gradient_accumulation_fusion,
                use_fp32_output=False,
            )
        finally:
            torch.set_default_dtype(default_dtype)

        self.use_fp32_head_weight = cfg.use_fp32_head_weight

    def forward(
        self, input_args, images_split_mask: Optional[torch.BoolTensor] = None, **kwargs
    ):
        """Normalize and project hidden states.

        Args:
            input_args: either a tensor ``x`` or a tuple ``(x, residual)``; both
                parts are passed to ``self.norm``.
            images_split_mask: optional boolean mask whose first two dims are
                ``(bs, n_images)``. It is broadcast against ``x.unsqueeze(1)``
                (``x`` assumed ``(bs, seq, n_embed)``; confirm) to select each
                image's tokens.
            **kwargs: ignored.

        Returns:
            ``self.head(x)``, or, when a mask is given, the head applied to the
            selected tokens reshaped to ``(bs * n_images, -1, n_embed)``.
        """
        residual = None
        if isinstance(input_args, tuple):
            x, residual = input_args
        else:
            x = input_args

        x = self.norm(x, residual)

        if self.use_fp32_head_weight:
            # The fp32 head expects fp32 activations.
            assert (
                self.head.weight.dtype == torch.float32
            ), f"head.weight is {self.head.weight.dtype}"
            x = x.float()

        if images_split_mask is None:
            logits = self.head(x)
        else:
            bs, n_images = images_split_mask.shape[:2]
            n_embed = x.shape[-1]

            # Broadcast x over the image axis and keep only masked positions.
            images_embed = torch.masked_select(
                x.unsqueeze(1), images_split_mask.unsqueeze(-1)
            )
            images_embed = images_embed.view((bs * n_images, -1, n_embed))
            logits = self.head(images_embed)

        return logits


class GlobalResponseNorm(nn.Module):
    """Global Response Normalization (GRN) for channels-last tensors.

    Computes ``bias + (weight * nx + 1) * x``, where ``nx`` is the L2 norm of
    ``x`` over dims 1 and 2 divided by its mean over the last dim. Both
    ``weight`` and ``bias`` start at zero, so a freshly built layer is the
    identity.
    """

    # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.weight = nn.Parameter(torch.zeros(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        spatial_l2 = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # 1e-6 guards against division by zero for all-zero inputs.
        relative = spatial_l2 / (spatial_l2.mean(dim=-1, keepdim=True) + 1e-6)
        gain = self.weight * relative + 1
        return torch.addcmul(self.bias, gain, x, value=1)


class Downsample2D(nn.Module):
    """A 2D downsampling layer: optional channel norm, then a strided conv or average pool.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        use_conv (`bool`, default `False`):
            downsample with a strided convolution; otherwise use average pooling
            (which requires `out_channels == channels`).
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            convolution padding. With `padding == 0` and `use_conv`, the input is
            zero-padded by one pixel on the right and bottom before the conv.
        name (`str`, default `conv`):
            with `"conv"` the layer is registered as both `Conv2d_0` and `conv`;
            otherwise only as `conv`.
        kernel_size (`int`, default `3`):
            convolution kernel size (ignored for average pooling).
        stride (`int`, default `2`):
            conv stride, or pooling kernel size and stride.
        norm_type (`str`, optional):
            `"ln_norm"` (LayerNorm), `"rms_norm"` (RMSNorm) or `None`; applied
            over the channel dim before downsampling.
        eps (`float`, optional):
            norm epsilon; should be set when `norm_type` is given.
        elementwise_affine (`bool`, optional):
            passed to LayerNorm for `"ln_norm"`.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
        kernel_size=3,
        stride=2,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        # The norm is registered before the conv so submodule order is stable.
        if norm_type is None:
            self.norm = None
        elif norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps)
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        if not use_conv:
            assert self.channels == self.out_channels
            layer = nn.AvgPool2d(kernel_size=stride, stride=stride)
        else:
            layer = nn.Conv2d(
                self.channels,
                self.out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            )

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # name == "conv" additionally exposes the same module as `Conv2d_0`.
        if name == "conv":
            self.Conv2d_0 = layer
        self.conv = layer

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            # Norms act on the last dim, so move channels there and back.
            channels_last = hidden_states.permute(0, 2, 3, 1)
            hidden_states = self.norm(channels_last).permute(0, 3, 1, 2)

        if self.use_conv and self.padding == 0:
            # Asymmetric zero padding: one extra column on the right, one row at the bottom.
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels

        return self.conv(hidden_states)


class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        use_conv (`bool`, default `False`):
            apply a (stride-1) convolution after the optional interpolation.
        use_conv_transpose (`bool`, default `False`):
            upsample with a transposed convolution instead of interpolation.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            attribute under which the convolution is registered: `"conv"` gives
            `self.conv`, any other value gives `self.Conv2d_0` (legacy weight names).
        kernel_size (`int`, optional):
            convolution kernel size; defaults to 4 for the transposed conv and
            3 for the regular conv.
        padding (`int`, default `1`):
            convolution padding.
        stride (`int`, default `2`):
            transposed-conv stride, and the interpolation scale factor.
        norm_type (`str`, optional):
            `"ln_norm"` (LayerNorm), `"rms_norm"` (RMSNorm) or `None`; applied
            over the channel dim before upsampling.
        eps (`float`, optional):
            norm epsilon; should be set when `norm_type` is given.
        elementwise_affine (`bool`, optional):
            passed to LayerNorm for `"ln_norm"`.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
        interpolate (`bool`, default `True`):
            whether to nearest-interpolate on the non-transposed path.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = None,
        padding=1,
        stride=2,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
        interpolate=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.interpolate = interpolate
        self.stride = stride

        if norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps)
        elif norm_type is None:
            self.norm = None
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        conv = None
        if use_conv_transpose:
            if kernel_size is None:
                kernel_size = 4
            conv = nn.ConvTranspose2d(
                channels,
                self.out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias,
            )
        elif use_conv:
            if kernel_size is None:
                kernel_size = 3
            conv = nn.Conv2d(
                self.channels,
                self.out_channels,
                kernel_size=kernel_size,
                padding=padding,
                bias=bias,
            )

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
            self.conv = conv
        else:
            self.Conv2d_0 = conv

    def _get_conv(self):
        """Return the convolution under the attribute name chosen in ``__init__``."""
        return self.conv if self.name == "conv" else self.Conv2d_0

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_size: Optional[int] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Upsample ``hidden_states`` (channels in dim 1).

        ``output_size``, when given, is passed to ``F.interpolate`` as ``size``
        instead of using ``scale_factor=self.stride``. It is ignored on the
        transposed-conv path.
        """
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            # Norms act on the last dim, so move channels there and back.
            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(
                0, 3, 1, 2
            )

        if self.use_conv_transpose:
            # Honour `name` here as well: with name != "conv" the module only
            # exists as `Conv2d_0`, so reading `self.conv` would raise.
            return self._get_conv()(hidden_states)

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=self.stride`
        if self.interpolate:
            if output_size is None:
                hidden_states = F.interpolate(
                    hidden_states, scale_factor=self.stride, mode="nearest"
                )
            else:
                hidden_states = F.interpolate(
                    hidden_states, size=output_size, mode="nearest"
                )

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        if self.use_conv:
            hidden_states = self._get_conv()(hidden_states)

        return hidden_states


class ConvNextBlock(nn.Module):
    """ConvNeXt-style residual block followed by a conditioning modulation.

    Pipeline: 7x7 depthwise conv -> (channels-last) RMSNorm -> Linear ->
    GELU -> GlobalResponseNorm -> Linear -> Dropout, added to the input.
    The sum is then scaled and shifted by ``(1 + scale, shift)`` from a linear
    map of ``silu(cond_embeds)``.

    ``elementwise_affine`` is accepted but not used by this block.
    """

    def __init__(
        self,
        channels,
        norm_eps,
        elementwise_affine,
        use_bias,
        hidden_dropout,
        hidden_size,
        res_ffn_factor: int = 4,
    ):
        super().__init__()
        inner_channels = int(channels * res_ffn_factor)

        self.depthwise = nn.Conv2d(
            channels,
            channels,
            kernel_size=7,
            padding=3,
            groups=channels,
            bias=use_bias,
        )
        self.norm = RMSNorm(channels, norm_eps)
        self.channelwise_linear_1 = nn.Linear(channels, inner_channels, bias=use_bias)
        self.channelwise_act = nn.GELU()
        self.channelwise_norm = GlobalResponseNorm(inner_channels)
        self.channelwise_linear_2 = nn.Linear(inner_channels, channels, bias=use_bias)
        self.channelwise_dropout = nn.Dropout(hidden_dropout)
        # Produces a per-channel (scale, shift) pair from the condition.
        self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)

    def forward(self, x, cond_embeds):
        skip = x

        # The channel-wise MLP runs channels-last.
        h = self.depthwise(x).permute(0, 2, 3, 1)
        for layer in (
            self.norm,
            self.channelwise_linear_1,
            self.channelwise_act,
            self.channelwise_norm,
            self.channelwise_linear_2,
            self.channelwise_dropout,
        ):
            h = layer(h)
        h = h.permute(0, 3, 1, 2) + skip

        scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
        # Equivalent to h * (1 + scale) + shift, broadcast over the spatial dims.
        return torch.addcmul(
            shift[:, :, None, None], h, (1 + scale)[:, :, None, None], value=1
        )


class Patchify(nn.Module):
    """Strided conv patch embedding followed by a channel RMSNorm.

    ``kernel_size`` defaults to ``patch_size``. ``elementwise_affine`` is
    accepted but not used.
    """

    def __init__(
        self,
        in_channels,
        block_out_channels,
        patch_size,
        bias,
        elementwise_affine,
        eps,
        kernel_size=None,
    ):
        super().__init__()
        conv_kernel = patch_size if kernel_size is None else kernel_size
        self.patch_conv = nn.Conv2d(
            in_channels,
            block_out_channels,
            kernel_size=conv_kernel,
            stride=patch_size,
            bias=bias,
        )
        self.norm = RMSNorm(block_out_channels, eps)

    def forward(self, x):
        # RMSNorm normalizes the last dim, so go channels-last and back.
        channels_last = self.patch_conv(x).permute(0, 2, 3, 1)
        return self.norm(channels_last).permute(0, 3, 1, 2)


class Unpatchify(nn.Module):
    """Channel RMSNorm, 1x1 conv, then PixelShuffle to undo patching.

    The 1x1 conv emits ``out_channels * patch_size**2`` channels, which
    ``nn.PixelShuffle(patch_size)`` rearranges into ``out_channels`` channels at
    ``patch_size`` times the spatial resolution. ``elementwise_affine`` is
    accepted but not used.
    """

    def __init__(
        self, in_channels, out_channels, patch_size, bias, elementwise_affine, eps
    ):
        super().__init__()
        self.norm = RMSNorm(in_channels, eps)
        self.unpatch_conv = nn.Conv2d(
            in_channels,
            out_channels * patch_size**2,
            kernel_size=1,
            bias=bias,
        )
        self.pixel_shuffle = nn.PixelShuffle(patch_size)
        self.patch_size = patch_size

    def forward(self, x):
        # x: [b, c, h, w]; RMSNorm acts on the last dim, so normalize channels-last.
        normed = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return self.pixel_shuffle(self.unpatch_conv(normed))


class UVitBlock(nn.Module):
    """A stack of conditioned ConvNeXt blocks, optionally followed by resampling.

    When enabled, the downsampler is an RMS-normalized strided 3x3 conv and
    the upsampler is an RMS-normalized nearest-interpolation + 3x3 conv.
    ``seq_len``, ``concat_input``, ``original_input_channels``, ``use_zero``,
    ``norm_type`` and ``forward``'s ``recompute`` are accepted but not used.
    """

    def __init__(
        self,
        channels,
        out_channels,
        num_res_blocks,
        stride,
        hidden_size,
        hidden_dropout,
        elementwise_affine,
        norm_eps,
        use_bias,
        downsample: bool,
        upsample: bool,
        res_ffn_factor: int = 4,
        seq_len=None,
        concat_input=False,
        original_input_channels=None,
        use_zero=True,
        norm_type="RMS",
    ):
        super().__init__()

        self.res_blocks = nn.ModuleList(
            [
                ConvNextBlock(
                    channels,
                    norm_eps,
                    elementwise_affine,
                    use_bias,
                    hidden_dropout,
                    hidden_size,
                    res_ffn_factor=res_ffn_factor,
                )
                for _ in range(num_res_blocks)
            ]
        )

        # Arguments shared by both resamplers.
        resample_kwargs = dict(
            channels=channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
            stride=stride,
            norm_type="rms_norm",
            eps=norm_eps,
            elementwise_affine=elementwise_affine,
            bias=use_bias,
        )

        self.downsample = (
            Downsample2D(use_conv=True, name="Conv2d_0", **resample_kwargs)
            if downsample
            else None
        )

        self.upsample = (
            Upsample2D(
                use_conv_transpose=False,
                use_conv=True,
                name="conv",
                interpolate=True,
                **resample_kwargs,
            )
            if upsample
            else None
        )

    def forward(self, x, emb, recompute=False):
        for block in self.res_blocks:
            x = block(x, emb)

        for resampler in (self.downsample, self.upsample):
            if resampler is not None:
                x = resampler(x)

        return x


class ShallowUViTEncoder(nn.Module):
    """Shallow encoder: strided stem conv plus an optional conditioned mid block.

    Timesteps are embedded with diffusers ``Timesteps`` + ``TimestepEmbedding``
    into a ``hidden_size`` vector that conditions the mid block. When
    ``padding`` is None, the stem uses ``ceil(kernel_size - stride)``.
    """

    def __init__(
        self,
        input_channels=3,
        stride=4,
        kernel_size=7,
        padding=None,
        block_out_channels=(768,),
        layers_in_middle=2,
        hidden_size=2048,
        elementwise_affine=True,
        use_bias=True,
        norm_eps=1e-6,
        dropout=0.0,
        use_mid_block=True,
        **kwargs,
    ):
        super().__init__()

        stem_channels = block_out_channels[0]
        self.time_proj = Timesteps(
            stem_channels, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.time_embed = TimestepEmbedding(
            stem_channels, hidden_size, sample_proj_bias=use_bias
        )

        stem_padding = math.ceil(kernel_size - stride) if padding is None else padding
        self.in_conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=stem_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=stem_padding,
        )

        self.mid_block = (
            UVitBlock(
                block_out_channels[-1],
                block_out_channels[-1],
                num_res_blocks=layers_in_middle,
                hidden_size=hidden_size,
                hidden_dropout=dropout,
                elementwise_affine=elementwise_affine,
                norm_eps=norm_eps,
                use_bias=use_bias,
                downsample=False,
                upsample=False,
                stride=1,
                res_ffn_factor=4,
            )
            if use_mid_block
            else None
        )

    def get_num_extra_tensors(self):
        return 2

    def forward(self, x, timesteps, return_hs=True):
        """Encode ``x`` and embed ``timesteps``.

        Returns ``(x_emb, t_emb, hs)``, where ``hs`` is ``[x_emb]`` when
        ``return_hs`` is true and ``[]`` otherwise.
        """
        batch = x.shape[0]

        # One timestep projection per batch element, cast to the input dtype.
        projected = self.time_proj(timesteps.flatten()).view(batch, -1).to(x.dtype)
        t_emb = self.time_embed(projected)

        x_emb = self.in_conv(x)
        if self.mid_block is not None:
            x_emb = self.mid_block(x_emb, t_emb)

        return x_emb, t_emb, ([x_emb] if return_hs else [])


class ShallowUViTDecoder(nn.Module):
    """Shallow decoder: normalize, concat the last skip feature, mid block, unpatchify.

    The mid block runs on ``in_channels + block_out_channels[-1]`` channels,
    which presumes ``hs[-1]`` has ``block_out_channels[-1]`` channels (confirm
    against the encoder). ``upsamples`` ``Unpatchify`` stages (patch size 2)
    follow, and the last one emits ``out_channels`` channels.
    """

    def __init__(
        self,
        in_channels=768,
        out_channels=3,
        block_out_channels: Tuple[int] = (768,),
        upsamples=2,
        layers_in_middle=2,
        hidden_size=2048,
        elementwise_affine=True,
        norm_eps=1e-6,
        use_bias=True,
        dropout=0.0,
        use_mid_block=True,
        **kwargs,
    ):
        super().__init__()
        width = block_out_channels[-1]

        # The second positional argument (out_channels) has no effect because
        # neither downsampling nor upsampling is enabled.
        self.mid_block = (
            UVitBlock(
                in_channels + width,
                width,
                num_res_blocks=layers_in_middle,
                hidden_size=hidden_size,
                hidden_dropout=dropout,
                elementwise_affine=elementwise_affine,
                norm_eps=norm_eps,
                use_bias=use_bias,
                downsample=False,
                upsample=False,
                stride=1,
                res_ffn_factor=4,
            )
            if use_mid_block
            else None
        )

        self.out_convs = nn.ModuleList()
        last_stage = upsamples - 1
        for stage in range(upsamples):
            # The first stage also consumes the concatenated input channels.
            stage_in = width + in_channels if stage == 0 else width
            stage_out = out_channels if stage == last_stage else width
            self.out_convs.append(
                Unpatchify(
                    stage_in,
                    stage_out,
                    patch_size=2,
                    bias=use_bias,
                    elementwise_affine=elementwise_affine,
                    eps=norm_eps,
                )
            )
        self.input_norm = RMSNorm(in_channels, norm_eps)

    def forward(self, x, hs, t_emb):
        # RMSNorm acts on the last dim, so normalize channels-last.
        normed = self.input_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        # Read (not pop) the most recent skip feature so `hs` is left unchanged.
        h = torch.cat([normed, hs[-1]], dim=1)
        if self.mid_block is not None:
            h = self.mid_block(h, t_emb)
        for stage in self.out_convs:
            h = stage(h)
        return h
