library(torchvision)
library(torchdatasets)
library(torch)
library(luz)

require(MASS)

library(rtracklayer)
library(plyranges)

source('mercat_helper_fns.R')

compute_loss_geo_torch <- function(Y, Y_p, X, X_p)
{
  ## Root-mean-square mismatch between two pairwise-cosine matrices, both seen
  ## from a single anchor point:
  ##   * sphere side: for every row of Y, the normal of the plane through the
  ##     origin, that row and the anchor Y_p (torch_cross over dim 2, so Y must
  ##     have 3 columns). Cosines between these normals are the cosines of the
  ##     angles between the planes, i.e. between the geodesics.
  ##   * ambient side: the difference vectors X - X_p and the cosines between them.
  ## Y, X:     2-D tensors whose rows are matched (same row order).
  ## Y_p, X_p: the anchor rows (1-D tensors).
  ## Returns a 0-dim tensor.
  ## NOTE(review): the sum runs over the full m x m matrix (both triangles; the
  ## diagonal differences are ~0) but is divided by the m(m-1)/2 unordered
  ## pairs, so the value is sqrt(2) times the per-pair RMS. Confirm this scale
  ## is intended.
  ## NOTE(review): a zero-length normal or difference vector (coincident or
  ## antipodal points) makes the cosine 0/0 = NaN. Nothing here guards against it.
  cosine_matrix <- function(V)
  {
    V_len <- torch_norm(V, dim=2)
    torch_div(torch_matmul(V, torch_transpose(V, 1, 2)), torch_outer(V_len, V_len))
  }
  cos_sphere <- cosine_matrix(torch_cross(Y, torch_unsqueeze(Y_p, 1), dim=2))
  cos_ambient <- cosine_matrix(torch_sub(X, X_p))
  m <- cos_sphere$shape[1]
  n_pairs <- (m * m - m) / 2
  sq_err <- torch_sum(torch_square(cos_sphere - cos_ambient))
  torch_sqrt(sq_err / n_pairs)
}


compute_loss_full <- function(X, Y, angle_subsample)
{
  ## Mean, over every row k used as the anchor, of compute_loss_geo_torch()
  ## between row k and the remaining rows (optionally subsampled).
  ##
  ## X:               ambient-space coordinates, one row per point.
  ## Y:               embedding coordinates with the same row order as X. They
  ##                  must have 3 columns, because compute_loss_geo_torch
  ##                  applies torch_cross to them.
  ## angle_subsample: NULL to compare each anchor with all other rows.
  ##                  Otherwise, at most this many reference rows are drawn per
  ##                  anchor using the global RNG.
  ## Returns a 0-dim torch tensor. Anchors with fewer than 2 reference rows
  ## contribute nothing but still count in the 1/nrow(Y) average.
  n_obs <- nrow(Y)
  loss <- torch_tensor(0)
  for (k in seq_len(n_obs))
  {
    ## Every row index except the anchor k (empty when n_obs == 1).
    sub_idcs <- seq_len(n_obs)[-k]
    if (!is.null(angle_subsample))
    {
      ## Index through sample.int() so a length-1 `sub_idcs` is not
      ## reinterpreted by sample() as "draw from 1:sub_idcs".
      n_draw <- min(angle_subsample, length(sub_idcs))
      sub_idcs <- sub_idcs[sample.int(length(sub_idcs), n_draw)]
    }
    if (length(sub_idcs) > 1)
    {
      loss <- loss + compute_loss_geo_torch(Y[sub_idcs,], Y[k,], X[sub_idcs,], X[k,])
    }
  }
  return(1/n_obs*loss)
}

compute_loss_batch <- function(X, Y, slice_idcs, angle_subsample)
{
  ## Mean compute_loss_geo_torch() over the anchors listed in `slice_idcs`.
  ## Only those rows act as anchors, but any other row of X/Y may be drawn
  ## as a reference point.
  ##
  ## X:               ambient-space coordinates, one row per point.
  ## Y:               embedding coordinates with the same row order as X. They
  ##                  must have 3 columns, because compute_loss_geo_torch
  ##                  applies torch_cross to them.
  ## slice_idcs:      integer row indices used as anchors.
  ## angle_subsample: NULL to compare each anchor with all other rows.
  ##                  Otherwise, at most this many reference rows are drawn per
  ##                  anchor using the global RNG.
  ## Returns a 0-dim torch tensor (zero for an empty slice).
  n_obs <- nrow(Y)
  loss <- torch_tensor(0)
  if (length(slice_idcs) == 0)
  {
    ## Nothing to average; dividing by 0 would give 1/0 * 0 = NaN.
    return(loss)
  }
  for (k in slice_idcs)
  {
    ## Every row index except the anchor k (empty when n_obs == 1).
    sub_idcs <- seq_len(n_obs)[-k]
    if (!is.null(angle_subsample))
    {
      ## Index through sample.int() so a length-1 `sub_idcs` is not
      ## reinterpreted by sample() as "draw from 1:sub_idcs".
      n_draw <- min(angle_subsample, length(sub_idcs))
      sub_idcs <- sub_idcs[sample.int(length(sub_idcs), n_draw)]
    }
    if (length(sub_idcs) > 1)
    {
      loss <- loss + compute_loss_geo_torch(Y[sub_idcs,], Y[k,], X[sub_idcs,], X[k,])
    }
  }
  return(1/length(slice_idcs)*loss)
}

compute_angles <- function(X, angle_subsample, seeded=TRUE, geodesic=FALSE)
{
  ## For every row k of X, the matrix of cosines between the vectors that
  ## connect anchor k to the other rows, or to a random subset of them.
  ##
  ## X               points, one per row. Assumed to be a 2-D torch tensor,
  ##                 since rows are fed straight to torch ops (TODO confirm).
  ## angle_subsample NULL to use all other rows. Otherwise, at most this many
  ##                 reference rows are drawn for each anchor.
  ## seeded          if TRUE, the draw for anchor k uses set.seed(k). Calls on
  ##                 X and on another matrix with the same rows therefore pick
  ##                 the same reference rows. The caller's RNG state is
  ##                 restored on exit.
  ## geodesic        FALSE: cosines between difference vectors X[j,] - X[k,].
  ##                 TRUE:  cosines between normals cross(X[j,], X[k,]), i.e.
  ##                 between the planes through the origin, X[j,] and X[k,].
  ##                 This needs 3 columns.
  ##
  ## Returns a tensor of shape (nrow(X), d, d), where
  ## d = min(angle_subsample, nrow(X) - 1), or nrow(X) - 1 without subsampling.
  n_obs <- nrow(X)
  if (is.null(angle_subsample))
  {
    dim_ang <- n_obs - 1
  } else {
    ## Cap at n_obs - 1. A larger angle_subsample cannot be honoured, and an
    ## uncapped size would make every slice assignment below mismatch.
    dim_ang <- min(angle_subsample, n_obs - 1)
  }
  X_angles <- torch_zeros(n_obs, dim_ang, dim_ang)

  if (seeded && !is.null(angle_subsample))
  {
    ## set.seed() below overwrites the global RNG; put the caller's stream back.
    rng_env <- globalenv()
    if (exists(".Random.seed", envir = rng_env, inherits = FALSE))
    {
      saved_seed <- get(".Random.seed", envir = rng_env, inherits = FALSE)
      on.exit(assign(".Random.seed", saved_seed, envir = rng_env), add = TRUE)
    } else {
      on.exit(if (exists(".Random.seed", envir = rng_env, inherits = FALSE))
        rm(list = ".Random.seed", envir = rng_env), add = TRUE)
    }
  }

  for (k in seq_len(n_obs))
  {
    ## Every row index except the anchor k.
    ref_idcs <- seq_len(n_obs)[-k]
    if (!is.null(angle_subsample) && length(ref_idcs) > 1)
    {
      ## Per-anchor seed gives consistent draws when comparing X and Y angles.
      if (seeded)
      {
        set.seed(k)
      }
      ref_idcs <- sample(ref_idcs, min(angle_subsample, length(ref_idcs)))
    }
    if (length(ref_idcs) == 0)
    {
      next
    }
    ## drop = FALSE keeps a single reference row as a 1 x p matrix.
    refs <- X[ref_idcs, , drop = FALSE]
    anchor <- X[k,]
    if (geodesic)
    {
      ## Normals of the planes through the origin, the anchor and each reference row.
      vecs <- torch_cross(refs, torch_unsqueeze(anchor, 1), dim=2)
    } else {
      vecs <- torch_sub(refs, anchor)
    }
    vec_len <- torch_norm(vecs, dim=2)
    X_angles[k,,] <- torch_div(torch_matmul(vecs, torch_transpose(vecs,1,2)), torch_outer(vec_len, vec_len))
  }
  return(X_angles)
}



########################################
##### Routine to compute embedding #####
########################################
mercat <- function(X, iter=100, lr=0.01, stepsize=50,
                   report_progress=FALSE, X_redPCA=TRUE,
                   batch_size=64, angle_subsample=NULL,
                   Y_prior=NULL, shuffle=TRUE)
{
  ## Fit an embedding of the rows of X by optimising a tensor of angles with
  ## Adam. Y = RtoY(angles) is scored against X with compute_loss_batch() /
  ## compute_loss_full(), and the angles are re-bounded with angle_bounding()
  ## after every optimizer step.
  ##
  ## X               data matrix, one row per observation (passed to prcomp()).
  ## iter            number of epochs.
  ## lr              initial Adam learning rate.
  ## stepsize        scalar: lr is multiplied by 0.1 every `stepsize` epochs.
  ##                 Vector: lr is multiplied by 0.1 for each milestone <= epoch.
  ## report_progress if TRUE, record the loss, Y and angles after every epoch.
  ## X_redPCA        if TRUE and X has more than 50 columns, replace X by its
  ##                 scores on the leading 50 principal components.
  ## batch_size      anchors per optimizer step. NULL means one full-data step
  ##                 per epoch.
  ## angle_subsample reference rows drawn per anchor; NULL means all other rows.
  ## Y_prior         optional initial embedding, converted with YtoR_plain().
  ## shuffle         if TRUE, the row order is re-permuted every epoch.
  ##
  ## Returns as_array() of the final Y. If report_progress is TRUE, it instead
  ## returns list(loss, Y, R, Y_init) with one loss/Y/R entry per epoch.

  ## PCA is always computed: its first two PCs seed the initial embedding.
  X_PCA <- prcomp(X, rank. = 50)
  if (X_redPCA && ncol(X) > 50)
  {
    X <- X_PCA$x
  }
  X <- torch_tensor(X, requires_grad = FALSE)
  n_obs <- nrow(X)
  if (is.null(Y_prior))
  {
    ## Map the 2 leading PC scores linearly onto [0.2*pi, 0.8*pi]. This is
    ## only a rough initial guess, kept a bit away from the poles.
    Y <- X_PCA$x[,1:2]
    Y_max <- max(Y)
    Y_min <- min(Y)
    angles <- torch_tensor((Y - Y_min)/(Y_max - Y_min)*.6*pi + .2*pi, requires_grad = TRUE)
  } else {
    angles <- torch_tensor(YtoR_plain(Y_prior), requires_grad = TRUE)
  }
  Y <- RtoY(angles)

  print('Computed initial PCA and initialization.')

  if (report_progress)
  {
    Y_init <- as_array(Y$detach())
    Y_prog <- vector("list", length = iter)
    R_prog <- vector("list", length = iter)
    loss_prog <- rep(0, iter)
  }
  optimizer <- optim_adam(angles, lr=lr)
  lr_lambda_fn <- function(epoch)
  {
    ## Number of 10x decays already applied at this scheduler epoch.
    if (length(stepsize) > 1)
    {
      step <- sum(stepsize <= epoch)
    } else {
      step <- (epoch %/% stepsize)
    }
    return(.1 ** step)
  }
  scheduler <- lr_lambda(optimizer, lr_lambda_fn)
  for (it in seq_len(iter))
  {
    if (it %% 100 == 0)
    {
      print(paste0('Epoch ', it))
    }
    if (shuffle)
    {
      subsample_idcs <- sample(n_obs)
    } else {
      subsample_idcs <- seq_len(n_obs)
    }
    if (!is.null(batch_size))
    {
      ## Batched version: one optimizer step per block of anchors.
      ## X never changes, so it is permuted once per epoch. Y is rebuilt from
      ## `angles` after every step, so it must be re-indexed per batch.
      X_perm <- X[subsample_idcs,]
      comb_loss <- 0
      for (slice_idx in seq(1, n_obs, batch_size))
      {
        ## Consecutive, non-overlapping blocks; the last one is truncated at n_obs.
        slice_idcs <- slice_idx:min(slice_idx + batch_size - 1, n_obs)
        optimizer$zero_grad()
        loss <- compute_loss_batch(X_perm, Y[subsample_idcs,], slice_idcs, angle_subsample)
        loss$backward()
        optimizer$step()
        ## Overwrite the angles in place (no autograd) with angle_bounding()'s
        ## output, so the optimizer keeps tracking the same tensor.
        with_no_grad({
          angle_cut <- angle_bounding(as_array(angles))
          angles[,1] <- angle_cut[,1]
          angles[,2] <- angle_cut[,2]
        })
        Y <- RtoY(angles)
        comb_loss <- comb_loss + as_array(loss)
      }
    } else {
      ## Unbatched version: a single step on the full loss.
      optimizer$zero_grad()
      loss <- compute_loss_full(X[subsample_idcs,], Y[subsample_idcs,], angle_subsample)
      loss$backward()
      optimizer$step()
      with_no_grad({
        angle_cut <- angle_bounding(as_array(angles))
        angles[,1] <- angle_cut[,1]
        angles[,2] <- angle_cut[,2]
      })
      Y <- RtoY(angles)
    }
    scheduler$step()
    if (report_progress)
    {
      ## Batched runs record the sum of the per-batch losses of this epoch.
      if (!is.null(batch_size))
      {
        loss_prog[it] <- comb_loss
      } else {
        loss_prog[it] <- as_array(loss)
      }
      Y_prog[[it]] <- as_array(torch_clone(Y$detach()))
      R_prog[[it]] <- as_array(torch_clone(angles$detach()))
    }
  }
  if (report_progress)
  {
    return(list(loss=loss_prog,
                Y=Y_prog, R=R_prog,
                Y_init=Y_init
    ))
  } else {
    return(as_array(Y$detach()))
  }
}


