import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class RotaryPositionEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE) implementation for improved relative position awareness.

    The first ``2 * (dim // 2)`` features of the input are split into two halves
    ``[x1, x2]``. Each pair ``(x1[i], x2[i])`` is rotated by the angle
    ``pos * 10000 ** (-i / (dim // 2))``. Any remaining features (including the
    last one when ``dim`` is odd) are passed through unchanged.

    Args:
        dim: Number of leading features to rotate.
        max_seq_len: Number of positions whose cos/sin tables are precomputed and
            registered as buffers. Longer inputs get their tables computed on the
            fly in ``forward``.
    """

    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        # Precompute cosine and sine caches for rotation; as buffers they follow
        # the module through .to(device) / .half() calls.
        self.register_buffer("cos_cached", self._get_cos_cached(max_seq_len, dim))
        self.register_buffer("sin_cached", self._get_sin_cached(max_seq_len, dim))

    @staticmethod
    def _rotation_angles(seq_len, dim):
        """Return the [seq_len, dim // 2] float32 matrix of angles ``pos * freq``."""
        half_dim = dim // 2
        positions = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)  # [seq_len, 1]
        # max(..., 1) avoids a ZeroDivisionError for dim < 2; arange(0) is empty then anyway.
        freqs = torch.exp(
            -torch.arange(half_dim, dtype=torch.float)
            * (math.log(10000.0) / max(half_dim, 1))
        ).unsqueeze(0)  # [1, half_dim]
        return positions * freqs  # [seq_len, half_dim]

    def _get_cos_cached(self, seq_len, dim):
        """Precompute cosine values for each position and dimension."""
        return torch.cos(self._rotation_angles(seq_len, dim))

    def _get_sin_cached(self, seq_len, dim):
        """Precompute sine values for each position and dimension."""
        return torch.sin(self._rotation_angles(seq_len, dim))

    def forward(self, x, seq_dim=1):
        """
        Apply rotational position encoding to the input tensor.

        Args:
            x: Tensor whose last axis holds the features, e.g.
                [batch_size, seq_len, hidden_dim]. Must have at least
                ``2 * (dim // 2)`` features.
            seq_dim: Index of the sequence dimension (default: 1). Negative
                indices are accepted.
        Returns:
            Tensor of the same shape and dtype as ``x`` with RoPE applied to the
            first ``2 * (dim // 2)`` features; later features are unchanged.
        """
        seq_len = x.size(seq_dim)
        # Select appropriate cosine and sine
        if seq_len > self.max_seq_len:
            # Positions beyond the cached range: build the tables on demand.
            cos = self._get_cos_cached(seq_len, self.dim).to(x.device)
            sin = self._get_sin_cached(seq_len, self.dim).to(x.device)
        else:
            cos = self.cos_cached[:seq_len]
            sin = self.sin_cached[:seq_len]

        half_dim = self.dim // 2
        # Reshape the [seq_len, half_dim] tables so positions line up with x's
        # sequence axis and frequencies with its feature axis (ones elsewhere).
        bshape = [1] * x.dim()
        bshape[seq_dim] = seq_len
        bshape[-1] = half_dim
        cos = cos.reshape(bshape)
        sin = sin.reshape(bshape)

        # Rotate the first half against the second half ("rotate-half" layout).
        x_first = x[..., :half_dim]
        x_second = x[..., half_dim:2 * half_dim]
        x_rotated = torch.cat(
            [x_first * cos - x_second * sin, x_first * sin + x_second * cos],
            dim=-1,
        ).to(x.dtype)  # float32 tables would otherwise promote half-precision inputs
        # Concatenate any remaining dimensions unchanged (including the last
        # feature when dim is odd).
        if x.size(-1) > 2 * half_dim:
            x_rotated = torch.cat([x_rotated, x[..., 2 * half_dim:]], dim=-1)
        return x_rotated

class RotaryEmbedding(nn.Module):
    """
    Apply RoPE to query/key tensors in multi-head attention.

    Adjacent feature pairs ``(x[..., 2i], x[..., 2i + 1])`` are rotated by the
    angle ``pos * base ** (-2i / D)``, where ``pos`` is the index along axis 2
    of the input.

    Args:
        dim: Rotary dimension; must be even.
        base: Frequency base of the rotation angles (default 10000).
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        assert dim % 2 == 0, "RoPE dimension must be even."
        self.dim = dim
        self.base = base

    def forward(self, x, seq_len):
        """
        Args:
            x: Tensor of shape [B, H, N, D] with D even.
            seq_len: Sequence length N. Not used: positions ``0..N-1`` are taken
                from ``x.shape[2]``; the argument is kept for API compatibility.
        Returns:
            Tensor with RoPE applied, same shape and dtype as input.
        """
        B, H, N, D = x.shape
        device = x.device
        # Compute positional frequencies (in float32, independent of x's dtype).
        pos = torch.arange(N, device=device).float()
        inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2, device=device).float() / D))
        angles = pos.unsqueeze(1) * inv_freq.unsqueeze(0)  # [N, D/2]
        # [N, D/2] tables broadcast directly over [B, H, N, D/2]; no permute needed.
        cos_val = torch.cos(angles)
        sin_val = torch.sin(angles)

        # View the feature axis as D/2 adjacent (even, odd) pairs.
        pairs = x.reshape(B, H, N, D // 2, 2)
        x_even = pairs[..., 0]
        x_odd = pairs[..., 1]
        # Apply RoPE formulas (2-D rotation of each pair).
        x_even_out = x_even * cos_val - x_odd * sin_val
        x_odd_out = x_even * sin_val + x_odd * cos_val

        x_out = torch.stack([x_even_out, x_odd_out], dim=-1).reshape(B, H, N, D)
        # The float32 tables promote fp16/bf16 inputs; cast back so the output
        # matches the input dtype (and concatenates cleanly with unrotated parts).
        return x_out.to(x.dtype)

class RoPEAttention(nn.Module):
    """
    Multi-head self-attention with rotary position embeddings for Q/K.

    Input:
      x: Tensor of shape [B, N, C]

    Args:
        dim: Embedding dimension C; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        rope_dim: Number of leading per-head features of Q and K that receive
            RoPE. ``None`` (default) rotates the whole head; ``0`` disables RoPE;
            values larger than the head dimension leave Q/K unrotated.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, rope_dim=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, "Embedding dimension must be divisible by number of heads."
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # `is None` instead of `or`: an explicit rope_dim=0 must stay 0 so the
        # range guard in forward() disables RoPE rather than rotating the full head.
        self.rope_dim = self.head_dim if rope_dim is None else rope_dim
        self.rope_q = RotaryEmbedding(self.rope_dim)
        self.rope_k = RotaryEmbedding(self.rope_dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: Tensor of shape [B, N, C].
            mask: Optional tensor added to the attention logits before softmax;
                must broadcast to [B, num_heads, N, N]. Presumably an additive
                mask (large negative values at masked positions) — confirm with
                callers.
        Returns:
            Tensor of shape [B, N, C].
        """
        B, N, C = x.shape
        # Compute Q, K, V, each of shape [B, num_heads, N, head_dim].
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)

        # Apply RoPE to the first rope_dim features of Q and K; the rest pass through.
        if 0 < self.rope_dim <= self.head_dim:
            q_rope, q_rest = q[..., :self.rope_dim], q[..., self.rope_dim:]
            k_rope, k_rest = k[..., :self.rope_dim], k[..., self.rope_dim:]
            q = torch.cat([self.rope_q(q_rope, seq_len=N), q_rest], dim=-1)
            k = torch.cat([self.rope_k(k_rope, seq_len=N), k_rest], dim=-1)

        # Scaled dot-product attention weights: [B, num_heads, N, N].
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if mask is not None:
            attn = attn + mask
        attn = torch.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        # Combine heads: [B, num_heads, N, head_dim] -> [B, N, C].
        x_out = (attn @ v).transpose(2, 1).reshape(B, N, C)
        x_out = self.proj(x_out)
        return self.proj_drop(x_out)
