import math

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor


class RoPE2D(nn.Module):
    """Split each vector into two halves and apply RoPE to each half separately: the first half
    is rotated according to the x coordinate, the second half according to the y coordinate.

    This implementation uses the same frequencies for both the x and y coordinates, so that the
    result is invariant to symmetries. Frequencies are not learnt.

    ---
    See:
        https://www.ecva.net/papers/eccv_2024/papers_ECCV/html/1584_ECCV_2024_paper.php
        https://github.com/naver-ai/rope-vit
    """

    def __init__(self, head_dim: int, theta: float = 100.0, coord_scale: float = 14.0):
        """Precompute the rotation frequencies.

        ---
        Args:
            head_dim: Size of the vectors to rotate. Must be a multiple of 4: two halves
                (x and y), each made of (real, imaginary) pairs.
            theta: Base of the geometric progression of frequencies.
            coord_scale: Factor applied to the coordinates before computing the rotation
                angles.

        ---
        Raises:
            ValueError: If `head_dim` is not a multiple of 4.
        """
        super().__init__()

        # Explicit raise rather than `assert`, which is stripped when running with `-O`.
        if head_dim % 4 != 0:
            raise ValueError(f"head_dim must be a multiple of 4, got {head_dim}")

        n_freqs = head_dim // 4
        freqs = torch.asarray([math.pow(theta, -i / n_freqs) for i in range(n_freqs)])
        self.register_buffer("freqs", freqs)
        self.coord_scale = coord_scale

    @staticmethod
    def _full_precision(dtype: torch.dtype) -> torch.dtype:
        """Return `dtype` if complex views support it, otherwise fall back to float32."""
        return dtype if dtype in (torch.float32, torch.float64) else torch.float32

    def forward(self, x: Tensor, p: Tensor) -> Tensor:
        """Apply RoPE to the given vector x.

        ---
        Args:
            x: The vectors to rotate.
                Shape of [batch_size, n_heads, seq_len, head_dim].
            p: Coordinates of the vectors.
                Shape of [batch_size, seq_len, 2].

        ---
        Returns:
            The rotated vectors, with the same dtype as `x`.
                Shape of [batch_size, n_heads, seq_len, head_dim].

        ---
        Raises:
            ValueError: If `p` is not 3-dimensional with at least two coordinates.
        """
        if p.ndim != 3 or p.shape[-1] < 2:
            raise ValueError(
                f"p must have shape [batch_size, seq_len, 2], got {tuple(p.shape)}"
            )

        # `torch.polar` and `torch.view_as_complex` only handle float32/float64, so half and
        # bfloat16 inputs (and a module cast to reduced precision) are computed in float32.
        angle_dtype = self._full_precision(p.dtype)
        coords = p[:, :, :2].to(angle_dtype) * self.coord_scale
        # One angle per (coordinate, frequency): [batch_size, seq_len, 2, head_dim // 4].
        angles = coords.unsqueeze(-1) * self.freqs.to(angle_dtype)
        # Unit complex numbers, with a singleton axis to broadcast over the heads.
        rotations = torch.polar(torch.ones_like(angles), angles).unsqueeze(1)

        # [b, h, s, head_dim] -> [b, h, s, 2 (x/y half), head_dim // 4, 2 (real/imag)].
        pairs = x.to(self._full_precision(x.dtype)).unflatten(-1, (2, -1, 2))
        rotated = torch.view_as_complex(pairs) * rotations

        # Flattening (half, pair, real/imag) restores the original layout of the last axis.
        return torch.view_as_real(rotated).flatten(3).to(x.dtype)
