import math
from dataclasses import dataclass
from typing import Optional, Union, Tuple

from random import random

from einops import rearrange, repeat

import torch
from torch import nn
from torch.special import expm1
from torch.nn import functional as F

from transformers.activations import ACT2FN
from transformers.utils import ModelOutput


@dataclass
class VariationalDiffusionOutput(ModelOutput):
    """Outputs of ``VariationalDiffusionTransformer.forward``.

    Field order matters: ``ModelOutput`` uses it for tuple conversion and
    positional indexing, so do not reorder.
    """

    # Scores from the final ``classifier`` head (last dim = num_embeddings).
    logits: torch.FloatTensor
    # softmax(logits) @ embedding table; only populated when output_denoised=True.
    denoised: Optional[torch.FloatTensor] = None
    # Cross-entropy loss (plus beta * kl_loss on the unweighted training path);
    # None when no targets were given.
    loss: Optional[torch.FloatTensor] = None
    # KL regulariser on the learned embedding variance; None when not computed.
    kl_loss: Optional[torch.FloatTensor] = None


class DiffusionConfig:
    """Hyper-parameter bag for ``VariationalDiffusionTransformer``.

    Every constructor argument is stored, unchanged, as an attribute of the
    same name (in the order the arguments are declared).
    """

    def __init__(
        self,
        num_embeddings: int = 30000,
        embedding_size: int = 256,
        hidden_size: int = 1024,
        num_attention_heads: int = 8,
        num_hidden_layers: int = 8,
        rotary_pct: float = 0.25,
        hidden_act: str = "gelu_new",
        time_embedding_size: int = 128,
        initializer_range: float = 0.02,
        embedding_initializer_range: float = 0.001,
        max_position_embeddings: int = 64,
        rotary_emb_base: int = 10000,
        num_bins: int = 100,
        time_wrapping_epsilon: float = 5e-4,
    ):
        hyperparameters = {
            "num_embeddings": num_embeddings,
            "embedding_size": embedding_size,
            "hidden_size": hidden_size,
            "num_attention_heads": num_attention_heads,
            "num_hidden_layers": num_hidden_layers,
            "rotary_pct": rotary_pct,
            "hidden_act": hidden_act,
            "time_embedding_size": time_embedding_size,
            "initializer_range": initializer_range,
            "embedding_initializer_range": embedding_initializer_range,
            "max_position_embeddings": max_position_embeddings,
            "rotary_emb_base": rotary_emb_base,
            "num_bins": num_bins,
            "time_wrapping_epsilon": time_wrapping_epsilon,
        }
        for attribute, value in hyperparameters.items():
            setattr(self, attribute, value)


def rotate_half(x):
    """Split the last dim into halves ``(a, b)`` and return ``cat(-b, a)``.

    For an odd last dim the first half is the shorter one
    (``x.shape[-1] // 2`` elements).
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat([back.neg(), front], dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    """Rotate queries and keys by the cos/sin tables.

    The tables are windowed along their second-to-last axis to the positions
    ``[offset, offset + q.shape[-2])``, so decoding with a cache can start
    mid-sequence. Each input ``t`` becomes ``t * cos + rotate_half(t) * sin``.

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    window = slice(offset, offset + q.shape[-2])
    cos_window = cos[..., window, :]
    sin_window = sin[..., window, :]

    def _rotate(t):
        return (t * cos_window) + (rotate_half(t) * sin_window)

    return _rotate(q), _rotate(k)


class RotaryEmbedding(torch.nn.Module):
    """Precomputed cos/sin tables for rotary position embeddings.

    Tables are stored with shape ``[1, 1, cached_len, 2 * ceil(dim / 2)]`` (i.e.
    ``dim`` for even ``dim``) so they broadcast over the batch and head axes.
    """

    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        self._set_cos_sin_cache(self.max_seq_len_cached, self.inv_freq.device)

    def _set_cos_sin_cache(self, seq_len, device):
        """(Re)build ``cos_cached`` / ``sin_cached`` for positions ``0..seq_len-1``."""
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to
        # obtain the same calculation.
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions.

        Args:
            x: tensor shaped ``[bs, num_attention_heads, seq_len, head_size]``;
                only its device (and, if ``seq_len`` is None, its
                second-to-last dim) is used.
            seq_len: number of positions needed; defaults to ``x.shape[-2]``.
                The cache is grown when this exceeds the cached length.

        Returns:
            ``(cos, sin)``, each ``[1, 1, seq_len, rot_dim]`` on ``x.device``.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            self._set_cos_sin_cache(seq_len, x.device)
        # Slice the sequence axis (dim 2), not the size-1 broadcast axis.
        return (
            self.cos_cached[:, :, :seq_len, ...].to(x.device),
            self.sin_cached[:, :, :seq_len, ...].to(x.device),
        )


class LearnedSinusoidalPosEmb(nn.Module):
    """Learned sinusoidal embedding of scalar (per-example) timestamps.

    Following @crowsonkb's lead with a learned sinusoidal pos emb:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    Each scalar ``t`` is multiplied by ``time_embedding_size // 2`` learned
    frequencies (times 2*pi), mapped through sin and cos, and passed through a
    two-layer bias-free MLP with ``config.hidden_act`` after each layer. The
    output has ``config.time_embedding_size`` features.
    """

    def __init__(self, config):
        super().__init__()
        dim = config.time_embedding_size
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if dim % 2 != 0:
            raise ValueError(f"time_embedding_size must be even, got {dim}")
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))
        self.projector = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            ACT2FN[config.hidden_act],
            nn.Linear(dim, dim, bias=False),
            ACT2FN[config.hidden_act],
        )

    def forward(self, x):
        """Embed a 1-D batch of timestamps ``x`` into ``[batch, dim]`` features."""
        x = rearrange(x, "b -> b 1")
        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        return self.projector(fouriered)


class Attention(nn.Module):
    """Multi-head self-attention with partial rotary position embeddings.

    A single fused linear produces Q, K and V for every head. Rotary embeddings
    are applied to the first ``rotary_ndims`` channels of each query/key head,
    and the remaining channels pass through unchanged. No causal mask is built
    here; the only masking is the optional additive ``attention_mask``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_attention_heads
        self.rotary_ndims = int(self.head_size * config.rotary_pct)
        # Not read anywhere in this module; kept registered so state_dict keys
        # stay compatible with existing checkpoints.
        self.register_buffer("masked_bias", torch.tensor(-1e9))
        self.rotary_emb = RotaryEmbedding(
            self.rotary_ndims,
            config.max_position_embeddings,
            base=config.rotary_emb_base,
        )
        # sqrt(head_size): raw scores are divided by this before the softmax.
        self.norm_factor = torch.sqrt(
            torch.tensor(self.head_size, dtype=torch.float32)
        ).to(torch.get_default_dtype())
        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        layer_past=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Self-attend over ``hidden_states`` ([batch, seq_len, hidden_size]).

        Args:
            hidden_states: input activations.
            attention_mask: optional additive mask, broadcastable to
                [batch, heads, q_len, k_len], added to the raw scores.
            head_mask: optional multiplicative mask on the attention probabilities.
            layer_past: optional ``(past_key, past_value)`` prepended along the
                sequence axis; its length is also used as the rotary offset.
            use_cache: when True, return the concatenated ``(key, value)``.
            output_attentions: when True, also return the attention weights.

        Returns:
            ``(attn_output, present)``, plus ``attn_weights`` when
            ``output_attentions`` is set.
        """
        has_layer_past = layer_past is not None

        # Compute QKV
        # Attention heads [batch, seq_len, hidden_size]
        #   --> [batch, seq_len, (np * 3 * head_size)]
        qkv = self.query_key_value(hidden_states)

        # [batch, seq_len, (num_heads * 3 * head_size)]
        #   --> [batch, seq_len, num_heads, 3 * head_size]
        new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
        qkv = qkv.view(*new_qkv_shape)

        # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
        query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
        key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
        value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

        # Compute rotary embeddings on rotary_ndims
        query_rot = query[..., : self.rotary_ndims]
        query_pass = query[..., self.rotary_ndims :]
        key_rot = key[..., : self.rotary_ndims]
        key_pass = key[..., self.rotary_ndims :]

        # Compute token offset for rotary embeddings (when decoding)
        seq_len = key.shape[-2]
        offset = 0
        if has_layer_past:
            offset = layer_past[0].shape[-2]
            seq_len += offset
        cos, sin = self.rotary_emb(value, seq_len=seq_len)
        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
        query = torch.cat((query, query_pass), dim=-1)
        key = torch.cat((key, key_pass), dim=-1)

        # Cache QKV values
        if has_layer_past:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        # Compute attention
        attn_output, attn_weights = self._attn(
            query, key, value, attention_mask, head_mask
        )

        # Reshape outputs
        attn_output = self._merge_heads(
            attn_output, self.num_attention_heads, self.head_size
        )
        attn_output = self.dense(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    @classmethod
    def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Splits hidden dim into attn_head_size and num_attention_heads
        """
        # tensor: [bs, seq_len, hidden_size]
        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(new_shape)
        # -> [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor

    @classmethod
    def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden dim
        """
        # tensor [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(
            tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
        )
        # -> [bs, seq_len, hidden_size]
        return tensor

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Scaled dot-product attention.

        Computes ``softmax(q @ k^T / norm_factor + attention_mask) @ value``,
        with ``head_mask`` (if any) multiplied into the probabilities.

        Returns:
            ``(attn_output, attn_weights)``.
        """
        # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        query = query.view(
            batch_size * num_attention_heads, query_length, attn_head_size
        )
        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
        attn_scores = torch.zeros(
            batch_size * num_attention_heads,
            query_length,
            key_length,
            dtype=query.dtype,
            device=key.device,
        )
        # attn_scores = 0 * 1.0 + (q @ k^T) / norm_factor in one fused op.
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=(
                torch.tensor(
                    1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device
                )
                / self.norm_factor
            ),
        )
        attn_scores = attn_scores.view(
            batch_size, num_attention_heads, query_length, key_length
        )

        if attention_mask is not None:
            # Apply the (additive) attention mask
            attn_scores = attn_scores + attention_mask

        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(value.dtype)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights


class ConditionalLayerNorm(nn.Module):
    """Layer norm (no learned affine) whose scale and shift come from a time embedding.

    ``gamma_cond`` / ``beta_cond`` project the time embedding to a per-channel
    scale / shift. After ``init_weights`` the module outputs the plain
    normalised input (scale 1, shift 0) regardless of the time embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.gamma_cond = nn.Linear(config.time_embedding_size, config.hidden_size)
        self.beta_cond = nn.Linear(config.time_embedding_size, config.hidden_size)
        self.epsilon = 1e-5

    def init_weights(self):
        """Zero both projection weights; set the scale bias to 1 and the shift bias to 0."""
        for projection, bias_value in ((self.gamma_cond, 1.0), (self.beta_cond, 0.0)):
            torch.nn.init.constant_(projection.weight, 0.0)
            torch.nn.init.constant_(projection.bias, bias_value)

    def forward(self, time_embedding, hidden_state):
        """Normalise ``hidden_state`` over its last dim, then apply the conditioned affine."""
        centered = hidden_state - hidden_state.mean(dim=-1, keepdim=True)
        variance = centered.pow(2).mean(dim=-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.epsilon)
        # One scale/shift vector per batch element, broadcast over the sequence axis.
        batch = hidden_state.size(0)
        scale = self.gamma_cond(time_embedding).view(batch, 1, -1)
        shift = self.beta_cond(time_embedding).view(batch, 1, -1)
        return normalized * scale + shift


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer whose layer norms are conditioned on time.

    Computes ``h = h + Attn(CLN_1(t, h))`` and then
    ``h = h + Down(act(Up(CLN_2(t, h))))``, where the MLP expands to
    ``4 * hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.conditional_layer_norm_1 = ConditionalLayerNorm(config)
        self.attention = Attention(config)
        self.conditional_layer_norm_2 = ConditionalLayerNorm(config)
        self.up_proj = nn.Linear(config.hidden_size, config.hidden_size * 4)
        self.act = ACT2FN[config.hidden_act]
        self.down_proj = nn.Linear(config.hidden_size * 4, config.hidden_size)

    def forward(self, hidden_states, time_embeddings, attention_mask=None):
        """Apply the attention and MLP sub-layers with residual connections.

        Args:
            hidden_states: ``[batch, seq_len, hidden_size]`` activations.
            time_embeddings: per-example time embedding fed to both layer norms.
            attention_mask: optional additive mask forwarded to the attention
                scores (None means attend everywhere).
        """
        cond_1 = self.conditional_layer_norm_1(time_embeddings, hidden_states)
        # The mask must reach the attention scores; otherwise it is silently ignored.
        attn_res = self.attention(cond_1, attention_mask=attention_mask)[0]
        hidden_states = hidden_states + attn_res
        cond_2 = self.conditional_layer_norm_2(time_embeddings, hidden_states)
        projected_up = self.act(self.up_proj(cond_2))
        projected_down = self.down_proj(projected_up)
        hidden_states = hidden_states + projected_down
        return hidden_states


class VariationalDiffusionTransformer(nn.Module):
    """Diffusion denoiser over token embeddings with a learned embedding variance.

    ``forward`` embeds ``input_ids`` (unless ``noisy_hidden_states`` is given)
    and adds timestamp- and variance-scaled Gaussian noise to the
    non-conditioning positions. It concatenates a self-conditioning estimate,
    runs the stack of time-conditioned ``TransformerBlock``s, and classifies
    every position back into the vocabulary.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Embedding(config.num_embeddings, config.embedding_size)
        # The (misspelled) attribute name is part of the state_dict keys; keep it.
        self.time_emmbeddings = LearnedSinusoidalPosEmb(config)
        # Input is [noisy_hidden_states, self_conditioning_hidden_states]
        # concatenated on the last dim, hence embedding_size * 2.
        self.emb_to_hidden = nn.Linear(config.embedding_size * 2, config.hidden_size)
        self.encoders = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.classifier = nn.Linear(
            config.hidden_size, config.num_embeddings, bias=False
        )

        # Maps an embedding to a per-dimension value that get_noisy_embeddings
        # treats as log-sigma.
        self.var_predictor = nn.Sequential(
            nn.Linear(config.embedding_size, config.embedding_size * 2),
            ACT2FN[config.hidden_act],
            nn.Linear(config.embedding_size * 2, config.embedding_size * 2),
            ACT2FN[config.hidden_act],
            nn.Linear(config.embedding_size * 2, config.embedding_size),
        )

        for m in self.modules():
            self._init_weights(m)

        # Re-initialise the conditional layer norms after the generic init so
        # they start with scale 1 / shift 0.
        for e in self.encoders:
            e.conditional_layer_norm_1.init_weights()
            e.conditional_layer_norm_2.init_weights()

    def get_embeddings(
        self, input_ids, is_valid=False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Look up token embeddings and the value used as their log-sigma.

        Returns ``(mu, log_sigma)``. When ``is_valid`` is True the second tensor
        is a constant -20 (``exp(-20)`` is about 2e-9, i.e. effectively no
        embedding noise); otherwise it is predicted by ``var_predictor``.
        """
        embeddings = self.embeddings(input_ids)
        if is_valid:
            return embeddings, -20 * torch.ones_like(embeddings)
        mu = embeddings
        sigma = self.get_sigma_e(embeddings)
        return mu, sigma

    def get_noisy_embeddings(
        self,
        timestamps,
        conditioning_mask,
        input_ids,
        drop_conditioning_embeddings,
        output_kl_loss: bool = False,
        is_valid: bool = True,
    ):
        """Embed ``input_ids`` and perturb the non-conditioning positions.

        With ``sigma = exp(clamp(log_sigma, -20, 3))`` and ``t`` the per-example
        timestamp, the result is
        ``(embeddings + noise * (t + sigma)) / sqrt(sigma**2 + t**2)``.
        ``noise`` is standard Gaussian where ``conditioning_mask`` is False and
        zero where it is True. With ``drop_conditioning_embeddings`` the
        conditioning positions are zeroed entirely.

        Returns:
            ``(noisy_hidden_states, kl_loss)``; ``kl_loss`` is None unless
            ``output_kl_loss`` is True.
        """
        embeddings, log_sigmas = self.get_embeddings(input_ids, is_valid=is_valid)
        noise = torch.randn_like(embeddings) * (~conditioning_mask).unsqueeze(-1)
        log_sigmas = torch.clamp(log_sigmas, -20, 3)
        sum_sigma = timestamps.view(-1, 1, 1) + torch.exp(log_sigmas)
        scale = 1 / torch.sqrt(
            torch.exp(log_sigmas).pow(2) + timestamps.view(-1, 1, 1).pow(2)
        )

        noisy_hidden_states = (embeddings + noise * sum_sigma) * scale
        if drop_conditioning_embeddings:
            noisy_hidden_states = noisy_hidden_states * (~conditioning_mask).unsqueeze(
                -1
            )

        if output_kl_loss:
            # NOTE(review): this is the Gaussian KL formula for log_sigmas read
            # as a log-*variance*, while the noise path above uses
            # exp(log_sigmas) as a standard deviation; confirm which is intended.
            # The sum runs over dim 1 (the sequence axis when input_ids is
            # [batch, seq]) and the mean over the remaining axes.
            loss_kld = (
                -0.5
                * torch.sum(
                    1 + log_sigmas - embeddings.pow(2) - log_sigmas.exp(), dim=1
                )
            ).mean()
            return noisy_hidden_states, loss_kld

        return noisy_hidden_states, None

    def get_sigma_e(self, e):
        """Predict the per-dimension log-sigma for embeddings ``e``."""
        return self.var_predictor(e)

    @staticmethod
    def _extended_attention_mask(attention_mask, batch_size, dtype):
        """Turn a keep-mask (1 = attend, 0 = masked) into an additive mask.

        The result has shape ``[batch_size, 1, 1, to_seq_length]`` so it
        broadcasts over heads and query positions. It is 0.0 where attention
        is allowed and ``finfo(dtype).min`` where it is masked.
        """
        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")
        attention_mask = attention_mask.view(batch_size, -1)[:, None, None, :]
        attention_mask = attention_mask.to(dtype=dtype)  # fp16 compatibility
        return (1.0 - attention_mask) * torch.finfo(dtype).min

    @staticmethod
    def _timestamp_weighted_loss(logits, input_ids, timestamps, t_max):
        """Batch mean of (per-sequence mean cross-entropy) * timestamps / t_max."""
        if t_max is None:
            raise ValueError("t_max must be provided when weighted_loss=True")
        per_token = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            input_ids.reshape(-1),
            reduction="none",
        ).view(input_ids.size(0), -1)
        per_sequence = per_token.mean(dim=-1)
        return (per_sequence * timestamps / t_max).mean()

    def forward(
        self,
        timestamps: torch.Tensor,
        conditioning_mask: Optional[torch.BoolTensor],
        input_ids: Optional[torch.LongTensor] = None,
        noisy_hidden_states: Optional[torch.FloatTensor] = None,
        self_conditioning_hidden_states: Optional[torch.FloatTensor] = None,
        output_denoised: bool = False,
        drop_conditioning_embeddings: bool = False,
        self_conditioning: bool = False,
        output_loss: bool = True,
        weighted_loss: bool = False,
        t_max: float = None,
        attention_mask: torch.Tensor = None,
        beta=0.5,
    ) -> VariationalDiffusionOutput:
        """Predict clean tokens from noisy embeddings.

        Args:
            timestamps: per-example diffusion time; also fed to the time embedding.
            conditioning_mask: True at context positions. These are not noised
                and are excluded from the unweighted loss.
            input_ids: token ids; used to build the noisy input when
                ``noisy_hidden_states`` is None, and as loss targets.
            noisy_hidden_states: precomputed noisy input (skips noising and KL).
            self_conditioning_hidden_states: previous estimate concatenated to
                the input; zeros when None.
            output_denoised: also return ``softmax(logits) @ embedding table``.
            drop_conditioning_embeddings: zero the conditioning positions of the
                noisy input.
            self_conditioning: first run a no-grad pass and use its denoised
                estimate (zeroed at conditioning positions) as the
                self-conditioning input.
            output_loss: compute ``loss`` when ``input_ids`` is given.
            weighted_loss: weight each sequence's mean cross-entropy by
                ``timestamps / t_max`` (requires ``t_max``).
            t_max: normaliser for ``weighted_loss``.
            attention_mask: keep-mask with ``batch * seq`` elements,
                1 = attend, 0 = masked.
            beta: weight of the KL term added to the unweighted training loss.

        Raises:
            ValueError: if ``weighted_loss`` is set without ``t_max``, or an
                attention mask is given with an empty batch.
        """
        # Training: noise from the predicted variance and return the KL term.
        # Eval: near-zero embedding variance and no KL.
        is_valid = not self.training
        output_kl_loss = self.training
        kl_loss = None
        if noisy_hidden_states is None:
            noisy_hidden_states, kl_loss = self.get_noisy_embeddings(
                timestamps=timestamps,
                conditioning_mask=conditioning_mask,
                input_ids=input_ids,
                drop_conditioning_embeddings=drop_conditioning_embeddings,
                output_kl_loss=output_kl_loss,
                is_valid=is_valid,
            )

        if self_conditioning_hidden_states is None:
            self_conditioning_hidden_states = torch.zeros_like(noisy_hidden_states)

        batch_size = noisy_hidden_states.size(0)

        if self_conditioning:
            with torch.no_grad():
                s_outputs = self.forward(
                    timestamps=timestamps,
                    conditioning_mask=conditioning_mask,
                    noisy_hidden_states=noisy_hidden_states,
                    output_loss=False,
                    output_denoised=True,
                    attention_mask=attention_mask,
                )
                self_conditioning_hidden_states = s_outputs.denoised * (
                    ~conditioning_mask.unsqueeze(-1)
                )

        if attention_mask is not None:
            # nn.Module has no `.dtype`; use the dtype of the projection that
            # produces the attention inputs.
            attention_mask = self._extended_attention_mask(
                attention_mask, batch_size, self.emb_to_hidden.weight.dtype
            )

        embeddings_cat = torch.cat(
            (noisy_hidden_states, self_conditioning_hidden_states), dim=-1
        )
        hidden_states = self.emb_to_hidden(embeddings_cat)
        time_embeddings = self.time_emmbeddings(timestamps)
        for encoder in self.encoders:
            hidden_states = encoder(
                hidden_states=hidden_states,
                time_embeddings=time_embeddings,
                attention_mask=attention_mask,
            )

        logits = self.classifier(hidden_states)

        loss = None
        if input_ids is not None and output_loss:
            if weighted_loss:
                loss = self._timestamp_weighted_loss(
                    logits, input_ids, timestamps, t_max
                )
            else:
                # Score only the non-conditioning (generated) positions.
                targets = input_ids.view(-1)
                keep = ~conditioning_mask.view(-1)
                logits_r = logits.view(targets.size(0), -1)[keep]
                loss = F.cross_entropy(logits_r, targets[keep])
                # kl_loss is None when noisy_hidden_states was supplied.
                if output_kl_loss and kl_loss is not None:
                    loss = loss + beta * kl_loss

        denoised = None
        if output_denoised:
            with torch.no_grad():
                probs = F.softmax(logits, dim=-1)
                # Expected embedding under the predicted token distribution.
                denoised = probs @ self.embeddings.weight.data
        return VariationalDiffusionOutput(
            logits=logits, loss=loss, denoised=denoised, kl_loss=kl_loss
        )

    def _init_weights(self, module):
        """Normal(0, initializer_range) weights and zero biases for Linear/Embedding."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # NOTE(review): config.embedding_initializer_range is never read in
            # this file; embeddings use initializer_range. Confirm intended.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


if __name__ == "__main__":
    # Smoke test: one forward pass in training mode (the nn.Module default) on
    # dummy inputs. The first half of each sequence is conditioning context and
    # the second half is noised and scored. With an all-True mask the
    # cross-entropy would run over an empty selection and the loss would be NaN.
    config = DiffusionConfig()
    model = VariationalDiffusionTransformer(config)
    batch_size, seq_len = 10, 10
    conditioning_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    conditioning_mask[:, : seq_len // 2] = True
    outputs = model(
        input_ids=torch.ones(batch_size, seq_len, dtype=torch.long),
        timestamps=torch.ones(batch_size),
        conditioning_mask=conditioning_mask,
    )
    print("logits shape:", tuple(outputs.logits.shape))
    print("loss:", outputs.loss.item())
