import torch
from math import pi, sqrt
from Positional_Embeddings.PE_registry import PE_REGISTRY, register_PE
from model_registry import register_model, MODEL_REGISTRY


class Mixed_RoPE(torch.nn.Module):
  """2D rotary positional embedding with learnable per-axis frequencies.

  Consecutive channel pairs ``(z[..., 2k], z[..., 2k + 1])`` of the input are
  rotated by the angle ``i * frequency_x[k] + j * frequency_y[k]``, where
  ``i`` / ``j`` are patch coordinates along x / y. For the grid given at
  construction the coordinates are evenly spaced in ``[-pi, pi]``;
  ``set_Patches`` keeps the same spacing scale for other grid sizes.

  Args:
    embedding_dim: Size of the last dimension of the tensors to rotate. Half
      of it is the number of learnable frequencies per axis.
    P_x: Number of patches in the x direction.
    P_y: Number of patches in the y direction.
    uniform_freq: Accepted for interface compatibility; not used by this
      class.

  Note:
    The flattened patch axis is interpreted as ``p = y * P_x + x`` (y is the
    slow axis). NOTE(review): confirm this matches how the caller flattens
    patches; for square grids the result is unchanged from the previous
    implementation.
  """

  def __init__(self, embedding_dim, P_x, P_y, uniform_freq=False):
    super().__init__()
    self.embedding_dim = embedding_dim
    # Grid size at construction time; set_Patches() scales the coordinate
    # range relative to these values.
    self.og_P_x = P_x
    self.og_P_y = P_y
    # Grid size currently in use.
    self.P_x = P_x
    self.P_y = P_y

    half_dim = int(embedding_dim / 2)

    # Learnable frequencies, one per channel pair and axis. The x parameter is
    # drawn first so seeded initialisation stays reproducible.
    self.frequency_x = torch.nn.Parameter(torch.randn(1, half_dim) * 2 * torch.pi)
    self.frequency_y = torch.nn.Parameter(torch.randn(1, half_dim) * 2 * torch.pi)
    # Patch coordinates as column vectors: i is (P_x, 1), j is (P_y, 1).
    self.register_buffer('i', torch.linspace(-torch.pi, torch.pi, P_x).unsqueeze(1).float())
    self.register_buffer('j', torch.linspace(-torch.pi, torch.pi, P_y).unsqueeze(1).float())
    self.train_mode()

  def forward(self, x):
    """Fallback used only if no mode has bound ``self.forward``.

    Raises:
      ValueError: Always. ``train_mode`` replaces ``forward`` on the instance,
        so reaching this method means no mode is active.
    """
    raise ValueError("forward() called while mode not specified")

  def forward_extapolate(self, x):
    """Placeholder for an extrapolation forward pass; does nothing and returns None."""
    pass

  def set_Patches(self, P_x, P_y):
    """Switch to a ``P_x`` x ``P_y`` patch grid.

    The coordinate range is scaled by ``P / og_P`` per axis (the construction
    grid spans ``[-pi, pi]``). The rebuilt coordinate buffers keep the device
    and dtype of the buffers they replace.

    Args:
      P_x: New number of patches in the x direction.
      P_y: New number of patches in the y direction.
    """
    device = self.i.device
    self.P_x = P_x
    self.P_y = P_y

    self.i = self._make_coords(torch.pi * P_x / self.og_P_x, P_x, self.i)
    self.j = self._make_coords(torch.pi * P_y / self.og_P_y, P_y, self.j)
    self._set_device_(device)

  @staticmethod
  def _make_coords(half_range, count, like):
    """Return ``count`` evenly spaced coordinates in ``[-half_range, half_range]`` as a column.

    The result takes the device and dtype of ``like`` so that a module that
    was cast (e.g. ``.double()``) stays consistent with its parameters.
    """
    coords = torch.linspace(-half_range, half_range, count).unsqueeze(1).float()
    return coords.to(device=like.device, dtype=like.dtype)

  def forward_train(self, x):
    """Apply the rotary embedding to ``x``; see ``apply_rope``."""
    return self.apply_rope(x)

  def _set_device_(self, device):
    """Move the frequency parameters to ``device`` in place.

    The data (and gradient, if any) is moved through ``.data`` so the
    attributes remain the same ``Parameter`` objects. Re-assigning the result
    of ``Parameter.to(device)`` would hand ``nn.Module`` a plain tensor when
    the device actually changes, which it rejects with a ``TypeError``.
    """
    for param in (self.frequency_x, self.frequency_y):
      param.data = param.data.to(device)
      if param.grad is not None:
        param.grad.data = param.grad.data.to(device)

  def extrapolate_mode(self):
    """Currently a no-op.

    Binding ``forward`` to ``forward_extapolate`` is left disabled because
    that method is not implemented.
    """
    pass

  def train_mode(self):
    """Bind ``self.forward`` to ``forward_train``."""
    self.forward = self.forward_train

  def apply_rope(self, z):
    """Rotate every channel pair of ``z`` by its position-dependent angle.

    Args:
      z: Tensor of shape ``[B, N, P, D]`` with ``P == P_x * P_y`` and even
        ``D``. ``D / 2`` must match the number of learned frequencies.

    Returns:
      Tensor of shape ``[B, N, P, D]`` holding the rotated pairs.

    Raises:
      AssertionError: If ``P != P_x * P_y`` or ``D`` is odd.
    """
    B, N, P, D = z.shape

    P_x, P_y = self.P_x, self.P_y
    assert P == P_x * P_y, "P must equal P_x * P_y"
    assert D%2 == 0
    D2 = D//2
    # The angle table is laid out as (P_y, P_x, D/2), so the patch axis is
    # unfolded in the same order. Unfolding as (P_x, P_y) only broadcast for
    # square grids; this order gives identical values there and also works
    # for non-square grids. reshape() also accepts non-contiguous inputs.
    pairs = z.reshape(B, N, P_y, P_x, D2, 2)  # last dim holds the 2D vectors

    # Split vector components, each of shape [B, N, P_y, P_x, D2]
    x, y = pairs.unbind(dim=-1)

    sin, cos = self._init_sin_cos(P_x, P_y, D)
    # Apply the 2D rotation
    x_ = x * cos - y * sin
    y_ = x * sin + y * cos
    # Stack back into rotated vectors: [B, N, P_y, P_x, D2, 2]
    z_rot = torch.stack([x_, y_], dim=-1)
    return z_rot.view(B, N, P, D)

  def _init_sin_cos(self, P_x,P_y, D):
    """Build the sine and cosine of the rotation angles for the current grid.

    Args:
      P_x: Number of patches in the X direction.
      P_y: Number of patches in the Y direction.
      D: Hidden/embedding dimension; must be even.

    Returns:
      Tuple ``(sin, cos)``, each of shape ``[1, 1, P_y, P_x, D/2]``.

    Raises:
      AssertionError: If ``D`` is odd.
    """
    assert D % 2 == 0, "D must be divisible by 2"

    thetas_x = torch.matmul(self.i, self.frequency_x) # P_x x D/2
    thetas_y = torch.matmul(self.j, self.frequency_y) # P_y x D/2
    # Broadcast the per-axis angles over the grid: y on dim 2, x on dim 3.
    thetas_x = thetas_x.view(1, 1, 1, P_x, -1)
    thetas_y = thetas_y.view(1, 1, P_y, 1, -1)
    thetas = thetas_x + thetas_y

    return torch.sin(thetas), torch.cos(thetas)
  
# Register the Mixed_RoPE class under the name "Mixed RoPE". Every boolean
# option in its configuration is False. NOTE(review): how register_model
# interprets these keys is defined in model_registry, not in this file.
register_model(
    "Mixed RoPE",
    {
        "PE_method": Mixed_RoPE,
        "stem_only": False,
        "shared_pe": False,
        "rot_x": False,
        "rot_value": False,
    },
)