from __future__ import annotations

import torch
from torch import Tensor


def _rotate_half(x: Tensor) -> Tensor:
    """Map each adjacent pair ``(a, b)`` along the last axis to ``(-b, a)``.

    The last axis is read as interleaved pairs: even indices hold the first
    component of a pair and odd indices the second. The result has the same
    shape as ``x``.
    """
    even = x[..., 0::2]
    odd = x[..., 1::2]
    # Stacking on a new trailing axis yields (..., n_pairs, 2); collapsing the
    # last two axes re-interleaves the components into the original layout.
    swapped = torch.stack([odd.neg(), even], dim=-1)
    return swapped.flatten(start_dim=-2)


class RotaryEmbedding:
    """Minimal rotary positional embedding (RoPE).

    Applies a complex rotation to query/key vectors after projection and before attention.
    Each rotation frequency is shared by two adjacent features of the head dimension
    (interleaved layout).

    Parameters
    ----------
    dim : int
        Head dimension per attention head (must be even)

    theta : float, default=30000.0
        RoPE base controlling rotation frequency. Larger values = slower rotation
    """

    def __init__(self, dim: int, theta: float = 30000.0) -> None:
        """Store the head dimension and base.

        Raises
        ------
        ValueError
            If ``dim`` is odd.
        """
        if dim % 2 != 0:
            raise ValueError("Rotary head dimension must be even")
        self.dim = dim
        self.theta = float(theta)

    def _cos_sin(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        """Build the cosine and sine tables for positions ``0 .. seq_len - 1``.

        Parameters
        ----------
        seq_len : int
            Number of positions to generate.
        device : torch.device
            Device on which the tables are created.
        dtype : torch.dtype
            Dtype of the returned tables.

        Returns
        -------
        tuple[Tensor, Tensor]
            ``(cos, sin)``, each of shape ``(seq_len, dim)`` with every frequency
            repeated for two adjacent features.
        """
        # Positions and angles are computed in at least float32. Half-precision
        # dtypes cannot represent larger integer positions exactly (bfloat16
        # already rounds above 256), and the rounding error of the inverse
        # frequencies is multiplied by the position, so building the table
        # directly in the target dtype yields badly wrong rotation angles.
        compute_dtype = torch.float64 if dtype == torch.float64 else torch.float32
        half_dim = self.dim // 2
        idx = torch.arange(half_dim, device=device, dtype=compute_dtype)
        inv_freq = 1.0 / (self.theta ** (idx / half_dim))
        positions = torch.arange(seq_len, device=device, dtype=compute_dtype)
        freqs = torch.outer(positions, inv_freq)  # (seq_len, dim/2)
        # Only the bounded cos/sin values are cast down to the requested dtype.
        cos = torch.cos(freqs).repeat_interleave(2, dim=-1).to(dtype)
        sin = torch.sin(freqs).repeat_interleave(2, dim=-1).to(dtype)
        return cos, sin

    def rotate_queries_or_keys(self, t: Tensor) -> Tensor:
        """Apply rotary embedding to a tensor shaped (..., n_heads, seq_len, head_dim).

        The sequence length is read from the second-to-last axis and positions
        are numbered from 0. The result has the same shape as ``t``.

        Raises
        ------
        ValueError
            If the last dimension of ``t`` differs from ``dim``.
        """
        if t.shape[-1] != self.dim:
            raise ValueError(f"Expected last dim {self.dim}, got {t.shape[-1]}")

        seq_len = t.shape[-2]
        device, dtype = t.device, t.dtype
        cos, sin = self._cos_sin(seq_len, device, dtype)
        # Prepend singleton axes so the tables broadcast over all leading dims.
        view_shape = (1,) * (t.ndim - 2) + (seq_len, self.dim)
        cos = cos.view(*view_shape)
        sin = sin.view(*view_shape)
        return (t * cos) + (_rotate_half(t) * sin)

