import torch
import numpy as np
import math
from sklearn.decomposition import PCA



def compute_loss_geo_torch(Y, Y_p, X, X_p):
  """Angle-preservation loss between sphere points and data points around one anchor.

  Sphere side: each row Y_i, the anchor Y_p and the sphere centre span a plane;
  the cosine of the angle between two such planes is taken from their normals
  Y_i x Y_p. Data side: the cosine of the angle between the difference vectors
  X_i - X_p and X_j - X_p. The two cosine matrices are compared entry-wise.

  Args:
    Y: (m, 3) tensor of reference points on the sphere.
    Y_p: (3,) tensor, the anchor point on the sphere.
    X: (m, d) tensor of the data points matching the rows of Y.
    X_p: (d,) tensor, the anchor's data point.

  Returns:
    Scalar tensor sqrt(S / (m*(m-1)/2)), where S is the sum of squared
    differences over the full m x m cosine matrices (diagonal entries are
    1 - 1). Divides by zero when m == 1, and produces NaN if some Y_i is
    parallel to Y_p or some X_i equals X_p (zero-length normal/difference).
  """
  ## Normals of the planes through (Y_i, Y_p, origin).
  normals = torch.cross(Y, Y_p.expand_as(Y), dim=1)
  normal_len = normals.norm(dim=1)
  cos_sphere = (normals @ normals.T) / torch.outer(normal_len, normal_len)

  ## Cosine similarity of the data-space offsets from the anchor.
  offsets = X - X_p
  offset_len = offsets.norm(dim=1)
  cos_data = (offsets @ offsets.T) / torch.outer(offset_len, offset_len)

  m = cos_sphere.shape[0]
  n_pairs = (m * m - m) / 2
  return torch.sqrt((cos_sphere - cos_data).square().sum() / n_pairs)


def compute_loss_full(X, Y, angle_subsample, device):
  """Average geodesic loss with every row of Y used once as the anchor.

  For each anchor index, the reference set is every other index, optionally
  reduced to `angle_subsample` indices drawn without replacement via
  np.random.choice. Anchors whose reference set has at most one element are
  skipped; they still count in the final 1/n averaging.

  Args:
    X: data-space tensor, rows aligned with Y.
    Y: (n, 3) tensor of sphere coordinates.
    angle_subsample: max number of reference points per anchor, or None for all.
    device: device of the accumulator tensor.

  Returns:
    Scalar tensor: (1/n) * sum of per-anchor compute_loss_geo_torch values.
  """
  n_points = Y.shape[0]
  total = torch.tensor(0, device=device)
  for anchor in range(n_points):
    ## All indices except the anchor itself.
    others = np.delete(np.arange(n_points), anchor)
    if angle_subsample is not None:
      others = np.random.choice(others, min(angle_subsample, others.shape[0]), replace=False)
    ## A single reference point has no pair to compare.
    if others.shape[0] <= 1:
      continue
    total = total + compute_loss_geo_torch(Y[others, :], Y[anchor, :], X[others, :], X[anchor, :])
  return 1/n_points*total


def compute_loss_batch(X, Y, slice_idcs, angle_subsample, device):
  """Average geodesic loss over the anchors listed in `slice_idcs`.

  For each anchor k, reference points are drawn uniformly without replacement
  from all other rows of Y with torch.multinomial: `angle_subsample` of them,
  or every other row when `angle_subsample` is None (matching
  compute_loss_full). When fewer than two reference points are available, no
  terms are added, because compute_loss_geo_torch would divide by zero.

  Args:
    X: data-space tensor, rows aligned with Y.
    Y: (n, 3) tensor of sphere coordinates.
    slice_idcs: array of anchor row indices (needs `.shape[0]`).
    angle_subsample: max number of reference points per anchor, or None for all.
    device: device of the accumulator tensor.

  Returns:
    Scalar tensor: (1/len(slice_idcs)) * sum of per-anchor losses.
  """
  n = Y.shape[0]
  ## None previously crashed in min(None, int); treat it as "no subsampling".
  n_ref = n - 1 if angle_subsample is None else min(angle_subsample, n - 1)

  loss = torch.tensor(0, device=device)
  if n_ref > 1:
    ## Uniform weights for torch-side sampling; the anchor's own weight is
    ## zeroed while it is processed so it is never drawn as its own reference.
    weights = torch.ones(n, device=Y.device)/n
    for k in slice_idcs:
      weights[k] = 0
      sub_idcs = torch.multinomial(weights, n_ref, replacement=False)
      loss = loss + compute_loss_geo_torch(Y[sub_idcs, :], Y[k, :], X[sub_idcs, :], X[k, :])
      weights[k] = 1/n

  return 1/slice_idcs.shape[0]*loss




class mercat_internal(torch.nn.Module):
    """Learnable spherical embedding: one (theta, phi) angle pair per point.

    `angles[:, 0]` is the polar angle theta (from +z) and `angles[:, 1]` the
    azimuth phi; `sphere_coordinates` maps them to
    (sin(theta)cos(phi), sin(theta)sin(phi), cos(theta)) on the unit sphere.
    """

    def __init__(self, angle_init, device='cpu'):
        super().__init__()
        self.angles = torch.nn.Parameter(angle_init)
        ## Placeholder; replaced by the current embedding on every forward pass.
        self.Y = torch.zeros(self.angles.shape[0], 3, requires_grad=False)
        self.device = device

    def sphere_coordinates(self):
        """Return the (n, 3) Cartesian coordinates of the current angles."""
        theta = self.angles[:, 0]
        phi = self.angles[:, 1]
        sin_theta = torch.sin(theta)
        return torch.stack((sin_theta * torch.cos(phi),
                            sin_theta * torch.sin(phi),
                            torch.cos(theta)), dim=1)

    def forward(self, X, batch_idx, angle_subsample, subsample_idcs):
        """Compute the embedding loss; also caches the current embedding in `self.Y`.

        Rows of X and Y are first reordered/selected by `subsample_idcs`. With
        `batch_idx` set, only those (post-subsampling) anchor rows are scored
        via compute_loss_batch; otherwise every row is scored via
        compute_loss_full.
        """
        self.Y = self.sphere_coordinates()
        X_sub = X[subsample_idcs, :]
        Y_sub = self.Y[subsample_idcs, :]

        if batch_idx is not None:
            ## Batched version
            return compute_loss_batch(X_sub, Y_sub, batch_idx, angle_subsample, self.device)
        ## Unbatched version
        return compute_loss_full(X_sub, Y_sub, angle_subsample, self.device)

    def angle_bounding(self):
        """Wrap angles to theta in [0, pi] and phi in [0, 2*pi) without moving any point.

        Writes into the parameter in place, so call it under torch.no_grad()
        (as `mercat` does).
        """
        theta = self.angles[:, 0] % (2*math.pi)
        flip = theta > math.pi
        ## theta -> 2*pi - theta negates sin(theta), which would mirror the point
        ## across the z-axis; shifting phi by pi restores the original position.
        ## NOTE(review): optimizer moment estimates for flipped entries are not
        ## adjusted, although d(theta) changes sign after the reflection.
        self.angles[:, 0] = torch.where(flip, 2*math.pi - theta, theta)
        self.angles[:, 1] = torch.where(flip, self.angles[:, 1] + math.pi, self.angles[:, 1]) % (2*math.pi)





########################################
##### Routine to compute embedding #####
########################################
def step_decay_factor(epoch, stepsize):
  """Return the multiplicative learning-rate factor `mercat` uses at `epoch`.

  The rate is divided by 10 at every milestone. When `stepsize` is a list or
  tuple of epochs, each milestone already reached counts once; otherwise the
  decay happens once every `stepsize` epochs.
  """
  if isinstance(stepsize, (list, tuple)):
    return .1 ** np.sum(np.array(stepsize) <= epoch)
  ## Parenthesised on purpose: `.1 ** epoch // stepsize` floor-divides the
  ## already-decayed factor and is 0 from epoch 0, i.e. the LR was always 0.
  return .1 ** (epoch // stepsize)


def mercat(X, itermax=1000, lr=0.01, stepsize=300,
                   X_redPCA=True, device='cpu',
                   batch_size=None, angle_subsample=None,
                   Y_prior=None, shuffle=True):
  """Embed the rows of X on the unit sphere by matching local angle structure.

  Args:
    X: 2-D array (n_samples, n_features).
    itermax: number of epochs.
    lr: initial Adam learning rate.
    stepsize: int (decay LR by 10x every `stepsize` epochs) or list of
      milestone epochs; see step_decay_factor.
    X_redPCA: if True and X has more than 50 features, optimise against the
      leading 50 principal components instead of X.
    device: torch device for all tensors.
    batch_size: anchors per optimizer step; None for one step per epoch.
    angle_subsample: reference points sampled per anchor (None: no subsampling).
    Y_prior: unsupported; must be None.
    shuffle: randomly permute the rows every epoch.

  Returns:
    numpy array (n_samples, 3) of sphere coordinates, rows in the order of X.

  Raises:
    NotImplementedError: if Y_prior is given.
  """
  if Y_prior is not None:
    ## Raise instead of exit(): a library function must not kill the interpreter.
    raise NotImplementedError('Custom Y prior currently not supported!')

  ## Get X down to 50 PCs
  pca = PCA(n_components=min(X.shape[1], 50))
  X_PCA = pca.fit_transform(X)
  if X_redPCA and X.shape[1] > 50:
    X = X_PCA[:, :]

  X = torch.tensor(X, requires_grad=False, device=device)

  ## Initialize with 2 leading PCs, wrap that around a unit half-sphere (very
  ## roughly, it's just an initial guess); keep it a bit away from the poles.
  Z = torch.tensor(X_PCA[:, :2], device=device)
  Z_max = torch.max(Z)
  Z_min = torch.min(Z)
  angles = (Z - Z_min)/(Z_max - Z_min)*.6*math.pi + .2*math.pi
  mercat_model = mercat_internal(angles, device=device)

  print('Computed initial PCA and initialization.')

  optimizer = torch.optim.Adam(mercat_model.parameters(), lr=lr)
  scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda epoch: step_decay_factor(epoch, stepsize))

  n = X.shape[0]
  ## max(1, ...) avoids a modulo by zero when itermax < 10.
  log_every = max(1, itermax // 10)
  for it in range(itermax):
    if it % log_every == 0:
      print('Epoch ' + str(it))

    if shuffle:
      subsample_idcs = torch.randperm(n, device=device)
    else:
      subsample_idcs = torch.arange(n, device=device)

    if batch_size is not None:
      for slice_idx in np.arange(0, n, batch_size):
        slice_idcs = np.arange(slice_idx, min(slice_idx + batch_size, n))
        ## Reset per mini-batch: without this each step would re-apply the
        ## gradients of every earlier batch of the epoch.
        optimizer.zero_grad()
        loss = mercat_model(X, slice_idcs, angle_subsample, subsample_idcs)
        loss.backward()
        optimizer.step()
    else:
      optimizer.zero_grad()
      loss = mercat_model(X, None, angle_subsample, subsample_idcs)
      loss.backward()
      optimizer.step()

    scheduler.step()

    with torch.no_grad():
      mercat_model.angle_bounding()

  return mercat_model.Y.detach().cpu().numpy()

