import torch
import torch.nn as nn
from model_registry import register_model, MODEL_REGISTRY
from math import pi
from base.PE_Base import PE_Base

class Learned_Axial_RoPE(PE_Base):
  """Axial rotary position embedding (RoPE) with learnable frequencies.

  The channel dimension D of the input is split into two halves. The first
  half is rotated by angles built from ``self.j`` and ``freq_y``; the second
  half is rotated by angles built from ``self.i`` and ``freq_x``. Inside each
  half, adjacent channels (2k, 2k + 1) form one 2-D rotation plane with its
  own learned frequency, so each axis owns embedding_dim // 4 frequencies.
  """

  def __init__(self, embedding_dim, P_x, P_y):
    super(Learned_Axial_RoPE, self).__init__(embedding_dim, P_x, P_y)
    # Each axis rotates D // 2 channels as D // 4 pairs, so D must split into
    # whole quarters; fail at construction rather than at the first forward.
    if embedding_dim % 4 != 0:
      raise ValueError(
          f"embedding_dim must be divisible by 4, got {embedding_dim}")
    self.embedding_dim = embedding_dim
    d_size = embedding_dim // 4

    # One trainable frequency per rotation plane and axis, initialised from
    # a standard normal scaled by 2*pi.
    self.freq_x = torch.nn.Parameter(torch.randn(1, d_size) * 2 * pi, requires_grad=True)
    self.freq_y = torch.nn.Parameter(torch.randn(1, d_size) * 2 * pi, requires_grad=True)

  def _init_sin_cos_(self, P_x, P_y, D, init_x=0, init_y=0, uniform_freq=False):
    """Compute the sin/cos tables of the rotation angles for both axes.

    The tables are recomputed on every call because ``freq_x`` and ``freq_y``
    are trainable parameters.

    Args:
      P_x: number of positions along x; used to shape the x tables.
      P_y: number of positions along y; used to shape the y tables.
      D: channel dimension of the tensor being rotated; must be divisible by 4.
      init_x, init_y, uniform_freq: accepted but unused by this implementation
        (NOTE(review): presumably kept for signature compatibility with
        PE_Base -- confirm).

    Returns:
      ``(sin_x, cos_x, sin_y, cos_y)``. The x tables are viewed as
      [1, 1, P_x, 1, K] and the y tables as [1, 1, 1, P_y, K], where K is the
      number of learned frequencies (expected to equal D // 4). These shapes
      broadcast against a [B, N, P_x, P_y, D // 4] tensor.
    """
    assert D % 4 == 0, "D must be divisible by 4"

    # self.i / self.j are not defined in this class. Multiplying by the
    # [1, K] frequency rows requires their last dim to be 1 (presumably
    # [P_x, 1] and [P_y, 1] position indices set up by PE_Base -- TODO confirm).
    theta_x = self.i @ self.freq_x
    theta_y = self.j @ self.freq_y

    sin_x = torch.sin(theta_x).view(1, 1, P_x, 1, -1)
    cos_x = torch.cos(theta_x).view(1, 1, P_x, 1, -1)
    sin_y = torch.sin(theta_y).view(1, 1, 1, P_y, -1)
    cos_y = torch.cos(theta_y).view(1, 1, 1, P_y, -1)
    return sin_x, cos_x, sin_y, cos_y

  @staticmethod
  def _rotate_pairs(t, sin, cos):
    """Rotate each channel pair (2k, 2k + 1) of ``t`` by the given angle tables.

    The result has the same shape as ``t``, and channel order is preserved:
    rotated channel 2k stays at index 2k and 2k + 1 at index 2k + 1.
    """
    t1 = t[..., ::2]
    t2 = t[..., 1::2]
    rotated = torch.stack([t1 * cos - t2 * sin, t1 * sin + t2 * cos], dim=-1)
    # stack produces [..., D//4, 2]; flatten(-2) re-interleaves the pairs so
    # every rotated value lands at its original channel index.
    return rotated.flatten(-2)

  def apply_rope(self, z):
    """Rotate ``z`` by the learned axial position angles.

    Args:
      z: tensor of shape [B, N, P, D] with P == P_x * P_y. The P axis is
        unflattened row-major into (P_x, P_y).

    Returns:
      Tensor of shape [B, N, P, D] in the same channel layout as ``z``.
      Channels [0, D//2) are rotated by the y angles and [D//2, D) by the
      x angles, each pair (2k, 2k + 1) in place, so zero angles return ``z``
      unchanged.
    """
    B, N, P, D = z.shape
    P_x, P_y = self.P_x, self.P_y
    assert P == P_x * P_y, "P must equal P_x * P_y"

    z = z.view(B, N, P_x, P_y, D)
    y = z[..., :D//2]  # first half: driven by the y axis
    x = z[..., D//2:]  # second half: driven by the x axis

    # Use the learned frequencies (recomputed every call since they train).
    sin_x, cos_x, sin_y, cos_y = self._init_sin_cos_(P_x, P_y, D)
    x_rot = self._rotate_pairs(x, sin_x, cos_x)  # [B, N, P_x, P_y, D//2]
    y_rot = self._rotate_pairs(y, sin_y, cos_y)  # [B, N, P_x, P_y, D//2]

    z_rot = torch.cat([y_rot, x_rot], dim=-1)  # [B, N, P_x, P_y, D]
    return z_rot.reshape(B, N, P, D)
  
# Register Learned_Axial_RoPE with the model registry under the name
# "Learned Axial RoPE" (presumably the lookup key). The boolean option
# flags are interpreted by the registry's consumers, which are not visible
# in this file.
register_model(
    "Learned Axial RoPE",
    dict(
        PE_method=Learned_Axial_RoPE,
        stem_only=False,
        shared_pe=False,
        rot_x=False,
        rot_value=False,
    ),
)