# ---- Setup -------------------------------------------------------------------
# Start from an empty global environment, then load dependencies.
rm(list = ls())
library(ExtDist)

# Directory prefixes for the helper script and for the result files. They are
# pasted directly in front of file names, so a non-empty value needs a
# trailing separator; "" means the current working directory.
path <- ""
res_path <- ""

# NOTE(review): presumably defines the helpers used below (rSPD, rSPD_dist,
# frechet_mean, log_fn, exp_fn, mat2vec, vec2mat, distance_fn, rLap) —
# confirm against spd_functions.R.
source(paste0(path, "spd_functions.R"))

# ---- Experiment parameters ---------------------------------------------------
# d: matrix dimension (d x d SPD matrices); r: third argument of rSPD
# (presumably a radius bounding the data); n: number of sampled matrices.
d <- 6
r <- 1.5
n <- 40

# Larger configuration (kept for reference):
# d <- 10
# r <- 45
# n <- 1500

# Sensitivity; the Laplace scale for each budget is sens / eps.
sens <- (2 * r) / n
# Privacy budgets to sweep: 0.1, 0.2, ..., 0.7, then 1, 1.5, 2.
eps_list <- c(seq(from = 0.1, to = 0.7, by = 0.1), 1, 1.5, 2)

# ---- Experiment loop ---------------------------------------------------------
# For every privacy budget in `eps_list`, release a noisy version of the
# Frechet mean of a fixed synthetic data set `n_rep` times with two mechanisms
# and record how far each release lands from the true mean and how long it
# takes:
#   "tangent"   - Laplace noise added to the log-map of the mean at
#                 `fix_point`, mapped back with exp_fn;
#   "rspd_dist" - a single draw from rSPD_dist(..., "laplace") centred at the
#                 true mean.
# Column 1 of the per-repetition matrices is "tangent", column 2 "rspd_dist".
mech_names <- c("tangent", "rspd_dist")
n_rep <- 100L # repetitions of each mechanism per epsilon

# One row per epsilon: mean/sd of the distances, then mean/sd of the timings
# (in seconds), each split by mechanism.
res_list <- matrix(
  NA,
  nrow = length(eps_list), ncol = 8,
  dimnames = list(NULL, c(
    paste0("dist_mean_", mech_names), paste0("dist_sd_", mech_names),
    paste0("time_mean_", mech_names), paste0("time_sd_", mech_names)
  ))
)

# Apply `fn` to every column of matrix `m` after dropping non-finite entries
# (NA/NaN/Inf), e.g. repetitions whose distance could not be computed.
col_finite_stat <- function(m, fn) {
  apply(m, 2, function(x) fn(x[is.finite(x)]))
}

for (ind in seq_along(eps_list)) {
  eps <- eps_list[ind]
  # Laplace scale for this budget.
  lap_sig <- sens / eps
  print(lap_sig)
  # Re-seed so that every epsilon sees the same data set and RNG stream.
  set.seed(1)
  # fix_point <- rSPD(1, d, r)[[1]]
  fix_point <- diag(rep(1, d))
  # set.seed(NULL)
  dat <- rSPD(n, d, r)
  # dat <- rSPD(1, d, r)
  print(Sys.time())
  true_p <- frechet_mean(dat, d)
  # true_p <- dat[[1]]

  # Preallocate instead of growing with rbind() on every repetition.
  dist_list <- matrix(NA_real_, nrow = n_rep, ncol = 2)
  time_list <- matrix(NA_real_, nrow = n_rep, ncol = 2)
  for (i in seq_len(n_rep)) {
    # Mechanism 1 ("tangent"); the timed span includes the log/exp maps.
    start_1 <- Sys.time()
    true_p_tang <- log_fn(fix_point, true_p)
    true_p_tang_vec <- mat2vec(true_p_tang)
    # d * (d + 1) / 2 is presumably the length of mat2vec()'s output (the free
    # entries of a symmetric d x d matrix) — confirm against spd_functions.R.
    p_dp_tang <- true_p_tang_vec +
      rLap(1, d * (d + 1) / 2, mu = 0, sig = lap_sig)
    p_dp <- exp_fn(fix_point, vec2mat(p_dp_tang))

    # Mechanism 2 ("rspd_dist").
    start_2 <- Sys.time()
    p_dp_2 <- rSPD_dist(
      1, d,
      mu = true_p, sigma = lap_sig, "laplace", n_burn = 10000
    )[[1]]
    start_3 <- Sys.time()

    dist_list[i, ] <- c(distance_fn(true_p, p_dp), distance_fn(true_p, p_dp_2))
    # Force seconds: `t2 - t1` picks its units automatically (secs, mins, ...)
    # and the units are lost once stored in a numeric matrix, which would mix
    # units across repetitions.
    time_list[i, ] <- c(
      as.numeric(difftime(start_2, start_1, units = "secs")),
      as.numeric(difftime(start_3, start_2, units = "secs"))
    )
    print(dist_list[i, ])
  }

  # print(dist_list)
  dist_mean <- col_finite_stat(dist_list, mean)
  dist_std <- col_finite_stat(dist_list, sd)
  time_mean <- col_finite_stat(time_list, mean)
  time_std <- col_finite_stat(time_list, sd)

  print(c(dist_mean, dist_std))
  res_list[ind, ] <- c(dist_mean, dist_std, time_mean, time_std)
}

# ---- Save results ------------------------------------------------------------
# Prepend the privacy budgets as a leading "eps_list" column, then write the
# summary table to <res_path>vanilla_DP_<d>_<n>.csv.
res_list <- cbind(eps_list, res_list)
out_file <- sprintf("%svanilla_DP_%s_%s.csv", res_path, d, n)
write.csv(res_list, file = out_file)