import math
from dataclasses import dataclass
from typing import Optional

from einops import rearrange

import torch
from torch import nn
from torch.nn import functional as F

from transformers.activations import ACT2FN
from transformers.utils import ModelOutput


@dataclass
class DiffusionOutput(ModelOutput):
    """Return container for the diffusion transformer's forward pass.

    NOTE(review): ``ModelOutput`` subclasses are commonly accessed by key and by
    position as well as by attribute, so the field declaration order below
    should be treated as part of the interface — confirm before reordering.
    """

    # Unnormalized per-token scores over the vocabulary (classifier output).
    logits: torch.FloatTensor
    # Probability-weighted average of the normalized embedding table per token;
    # None unless the caller asked for it.
    denoised: Optional[torch.FloatTensor] = None
    # Cross-entropy training loss; None when no loss was computed.
    loss: Optional[torch.FloatTensor] = None


class DiffusionConfig:
    """Plain hyper-parameter container for the diffusion transformer.

    Every constructor argument is stored verbatim as an attribute of the same
    name. Nothing is validated or derived here.
    """

    def __init__(
        self,
        num_embeddings: int = 30000,
        embedding_size: int = 256,
        hidden_size: int = 1024,
        num_attention_heads: int = 8,
        num_hidden_layers: int = 8,
        rotary_pct: float = 0.25,
        hidden_act: str = "gelu_new",
        time_embedding_size: int = 128,
        initializer_range: float = 0.02,
        embedding_initializer_range: float = 0.001,
        max_position_embeddings: int = 64,
        rotary_emb_base: int = 10000,
        num_bins: int = 100,
        time_wrapping_epsilon: float = 5e-4,
    ):
        # Vocabulary size and token-embedding width.
        self.num_embeddings = num_embeddings
        self.embedding_size = embedding_size

        # Transformer trunk geometry.
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers

        # Fraction of each attention head's channels that receive rotary encoding.
        self.rotary_pct = rotary_pct
        # Key into the activation-function registry.
        self.hidden_act = hidden_act
        # Width of the learned sinusoidal time embedding.
        self.time_embedding_size = time_embedding_size

        # Standard deviations used for weight initialization.
        self.initializer_range = initializer_range
        self.embedding_initializer_range = embedding_initializer_range

        # Rotary embedding table size and frequency base.
        self.max_position_embeddings = max_position_embeddings
        self.rotary_emb_base = rotary_emb_base

        # Remaining diffusion-schedule settings (not read by the model code shown).
        self.num_bins = num_bins
        self.time_wrapping_epsilon = time_wrapping_epsilon


class KerpleLog(torch.nn.Module):
    """Kernelized T5 Relative Position Bias parallelized in the heads dimension.

    Adds a learned, per-head logarithmic distance penalty to attention scores:
    ``bias[h, i, j] = -p_h * log(1 + a_h * (i - j))`` for ``j <= i``, and ``0``
    for ``j > i`` (the distance matrix is lower-triangular). Before every use,
    ``p`` and ``a`` are clamped in place to be at least ``eps``.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.eps = 1e-2

        # Allocate weights and initialize.
        # The kernel has the form -p*log(1+a*|m-n|)
        def get_parameter(scale, init_method):
            if init_method == "ones":
                values = torch.ones(self.num_heads)
            elif init_method == "uniform":
                values = torch.rand(self.num_heads)
            else:
                raise ValueError(f"Unknown init_method: {init_method!r}")
            # Shape [num_heads, 1, 1] so it broadcasts over a [seq, seq] matrix.
            return nn.Parameter(values[:, None, None] * scale)

        self.bias_p = get_parameter(2, "uniform")
        self.bias_a = get_parameter(1, "uniform")

        # Distance matrix cache. It is rebuilt whenever the key length, device or
        # dtype of the input changes, so a stale matrix is never reused.
        self.cached_matrix = None
        self.cached_seq_len = None

    def stats(self):
        """Return mean/std/max/min of both kernel parameters (as CPU tensors).

        Side effect: clamps ``bias_a`` and ``bias_p`` in place to ``>= eps``.
        """

        def get_stats(name, obj):
            return {
                name + "_mean": obj.mean().detach().cpu(),
                name + "_std": obj.std().detach().cpu(),
                name + "_max": obj.max().detach().cpu(),
                name + "_min": obj.min().detach().cpu(),
            }

        dd = {}
        self.bias_a.data = self.bias_a.data.clamp(min=self.eps)
        dd.update(get_stats("bias_a", self.bias_a))
        self.bias_p.data = self.bias_p.data.clamp(min=self.eps)
        dd.update(get_stats("bias_p", self.bias_p))
        return dd

    def _distance_matrix(self, seq_len_k, device, dtype):
        """Return the lower-triangular ``i - j`` matrix, using the cache when valid."""
        diff = self.cached_matrix
        if (
            diff is None
            or self.cached_seq_len != seq_len_k
            or diff.device != device
            or diff.dtype != dtype
        ):
            positions = torch.arange(seq_len_k, device=device)
            diff = torch.tril(positions[:, None] - positions[None, :]).to(dtype)
            self.cached_seq_len = seq_len_k
            self.cached_matrix = diff
        return diff

    def forward(self, x):
        """Add the log-distance bias to attention scores ``x``.

        ``x`` is expected as [b, num_heads, sq, sk] (NOTE(review): the head axis
        must equal ``num_heads`` or 1 for the bias to broadcast). ``sq`` must
        equal ``sk``, or be 1 for a single decoding step against a cache.
        """
        seq_len_q = x.shape[-2]
        seq_len_k = x.shape[-1]
        diff = self._distance_matrix(seq_len_k, x.device, x.dtype)

        self.bias_p.data = self.bias_p.data.clamp(min=self.eps)
        self.bias_a.data = self.bias_a.data.clamp(min=self.eps)
        bias = -self.bias_p * torch.log(1 + self.bias_a * diff)  # log kernel

        if seq_len_q != seq_len_k:
            # At inference time with a layer_past cache, the query holds only the
            # last token of the sequence, so keep only that token's bias row.
            assert (
                seq_len_q == 1
            ), "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"
            # seq_len_k - 1 indexes the last token; the slice keeps a length-1 axis.
            bias = bias[:, seq_len_k - 1 : seq_len_k, :]

        return x + bias


def rotate_half(x):
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    ``a`` is the first ``n // 2`` channels and ``b`` is the rest. For an odd
    width, ``b`` is therefore the larger part.
    """
    width = x.shape[-1]
    half = width // 2
    first, second = x.split([half, width - half], dim=-1)
    return torch.cat((-second, first), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    """Apply rotary position embedding to ``q`` and ``k``.

    Rows ``offset .. offset + q.shape[-2]`` are taken from the second-to-last axis
    of ``cos``/``sin``, so the query positions start at ``offset``. The same
    window is applied to ``k``. Returns the rotated ``(q, k)`` pair.
    """
    window = slice(offset, offset + q.shape[-2])
    cos_window = cos[..., window, :]
    sin_window = sin[..., window, :]

    def rotate(t):
        return (t * cos_window) + (rotate_half(t) * sin_window)

    return rotate(q), rotate(k)


class RotaryEmbedding(torch.nn.Module):
    """Produce rotary-position-embedding cos/sin tables.

    The tables have shape [1, 1, seq_len, dim] and are built eagerly for
    ``max_position_embeddings`` positions. They are regrown on demand when a
    longer sequence is requested.
    """

    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        self._build_cache(self.max_seq_len_cached, self.inv_freq.device)

    def _build_cache(self, seq_len, device):
        """(Re)compute ``cos_cached``/``sin_cached`` for ``seq_len`` positions on ``device``."""
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` of shape [1, 1, seq_len, dim] on ``x.device``.

        ``x`` is [bs, num_attention_heads, seq_len, head_size]. Only its device
        is used, plus its sequence length when ``seq_len`` is None.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        # Unlikely after the eager build in __init__, but supports longer inputs.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            self._build_cache(seq_len, x.device)
        # Slice the sequence axis (axis 2); axes 0 and 1 are singleton broadcast dims.
        return (
            self.cos_cached[:, :, :seq_len, ...].to(x.device),
            self.sin_cached[:, :, :seq_len, ...].to(x.device),
        )


class LearnedSinusoidalPosEmb(nn.Module):
    """Learned random-Fourier-feature embedding of a per-example scalar (e.g. diffusion time).

    Following @crowsonkb's lead with learned sinusoidal pos emb:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    Input is a 1-D tensor of shape [b]. Output has shape [b, time_embedding_size]:
    the sin and cos of learned-frequency projections, passed through a two-layer
    MLP.
    """

    def __init__(self, config):
        super().__init__()
        width = config.time_embedding_size
        assert (width % 2) == 0
        # One learned frequency per sin/cos pair.
        self.weights = nn.Parameter(torch.randn(width // 2))
        first_linear = nn.Linear(width, width, bias=False)
        first_act = ACT2FN[config.hidden_act]
        second_linear = nn.Linear(width, width, bias=False)
        second_act = ACT2FN[config.hidden_act]
        self.projector = nn.Sequential(first_linear, first_act, second_linear, second_act)

    def forward(self, x):
        column = rearrange(x, "b -> b 1")
        row = rearrange(self.weights, "d -> 1 d")
        angles = column * row * 2 * math.pi
        features = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return self.projector(features)


class Attention(nn.Module):
    """Multi-head self-attention with partial rotary position embeddings.

    Queries, keys and values come from one fused linear projection. Only the first
    ``rotary_ndims`` channels of each head's query and key are rotated; the other
    channels pass through unchanged. This module applies no causal mask:
    attention is bidirectional unless the caller passes an additive
    ``attention_mask``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_attention_heads
        self.rotary_ndims = int(self.head_size * config.rotary_pct)
        # Not read by this module. It stays registered because it is a persistent
        # buffer: existing state_dicts contain it and must still load strictly.
        self.register_buffer("masked_bias", torch.tensor(-1e9))
        self.rotary_emb = RotaryEmbedding(
            self.rotary_ndims,
            config.max_position_embeddings,
            base=config.rotary_emb_base,
        )
        # sqrt(head_size); raw dot-product scores are divided by it.
        self.norm_factor = torch.sqrt(
            torch.tensor(self.head_size, dtype=torch.float32)
        ).to(torch.get_default_dtype())
        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        layer_past=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Attend over ``hidden_states`` of shape [batch, seq_len, hidden_size].

        Returns ``(attn_output, present)``, with ``attn_weights`` appended when
        ``output_attentions`` is set. ``present`` is the ``(key, value)`` cache
        when ``use_cache`` is set and None otherwise. ``layer_past`` is a previous
        ``(key, value)`` pair; its length is used as the rotary position offset.
        """
        has_layer_past = layer_past is not None

        # Compute QKV
        # Attention heads [batch, seq_len, hidden_size]
        #   --> [batch, seq_len, (np * 3 * head_size)]
        qkv = self.query_key_value(hidden_states)

        # [batch, seq_len, (num_heads * 3 * head_size)]
        #   --> [batch, seq_len, num_heads, 3 * head_size]
        new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
        qkv = qkv.view(*new_qkv_shape)

        # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
        query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
        key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
        value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

        # Compute rotary embeddings on rotary_ndims
        query_rot = query[..., : self.rotary_ndims]
        query_pass = query[..., self.rotary_ndims :]
        key_rot = key[..., : self.rotary_ndims]
        key_pass = key[..., self.rotary_ndims :]

        # Compute token offset for rotary embeddings (when decoding)
        seq_len = key.shape[-2]
        offset = 0
        if has_layer_past:
            offset = layer_past[0].shape[-2]
            seq_len += offset
        cos, sin = self.rotary_emb(value, seq_len=seq_len)
        query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
        query = torch.cat((query, query_pass), dim=-1)
        key = torch.cat((key, key_pass), dim=-1)

        # Cache QKV values
        if has_layer_past:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        # Compute attention
        attn_output, attn_weights = self._attn(
            query, key, value, attention_mask, head_mask
        )

        # Reshape outputs
        attn_output = self._merge_heads(
            attn_output, self.num_attention_heads, self.head_size
        )
        attn_output = self.dense(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    @classmethod
    def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Splits hidden dim into attn_head_size and num_attention_heads
        """
        # tensor: [bs, seq_len, hidden_size]
        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(new_shape)
        # -> [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor

    @classmethod
    def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden dim
        """
        # tensor [bs, num_attention_heads, seq_len, attn_head_size]
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        # -> [bs, seq_len, num_attention_heads, attn_head_size]
        tensor = tensor.view(
            tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
        )
        # -> [bs, seq_len, hidden_size]
        return tensor

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Scaled dot-product attention.

        query/key/value: [bs, num_attention_heads, seq_len, attn_head_size].
        attention_mask: optional additive mask added to the scores before softmax.
        head_mask: optional multiplicative mask applied to the probabilities.
        Returns ``(attn_output, attn_weights)``.
        """
        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        query = query.view(
            batch_size * num_attention_heads, query_length, attn_head_size
        )
        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
        attn_scores = torch.zeros(
            batch_size * num_attention_heads,
            query_length,
            key_length,
            dtype=query.dtype,
            device=key.device,
        )
        # scores = (q @ k^T) / sqrt(head_size), computed in one fused call.
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=(
                torch.tensor(
                    1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device
                )
                / self.norm_factor
            ),
        )
        attn_scores = attn_scores.view(
            batch_size, num_attention_heads, query_length, key_length
        )

        if attention_mask is not None:
            # Apply the (additive) attention mask
            attn_scores = attn_scores + attention_mask

        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(value.dtype)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights


class ConditionalLayerNorm(nn.Module):
    """Layer norm whose per-channel scale and shift are predicted from a time embedding.

    The input is normalized over its last axis with no learned affine. It is then
    multiplied by ``gamma_cond(time_embedding)`` and shifted by
    ``beta_cond(time_embedding)``, one vector per batch element.
    """

    def __init__(self, config):
        super().__init__()
        self.gamma_cond = nn.Linear(config.time_embedding_size, config.hidden_size)
        self.beta_cond = nn.Linear(config.time_embedding_size, config.hidden_size)
        self.epsilon = 1e-5

    def init_weights(self):
        """Make the layer start as a plain layer norm: scale 1, shift 0, independent of time."""
        for layer, bias_value in ((self.gamma_cond, 1.0), (self.beta_cond, 0.0)):
            torch.nn.init.constant_(layer.weight, 0.0)
            torch.nn.init.constant_(layer.bias, bias_value)

    def forward(self, time_embedding, hidden_state):
        # Biased variance over the last axis, with epsilon inside the square root.
        centered = hidden_state - hidden_state.mean(dim=-1, keepdim=True)
        variance = centered.pow(2).mean(dim=-1, keepdim=True)
        normalized = centered / (variance + self.epsilon).sqrt()

        batch = hidden_state.size(0)
        scale = self.gamma_cond(time_embedding).view(batch, 1, -1)
        shift = self.beta_cond(time_embedding).view(batch, 1, -1)
        return normalized * scale + shift


class TransformerBlock(nn.Module):
    """Pre-norm residual transformer layer conditioned on a time embedding.

    Computes ``h = h + Attn(CLN1(t, h))``, then ``h = h + MLP(CLN2(t, h))``. The
    MLP expands the hidden size 4x.
    """

    def __init__(self, config):
        super().__init__()
        self.conditional_layer_norm_1 = ConditionalLayerNorm(config)
        self.attention = Attention(config)
        self.conditional_layer_norm_2 = ConditionalLayerNorm(config)
        self.up_proj = nn.Linear(config.hidden_size, config.hidden_size * 4)
        self.act = ACT2FN[config.hidden_act]
        self.down_proj = nn.Linear(config.hidden_size * 4, config.hidden_size)

    def forward(self, hidden_states, time_embeddings, attention_mask=None):
        """Run one block.

        ``attention_mask`` is an additive mask broadcastable to the attention
        scores. It is forwarded to the attention layer; previously it was
        accepted and silently dropped.
        """
        normed = self.conditional_layer_norm_1(time_embeddings, hidden_states)
        attn_res = self.attention(normed, attention_mask=attention_mask)[0]
        hidden_states = hidden_states + attn_res

        normed = self.conditional_layer_norm_2(time_embeddings, hidden_states)
        mlp_res = self.down_proj(self.act(self.up_proj(normed)))
        hidden_states = hidden_states + mlp_res
        return hidden_states


class SimplifiedDiffusionTransformer(nn.Module):
    """Transformer denoiser for embedding-space diffusion over tokens.

    The pipeline:
    1. Token ids are embedded, L2-normalized and rescaled to norm
       ``sqrt(embedding_size)``.
    2. The embeddings are corrupted with Gaussian noise whose scale is the
       per-example ``timestamps`` value. Conditioning positions are left clean.
    3. The result is rescaled by ``1 / sqrt(t^2 + 1)`` and concatenated with a
       self-conditioning estimate.
    4. Time-conditioned transformer blocks map it to logits over the vocabulary.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Embedding(config.num_embeddings, config.embedding_size)
        # The misspelled attribute name is part of the state_dict keys; keep it.
        self.time_emmbeddings = LearnedSinusoidalPosEmb(config)
        # The input is [noisy_hidden_states ; self_conditioning_hidden_states].
        # Each part is embedding_size wide, hence the factor of 2.
        self.emb_to_hidden = nn.Linear(config.embedding_size * 2, config.hidden_size)
        self.encoders = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.classifier = nn.Linear(
            config.hidden_size, config.num_embeddings, bias=False
        )

        for m in self.modules():
            self._init_weights(m)

        # Override the generic Linear init so each conditional layer norm starts
        # with scale 1 and shift 0, independent of the time embedding.
        for e in self.encoders:
            e.conditional_layer_norm_1.init_weights()
            e.conditional_layer_norm_2.init_weights()

    def _scale_to_embedding_sphere(self, embeddings):
        """L2-normalize the last axis and rescale it to norm sqrt(embedding_size)."""
        # Each coordinate then has roughly unit scale (~N(0, 1)).
        return F.normalize(embeddings, dim=-1) * math.sqrt(self.config.embedding_size)

    def get_embeddings(self, input_ids):
        """Embed ``input_ids`` and project each vector onto the sqrt(embedding_size) sphere."""
        return self._scale_to_embedding_sphere(self.embeddings(input_ids))

    def get_noisy_embeddings(
        self, timestamps, conditioning_mask, input_ids, drop_conditioning_embeddings
    ):
        """Embed ``input_ids``, add time-scaled Gaussian noise and renormalize.

        The noise std is ``timestamps`` (one value per batch element). Positions
        where ``conditioning_mask`` is True get no noise, and are zeroed entirely
        when ``drop_conditioning_embeddings`` is set. The result is scaled by
        ``1 / sqrt(t^2 + 1)``.
        """
        embeddings = self.get_embeddings(input_ids)
        noise = (
            timestamps.view(-1, 1, 1)
            * torch.randn_like(embeddings)
            * (~conditioning_mask).unsqueeze(-1)
        )

        noisy_hidden_states = embeddings + noise
        if drop_conditioning_embeddings:
            noisy_hidden_states = noisy_hidden_states * (~conditioning_mask).unsqueeze(
                -1
            )

        scale = 1 / torch.sqrt(timestamps**2 + 1)
        noisy_hidden_states = noisy_hidden_states * scale.view(
            -1, 1, 1
        )  # renormalize embeddings again. See Section 6.1
        return noisy_hidden_states

    @staticmethod
    def _default_conditioning_mask(input_ids, noisy_hidden_states):
        """Build an all-False [batch, seq_len] mask, i.e. no conditioning positions."""
        reference = input_ids if input_ids is not None else noisy_hidden_states
        if reference is None:
            raise ValueError(
                "Either input_ids or noisy_hidden_states must be provided."
            )
        return torch.zeros(
            reference.shape[:2], dtype=torch.bool, device=reference.device
        )

    def _extend_attention_mask(self, attention_mask, batch_size):
        """Turn a [batch, seq_len] 1/0 keep-mask into an additive [batch, 1, 1, seq_len] mask.

        Kept positions become 0.0 and masked positions become the dtype's most
        negative value, so they vanish after the softmax.
        """
        if batch_size <= 0:
            raise ValueError("batch_size has to be defined and > 0")
        # nn.Module has no `.dtype`; use the dtype of the parameters that produce
        # the attention inputs (fp16 compatibility).
        dtype = self.emb_to_hidden.weight.dtype
        attention_mask = attention_mask.view(batch_size, -1)[:, None, None, :]
        attention_mask = attention_mask.to(dtype=dtype)
        return (1.0 - attention_mask) * torch.finfo(dtype).min

    @staticmethod
    def _compute_loss(
        logits, input_ids, conditioning_mask, timestamps, weighted_loss, t_max
    ):
        """Cross-entropy of ``logits`` against ``input_ids``.

        Unweighted: averaged over non-conditioning positions only.
        Weighted: each example's mean cross-entropy over all of its positions,
        scaled by ``timestamps / t_max`` and then averaged over the batch.
        """
        if weighted_loss:
            if t_max is None:
                raise ValueError("t_max must be provided when weighted_loss=True")
            # cross_entropy expects the class axis at dim 1: [b, vocab, seq].
            # NOTE(review): assumes input_ids contain no ignore_index (-100)
            # entries; they are embedding indices, so they should not.
            per_example = F.cross_entropy(
                logits.transpose(1, 2), input_ids, reduction="none"
            ).mean(dim=-1)
            return (per_example * timestamps / t_max).mean()

        keep = ~conditioning_mask.view(-1)
        logits_r = logits.reshape(-1, logits.size(-1))[keep]
        targets = input_ids.reshape(-1)[keep]
        return F.cross_entropy(logits_r, targets)

    def forward(
        self,
        timestamps: torch.Tensor,
        conditioning_mask: Optional[torch.BoolTensor],
        input_ids: Optional[torch.LongTensor] = None,
        noisy_hidden_states: Optional[torch.FloatTensor] = None,
        self_conditioning_hidden_states: Optional[torch.FloatTensor] = None,
        output_denoised: bool = False,
        drop_conditioning_embeddings: bool = False,
        self_conditioning: bool = False,
        output_loss: bool = True,
        weighted_loss: bool = False,
        t_max: Optional[float] = None,
        attention_mask: torch.Tensor = None,
    ) -> DiffusionOutput:
        """Denoise one diffusion step.

        Args:
            timestamps: noise level per batch element.
            conditioning_mask: True at positions that are given clean. None means
                no position is conditioning.
            input_ids: target token ids. They are noised when
                ``noisy_hidden_states`` is None, and used for the loss.
            noisy_hidden_states: precomputed noisy inputs; bypasses noising.
            self_conditioning_hidden_states: previous denoised estimate (zeros if None).
            output_denoised: also return the probability-weighted embedding estimate.
            drop_conditioning_embeddings: zero out conditioning positions in the input.
            self_conditioning: first run a no-grad pass and feed its estimate back in.
            output_loss: compute the loss when ``input_ids`` is given.
            weighted_loss: weight per-example loss by ``timestamps / t_max``.
            t_max: normalizer for ``weighted_loss`` (required when it is set).
            attention_mask: [batch, seq_len] mask, 1 = attend and 0 = ignore.
        """
        if conditioning_mask is None:
            conditioning_mask = self._default_conditioning_mask(
                input_ids, noisy_hidden_states
            )

        if noisy_hidden_states is None:
            noisy_hidden_states = self.get_noisy_embeddings(
                timestamps=timestamps,
                conditioning_mask=conditioning_mask,
                input_ids=input_ids,
                drop_conditioning_embeddings=drop_conditioning_embeddings,
            )

        if self_conditioning_hidden_states is None:
            self_conditioning_hidden_states = torch.zeros_like(noisy_hidden_states)

        batch_size = noisy_hidden_states.size(0)

        if self_conditioning:
            # The inner pass receives the raw attention_mask and extends it itself.
            with torch.no_grad():
                s_outputs = self.forward(
                    timestamps=timestamps,
                    conditioning_mask=conditioning_mask,
                    noisy_hidden_states=noisy_hidden_states,
                    output_loss=False,
                    output_denoised=True,
                    attention_mask=attention_mask,
                )
                self_conditioning_hidden_states = s_outputs.denoised * (
                    ~conditioning_mask.unsqueeze(-1)
                )

        if attention_mask is not None:
            attention_mask = self._extend_attention_mask(attention_mask, batch_size)

        embeddings_cat = torch.cat(
            (noisy_hidden_states, self_conditioning_hidden_states), dim=-1
        )

        hidden_states = self.emb_to_hidden(embeddings_cat)
        time_embeddings = self.time_emmbeddings(timestamps)
        for encoder in self.encoders:
            hidden_states = encoder(
                hidden_states=hidden_states,
                time_embeddings=time_embeddings,
                attention_mask=attention_mask,
            )

        logits = self.classifier(hidden_states)

        loss = None
        if output_loss and input_ids is not None:
            loss = self._compute_loss(
                logits, input_ids, conditioning_mask, timestamps, weighted_loss, t_max
            )

        denoised = None
        if output_denoised:
            with torch.no_grad():
                probs = F.softmax(logits, dim=-1)
                # Expected embedding under the predicted token distribution.
                denoised = probs @ self._scale_to_embedding_sphere(
                    self.embeddings.weight.data
                )
        return DiffusionOutput(logits=logits, loss=loss, denoised=denoised)

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding, zero biases, unit LayerNorm."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # NOTE(review): config.embedding_initializer_range is defined but not
            # used here; confirm whether embeddings were meant to use it.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
