"""
From: https://github.com/naver-ai/rope-vit/blob/main/self-attn/rope_self_attn.py
"""

import math
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .normalization import get_norm_layer

#######################
### Basic Attention ###
#######################


class Attention(nn.Module):
    # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
    r"""
    A self / cross attention layer.

    Queries are projected from `hidden_states`; keys and values are projected from
    `cross_hidden_states` when it is given and from `hidden_states` otherwise. The attention
    itself is computed with `torch.nn.functional.scaled_dot_product_attention`.

    Args:
        query_dim (`int`):
            Feature size of `hidden_states` (input size of the query projection).
        cross_attention_dim (`int`, *optional*):
            Feature size of `cross_hidden_states` (input size of the key / value projections).
            Falls back to `query_dim` when not given.
        heads (`int`, defaults to `8`):
            Number of attention heads.
        dim_head (`int`, *optional*):
            Feature size of one head. When `None` it is derived as `inner_dim // heads`.
        out_dim (`int`, *optional*):
            Feature size of the output projection. When `None` it is `dim_head * heads` if
            `dim_head` is given, else `query_dim`.
        dropout (`float`, defaults to `0.0`):
            Probability of the `nn.Dropout` applied after the output projection.
        attn_drop (`float`, defaults to `0.0`):
            Dropout probability passed to the attention kernel (training mode only).
        bias (`bool`, defaults to `False`):
            Whether the query / key / value projections have a bias.
        qk_norm (`str`, *optional*):
            Norm name forwarded to `get_norm_layer` to build the per-head query and key norms.
        cross_attention_norm (`str`, *optional*):
            Norm name forwarded to `get_norm_layer` for `cross_hidden_states`. Only used when
            `cross_attention_dim` is given.
        out_bias (`bool`, defaults to `True`):
            Whether the output projection has a bias.
        eps (`float`, defaults to `1e-5`):
            Forwarded to `get_norm_layer`.
        elementwise_affine (`bool`, defaults to `True`):
            Forwarded to `get_norm_layer` for the query / key norms.
        is_causal (`bool`, defaults to `False`):
            Forwarded as `is_causal` to the attention kernel.
        rescale_output_factor (`float`, defaults to `1.0`):
            The final output is divided by this factor.
        residual_connection (`bool`, defaults to `False`):
            When `True`, the module input is added to the output before rescaling.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: Optional[int] = None,
        out_dim: Optional[int] = None,
        dropout: float = 0.0,
        attn_drop: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        cross_attention_norm: Optional[str] = None,
        out_bias: bool = True,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        is_causal: bool = False,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
    ):
        super().__init__()

        # Compute dims
        if out_dim is None:
            out_dim = dim_head * heads if dim_head is not None else query_dim
        inner_dim = out_dim if dim_head is None else dim_head * heads
        is_cross_attention = cross_attention_dim is not None
        cross_attention_dim = cross_attention_dim or query_dim
        if dim_head is None:
            dim_head = inner_dim // heads

        # Only fails when dim_head was derived by floor division above.
        assert heads * dim_head == inner_dim, f"inner_dim ({inner_dim}) must be divisible by heads ({heads})"

        # Args
        self.is_causal = is_causal
        self.heads = heads
        self.inner_dim = inner_dim
        self.dim_head = dim_head
        self.rescale_output_factor = rescale_output_factor
        self.attn_drop = attn_drop
        self.residual_connection = residual_connection

        # NOTE(review): forward() treats a `None` result from get_norm_layer as "no norm";
        # confirm that is what get_norm_layer returns for a `None` norm name.
        self.norm_q = get_norm_layer(qk_norm, dim_head, heads=heads, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = get_norm_layer(qk_norm, dim_head, heads=heads, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_cross = get_norm_layer(cross_attention_norm if is_cross_attention else None, cross_attention_dim, eps=eps)

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

        self.out_proj = nn.Linear(inner_dim, out_dim, bias=out_bias)
        self.out_drop = nn.Dropout(dropout)

        # Never assigned inside this class; when set from outside it replaces the
        # `attention_mask` argument in eval mode (see forward()).
        self._cache_attn_mask = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states=None,  # For compatibility with transformers
    ) -> torch.Tensor:
        r"""
        Run attention over `hidden_states`.

        Args:
            hidden_states (`torch.Tensor`):
                Query input, either 3-D `(batch, sequence, channel)` or 4-D
                `(batch, channel, height, width)`. 4-D inputs are flattened to a sequence
                and reshaped back on output.
            cross_hidden_states (`torch.Tensor`, *optional*):
                3-D key / value input. When `None`, self-attention over `hidden_states`.
            attention_mask (`torch.Tensor`, *optional*):
                Mask forwarded to the attention kernel after `prepare_attention_mask`.
                Ignored in eval mode when `_cache_attn_mask` is set.
            encoder_hidden_states:
                Must be `None`; accepted only for signature compatibility.

        Returns:
            `torch.Tensor`: The attention output, with the same number of dimensions as
            `hidden_states`.
        """
        assert encoder_hidden_states is None, "encoder_hidden_states should be None"
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")

        # The mask has to match the key / value sequence, i.e. the cross input when present.
        batch_size, sequence_length, _ = hidden_states.shape if cross_hidden_states is None else cross_hidden_states.shape

        if not self.training and self._cache_attn_mask is not None:
            # If eval and have cached attention mask, use it
            attention_mask = self._cache_attn_mask
        elif attention_mask is not None:
            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1])

        query = self.to_q(hidden_states)

        if cross_hidden_states is None:
            cross_hidden_states = hidden_states
        elif self.norm_cross is not None:
            cross_hidden_states = self.norm_cross(cross_hidden_states)

        key = self.to_k(cross_hidden_states)
        value = self.to_v(cross_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        hidden_states = self._process_attn(query, key, value, attn_mask=attention_mask)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states = self.out_proj(hidden_states)
        hidden_states = self.out_drop(hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if self.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / self.rescale_output_factor

        return hidden_states

    def _process_attn(self, query, key, value, attn_mask):
        """Compute scaled dot-product attention for `(batch, heads, seq, head_dim)` inputs."""
        # The functional kernel does not know about module state, so attention dropout
        # has to be switched off explicitly outside of training.
        dropout_p = self.attn_drop if self.training else 0.0
        return F.scaled_dot_product_attention(  # pylint: disable=not-callable
            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=self.is_causal
        )

    def prepare_attention_mask(self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3) -> torch.Tensor:
        r"""
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`):
                The attention mask to prepare.
            target_length (`int`):
                The target length of the attention mask. This is the length of the attention mask after padding.
            batch_size (`int`):
                The batch size, which is used to repeat the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the attention mask. Can be either `3` or `4`.

        Returns:
            `torch.Tensor`: The prepared attention mask. A mask shorter than `target_length` is
            right-padded with zeros on its last dimension; a longer one is left untouched.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        # Pad only the missing positions so the last dimension ends up at `target_length`.
        remaining_length = target_length - attention_mask.shape[-1]
        if remaining_length > 0:
            attention_mask = F.pad(attention_mask, (0, remaining_length), value=0.0)

        if out_dim == 3:
            # One mask per (batch, head) pair along dim 0, unless already expanded.
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask


######################
### RoPE Attention ###
######################


def init_2d_freqs(dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
    """Build the initial per-head 2D rotary frequencies.

    Each head gets the same magnitudes ``1 / theta ** (k / dim)`` for ``k = 0, 4, 8, ...``,
    laid out along a direction ``angle`` and its perpendicular ``angle + pi / 2``. With
    ``rotate=True`` one angle per head is drawn uniformly from ``[0, 2 * pi)``; otherwise
    the angle is zero.

    Returns:
        Tensor of shape ``(2, num_heads, 2 * (dim // 4))``; index 0 holds the x frequencies
        and index 1 the y frequencies.
    """
    magnitudes = 1 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

    def one_head():
        # Exactly one random draw per head, made in head order.
        angle = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)
        x_part = torch.cat([magnitudes * torch.cos(angle), magnitudes * torch.cos(torch.pi / 2 + angle)], dim=-1)
        y_part = torch.cat([magnitudes * torch.sin(angle), magnitudes * torch.sin(torch.pi / 2 + angle)], dim=-1)
        return torch.stack([x_part, y_part], dim=0)

    per_head = [one_head() for _ in range(num_heads)]
    # Each entry is (2, features); stacking on dim 1 gives (2, num_heads, features).
    return torch.stack(per_head, dim=1)


def init_t_xy(end_x: int, end_y: int):
    """Return the x and y coordinate of every cell of a flattened grid.

    The grid has ``end_x * end_y`` cells enumerated row by row, so the x coordinate is the
    flat index modulo ``end_x`` and the y coordinate is the flat index floor-divided by
    ``end_x``. Both results are 1-D float32 tensors.
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    column = torch.remainder(flat_index, end_x).float()
    row = flat_index.div(end_x, rounding_mode="floor").float()
    return column, row


def compute_mixed_cis(freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, num_heads: int):
    """Turn mixed 2D frequencies and token coordinates into complex rotations.

    Args:
        freqs: Indexed as ``freqs[0]`` (x frequencies) and ``freqs[1]`` (y frequencies); each
            is a flat vector that is split evenly across ``num_heads``.
        t_x: 1-D tensor with the x coordinate of each of the ``N`` tokens.
        t_y: 1-D tensor with the y coordinate of each of the ``N`` tokens.
        num_heads: Number of attention heads.

    Returns:
        Complex tensor of shape ``(num_heads, N, -1)`` with unit magnitude and phase
        ``t_x * freq_x + t_y * freq_y``.
    """
    N = t_x.shape[0]
    # No float 16 for this range: keep the phase computation out of CUDA autocast.
    # `torch.autocast("cuda", ...)` is the non-deprecated spelling of `torch.cuda.amp.autocast(...)`.
    with torch.autocast(device_type="cuda", enabled=False):
        # (N, 1) @ (1, heads * d) -> (N, heads * d) -> (heads, N, d)
        freqs_x = (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2)).view(N, num_heads, -1).permute(1, 0, 2)
        freqs_y = (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2)).view(N, num_heads, -1).permute(1, 0, 2)
        freqs_cis = torch.polar(torch.ones_like(freqs_x), freqs_x + freqs_y)
    return freqs_cis


def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 100.0):
    """Compute axial 2D rotary rotations for an ``end_x`` by ``end_y`` grid.

    The x and y axes share the same frequencies ``1 / theta ** (k / dim)`` for
    ``k = 0, 4, 8, ...``. For every grid cell the unit-magnitude complex rotations of the
    x coordinate are followed, along the last dimension, by those of the y coordinate.

    Returns:
        Complex tensor of shape ``(end_x * end_y, 2 * (dim // 4))``.
    """
    exponent = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    base_freqs = 1.0 / (theta**exponent)

    rotations = []
    # init_t_xy yields the x coordinates first, then the y coordinates.
    for coords in init_t_xy(end_x, end_y):
        phase = torch.outer(coords, base_freqs)
        rotations.append(torch.polar(torch.ones_like(phase), phase))
    return torch.cat(rotations, dim=-1)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have the same shape as the last two or the last three dimensions of
    ``x``. The result keeps those dimensions and has size 1 in every leading dimension of
    ``x``.

    Raises:
        AssertionError: If ``x`` has fewer than two dimensions or the shape of
            ``freqs_cis`` matches neither the last two nor the last three dimensions of ``x``.
    """
    ndim = x.ndim
    for kept in (2, 3):
        trailing = tuple(x.shape[-kept:])
        if ndim >= kept and tuple(freqs_cis.shape) == trailing:
            return freqs_cis.view(*([1] * (ndim - kept)), *trailing)
    # Raised explicitly (rather than `assert False`) so the check survives `python -O`.
    raise AssertionError(
        f"freqs_cis shape {tuple(freqs_cis.shape)} does not match the trailing dimensions of x {tuple(x.shape)}"
    )


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate queries and keys by the complex rotations in ``freqs_cis``.

    Consecutive pairs along the last dimension of ``xq`` / ``xk`` are treated as the real
    and imaginary part of a complex number, multiplied by ``freqs_cis`` (broadcast over the
    leading dimensions) and unpacked again. The computation runs in float32 and each result
    is converted back to the type and device of its input.

    Returns:
        Tuple ``(xq_out, xk_out)``.
    """

    def to_complex(tensor):
        pairs = tensor.float().reshape(*tensor.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    def to_real_like(rotated, reference):
        # flatten(3) merges the (pair index, real/imag) dims back into one feature dim.
        flat = torch.view_as_real(rotated).flatten(3)
        return flat.type_as(reference).to(reference.device)

    q_complex = to_complex(xq)
    k_complex = to_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)

    return to_real_like(q_complex * rotation, xq), to_real_like(k_complex * rotation, xk)


class RoPEAttention(Attention):
    """Multi-head Attention block with rotary position embeddings.

    Tokens are assumed to form a square 2D grid, optionally preceded by one class token
    that is left unrotated.

    Args:
        rope_theta: Base of the rotary frequencies.
        rope_mixed: When ``True``, use learnable per-head mixed frequencies
            (``compute_mixed_cis``); otherwise use fixed axial frequencies
            (``compute_axial_cis``).
        grid_size: Side length of the grid for which coordinates / rotations are precomputed.
        cls_token: Whether the first token of the sequence is a class token.
        *args, **kwargs: Forwarded to ``Attention``.
    """

    def __init__(self, *args, rope_theta=10.0, rope_mixed=True, grid_size=14, cls_token=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.rope_mixed = rope_mixed
        self.cls_token = cls_token

        if self.rope_mixed:
            self.compute_cis = partial(compute_mixed_cis, num_heads=self.heads)

            # Flattened to (2, heads * features) so it can be used directly by compute_mixed_cis.
            freqs = init_2d_freqs(dim=self.dim_head, num_heads=self.heads, theta=rope_theta, rotate=True).view(2, -1)
            self.freqs = nn.Parameter(freqs, requires_grad=True)

            t_x, t_y = init_t_xy(end_x=grid_size, end_y=grid_size)
            self.register_buffer("freqs_t_x", t_x, persistent=False)
            self.register_buffer("freqs_t_y", t_y, persistent=False)

            # No need to copy using EMA
            # NOTE(review): `_ema` is presumably read by an EMA utility outside this file — confirm.
            self.freqs_t_x._ema = False
            self.freqs_t_y._ema = False
        else:
            self.compute_cis = partial(compute_axial_cis, dim=self.dim_head, theta=rope_theta)
            freqs_cis = self.compute_cis(end_x=grid_size, end_y=grid_size)
            # Kept as a plain attribute (not a buffer); it is moved to the query device on use.
            self.freqs_cis = freqs_cis

    @staticmethod
    def _grid_side(n_2dtoken):
        """Return the integer side length of a square grid holding ``n_2dtoken`` tokens."""
        side = math.isqrt(n_2dtoken)
        if side * side != n_2dtoken:
            raise ValueError(f"RoPEAttention expects a square token grid, got {n_2dtoken} grid tokens")
        return side

    def _process_attn(self, q, k, v, attn_mask):
        """Apply rotary position embeddings to ``q`` and ``k``, then run the parent attention.

        Raises:
            ValueError: If the number of grid tokens differs from the precomputed grid and is
                not a perfect square.
        """
        _, _, N, _ = q.shape
        n_2dtoken = N - 1 if self.cls_token else N

        ###### Apply rotary position embedding
        if self.rope_mixed:
            t_x, t_y = self.freqs_t_x, self.freqs_t_y
            if self.freqs_t_x.shape[0] != n_2dtoken:
                # Sequence length differs from the precomputed grid: rebuild the coordinates.
                w = h = self._grid_side(n_2dtoken)
                t_x, t_y = init_t_xy(end_x=w, end_y=h)
                t_x, t_y = t_x.to(q.device), t_y.to(q.device)
            freqs_cis = self.compute_cis(self.freqs, t_x, t_y)
        else:
            freqs_cis = self.freqs_cis
            if self.freqs_cis.shape[0] != n_2dtoken:
                w = h = self._grid_side(n_2dtoken)
                freqs_cis = self.compute_cis(end_x=w, end_y=h)
            freqs_cis = freqs_cis.to(q.device)

        if self.cls_token:
            # Rotate only the grid tokens; the class token at index 0 is passed through.
            q_rope, k_rope = apply_rotary_emb(q[:, :, 1:], k[:, :, 1:], freqs_cis=freqs_cis)
            q = torch.cat((q[:, :, :1], q_rope), dim=2)
            k = torch.cat((k[:, :, :1], k_rope), dim=2)
        else:
            q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        return super()._process_attn(q, k, v, attn_mask)
