import torch
from math import pi, sqrt
from Positional_Embeddings.PE_registry import PE_REGISTRY, register_PE
from model_registry import register_model, MODEL_REGISTRY


class PE_Base(torch.nn.Module):
  """Base class for positional embeddings defined over a 2-D grid of patches.

  The class stores one coordinate per patch along each axis as the buffers
  ``i`` (X axis, shape ``(P_x, 1)``) and ``j`` (Y axis, shape ``(P_y, 1)``).
  Subclasses are expected to override ``apply_rope`` (and, when they keep
  extra tensors, ``_set_device_``); the base implementation only raises.

  Args:
    embedding_dim: Stored as ``self.embedding_dim``; not otherwise used here.
    P_x: Number of patches in the x direction.
    P_y: Number of patches in the y direction.
    uniform_freq: Accepted for interface compatibility; unused by this base
      class (presumably consumed by subclasses -- verify against them).
    grid_mode: How patch coordinates are laid out along each axis:
      ``'pi'`` spans ``[-pi, pi]``, ``'positive'`` spans ``[0, P]`` and
      ``'centered'`` spans ``[-P/2, P/2]``, where ``P`` is the patch count
      of that axis. Both end points are included.

  Raises:
    ValueError: If ``grid_mode`` is not one of the supported modes.
  """

  # Values accepted for ``grid_mode``.
  _GRID_MODES = ('pi', 'positive', 'centered')

  def __init__(self, embedding_dim, P_x, P_y, uniform_freq=False, grid_mode='pi'):
    super(PE_Base, self).__init__()
    self.embedding_dim = embedding_dim

    # Patch counts given at construction time; set_Patches() never touches these.
    self.og_P_x = P_x
    self.og_P_y = P_y

    # Patch counts currently in use; updated by set_Patches().
    self.P_x = P_x
    self.P_y = P_y

    self.grid_mode = grid_mode

    i, j = self._make_grid(grid_mode, P_x, P_y)
    self.register_buffer('i', i)
    self.register_buffer('j', j)

    self.train_mode()

  @staticmethod
  def _make_grid(grid_mode, P_x, P_y):
    """Return the ``(i, j)`` coordinate columns for a ``P_x`` by ``P_y`` grid.

    Each axis uses its own range: the Y coordinates are built from the Y
    bounds, so non-square grids get the correct extent on both axes.

    Raises:
      ValueError: If ``grid_mode`` is not one of the supported modes.
    """
    if grid_mode == 'pi':
      min_x, max_x = -pi, pi
      min_y, max_y = -pi, pi
    elif grid_mode == 'positive':
      min_x, max_x = 0, P_x
      min_y, max_y = 0, P_y
    elif grid_mode == 'centered':
      min_x, max_x = -P_x/2, P_x/2
      min_y, max_y = -P_y/2, P_y/2
    else:
      raise ValueError(
        "Unknown grid_mode %r; expected one of %s"
        % (grid_mode, ", ".join(repr(m) for m in PE_Base._GRID_MODES))
      )

    i = torch.linspace(min_x, max_x, P_x).unsqueeze(1).float() # X indices, i
    j = torch.linspace(min_y, max_y, P_y).unsqueeze(1).float() # Y indices, j
    return i, j

  def forward(self, x):
    # Only reached if the instance-level ``forward`` installed by train_mode()
    # has been removed; normally calls are routed to forward_train().
    raise ValueError("forward() called while mode not specified")

  def set_Patches(self, P_x, P_y):
    """Rebuild the coordinate buffers for a new ``P_x`` by ``P_y`` patch grid.

    The new buffers are placed on the device of the current ``i`` buffer and
    ``_set_device_`` is then called so subclasses can move their own tensors.
    ``og_P_x`` / ``og_P_y`` are left unchanged.

    Raises:
      ValueError: If ``self.grid_mode`` is not one of the supported modes.
    """
    device = self.i.device

    # Build the grid first so a failure leaves the current state untouched.
    i, j = self._make_grid(self.grid_mode, P_x, P_y)

    self.P_x = P_x
    self.P_y = P_y

    # 'i' and 'j' are registered buffers, so assigning tensors replaces them
    # in place and keeps them registered on the module.
    self.i = i.to(device)
    self.j = j.to(device)
    self._set_device_(device)

  def forward_train(self, z):
    """Apply the positional embedding to ``z`` by delegating to ``apply_rope``."""
    return self.apply_rope(z)

  def _set_device_(self, device):
    """Hook for subclasses to move extra tensors to ``device``; no-op here."""
    pass

  def extrapolate_mode(self):
    """Deprecated no-op, formerly for evaluating on different patch sizes."""
    pass

  def train_mode(self):
    """Route calls to ``forward`` through ``forward_train`` on this instance."""
    self.forward = self.forward_train

  def apply_rope(self, z):
    """Apply the embedding to ``z``; must be overridden by subclasses.

    Raises:
      ValueError: Always, in this base class.
    """
    raise ValueError("apply_rope() called while not specified")