import torch
from model_registry import register_model, MODEL_REGISTRY
from torch import pi
from base.PE_Base import PE_Base

class Learned_Spherical_RoPE(PE_Base):
  """Rotary position embedding with learned frequencies over 3-D sub-vectors.

  The last dimension ``D`` of the input is split into ``D // 3`` groups of
  three components (x, y, z). Each group is first rotated in its (x, z) plane
  by an angle derived from the patch's index along the ``P_y`` grid axis, and
  then in its (y, z) plane by an angle derived from its index along the
  ``P_x`` grid axis. The per-axis frequencies are learned parameters.

  NOTE(review): the two rotations are applied in a fixed order and do not
  commute, so (unlike 1-D RoPE) the query/key dot product is in general not
  a function of relative position alone. Confirm this is intended.
  """

  def __init__(self, embedding_dim, P_x, P_y):
    super().__init__(embedding_dim, P_x, P_y)

    # One learned frequency per 3-component group, per grid axis.
    # int(...) // 3 equals the former int(D / 3) for positive inputs but avoids
    # computing an integer count through float division.
    n_groups = int(embedding_dim) // 3
    self.frequency_x = torch.nn.Parameter(torch.randn(1, n_groups) * 2 * torch.pi)
    self.frequency_y = torch.nn.Parameter(torch.randn(1, n_groups) * 2 * torch.pi)

  def apply_rope(self, z):
    """Rotate every 3-component group of ``z`` according to its patch position.

    Args:
      z: tensor of shape ``[B, N, P, D]`` with ``P == P_x * P_y`` and ``D``
        divisible by 3. Patch index ``p`` is interpreted as grid cell
        ``(p // P_y, p % P_y)`` (that is how the reshape below splits ``P``).

    Returns:
      A tensor of shape ``[B, N, P, D]`` holding the rotated groups.

    Raises:
      AssertionError: if ``P != P_x * P_y`` or ``D`` is not divisible by 3.
    """
    B, N, P, D = z.shape
    P_x, P_y = self.P_x, self.P_y
    assert P == P_x * P_y, "P must equal P_x * P_y"
    assert D % 3 == 0, "D must be divisible by 3"
    D3 = D // 3

    # [B, N, P, D] -> [B, N, P_x, P_y, D3, 3]: split P into the patch grid and
    # treat the last dim as D3 three-component vectors. Each component has
    # shape [B, N, P_x, P_y, D3].
    x, y, z_ = z.view(B, N, P_x, P_y, D3, 3).unbind(dim=-1)

    sin_x, cos_x, sin_y, cos_y = self._init_sin_cos(P_x, P_y, D)

    # Rotation in the (x, z) plane, driven by the P_y-axis position.
    x_yaw = x * cos_y - z_ * sin_y
    z_yaw = x * sin_y + z_ * cos_y

    # Rotation in the (y, z) plane, applied to the already-rotated z component
    # and driven by the P_x-axis position.
    y_roll = y * cos_x - z_yaw * sin_x
    z_roll = y * sin_x + z_yaw * cos_x

    # Re-interleave the components and restore the input layout.
    z_rot = torch.stack([x_yaw, y_roll, z_roll], dim=-1)  # [B, N, P_x, P_y, D3, 3]
    return z_rot.view(B, N, P, D)

  def _init_sin_cos(self, P_x, P_y, D):
    """Compute per-position sines/cosines for both grid axes.

    Args:
      P_x: number of patches along the first grid axis.
      P_y: number of patches along the second grid axis.
      D: hidden/embedding dimension; only validated here (must be divisible
        by 3).

    Returns:
      ``(sin_x, cos_x, sin_y, cos_y)``. The x terms are viewed as
      ``[1, 1, P_x, 1, -1]`` and the y terms as ``[1, 1, 1, P_y, -1]`` so
      they broadcast against ``[B, N, P_x, P_y, D3]`` components.
    """
    assert D % 3 == 0, "D must be divisible by 3"

    # NOTE(review): self.i / self.j come from PE_Base (not visible here);
    # presumably per-axis position columns of shape (P_x, 1) / (P_y, 1), giving
    # angle tables of shape (P_x, D/3) / (P_y, D/3). TODO confirm.
    sin_x, cos_x = self._axis_sin_cos(self.i, self.frequency_x, (1, 1, P_x, 1, -1))
    sin_y, cos_y = self._axis_sin_cos(self.j, self.frequency_y, (1, 1, 1, P_y, -1))
    return sin_x, cos_x, sin_y, cos_y

  @staticmethod
  def _axis_sin_cos(positions, frequency, shape):
    """Return ``(sin, cos)`` of ``positions @ frequency``, each viewed as ``shape``."""
    thetas = torch.matmul(positions, frequency).flatten()
    return torch.sin(thetas).view(shape), torch.cos(thetas).view(shape)
  

# Make this positional encoding available in the model registry under its
# display name. NOTE(review): the meaning of the boolean flags is defined by
# the registry's consumers, which are not visible in this file.
register_model(
  "Learned Spherical RoPE",
  {
    "PE_method": Learned_Spherical_RoPE,
    "stem_only": False,
    "shared_pe": False,
    "rot_x": False,
    "rot_value": False,
  },
)