import torch
from model_registry import register_model, MODEL_REGISTRY
from torch import pi
from base.PE_Base import PE_Base

class Spherical_RoPE(PE_Base):
    """Rotary position encoding that treats the embedding as 3-D vectors.

    The last dimension of the input is split into ``D / 3`` triplets. Each
    triplet is rotated twice: first a rotation mixing its first and third
    components, with angles looked up along the ``P_y`` patch axis, then a
    rotation mixing its second and third components, with angles looked up
    along the ``P_x`` patch axis. The sine/cosine tables for both rotations
    are precomputed by ``_init_sin_cos`` and stored via ``register_buffer``.
    """

    def __init__(self, embedding_dim, P_x, P_y, xy_range='pi'):
        """Set up per-triplet frequencies and the sin/cos lookup tables.

        Args:
            embedding_dim: Size ``D`` of the embedding dimension.
            P_x: Number of patches in the X direction.
            P_y: Number of patches in the Y direction.
            xy_range: Forwarded to the parent constructor as ``grid_mode``.
        """
        super().__init__(embedding_dim, P_x, P_y, grid_mode=xy_range)

        # One base frequency per triplet index d: 100 ** (-2 * d / D),
        # laid out as a row of shape [1, D/3].
        triplet_idx = torch.arange(0, embedding_dim / 3).unsqueeze(0)
        base_freq = 100 ** (-2 * triplet_idx / embedding_dim)
        self.register_buffer('theta_d', base_freq)

        # NOTE(review): _init_sin_cos reads self.frequency_x and
        # self.frequency_y, which this class only assigns after this call;
        # presumably the parent class already provides them -- confirm.
        self._init_sin_cos(P_x, P_y, embedding_dim)

        # Column of patch indices [P, 1] times the frequency row [1, D/3]
        # broadcasts to a [P, D/3] table.
        self.frequency_x = base_freq * torch.arange(0, P_x).unsqueeze(1)
        self.frequency_y = base_freq * torch.arange(0, P_y).unsqueeze(1)

    def apply_rope(self, z):
        """Rotate every 3-channel group of ``z`` according to its patch position.

        Args:
            z: Tensor of shape ``[B, N, P, D]`` where ``P`` equals
                ``self.P_x * self.P_y`` and ``D`` is divisible by 3.

        Returns:
            Tensor of shape ``[B, N, P, D]`` holding the rotated triplets.

        Raises:
            AssertionError: If ``P != P_x * P_y`` or ``D`` is not a multiple
                of 3.
        """
        B, N, P, D = z.shape
        P_x, P_y = self.P_x, self.P_y

        assert P == P_x * P_y, "P must equal P_x * P_y"
        assert D % 3 == 0

        # Expose the patch grid and read the channels as D/3 vectors of 3.
        vectors = z.view(B, N, P_x, P_y, D // 3, 3)
        v0, v1, v2 = torch.unbind(vectors, dim=-1)  # each [B, N, P_x, P_y, D/3]

        sin_x, cos_x = self.sin_x, self.cos_x
        sin_y, cos_y = self.sin_y, self.cos_y

        # First rotation (y tables): mixes components 0 and 2.
        out0 = v0 * cos_y - v2 * sin_y
        mid2 = v0 * sin_y + v2 * cos_y

        # Second rotation (x tables): mixes component 1 with the already
        # rotated component 2.
        out1 = v1 * cos_x - mid2 * sin_x
        out2 = v1 * sin_x + mid2 * cos_x

        # Re-interleave the components and restore the flat [B, N, P, D] layout.
        return torch.stack((out0, out1, out2), dim=-1).view(B, N, P, D)

    def set_Patches(self, P_x, P_y, xy_range='pi'):
        """Pass a new patch-grid configuration to the parent implementation.

        NOTE(review): nothing in this override rebuilds the sin/cos tables;
        whether the parent implementation does so is not visible here --
        confirm.
        """
        super().set_Patches(P_x, P_y, xy_range=xy_range)

    def _init_sin_cos(self, P_x, P_y, D):
        """Precompute and register the tables consumed by ``apply_rope``.

        Angles are ``matmul(self.i, self.frequency_x)`` for the x tables and
        ``matmul(self.j, self.frequency_y)`` for the y tables. The results
        are registered as buffers ``sin_x``/``cos_x`` with shape
        ``[1, 1, P_x, 1, -1]`` and ``sin_y``/``cos_y`` with shape
        ``[1, 1, 1, P_y, -1]`` so they broadcast over the patch grid.

        Args:
            P_x: Number of patches in the X direction.
            P_y: Number of patches in the Y direction.
            D: Hidden/embedding dimension; must be divisible by 3.

        Raises:
            AssertionError: If ``D`` is not divisible by 3.
        """
        assert D % 3 == 0, "D must be divisible by 3"

        def build_tables(positions, frequencies, shape):
            # Flatten the angle matrix, then fold it into the broadcast shape.
            angles = torch.matmul(positions, frequencies).flatten()
            return torch.sin(angles).view(*shape), torch.cos(angles).view(*shape)

        sin_x, cos_x = build_tables(self.i, self.frequency_x, (1, 1, P_x, 1, -1))
        sin_y, cos_y = build_tables(self.j, self.frequency_y, (1, 1, 1, P_y, -1))

        tables = (
            ('sin_x', sin_x),
            ('cos_x', cos_x),
            ('sin_y', sin_y),
            ('cos_y', cos_y),
        )
        for buffer_name, table in tables:
            self.register_buffer(buffer_name, table)
  

# Register Spherical_RoPE under the name "Spherical RoPE". The meaning of the
# flags is defined by whatever consumes the registry, not by this file.
register_model(
    "Spherical RoPE",
    {
        "PE_method": Spherical_RoPE,
        "stem_only": False,
        "shared_pe": True,
        "rot_x": False,
        "rot_value": False,
    },
)

# Register the same Spherical_RoPE class a second time as
# "Patch Spherical RoPE"; this entry sets "rot_x" to True.
register_model(
    "Patch Spherical RoPE",
    {
        "PE_method": Spherical_RoPE,
        "stem_only": False,
        "shared_pe": True,
        "rot_x": True,
        "rot_value": False,
    },
)