import torch
from math import pi, sqrt
from Positional_Embeddings.PE_registry import PE_REGISTRY, register_PE
from model_registry import register_model, MODEL_REGISTRY
from base.PE_Base import PE_Base


class Axial_RoPE(PE_Base):
  """Axial rotary positional embedding (RoPE) over a P_x-by-P_y patch grid.

  The channel dimension D is split in two halves: the first half is rotated
  by angles derived from ``self.j`` (the P_y axis) and the second half by
  angles derived from ``self.i`` (the P_x axis). Within each half, adjacent
  channels (even, odd) form the 2-D pairs that are rotated, so each axis
  uses D // 4 frequencies.

  NOTE(review): ``self.i``, ``self.j``, ``self.P_x``, ``self.P_y`` and
  ``self.embedding_dim`` are not defined in this class; they are presumably
  provided by ``PE_Base`` -- confirm there.
  """

  def __init__(self, embedding_dim, P_x, P_y, uniform_freq=False):
    """Build the per-pair frequencies and the sin/cos lookup tables.

    Args:
        embedding_dim: Channel size D that will be rotated; must be
            divisible by 4 (checked in ``_init_sin_cos_``).
        P_x: Number of patches along the first grid axis.
        P_y: Number of patches along the second grid axis.
        uniform_freq: If True, every channel pair uses the same frequency
            1 / pi; otherwise frequencies decay as 100 ** (-2 * d / D).
    """
    super(Axial_RoPE, self).__init__(embedding_dim, P_x, P_y)
    D = embedding_dim
    d_size = D // 4
    d = torch.arange(d_size, dtype=torch.float32).unsqueeze(0)  # [1, D//4]

    if uniform_freq:
        theta_d = torch.full(d.shape, 1 / pi)
    else:
        theta_d = 100 ** (-2 * d / D)
    self.register_buffer('theta_d', theta_d)

    # Record the flag before building the tables so the instance is fully
    # configured by the time _init_sin_cos_ runs.
    self.uniform_freq = uniform_freq
    self._init_sin_cos_(P_x, P_y, D, uniform_freq=uniform_freq)

  def set_Patches(self, P_x, P_y):
    """Change the patch grid size and rebuild the sin/cos tables for it."""
    super().set_Patches(P_x, P_y)
    self._init_sin_cos_(P_x, P_y, self.embedding_dim, uniform_freq=self.uniform_freq)

  def _set_device_(self, device):
    """Move the cached sin/cos tables to ``device``.

    The tables are plain attributes (not registered buffers), so they have
    to be moved explicitly.
    """
    self.sin_x = self.sin_x.to(device)
    self.cos_x = self.cos_x.to(device)
    self.sin_y = self.sin_y.to(device)
    self.cos_y = self.cos_y.to(device)

  def apply_rope(self, z):
    '''
      Apply Axial RoPE to the input tensor z using the pairwise (even/odd
      channel) rotation form of Su et al.
      Args:
          z: Input tensor of shape [B, N, P_x*P_y, D], where B is the batch size,
             P_x*P_y is the flattened patch grid (P_x is the outer axis) and
             D is the channel size (divisible by 4). N is the second leading
             dimension (presumably attention heads -- confirm with the caller).
      Returns:
          z_rot: Rotated tensor of the same shape and channel layout as z.
      Raises:
          AssertionError: If the third dimension of z is not P_x * P_y.
    '''
    B, N, P, D = z.shape
    P_x, P_y = self.P_x, self.P_y
    assert P == P_x * P_y, "P must equal P_x * P_y"

    z = z.reshape(B, N, P_x, P_y, D)
    # First half of the channels is rotated along y, second half along x.
    x = z[..., D//2:]
    y = z[..., :D//2]

    # Even / odd channels form the (first, second) component of each pair.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]

    y1 = y[..., ::2]
    y2 = y[..., 1::2]

    # Rotate every pair. stack(..., dim=-1) yields [..., D//4, 2]; flattening
    # the last two axes re-interleaves the components so each half returns to
    # its original [..., D//2] even/odd layout. Without the flatten, the cat
    # below would mix x and y channels per frequency instead of appending
    # the two halves.
    x_rot = torch.stack(
        [x1 * self.cos_x - x2 * self.sin_x, x1 * self.sin_x + x2 * self.cos_x],
        dim=-1,
    ).flatten(-2)
    y_rot = torch.stack(
        [y1 * self.cos_y - y2 * self.sin_y, y1 * self.sin_y + y2 * self.cos_y],
        dim=-1,
    ).flatten(-2)

    z_rot = torch.cat([y_rot, x_rot], dim=-1)  # [B, N, P_x, P_y, D]
    z_rot = z_rot.reshape(B, N, P, D)

    return z_rot

  def _init_sin_cos_(self, P_x, P_y, D, init_x=0, init_y=0, uniform_freq=False):
    """Precompute the broadcastable sin/cos tables for both grid axes.

    Args:
        P_x: Number of patches along the first grid axis.
        P_y: Number of patches along the second grid axis.
        D: Channel size; must be divisible by 4.
        init_x: Unused; kept for interface compatibility.
        init_y: Unused; kept for interface compatibility.
        uniform_freq: Unused here; the frequencies are read from the
            ``theta_d`` buffer built in ``__init__``.

    Raises:
        AssertionError: If D is not divisible by 4.
    """
    assert D % 4 == 0, "D must be divisible by 4"

    theta_d = self.theta_d  # [1, D//4]

    # Rotation angle = position * frequency, computed as a matrix product.
    # Assumes self.i / self.j are position columns of shape [P_x, 1] and
    # [P_y, 1] -- TODO confirm against PE_Base.
    theta_x = self.i @ theta_d  # Shape: [P_x, D//4]
    theta_y = self.j @ theta_d  # Shape: [P_y, D//4]

    sin_x, cos_x = torch.sin(theta_x), torch.cos(theta_x)
    sin_y, cos_y = torch.sin(theta_y), torch.cos(theta_y)

    # Reshape so the tables broadcast against [B, N, P_x, P_y, D//4].
    self.sin_x = sin_x.view(1, 1, P_x, 1, -1)  # [1, 1, P_x, 1, D//4]
    self.cos_x = cos_x.view(1, 1, P_x, 1, -1)
    self.sin_y = sin_y.view(1, 1, 1, P_y, -1)  # [1, 1, 1, P_y, D//4]
    self.cos_y = cos_y.view(1, 1, 1, P_y, -1)

class Uniform_Axial_RoPE(Axial_RoPE):
  """Axial_RoPE variant that always enables ``uniform_freq``."""

  def __init__(self, embedding_dim, P_x, P_y):
    # Forward the arguments unchanged; only the frequency mode is fixed.
    super(Uniform_Axial_RoPE, self).__init__(
        embedding_dim,
        P_x,
        P_y,
        uniform_freq=True,
    )

# Model registrations. Each entry maps a display name to the positional
# embedding class plus option flags; the meaning of the flags is defined by
# whatever consumes the registry (not visible in this file).

# Axial RoPE with the decaying frequency schedule.
register_model(
    "Axial RoPE",
    {
        "PE_method": Axial_RoPE,
        "stem_only": False,
        "shared_pe": False,
        "rot_x": False,
        "rot_value": False,
    },
)

# Axial RoPE with a single uniform frequency; registered with shared_pe on.
register_model(
    "Uniform Axial RoPE",
    {
        "PE_method": Uniform_Axial_RoPE,
        "stem_only": False,
        "shared_pe": True,
        "rot_x": False,
        "rot_value": False,
    },
)

# Same class as "Axial RoPE", registered with shared_pe and rot_x on.
register_model(
    "Patch RoPE",
    {
        "PE_method": Axial_RoPE,
        "stem_only": False,
        "shared_pe": True,
        "rot_x": True,
        "rot_value": False,
    },
)