
# Minkowski (Lorentzian) bilinear form
#   <u, v>_L = -u[1] * v[1] + sum_{i > 1} u[i] * v[i],
# with the first coordinate playing the role of the time-like axis.
inner_l <- function(u, v) {
  spatial <- sum(u[-1] * v[-1])
  spatial - u[1] * v[1]
}

# Lorentzian norm of u: sqrt(<u, u>_L), or the squared value <u, u>_L when
# sq = TRUE. The square root is only real (non-NaN) when <u, u>_L >= 0,
# e.g. for tangent vectors.
norm_l <- function(u, sq = FALSE) {
  sq_norm <- inner_l(u, u)
  if (!sq) {
    return(sqrt(sq_norm))
  }
  sq_norm
}


# Riemannian exponential map at p (hyperboloid model):
#   exp_p(u) = cosh(|u|_L) * p + sinh(|u|_L) * u / |u|_L,
# returning p itself for the zero vector. Assumes p lies on the hyperboloid
# and u is tangent at p.
exp_fn <- function(u, p) {
  # Clamp the squared Lorentzian norm at 0: rounding can leave it slightly
  # negative for (near-)zero tangent vectors, and sqrt() would then give NaN
  # and make the zero test below fail with a missing-value error.
  u_norm <- sqrt(max(norm_l(u, sq = TRUE), 0))
  if (u_norm == 0) {
    return(p)
  }
  cosh(u_norm) * p + sinh(u_norm) * u / u_norm
}

# Riemannian logarithm at p (hyperboloid model): for p, q on the hyperboloid,
# the tangent vector at p that exp_fn(., p) sends to q,
#   log_p(q) = acosh(-<p, q>_L) * (q + <p, q>_L * p) / sqrt(<p, q>_L^2 - 1).
# Argument order: the target q comes first, the base point p second.
# Nearly coincident points ((<p, q>_L + 1)^2 < 1e-10) give the zero vector,
# which also keeps acosh() and sqrt() away from invalid arguments.
log_fn <- function(q, p) {
  ip <- inner_l(p, q)
  coincident <- (ip + 1)^2 < 1e-10
  if (coincident) {
    return(0 * p)
  }
  theta <- acosh(-ip)
  theta * (q + ip * p) / sqrt(ip^2 - 1)
}

# Riemannian (geodesic) distance between p and q on the hyperboloid:
#   d(p, q) = acosh(-<p, q>_L).
# Returns 0 when (<p, q>_L + 1)^2 < 1e-10, i.e. |<p, q>_L + 1| < 1e-5. This
# keeps acosh() away from arguments pushed below 1 by rounding. For points on
# the hyperboloid it also maps distances below about acosh(1 + 1e-5) ~ 4.5e-3
# to 0, matching the tolerance used by log_fn().
dist_fn <- function(p, q) {
  # Compute the inner product once and reuse it (it used to be recomputed).
  inner <- inner_l(p, q)
  if ((inner + 1)^2 < 1e-10) {
    return(0)
  }
  acosh(-inner)
}

# Isomorphism from Euclidean space R^d into the tangent space at the origin
# (1, 0, ..., 0) of the hyperboloid: prepend a zero time-like coordinate.
iso_fn <- function(u) c(0, u)


# Draw one point within geodesic distance r of p by rejection: proposals come
# from rEWGauss(1, p, p, sig) and are returned (in the form rEWGauss produces)
# as soon as their distance to p is strictly below r.
# The result follows the proposal distribution truncated to the ball; it is
# not uniform on the ball (see rBall_uniform for that).
#
# sig: standard deviation of the Gaussian proposal (defaults to 2, the value
#      this function always used).
rBall_single <- function(p, r, sig = 2) {
  repeat {
    samp <- rEWGauss(1, p, p, sig)
    if (dist_fn(p, samp) < r) {
      return(samp)
    }
  }
}


# Draw n independent points with rBall_single(p, r).
# Returns an n x length(p) matrix with one point per row (0 rows for n = 0;
# seq_len avoids 1:0 producing two draws).
rBall <- function(n, p, r) {
  draws <- vapply(
    seq_len(n),
    function(i) as.vector(rBall_single(p, r)),
    numeric(length(p))
  )
  t(draws)
}


# Draw a geodesic radius rho in [0, r] with density proportional to
# sinh(rho)^(d - 1), by rejection from Uniform(0, r). A proposal is kept with
# probability (sinh(rho) / sinh(r))^(d - 1), which never exceeds 1 because
# sinh is increasing.
rradius <- function(r, d) {
  repeat {
    rho <- runif(1, 0, r)
    ratio <- (sinh(rho) / sinh(r))^(d - 1)
    if (runif(1, 0, 1) <= ratio) {
      return(rho)
    }
  }
}

# Draw n points uniformly (w.r.t. hyperbolic volume) from the geodesic ball of
# radius r centred at p. Assumes p lies on the hyperboloid <p, p>_L = -1.
#
# Unit directions are drawn uniformly in the tangent space at the origin
# (1, 0, ..., 0) and parallel-transported to T_p. Radii come from rradius().
# Each point is cosh(rad) * p + sinh(rad) * direction.
#
# Returns an n x length(p) matrix, one point per row.
rBall_uniform <- function(n, p, r) {
  d <- length(p) - 1

  # Unit directions in T_origin (first coordinate 0). rowSums keeps this
  # correct for d = 1, where apply() would collapse the result to a vector.
  direction <- matrix(rnorm(n * d, 0, 1), n, d)
  direction <- direction / sqrt(rowSums(direction^2))
  direction <- cbind(rep(0, n), direction)

  # Parallel transport from the origin o to p:
  #   v + <p, v>_L / (1 - <o, p>_L) * (o + p),
  # with <o, p>_L = -p[1] and <p, v>_L = sum(p[-1] * v[-1]) since v[1] = 0.
  # c(0, x) is tangent only at the origin; without this step the points leave
  # the hyperboloid whenever p is not the origin.
  origin <- c(1, rep(0, d))
  coef <- as.vector(direction[, -1, drop = FALSE] %*% p[-1]) / (1 + p[1])
  direction <- direction + outer(coef, origin + p)

  rad <- vapply(seq_len(n), function(i) rradius(r, d), numeric(1))
  outer(cosh(rad), p) + sinh(rad) * direction
}
# 
# p <- c(1, 0, 0)
# test <- rBall_uniform(200, c(1, 0, 0), 1.5)
# mean <- frechet_mean(test, n_iter = 10000, lambda = 1e-10, result = TRUE)
# test_plot <- scatterplot3d::scatterplot3d(test)
# test_plot$points3d(p[1], p[2], p[3], col = "red")
# test_plot$points3d(mean[1], mean[2], mean[3], col = "blue")


# Parallel transport of u from the tangent space at p to the tangent space at
# q along the connecting geodesic (hyperboloid model):
#   P(u) = u + <q + <p, q>_L p, u>_L / (1 - <p, q>_L) * (p + q).
# Assumes p, q lie on the hyperboloid and u is tangent at p. In that case the
# inner term equals <q, u>_L and the result satisfies <q, P(u)>_L = 0.
# No special case for nearby p and q is needed: on the hyperboloid
# 1 - <p, q>_L >= 2, and the correction term vanishes when p == q. A shortcut
# that returned u unchanged for close points would give vectors that are not
# tangent at q.
paral_trans <- function(u, p, q) {
  inner <- inner_l(p, q)
  u + inner_l(q + inner * p, u) / (1 - inner) * (p + q)
}



# Draw n points from a wrapped Gaussian at base point p: tangent vectors with
# mean log_p(eta) and independent N(0, sig^2) coordinates are mapped onto the
# hyperboloid with exp_fn(., p).
#
# Noise is added in the tangent space at the origin (1, 0, ..., 0), whose
# spatial coordinates form an orthonormal frame, and the result is
# parallel-transported to T_p. Writing c(0, noise) directly as a tangent
# vector at p is only valid when p is the origin; for the origin this
# function behaves as before.
#
# Returns an n x length(p) matrix, one point per row (0 rows for n = 0).
rEWGauss <- function(n, p, eta, sig) {
  d <- length(p) - 1
  origin <- c(1, rep(0, d))

  # Mean log_p(eta), carried to T_origin. Loop-invariant, so computed once.
  eta_tang <- paral_trans(log_fn(eta, p), p, origin)[-1]

  samp <- vapply(seq_len(n), function(i) {
    tang <- paral_trans(c(0, rnorm(d, eta_tang, sig)), origin, p)
    exp_fn(tang, p)
  }, numeric(d + 1))

  t(samp)
}

# p <- c(1, 0, 0)
# test <- rBall(10, c(1, 0, 0), 1.5)
# mean <- frechet_mean(test, n_iter = 10000, lambda = 1e-10, result = TRUE)
# eta <- c(1, 0, 0)
# test_2 <- rEWGauss(100, p, eta, 0.5)
# test_plot <- scatterplot3d::scatterplot3d(test_2)
# test_plot$points3d(eta[1], eta[2], eta[3], col = "red")




# ======================== sampling from Riemannian Laplace distributions ========================

# Draw one point from a Riemannian Laplace-type distribution centred at eta,
# by accept/reject.
#
# A reference proposal v_init is drawn once. Each further proposal v is
# drawn as N(0, sig^2 I) noise in the tangent space at the origin
# (1, 0, ..., 0), parallel-transported to eta. It is accepted with
# probability
#   min(1, |(sinh(r) / sinh(r_init))^(d - 1) * r_init / r|),
# where r and r_init are the Lorentzian norms of v and v_init. The accepted
# vector is mapped onto the hyperboloid with exp_fn(v, eta).
# NOTE(review): the acceptance probability does not involve sig, and it is
# measured against one random reference radius r_init. Confirm this targets
# the intended Laplace density.
rRLap_single <- function(eta, sig) {
  d <- length(eta) - 1
  origin <- c(1, rep(0, d))

  # Reference proposal, drawn once and never updated.
  v_init <- paral_trans(c(0, rnorm(d, 0, sig)), origin, eta)
  r_init <- norm_l(v_init)

  repeat {
    # Draw order (runif, then rnorm) is kept so seeded runs reproduce.
    u <- runif(1)
    v <- paral_trans(c(0, rnorm(d, 0, sig)), origin, eta)
    r <- norm_l(v)

    tau <- min(1, abs((sinh(r) / sinh(r_init))^(d - 1) * r_init / r))
    if (u <= tau) {
      # Only the accepted proposal needs mapping onto the manifold.
      return(exp_fn(v, eta))
    }
  }
}


# Draw n independent points with rRLap_single(eta, sig).
# Returns an n x length(eta) matrix, one point per row (0 rows for n = 0;
# seq_len avoids 1:0 producing two draws).
rRLap <- function(n, eta, sig) {
  draws <- vapply(
    seq_len(n),
    function(i) rRLap_single(eta, sig),
    numeric(length(eta))
  )
  t(draws)
}

# Markov-chain sampler for the same target as rRLap_single. Proposals are
# drawn independently of the current state: N(0, sig^2 I) noise in the
# tangent space at the origin (1, 0, ..., 0), parallel-transported to eta.
# A proposal with Lorentzian norm r replaces the current state (norm r_cur)
# with probability
#   min(1, |(sinh(r) / sinh(r_cur))^(d - 1) * r_cur / r|);
# otherwise the current point is repeated.
#
# n      : number of retained samples.
# n_burn : number of initial iterations discarded.
# Returns an n x length(eta) matrix, one point per row.
rRLap_MCMC <- function(n, eta, sig, n_burn = 1000) {
  d <- length(eta) - 1
  origin <- c(1, rep(0, d))
  samp <- matrix(NA, n, d + 1)

  # Initial state; its point on the manifold is cached and only recomputed
  # when a proposal is accepted.
  v_cur <- paral_trans(c(0, rnorm(d, 0, sig)), origin, eta)
  r_cur <- norm_l(v_cur)
  y_cur <- exp_fn(v_cur, eta)

  for (i in seq_len(n + n_burn)) {
    # Draw order (runif, then rnorm) is kept so seeded runs reproduce.
    u <- runif(1)
    v <- paral_trans(c(0, rnorm(d, 0, sig)), origin, eta)
    r <- norm_l(v)

    tau <- min(1, abs((sinh(r) / sinh(r_cur))^(d - 1) * r_cur / r))
    if (u <= tau) {
      r_cur <- r
      y_cur <- exp_fn(v, eta)
    }

    if (i > n_burn) {
      samp[i - n_burn, ] <- y_cur
    }
  }

  samp
}

# eta <- c(1, 0, 0)
# test <- rRLap(100, eta, 0.5)
# test <- rRLap_MCMC(100, eta, 0.5)
# test_plot <- scatterplot3d::scatterplot3d(test)
# test_plot$points3d(eta[1], eta[2], eta[3], col = "red")


# ======================== frechet mean computation ========================


# Frechet mean of the rows of `data`, by iterated Riemannian averaging.
#
# Starting from a randomly chosen row, each step averages the tangent vectors
# log_fn(x, mean) over all rows x and moves to
# exp_fn(step_size * average, mean). Iteration stops once the distance
# between successive iterates is at most `lambda`, or after `n_iter` steps.
#
# data      : matrix with one point per row (must have at least one row).
# n_iter    : maximum number of steps.
# step_size : multiplier applied to the averaged tangent vector.
# lambda    : stopping tolerance on dist_fn() between successive iterates.
# result    : if TRUE, print the final estimate.
# Returns the final estimate as a vector.
frechet_mean <- function(data, n_iter = 1000, step_size = 0.5,
                         lambda = 1e-10, result = FALSE) {
  if (nrow(data) == 0) {
    stop("`data` must contain at least one point", call. = FALSE)
  }

  mean_old <- data[sample.int(nrow(data), 1), ]
  step_dist <- 1
  iter <- 1
  while (step_dist > lambda && iter <= n_iter) {
    # One column per data point; rowMeans averages the tangent vectors.
    tangents <- apply(data, 1, function(x) log_fn(x, mean_old))
    v_new <- rowMeans(tangents)

    mean_new <- exp_fn(step_size * v_new, mean_old)
    step_dist <- dist_fn(mean_new, mean_old)
    mean_old <- mean_new
    iter <- iter + 1
  }

  if (result) {
    print("The Frechet mean is: ")
    print(mean_old)
  }

  mean_old
}

# === testing for frechet_mean(): ===
# test <- rBall(100, c(1, 0, 0), 1.5)
# mean <- frechet_mean(test, n_iter = 10000, lambda = 1e-10, result = TRUE)
# test_plot <- scatterplot3d::scatterplot3d(test)
# test_plot$points3d(mean[1], mean[2], mean[3], col = "red")



