import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import Mlp
import torch.nn.functional as F
from einops import repeat
from .dit.rotary import apply_rotary_pos_emb


def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
    """
    Build sinusoidal embeddings for a batch of (possibly fractional) timesteps.

    :param t: 1-D tensor holding one timestep per batch element.
    :param dim: width of the embedding produced for each timestep.
    :param max_period: sets the lowest frequency used by the sinusoids.
    :param repeat_only: if True, skip the sinusoids and tile each timestep
        value ``dim`` times instead.
    :return: an (N, dim) tensor.
    """
    # Reference: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    if repeat_only:
        return repeat(t, "b -> b d", d=dim)

    half = dim // 2
    # Frequencies decay exponentially from 1 towards 1 / max_period; shape [dim // 2].
    log_steps = -math.log(max_period) * torch.arange(
        start=0, end=half, dtype=torch.float32
    )
    freqs = torch.exp(log_steps / half).to(device=t.device)
    phases = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
    if dim % 2:
        # Odd widths: append one zero column so the output has exactly `dim` columns.
        zero_column = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_column], dim=-1)
    return embedding


def modulate(x, shift, scale):
    """Apply an affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a singleton axis inserted at dim 1, so one
    vector per batch element broadcasts across the second axis of ``x``.
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b


class TimestepEmbedder(nn.Module):
    """
    Map scalar timesteps to vectors: sinusoidal features -> Linear -> SiLU -> Linear.

    Args:
        hidden_size: width of the hidden layer (and of the output when
            ``out_size`` is omitted).
        frequency_embedding_size: width of the sinusoidal features fed to the MLP.
        out_size: optional output width; defaults to ``hidden_size``.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None):
        super().__init__()
        final_size = hidden_size if out_size is None else out_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, final_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t):
        # Cast the float32 sinusoidal features to the MLP's parameter dtype.
        param_dtype = self.mlp[0].weight.dtype
        features = timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features.type(param_dtype))


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension (no mean centering)."""

    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6):
        """
        Args:
            dim (int): Size of the normalized (last) dimension.
            elementwise_affine (bool, optional): If True, register a learnable
                per-channel ``weight`` initialized to ones. Default is True.
            eps (float, optional): Added to the mean square before the inverse
                square root, for numerical stability. Default is 1e-6.
        """
        super().__init__()
        self.eps = eps
        if not elementwise_affine:
            return
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the root mean square of its last dimension (plus eps)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """
        Normalize in float32, cast back to ``x``'s dtype, then apply ``weight`` if present.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The normalized (and optionally scaled) tensor.
        """
        normed = self._norm(x.float()).type_as(x)
        if not hasattr(self, "weight"):
            return normed
        return normed * self.weight


class FP32_Layernorm(nn.LayerNorm):
    """LayerNorm that always computes in float32 and casts the result back.

    The statistics are computed in float32 regardless of the input dtype, and
    the output is returned in the input's original dtype. Works with
    ``elementwise_affine=False`` (and ``bias=False`` where supported), in
    which case ``weight`` and/or ``bias`` are ``None``.
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        # weight/bias are None when the module has no affine parameters;
        # calling .float() on them unconditionally would raise AttributeError.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        return F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            weight,
            bias,
            self.eps,
        ).to(origin_dtype)


class FP32_SiLU(nn.SiLU):
    """SiLU evaluated in float32; the result is cast back to the input's dtype."""

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        input_dtype = inputs.dtype
        activated = torch.nn.functional.silu(inputs.float(), inplace=False)
        return activated.to(input_dtype)


class FinalLayer(nn.Module):
    """
    Output head of HunYuanDiT: LayerNorm modulated by (shift, scale) derived
    from the conditioning vector, followed by a linear projection.

    Args:
        out_feats: number of output features per token.
        final_hidden_size: feature width of the incoming tokens.
        c_emb_size: width of the conditioning vector ``c``.
        elementwise_affine: whether the LayerNorm has learnable affine parameters.
    """

    def __init__(
        self, out_feats, final_hidden_size, c_emb_size, elementwise_affine=False
    ):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            final_hidden_size, elementwise_affine=elementwise_affine, eps=1e-6
        )
        self.linear = nn.Linear(final_hidden_size, out_feats, bias=True)
        # Emits 2 * final_hidden_size values, split into (shift, scale) in forward().
        self.adaLN_modulation = nn.Sequential(
            FP32_SiLU(), nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
        )

    def forward(self, x, c):
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))


class AttentionPool(nn.Module):
    """
    Pool a token sequence into a single vector with multi-head attention.

    The mean token is prepended as an extra slot, learned positional
    embeddings are added to all ``spacial_dim + 1`` slots, and only the mean
    slot is used as the attention query.

    Args:
        spacial_dim: number of input tokens L.
        embed_dim: channel width C of the input tokens.
        num_heads: number of attention heads.
        output_dim: output width; defaults to ``embed_dim``.
    """

    def __init__(
        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        # Assumes x is (N, L, C) with L == spacial_dim; switch to sequence-first.
        seq = x.permute(1, 0, 2)  # (L, N, C)
        summary = seq.mean(dim=0, keepdim=True)  # (1, N, C)
        seq = torch.cat([summary, seq], dim=0)  # (L+1, N, C)
        seq = seq + self.positional_embedding[:, None, :].to(seq.dtype)
        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=seq[:1],
            key=seq,
            value=seq,
            embed_dim_to_check=seq.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        # (1, N, out) -> (N, out)
        return pooled.squeeze(0)


class CrossAttention(nn.Module):
    """
    Multi-head cross-attention from ``x`` (queries) to ``y`` (keys/values).

    Supports optional QK normalization and optional RoPE on the queries.

    Args:
        qdim: feature width of ``x``; also the attention width (num_heads * head_dim).
        kdim: feature width of ``y``.
        num_heads: number of attention heads.
        qkv_bias: whether the q / kv / output projections have biases.
        qk_norm: if True, normalize q and k per head with ``norm_layer``.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after the output projection.
        norm_layer: normalization class used when ``qk_norm`` is True.
    """

    def __init__(
        self,
        qdim,
        kdim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert (
            self.head_dim % 8 == 0 and self.head_dim <= 128
        ), "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias)
        # Fused key/value projection: kdim -> 2 * qdim.
        self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = (
            norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
            if qk_norm
            else nn.Identity()
        )
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y, freqs_cis=None, q_mask=None, k_mask=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) where hidden_dim = num_heads * head_dim.
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2), or (batch, hidden_dim2), in which
            case it is broadcast to every query position (seqlen2 = seqlen1).
        freqs_cis: torch.Tensor, optional
            RoPE frequencies applied to the queries only (passed to
            ``apply_rotary_pos_emb`` with ``unsqueeze_dim=2``).
        q_mask: torch.BoolTensor, optional
            (batch, seqlen1) query mask. Must be given together with ``k_mask``.
        k_mask: torch.BoolTensor, optional
            (batch, seqlen2) key mask. Logit (i, j) is set to -inf where
            ``q_mask[:, i] & k_mask[:, j]`` is True. A query row whose keys are
            all masked produces NaN after the softmax.

        Returns
        -------
        tuple[torch.Tensor]
            One-element tuple holding the (batch, seqlen1, qdim) output.

        Raises
        ------
        ValueError
            If ``q_mask`` is given without ``k_mask``.
        """
        b, s1, _ = x.shape  # [b, s1, qdim]

        if y.dim() == 2:
            y = y.unsqueeze(1).expand(-1, s1, -1)  # [b, s1, kdim]

        s2 = y.shape[1]  # [b, s2, kdim]

        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)  # [b, s1, h, d]
        kv = self.kv_proj(y).view(
            b, s2, 2, self.num_heads, self.head_dim
        )  # [b, s2, 2, h, d]
        k, v = kv.unbind(dim=2)  # each [b, s2, h, d]
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Apply RoPE to the queries if needed.
        if freqs_cis is not None:
            qq, _ = apply_rotary_pos_emb(q, None, freqs_cis, unsqueeze_dim=2)
            assert qq.shape == q.shape, f"qq: {qq.shape}, q: {q.shape}"
            q = qq

        q = q * self.scale
        q = q.transpose(-2, -3).contiguous()  # [b, s1, h, d] -> [b, h, s1, d]
        k = k.permute(0, 2, 3, 1).contiguous()  # [b, s2, h, d] -> [b, h, d, s2]
        attn = q @ k  # [b, h, s1, s2]

        if q_mask is not None:
            if k_mask is None:
                raise ValueError("k_mask must be provided together with q_mask")
            assert q_mask.shape == (
                b,
                s1,
            ), f"expecting q_mask shape of {(b, s1)}, but got {q_mask.shape}"
            assert k_mask.shape == (
                b,
                s2,
            ), f"expecting k_mask shape of {(b, s2)}, but got {k_mask.shape}"
            # The query mask must index rows (s1) and the key mask columns (s2)
            # so the combined mask lines up with attn [b, h, s1, s2].
            attn_mask = q_mask[:, None, :, None] & k_mask[:, None, None, :]  # [b, 1, s1, s2]
            attn = attn.masked_fill(attn_mask, float("-inf"))

        attn = attn.softmax(dim=-1)  # [b, h, s1, s2]
        attn = self.attn_drop(attn)
        x = attn @ v.transpose(-2, -3)  # v -> [b, h, s2, d]; x -> [b, h, s1, d]
        context = x.transpose(1, 2)  # [b, s1, h, d]

        context = context.contiguous().view(b, s1, -1)  # [b, s1, qdim]

        out = self.out_proj(context)
        out = self.proj_drop(out)

        return (out,)


class Attention(nn.Module):
    """
    Multi-head self-attention with optional QK normalization and RoPE.

    Layer names (``Wqkv``, ``out_proj``) are chosen to align with flash attention.

    Args:
        dim: model width, split evenly across ``num_heads`` heads.
        num_heads: number of attention heads.
        qkv_bias: whether the fused qkv projection has a bias.
        qk_norm: if True, normalize q and k per head with ``norm_layer``.
        attn_drop: dropout on the attention weights.
        proj_drop: dropout after the output projection.
        norm_layer: normalization class used when ``qk_norm`` is True.
    """

    def __init__(
        self,
        dim,
        num_heads,
        qkv_bias=True,
        qk_norm=False,
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, "dim should be divisible by num_heads"
        self.head_dim = self.dim // num_heads

        # Same head_dim limits as flash attention.
        assert (
            self.head_dim % 8 == 0 and self.head_dim <= 128
        ), "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim**-0.5

        def build_qk_norm():
            # TODO: eps should be 1 / 65530 if using fp16
            if not qk_norm:
                return nn.Identity()
            return norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)

        # Fused q/k/v projection (named Wqkv to match flash attention).
        self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = build_qk_norm()
        self.k_norm = build_qk_norm()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, freqs_cis=None, key_padding_mask=None, attn_mask=None):
        """
        Self-attention over ``x``.

        Args:
            x: (batch, seq_len, dim) input tokens.
            freqs_cis: optional RoPE frequencies, forwarded to
                ``apply_rotary_pos_emb`` together with q and k.
            key_padding_mask: optional (batch, seq_len) bool tensor; key
                positions where it is True get -inf logits.
            attn_mask: accepted for interface compatibility; currently ignored.

        Returns:
            (batch, seq_len, dim) tensor.
        """
        batch, seq_len, channels = x.shape
        projected = self.Wqkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)  # each [b, h, s, d]
        q, k = self.q_norm(q), self.k_norm(k)

        if freqs_cis is not None:
            q_rot, k_rot = apply_rotary_pos_emb(q, k, freqs_cis)
            assert (
                q_rot.shape == q.shape and k_rot.shape == k.shape
            ), f"qq: {q_rot.shape}, q: {q.shape}, kk: {k_rot.shape}, k: {k_rot.shape if False else k.shape}"
            q, k = q_rot, k_rot

        logits = (q * self.scale) @ k.transpose(-2, -1)  # [b, h, s, s]

        if key_padding_mask is not None:
            expected = (batch, seq_len)
            assert (
                key_padding_mask.shape == expected
            ), f"expecting key_padding_mask shape of {expected}, but got {key_padding_mask.shape}"
            # Broadcast [b, s] over heads and query rows.
            logits = logits.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))

        weights = self.attn_drop(logits.softmax(dim=-1))
        merged = (weights @ v).transpose(1, 2).reshape(batch, seq_len, channels)  # [b, s, dim]
        return self.proj_drop(self.out_proj(merged))


class DiTBlock(nn.Module):
    """
    Transformer block: self-attention, then cross-attention to ``text_states``,
    then an MLP, each added as a residual.

    NOTE(review): this block was described as a DiT block with adaLN-Zero
    conditioning, but forward() applies no modulation: ``c``, ``freqs_cis``
    and ``self.default_modulation`` are never used. Confirm that is intended.

    Args:
        hidden_size: feature width of ``x``.
        num_heads: number of heads in both attention layers.
        c_emb_size: width of the conditioning vector; when truthy a
            ``default_modulation`` layer is created (unused in forward()).
        text_states_dim: feature width of ``text_states``; defaults to
            ``hidden_size``.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        norm_type: ``"layer"`` for FP32_Layernorm or ``"rms"`` for RMSNorm.
        qk_norm: accepted but currently unused.
        qkv_bias: accepted but currently unused.
        elementwise_affine: affine flag for ``norm1``/``norm2``; ``norm3`` is
            always created with ``elementwise_affine=True``.
        dropout: dropout probability passed to both nn.MultiheadAttention layers.
        **block_kwargs: accepted but currently unused.

    Raises:
        ValueError: if ``norm_type`` is neither ``"layer"`` nor ``"rms"``.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        c_emb_size=None,
        text_states_dim=None,
        mlp_ratio=1.0,
        norm_type="layer",
        qk_norm=False,
        qkv_bias=True,
        elementwise_affine=True,
        dropout=0.1,
        **block_kwargs,
    ):
        super().__init__()

        self.c_emb_size = c_emb_size

        if text_states_dim is None:
            text_states_dim = hidden_size

        if norm_type == "layer":
            norm_layer = FP32_Layernorm
        elif norm_type == "rms":
            norm_layer = RMSNorm
        else:
            raise ValueError(f"Unknown norm_type: {norm_type}")

        # ========================= Self-Attention =========================
        self.norm1 = norm_layer(
            hidden_size, elementwise_affine=elementwise_affine, eps=1e-6
        )

        # NOTE(review): batch_first is left at its default (False), so
        # nn.MultiheadAttention expects sequence-first (L, N, E) inputs and a
        # (N, L) key_padding_mask — TODO confirm callers use that layout.
        self.attn1 = nn.MultiheadAttention(
            hidden_size,
            num_heads,
            dropout=dropout,
        )

        # ========================= FFN =========================
        self.norm2 = norm_layer(
            hidden_size, elementwise_affine=elementwise_affine, eps=1e-6
        )
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        # timm Mlp with its default activation (no act_layer override).
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            drop=0,
        )

        # Only created when c_emb_size is truthy; not referenced in forward().
        if c_emb_size:
            self.default_modulation = nn.Sequential(
                FP32_SiLU(), nn.Linear(c_emb_size, hidden_size, bias=True)
            )

        # ========================= Cross-Attention =========================
        # Queries have width hidden_size; keys/values come from text_states
        # with width text_states_dim. Same batch_first caveat as attn1.
        self.attn2 = nn.MultiheadAttention(
            hidden_size,
            num_heads,
            kdim=text_states_dim,
            vdim=text_states_dim,
            dropout=dropout,
        )

        # Pre-norm for the cross-attention queries.
        self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)

    def forward(
        self,
        x,
        c=None,
        text_states=None,
        freqs_cis=None,
        key_padding_mask=None,
    ):
        """
        Args:
            x: token features (layout per the batch_first note in __init__).
            c: conditioning vector — currently unused.
            text_states: keys/values for the cross-attention layer.
            freqs_cis: RoPE frequencies — currently unused.
            key_padding_mask: passed as ``key_padding_mask`` to both
                attention layers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-Attention.
        # NOTE(review): x is overwritten by norm1(x), so the residual below is
        # added to the normalized tensor rather than to the block input
        # (unlike the usual pre-norm layout). Confirm this is intended.
        x = self.norm1(x)
        x = x + self.attn1(x, x, x, key_padding_mask=key_padding_mask)[0]

        # Cross-Attention.
        # NOTE(review): the self-attention key_padding_mask is reused here,
        # but for attn2 it must match the key (text_states) sequence length.
        # That only works if x and text_states have the same length — TODO confirm.
        x = (
            x
            + self.attn2(
                self.norm3(x),
                text_states,
                text_states,
                key_padding_mask=key_padding_mask,
            )[0]
        )

        # FFN Layer
        mlp_inputs = self.norm2(x)
        x = x + self.mlp(mlp_inputs)

        return x
