import math
from dataclasses import dataclass
from typing import Optional, Union, Tuple

from random import random

from einops import rearrange, repeat

import torch
from torch import nn
from torch.special import expm1
from torch.nn import functional as F

from transformers.activations import ACT2FN
from transformers.utils import ModelOutput


@dataclass
class DiffusionOutput(ModelOutput):
    """Result container returned by the diffusion model's forward pass."""

    # Output of the final classifier layer (scores over the vocabulary).
    logits: torch.FloatTensor
    # Softmax-weighted mix of the normalized token embeddings; only populated
    # when the caller asks for the denoised estimate, otherwise None.
    denoised: Optional[torch.FloatTensor] = None
    # Cross-entropy loss; None when no target token ids were supplied.
    loss: Optional[torch.FloatTensor] = None


class DiffusionConfig:
    """Plain hyper-parameter holder for the diffusion transformer.

    Every constructor argument is stored, unchanged, as an instance attribute
    with the same name.

    Args:
        num_embeddings: Number of rows of the token embedding table; also the
            output width of the classifier.
        embedding_size: Width of each token embedding.
        hidden_size: Width of the transformer hidden states.
        num_attention_heads: Number of attention heads per layer.
        num_hidden_layers: Number of stacked transformer blocks.
        hidden_act: Key used to look up the activation in ``ACT2FN``.
        time_embedding_size: Width of the timestep embedding (the sinusoidal
            embedding module asserts that it is even).
        initializer_range: Standard deviation of the normal initializer.
        embedding_initializer_range: Stored only; nothing in this file reads
            it (presumably meant for the embedding table -- TODO confirm).
        max_position_embeddings: Stored only; not read in this file.
        num_bins: Stored only; not read in this file.
        time_wrapping_epsilon: Stored only; not read in this file.
    """

    def __init__(
        self,
        num_embeddings: int = 30000,
        embedding_size: int = 256,
        hidden_size: int = 1024,
        num_attention_heads: int = 8,
        num_hidden_layers: int = 8,
        hidden_act="gelu_new",
        time_embedding_size=128,
        initializer_range=0.02,
        embedding_initializer_range=0.001,
        max_position_embeddings=64,
        num_bins=100,
        time_wrapping_epsilon=5e-4,
    ):
        settings = {
            "num_embeddings": num_embeddings,
            "embedding_size": embedding_size,
            "hidden_size": hidden_size,
            "num_attention_heads": num_attention_heads,
            "num_hidden_layers": num_hidden_layers,
            "hidden_act": hidden_act,
            "time_embedding_size": time_embedding_size,
            "initializer_range": initializer_range,
            "embedding_initializer_range": embedding_initializer_range,
            "max_position_embeddings": max_position_embeddings,
            "num_bins": num_bins,
            "time_wrapping_epsilon": time_wrapping_epsilon,
        }
        # Mirror each argument onto the instance under its own name.
        for attribute, value in settings.items():
            setattr(self, attribute, value)


class KerpleLog(torch.nn.Module):
    """Kernelized T5 Relative Position Bias parallelized in the heads dimension.

    Adds a learned, per-head logarithmic bias ``-p * log(1 + a * d)`` to a
    tensor of attention scores, where ``d`` is the (lower-triangular) distance
    between query and key positions and ``p``/``a`` are per-head parameters
    that are clamped to be at least ``eps``.
    """

    def __init__(
        self,
        config,
    ):
        """Create the per-head parameters.

        Args:
            config: Object exposing ``num_attention_heads``.
        """
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.eps = 1e-2

        # Allocate weights and initialize.
        # The kernel has the form -p*log(1+a*|m-n|)
        def get_parameter(scale, init_method):
            if init_method == "ones":
                values = torch.ones(self.num_heads)
            elif init_method == "uniform":
                values = torch.rand(self.num_heads)
            else:
                raise ValueError(f"unknown init_method: {init_method!r}")
            # Shape [num_heads, 1, 1] so it broadcasts over [sq, sk].
            return nn.Parameter(values[:, None, None] * scale)

        self.bias_p = get_parameter(2, "uniform")
        self.bias_a = get_parameter(1, "uniform")

        # Distance matrix cache, rebuilt whenever the key length, device or
        # dtype of the incoming scores changes.
        self.cached_matrix = None
        self.cached_seq_len = None

    def _clamp_parameters(self):
        """Keep both kernel parameters at or above ``eps`` (updates ``.data``)."""
        self.bias_p.data = self.bias_p.data.clamp(min=self.eps)
        self.bias_a.data = self.bias_a.data.clamp(min=self.eps)

    def stats(self):
        """Return mean/std/max/min of both (clamped) parameters as CPU tensors."""

        def get_stats(name, obj):
            return {
                name + "_mean": obj.mean().detach().cpu(),
                name + "_std": obj.std().detach().cpu(),
                name + "_max": obj.max().detach().cpu(),
                name + "_min": obj.min().detach().cpu(),
            }

        self._clamp_parameters()
        dd = {}
        dd.update(get_stats("bias_a", self.bias_a))
        dd.update(get_stats("bias_p", self.bias_p))
        return dd

    def forward(self, x):
        """Add the position bias to attention scores.

        Args:
            x: Attention scores of shape ``[b, np, sq, sk]``. Either
                ``sq == sk`` or ``sq == 1`` (single-query decoding with
                cached keys).

        Returns:
            ``x`` plus the per-head bias, same shape as ``x``.
        """
        # [b, np, sq, sk]
        seq_len_q = x.shape[-2]
        seq_len_k = x.shape[-1]

        cached = self.cached_matrix
        cache_is_stale = (
            cached is None
            or self.cached_seq_len != seq_len_k
            or cached.device != x.device
            or cached.dtype != x.dtype
        )
        if cache_is_stale:
            positions = torch.arange(seq_len_k, device=x.device)
            # diff[i, j] = i - j for j <= i, and 0 above the diagonal.
            # NOTE(review): tril zeroes the distance for keys that come after
            # the query, so those positions receive no penalty -- confirm this
            # is intended when the scores are not causally masked.
            diff = torch.tril(
                positions.view(seq_len_k, 1) - positions.view(1, seq_len_k)
            ).to(x.dtype)
            self.cached_seq_len = seq_len_k
            self.cached_matrix = diff
        else:
            diff = cached

        self._clamp_parameters()
        bias = -self.bias_p * torch.log(1 + self.bias_a * diff)  # log kernel

        if seq_len_q != seq_len_k:
            # During training sq == sk. With cached keys only the newest token
            # is passed as query (sq == 1), so the bias row of the last key
            # position is the one that applies.
            assert (
                seq_len_q == 1
            ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"

            # seq_len_k - 1 points to the last token index in the current inference batch.
            bias = bias[:, seq_len_k - 1, :].view(bias.shape[0], 1, bias.shape[2])

        return x + bias


class LearnedSinusoidalPosEmb(nn.Module):
    """Learned sinusoidal embedding, following @crowsonkb 's lead.

    Reference:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    A 1-D batch of scalars is multiplied by learned frequencies, expanded into
    sine and cosine features and passed through a two-layer bias-free MLP.
    """

    def __init__(self, config):
        super().__init__()
        embedding_dim = config.time_embedding_size
        # Half of the features are sines, the other half cosines.
        assert (embedding_dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(embedding_dim // 2))
        activation_name = config.hidden_act
        layers = [
            nn.Linear(embedding_dim, embedding_dim, bias=False),
            ACT2FN[activation_name],
            nn.Linear(embedding_dim, embedding_dim, bias=False),
            ACT2FN[activation_name],
        ]
        self.projector = nn.Sequential(*layers)

    def forward(self, x):
        """Embed ``x`` of shape ``[b]`` into ``[b, time_embedding_size]``."""
        column = rearrange(x, "b -> b 1")
        learned_freqs = rearrange(self.weights, "d -> 1 d")
        angles = column * learned_freqs * 2 * math.pi
        features = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return self.projector(features)


class Attention(nn.Module):
    """Multi-head self-attention with a KERPLE relative position bias.

    Queries, keys and values come from one fused linear projection; the
    scaled dot-product scores are biased by ``self.rpe`` before the softmax.
    No causal mask is applied by this module.
    """

    def __init__(self, config):
        """Build the projections from ``config.hidden_size`` and
        ``config.num_attention_heads``."""
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_attention_heads
        # Not read anywhere in this class; kept registered so the state-dict
        # keys stay unchanged.
        self.register_buffer("masked_bias", torch.tensor(-1e9))
        self.rpe = KerpleLog(config=config)
        # sqrt(head_size), used to scale the dot products.
        self.norm_factor = torch.sqrt(
            torch.tensor(self.head_size, dtype=torch.float32)
        ).to(torch.get_default_dtype())
        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        layer_past=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: ``[batch, seq_len, hidden_size]``.
            attention_mask: Optional additive mask, added to the scores
                before the softmax.
            head_mask: Optional multiplicative mask applied to the attention
                weights.
            layer_past: Optional ``(key, value)`` pair from earlier steps;
                concatenated in front of the new keys/values.
            use_cache: When true, the (concatenated) keys/values are returned.
            output_attentions: When true, the attention weights are appended
                to the returned tuple.

        Returns:
            ``(attn_output, present)`` or ``(attn_output, present,
            attn_weights)``; ``present`` is ``None`` unless ``use_cache``.
        """
        # [batch, seq_len, hidden_size] --> [batch, seq_len, 3 * hidden_size]
        qkv = self.query_key_value(hidden_states)

        # --> [batch, seq_len, num_heads, 3 * head_size]
        new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
        qkv = qkv.view(*new_qkv_shape)

        # --> 3 x [batch, num_heads, seq_len, head_size]
        query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
        key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
        value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

        # Prepend cached keys/values when decoding incrementally.
        if layer_past is not None:
            past_key, past_value = layer_past[0], layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        # Compute attention
        attn_output, attn_weights = self._attn(
            query, key, value, attention_mask, head_mask
        )

        # Reshape outputs
        attn_output = self._merge_heads(
            attn_output, self.num_attention_heads, self.head_size
        )
        attn_output = self.dense(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    @classmethod
    def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Splits hidden dim into attn_head_size and num_attention_heads
        """
        # tensor: [bs, seq_len, hidden_size]
        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(new_shape)
        # -> [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor

    @classmethod
    def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden dim
        """
        # tensor [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(
            tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
        )
        # -> [bs, seq_len, hidden_size]
        return tensor

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Scaled dot-product attention with the relative position bias.

        Args:
            query, key, value: ``[bs, num_attention_heads, seq_len,
                attn_head_size]`` (key/value may be longer than query).
            attention_mask: Optional additive mask broadcastable to the
                scores.
            head_mask: Optional multiplicative mask for the weights.

        Returns:
            ``(attn_output, attn_weights)``.
        """
        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        # query/key are permuted slices of the fused QKV tensor and therefore
        # not contiguous; ``view`` cannot merge the batch and head dimensions
        # of such a tensor (it raises for batch > 1 with seq_len > 1), so
        # ``reshape`` is used and copies only when it has to.
        query = query.reshape(
            batch_size * num_attention_heads, query_length, attn_head_size
        )
        key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
        attn_scores = torch.zeros(
            batch_size * num_attention_heads,
            query_length,
            key_length,
            dtype=query.dtype,
            device=key.device,
        )
        # scores = (query @ key^T) / sqrt(head_size)
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=(
                torch.tensor(
                    1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device
                )
                / self.norm_factor
            ),
        )
        attn_scores = attn_scores.view(
            batch_size, num_attention_heads, query_length, key_length
        )

        attn_scores = self.rpe(attn_scores)

        if attention_mask is not None:
            # Apply the attention mask
            attn_scores = attn_scores + attention_mask

        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(value.dtype)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights


class ConditionalLayerNorm(nn.Module):
    """Layer normalization whose scale and shift are predicted from a time embedding.

    The hidden state is normalized over its last dimension; the per-example
    scale (``gamma_cond``) and shift (``beta_cond``) are linear functions of
    the time embedding and are broadcast over the sequence dimension.
    """

    def __init__(self, config):
        super().__init__()
        cond_dim, hidden_dim = config.time_embedding_size, config.hidden_size
        self.gamma_cond = nn.Linear(cond_dim, hidden_dim)
        self.beta_cond = nn.Linear(cond_dim, hidden_dim)
        self.epsilon = 1e-5

    def init_weights(self):
        """Reset to an identity modulation: scale is constant 1, shift constant 0."""
        for layer, bias_value in ((self.gamma_cond, 1.0), (self.beta_cond, 0.0)):
            torch.nn.init.constant_(layer.weight, 0.0)
            torch.nn.init.constant_(layer.bias, bias_value)

    def forward(self, time_embedding, hidden_state):
        """Normalize ``hidden_state`` and modulate it with ``time_embedding``."""
        batch = hidden_state.size(0)
        centered = hidden_state - hidden_state.mean(dim=-1, keepdim=True)
        # Biased variance over the feature dimension.
        variance = (centered**2).mean(dim=-1, keepdim=True)
        normalized = centered / (variance + self.epsilon).sqrt()
        # One scale/shift vector per example, shared by all positions.
        scale = self.gamma_cond(time_embedding).view(batch, 1, -1)
        shift = self.beta_cond(time_embedding).view(batch, 1, -1)
        return normalized * scale + shift


class TransformerBlock(nn.Module):
    """Pre-norm transformer block conditioned on a time embedding.

    Both sub-layers (self-attention and a 4x-wide MLP) are preceded by a
    conditional layer norm and wrapped in a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        self.conditional_layer_norm_1 = ConditionalLayerNorm(config)
        self.attention = Attention(config)
        self.conditional_layer_norm_2 = ConditionalLayerNorm(config)
        self.up_proj = nn.Linear(config.hidden_size, config.hidden_size * 4)
        self.act = ACT2FN[config.hidden_act]
        self.down_proj = nn.Linear(config.hidden_size * 4, config.hidden_size)

    def forward(self, hidden_states, time_embeddings, attention_mask=None):
        """Apply the block.

        Args:
            hidden_states: Input hidden states.
            time_embeddings: Conditioning passed to both conditional layer
                norms.
            attention_mask: Optional mask handed to the attention layer,
                which adds it to the attention scores.

        Returns:
            Hidden states with the same shape as the input.
        """
        cond_1 = self.conditional_layer_norm_1(time_embeddings, hidden_states)
        # Forward the mask: previously it was accepted here but dropped, so
        # masked (e.g. padded) positions were still attended to.
        attn_res = self.attention(cond_1, attention_mask=attention_mask)[0]
        hidden_states = hidden_states + attn_res
        cond_2 = self.conditional_layer_norm_2(time_embeddings, hidden_states)
        projected_up = self.act(self.up_proj(cond_2))
        projected_down = self.down_proj(projected_up)
        hidden_states = hidden_states + projected_down
        return hidden_states


class SimplifiedDiffusionTransformerKerple(nn.Module):
    """Transformer denoiser over noisy token embeddings with KERPLE attention.

    Token ids are embedded and L2-normalized, Gaussian noise scaled by the
    timestep is added to the non-conditioning positions, and a stack of
    time-conditioned transformer blocks predicts logits over the vocabulary.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Embedding(config.num_embeddings, config.embedding_size)
        # The misspelled attribute name is part of the state-dict keys and is
        # therefore left untouched.
        self.time_emmbeddings = LearnedSinusoidalPosEmb(config)
        # Input width is 2 * embedding_size: the noisy embeddings concatenated
        # with the self-conditioning embeddings.
        self.emb_to_hidden = nn.Linear(config.embedding_size * 2, config.hidden_size)
        self.encoders = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.classifier = nn.Linear(
            config.hidden_size, config.num_embeddings, bias=False
        )

        for m in self.modules():
            self._init_weights(m)

        # Must run after the generic init above, which would otherwise
        # overwrite the identity modulation of the conditional layer norms.
        for e in self.encoders:
            e.conditional_layer_norm_1.init_weights()
            e.conditional_layer_norm_2.init_weights()

    def get_embeddings(self, input_ids):
        """Return token embeddings rescaled to L2 norm ``sqrt(embedding_size)``."""
        embeddings = self.embeddings(input_ids)
        embeddings = F.normalize(embeddings, dim=-1) * math.sqrt(
            self.config.embedding_size
        )  # ~ N(0, 1)
        return embeddings

    def get_noisy_embeddings(
        self, timestamps, conditioning_mask, input_ids, drop_conditioning_embeddings
    ):
        """Embed ``input_ids`` and add timestep-scaled Gaussian noise.

        Args:
            timestamps: One noise level per example, shape ``[batch]``.
            conditioning_mask: Boolean mask, true at conditioning positions;
                those positions receive no noise.
            input_ids: Token ids to embed.
            drop_conditioning_embeddings: When true, the embeddings at the
                conditioning positions are zeroed out.

        Returns:
            The noisy embeddings multiplied by ``1 / sqrt(t**2 + 1)``.
        """
        embeddings = self.get_embeddings(input_ids)
        not_conditioning = (~conditioning_mask).unsqueeze(-1)
        noise = (
            timestamps.view(-1, 1, 1) * torch.randn_like(embeddings) * not_conditioning
        )

        noisy_hidden_states = embeddings + noise
        if drop_conditioning_embeddings:
            noisy_hidden_states = noisy_hidden_states * not_conditioning

        scale = 1 / torch.sqrt(timestamps**2 + 1)
        noisy_hidden_states = noisy_hidden_states * scale.view(
            -1, 1, 1
        )  # renormalize embeddings again. See Section 6.1
        return noisy_hidden_states

    def _prepare_attention_mask(self, attention_mask, batch_size, dtype):
        """Turn a 1/0 mask into an additive ``[batch, 1, 1, seq]`` mask.

        Positions marked 1 become 0.0 and positions marked 0 become the most
        negative value of ``dtype``, so adding the result to the raw scores
        effectively removes the masked positions from the softmax.
        """
        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")
        attention_mask = attention_mask.view(batch_size, -1)
        # Broadcastable to [batch_size, num_heads, from_seq_length, to_seq_length].
        attention_mask = attention_mask[:, None, None, :]
        attention_mask = attention_mask.to(dtype=dtype)  # fp16 compatibility
        return (1.0 - attention_mask) * torch.finfo(dtype).min

    def _compute_loss(
        self, logits, input_ids, conditioning_mask, timestamps, weighted_loss, t_max
    ):
        """Cross-entropy between ``logits`` and ``input_ids``.

        Unweighted: mean over all non-conditioning positions. Weighted:
        per-example mean over every position, multiplied by
        ``timestamps / t_max`` and averaged over the batch.
        """
        if weighted_loss:
            if t_max is None:
                raise ValueError("t_max must be provided when weighted_loss=True")
            # cross_entropy expects the class dimension second: [B, V, S].
            per_token = F.cross_entropy(
                logits.transpose(1, 2), input_ids, reduction="none"
            )
            per_example = per_token.mean(dim=-1)
            return (per_example * timestamps / t_max).mean()

        keep = ~conditioning_mask.view(-1)
        targets = input_ids.view(-1)
        flat_logits = logits.view(targets.size(0), -1)
        return F.cross_entropy(flat_logits[keep], targets[keep])

    def forward(
        self,
        timestamps: torch.Tensor,
        conditioning_mask: Optional[torch.BoolTensor],
        input_ids: Optional[torch.LongTensor] = None,
        noisy_hidden_states: Optional[torch.FloatTensor] = None,
        self_conditioning_hidden_states: Optional[torch.FloatTensor] = None,
        output_denoised: bool = False,
        drop_conditioning_embeddings: bool = False,
        self_conditioning: bool = False,
        output_loss: bool = True,
        weighted_loss: bool = False,
        t_max: float = None,
        attention_mask: torch.Tensor = None,
    ) -> "DiffusionOutput":
        """Predict token logits from (noisy) embeddings.

        Args:
            timestamps: Noise level per example, shape ``[batch]``.
            conditioning_mask: Boolean mask, true at conditioning positions.
                Required whenever embeddings are built from ``input_ids``,
                self-conditioning is used, or the unweighted loss is computed.
            input_ids: Token ids; used to build the noisy embeddings when
                ``noisy_hidden_states`` is not given, and as loss targets.
            noisy_hidden_states: Precomputed noisy embeddings.
            self_conditioning_hidden_states: Previous denoised estimate;
                zeros are used when omitted.
            output_denoised: Also return the probability-weighted average of
                the normalized embeddings.
            drop_conditioning_embeddings: Forwarded to
                ``get_noisy_embeddings``.
            self_conditioning: Run an extra gradient-free pass first and use
                its denoised output (zeroed at conditioning positions) as the
                self-conditioning input.
            output_loss: Accepted for compatibility; the loss is computed
                whenever ``input_ids`` is given, regardless of this flag.
            weighted_loss: Weight each example's loss by
                ``timestamps / t_max``.
            t_max: Normalizer for the weighted loss.
            attention_mask: 1/0 mask over key positions (1 = attend).

        Returns:
            ``DiffusionOutput`` with ``logits`` and, when available, ``loss``
            and ``denoised``.

        Raises:
            ValueError: If ``weighted_loss`` is set without ``t_max``.
        """
        if noisy_hidden_states is None:
            noisy_hidden_states = self.get_noisy_embeddings(
                timestamps=timestamps,
                conditioning_mask=conditioning_mask,
                input_ids=input_ids,
                drop_conditioning_embeddings=drop_conditioning_embeddings,
            )

        batch_size = noisy_hidden_states.size(0)

        if self_conditioning:
            # The raw (2-D) attention mask is passed on; the nested call
            # converts it itself.
            with torch.no_grad():
                s_outputs = self.forward(
                    timestamps=timestamps,
                    conditioning_mask=conditioning_mask,
                    noisy_hidden_states=noisy_hidden_states,
                    output_loss=False,
                    output_denoised=True,
                    attention_mask=attention_mask,
                )
                self_conditioning_hidden_states = s_outputs.denoised * (
                    ~conditioning_mask.unsqueeze(-1)
                )
        elif self_conditioning_hidden_states is None:
            self_conditioning_hidden_states = torch.zeros_like(noisy_hidden_states)

        if attention_mask is not None:
            # nn.Module has no ``dtype`` attribute, so the parameter dtype is
            # read from a weight instead.
            attention_mask = self._prepare_attention_mask(
                attention_mask, batch_size, self.emb_to_hidden.weight.dtype
            )

        embeddings_cat = torch.cat(
            (
                noisy_hidden_states,
                self_conditioning_hidden_states,
            ),
            dim=-1,
        )

        hidden_states = self.emb_to_hidden(embeddings_cat)
        time_embeddings = self.time_emmbeddings(timestamps)
        for encoder in self.encoders:
            hidden_states = encoder(
                hidden_states=hidden_states,
                time_embeddings=time_embeddings,
                attention_mask=attention_mask,
            )

        logits = self.classifier(hidden_states)

        loss = None
        if input_ids is not None:
            loss = self._compute_loss(
                logits, input_ids, conditioning_mask, timestamps, weighted_loss, t_max
            )

        denoised = None
        if output_denoised:
            with torch.no_grad():
                probs = F.softmax(logits, dim=-1)
                normalized_embeddings = F.normalize(
                    self.embeddings.weight.data, dim=-1
                ) * math.sqrt(
                    self.config.embedding_size
                )  # ~ N(0, 1)
                denoised = probs @ normalized_embeddings  # still normalized
        return DiffusionOutput(logits=logits, loss=loss, denoised=denoised)

    def _init_weights(self, module):
        """Initialize one sub-module according to its type."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # NOTE(review): config.embedding_initializer_range exists but is
            # not used here -- confirm which std is intended for embeddings.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


# Smoke test when run as a script: build the model from the default
# configuration (no forward pass is executed).
if __name__ == "__main__":
    config = DiffusionConfig()
    model = SimplifiedDiffusionTransformerKerple(config)
