# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
# Part of this implementation is adapted from https://github.com/facebookresearch/DiT
# which is released under NonCommercial-4.0 license
# Part of this implementation is adapted from https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
# which is released under MIT license
# Part of this implementation is adapted from https://github.com/louaaron/Score-Entropy-Discrete-Diffusion
# which is released under MIT license

import math
from typing import Optional

import torch
import torch.nn.functional as F

from einops import rearrange
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

from torch import nn, Tensor

from . import rotary


def bias_dropout_add_scale(
    x: Tensor, scale: Tensor, residual: Optional[Tensor], prob: float, training: bool
) -> Tensor:
    """Apply dropout to ``x``, scale it, and add it to an optional residual.

    Args:
        x: branch output to be dropped out and scaled.
        scale: multiplicative gate applied after dropout (broadcast against ``x``).
        residual: skip-connection tensor; when ``None`` only the scaled branch is
            returned, honoring the ``Optional`` annotation.
        prob: dropout probability.
        training: dropout is active only when True (passed to ``F.dropout``).

    Returns:
        ``residual + scale * dropout(x)``, or ``scale * dropout(x)`` without a residual.
    """
    branch = scale * F.dropout(x, p=prob, training=training)
    if residual is None:
        return branch
    return residual + branch


def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    """Affinely modulate ``x`` as ``x * (1 + scale) + shift`` (adaLN-style).

    ``shift`` and ``scale`` must broadcast against ``x``.
    """
    gain = 1 + scale
    return x * gain + shift


class LayerNorm(nn.Module):
    """Bias-free LayerNorm over the last dimension with a learnable per-feature gain.

    The normalization runs on a float32 copy of the input with CUDA autocast
    disabled. The result is not cast back to the input dtype.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Per-feature gain, initialized to 1 (no bias term).
        self.weight = nn.Parameter(torch.ones([dim]))
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` over its last dimension (size ``dim``) and apply the gain."""
        with torch.amp.autocast("cuda", enabled=False):
            x = F.layer_norm(x.float(), [self.dim])

        # Plain broadcasting over the trailing dimension works for inputs of any
        # rank; indexing as weight[None, None, :] would add a leading axis to 2-D input.
        return x * self.weight


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A fixed sinusoidal featurization of width ``frequency_embedding_size`` is fed
    through a two-layer SiLU MLP that outputs ``hidden_size`` features.
    """

    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(time: Tensor, dim: int, max_period: int = 10000) -> Tensor:
        """
        Create sinusoidal timestep embeddings.
        :param time: a 1-D Tensor of N values, one per batch element.
                     These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) float32 Tensor laid out as [cos(...), sin(...)],
                 right-padded with one zero column when ``dim`` is odd.
        """
        half = dim // 2
        # Build the frequency table directly on the input's device so no
        # host-to-device copy is needed on every call.
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32, device=time.device)
            / half
        )
        args = time[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, time: Tensor) -> Tensor:
        """Embed a 1-D tensor of timesteps into an (N, hidden_size) tensor."""
        t_freq = self.timestep_embedding(time=time, dim=self.frequency_embedding_size)
        # Match the MLP's parameter dtype rather than the timestep dtype. This lets
        # integer/float64/bf16 timesteps work, and a half-precision module still
        # receives half-precision inputs.
        t_freq = t_freq.to(dtype=self.mlp[0].weight.dtype)
        return self.mlp(t_freq)


class DDiTBlock(nn.Module):
    """One DDiT transformer layer with image cross-attention.

    Sub-layers, applied in order:
      1. adaLN-modulated, gated self-attention with rotary position embeddings;
      2. cross-attention from the tokens onto ``img_tokens`` (no modulation, no gate);
      3. adaLN-modulated, gated MLP.

    The six adaLN tensors (shift/scale/gate for attention and for the MLP) come from
    a zero-initialized linear layer on ``c``. Both gated branches therefore add
    nothing to the residual stream at initialization.
    """

    def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
        super().__init__()
        assert dim % n_heads == 0, "dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.dim = dim
        self.dropout = dropout
        self.head_dim = dim // n_heads
        hidden = mlp_ratio * dim

        # Submodule creation order fixes parameter order and RNG consumption;
        # keep it stable.
        self.norm1 = LayerNorm(dim)
        self.qw = nn.Linear(dim, dim, bias=False)
        self.kw = nn.Linear(dim, dim, bias=False)
        self.vw = nn.Linear(dim, dim, bias=False)
        self.attn_out = nn.Linear(dim, dim, bias=False)

        # NOTE(review): norm2 is never used in forward(). Removing it would change the
        # state_dict keys, so it stays registered — confirm before dropping it.
        self.norm2 = LayerNorm(dim)
        self.cross_attn = CrossAttention(dim=dim, n_heads=n_heads, dropout=dropout)
        self.norm_cross = LayerNorm(dim)

        self.norm3 = LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(hidden, dim, bias=True),
        )

        modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
        modulation.weight.data.zero_()
        modulation.bias.data.zero_()
        self.adaLN_modulation = modulation

    def forward(self, x: Tensor, rotary_cos_sin: Tensor, c: Tensor, img_tokens: Tensor) -> Tensor:
        batch, seq_len = x.shape[0], x.shape[1]
        heads, head_dim = self.n_heads, self.head_dim

        # Insert a sequence axis so every modulation tensor broadcasts over tokens
        # (assumes c is [B, cond_dim] — TODO confirm).
        params = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params

        # --- self-attention with rotary embeddings ---
        h = modulate(self.norm1(x), shift=shift_msa, scale=scale_msa)
        q = self.qw(h).view(batch, seq_len, heads, head_dim)
        k = self.kw(h).view(batch, seq_len, heads, head_dim)
        v = self.vw(h).view(batch, seq_len, heads, head_dim)

        # Rotary is applied in float32 with CUDA autocast off, then cast back.
        with torch.amp.autocast("cuda", enabled=False):
            cos, sin = rotary_cos_sin
            q = rotary.apply_rotary_emb_torch(q.float(), cos, sin).to(q.dtype)
            k = rotary.apply_rotary_emb_torch(k.float(), cos, sin).to(k.dtype)

        attended = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        )
        attended = rearrange(attended, "b h s d -> b s (h d)", b=batch)
        x = bias_dropout_add_scale(
            x=self.attn_out(attended),
            scale=gate_msa,
            residual=x,
            prob=self.dropout,
            training=self.training,
        )

        # --- cross-attention onto image tokens (no adaLN modulation or gate) ---
        x = x + self.cross_attn(self.norm_cross(x), img_tokens)

        # --- gated MLP ---
        h = modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp)
        return bias_dropout_add_scale(
            x=self.mlp(h),
            scale=gate_mlp,
            residual=x,
            prob=self.dropout,
            training=self.training,
        )
        
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``x``, keys and values from ``context``.

    ``x`` is [B, T, dim] and ``context`` is [B, S, dim]; the output is [B, T, dim].
    Dropout with probability ``dropout`` is applied to the attention weights while
    the dropout submodule is in training mode.
    """

    def __init__(self, dim, n_heads, dropout=0.1):
        super().__init__()
        # Validate up front; a non-divisible dim would otherwise only fail in forward().
        if dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        # 1/sqrt(head_dim): the same default scale F.scaled_dot_product_attention uses.
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        # Supplies the attention-dropout probability and its train/eval flag.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context):
        """Attend from ``x`` [B, T, dim] onto ``context`` [B, S, dim]; returns [B, T, dim]."""
        batch, q_len = x.shape[0], x.shape[1]
        kv_len = context.shape[1]
        heads, head_dim = self.n_heads, self.head_dim

        # [B, L, dim] -> [B, H, L, head_dim]
        q = self.q_proj(x).view(batch, q_len, heads, head_dim).transpose(1, 2)
        k = self.k_proj(context).view(batch, kv_len, heads, head_dim).transpose(1, 2)
        v = self.v_proj(context).view(batch, kv_len, heads, head_dim).transpose(1, 2)

        # Fused attention avoids materializing the [B, H, T, S] weight matrix when a
        # fused kernel is available; dropout only applies in training mode, as before.
        drop_p = self.dropout.p if self.dropout.training else 0.0
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)  # [B, H, T, head_dim]
        out = out.transpose(1, 2).reshape(batch, q_len, heads * head_dim)  # [B, T, dim]

        return self.out_proj(out)

class DDitFinalLayer(nn.Module):
    """Output head: adaLN-modulated LayerNorm followed by a linear projection.

    The projection and the modulation layer are both zero-initialized, so the
    head emits all-zero outputs at initialization.
    """

    def __init__(self, hidden_size: int, out_channels: int, cond_dim: int):
        super().__init__()
        self.norm_final = LayerNorm(hidden_size)
        self.linear = nn.Linear(hidden_size, out_channels)
        self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
        for layer in (self.linear, self.adaLN_modulation):
            layer.weight.data.zero_()
            layer.bias.data.zero_()

    def forward(self, x: Tensor, c: Tensor) -> Tensor:
        # Add a sequence axis so shift/scale broadcast across every position.
        modulation = self.adaLN_modulation(c).unsqueeze(1)
        shift, scale = modulation.chunk(2, dim=2)
        normed = self.norm_final(x)
        return self.linear(modulate(x=normed, shift=shift, scale=scale))


class Transformer(nn.Module):
    """DDiT token denoiser conditioned on a timestep and on image tokens.

    Args:
        vocab_size: number of token ids.
        masked: if True, one extra id (index ``vocab_size``) is added to both the
            input embedding and the output logits — presumably the mask token.
        config: must provide ``hidden_size``, ``cond_dim``, ``n_heads``,
            ``n_blocks`` and ``dropout``; a plain ``dict`` is wrapped with
            ``OmegaConf.create``.
    """

    def __init__(self, vocab_size: int, masked: bool, config: DictConfig):
        super().__init__()

        if isinstance(config, dict):
            config = OmegaConf.create(config)

        self.config = config
        self.vocab_size = vocab_size

        add_token = 1 if masked else 0

        self.vocab_embed = nn.Embedding(self.vocab_size + add_token, config.hidden_size)

        self.time_embedding = TimestepEmbedder(hidden_size=config.cond_dim)
        self.rotary_emb = rotary.Rotary(dim=config.hidden_size // config.n_heads)

        self.blocks = nn.ModuleList(
            [
                DDiTBlock(
                    dim=config.hidden_size,
                    n_heads=config.n_heads,
                    cond_dim=config.cond_dim,
                    dropout=config.dropout,
                )
                for _ in range(config.n_blocks)
            ]
        )

        self.output_layer = DDitFinalLayer(
            hidden_size=config.hidden_size,
            out_channels=vocab_size + add_token,
            cond_dim=config.cond_dim,
        )

    def forward(self, x_t: Tensor, time: Tensor, img_tokens: Tensor) -> Tensor:
        """Compute per-position logits.

        Args:
            x_t: token ids (presumably [B, T] — TODO confirm).
            time: 1-D tensor with one timestep per batch element.
            img_tokens: image features attended to by every block; the last
                dimension must equal ``config.hidden_size``.

        Returns:
            Logits with ``vocab_size`` (+1 if ``masked``) channels per position.
        """
        x = self.vocab_embed(x_t)               # [B, T, D]

        # Cast the timesteps to the embedder's parameter dtype instead of converting
        # the submodule (.half()/.float()) here. Converting it would mutate trainable
        # weights in place on every call, losing precision on fp32->fp16->fp32 round
        # trips.
        embed_dtype = next(self.time_embedding.parameters()).dtype
        t_emb = self.time_embedding(time=time.to(dtype=embed_dtype))  # [B, cond_dim]
        c = F.silu(t_emb)

        rotary_cos_sin = self.rotary_emb(x)

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            for block in self.blocks:
                x = block(x=x, rotary_cos_sin=rotary_cos_sin, c=c, img_tokens=img_tokens)

            x = self.output_layer(x=x, c=c)

        return x
