# Setup ----
# No `rm(list = ls())` here: clearing the global environment does not detach
# packages or reset options, and it silently destroys the caller's workspace
# when this script is source()d. Run via `Rscript` for a clean session.

# NOTE(review): ExtDist is not called directly in this script; presumably it
# is needed by spd_functions.R -- confirm before removing.
library(ExtDist)

# Directory prefixes for the helper script and the results file. They are
# pasted directly in front of file names (not joined with file.path()), so a
# non-empty value must end with "/". "" means the current working directory.
path <- ""
res_path <- ""

# Expected to define rSPD, rSPD_dist, frechet_mean, log_fn, vecd,
# vecd_inverse and distance_fn, which this script calls but does not define.
source(paste0(path, "spd_functions.R"))

# Experiment configuration ----
n <- 40    # number of samples drawn by rSPD per data set
d <- 4     # matrix dimension (d x d); also appears in the output file name
r <- 1.5   # third argument of rSPD; presumably a radius bound -- confirm

# Sensitivity used to scale both noise mechanisms: 2 * r / n
# (presumably the sensitivity of the Frechet mean for data bounded by r).
sens <- (2 * r) / n

# Privacy parameters mu to sweep: 0.1, 0.2, ..., 0.7, then 1, 1.5 and 2.
mu_list <- c(seq(from = 0.1, to = 0.7, by = 0.1), 1, 1.5, 2)

# Monte Carlo comparison ----
# For each mu: draw a data set with rSPD, compute its Frechet mean (true_p),
# then release it n_rep times with two mechanisms, recording each release's
# distance_fn distance to true_p and its wall-clock time:
#   (1) Gaussian noise with sd = sens / mu added to vecd(log_fn(I, true_p)),
#       mapped back with the matrix exponential;
#   (2) one draw from rSPD_dist with mu = true_p, sigma = lap_sig, "laplace".
#
# res_list: one row per mu, 8 columns in this order:
#   1-2 mean distance (mechanism 1, mechanism 2)
#   3-4 sd of distance (mechanism 1, mechanism 2)
#   5-6 mean time in seconds (mechanism 1, mechanism 2)
#   7-8 sd of time in seconds (mechanism 1, mechanism 2)
res_list <- matrix(NA_real_, nrow = length(mu_list), ncol = 8)

n_rep <- 100                  # releases per mu
n_burn <- 10000               # forwarded to rSPD_dist as n_burn
fix_point <- diag(rep(1, d))  # d x d identity; first argument of log_fn

# Apply `stat` to each column of matrix `m`, dropping non-finite entries
# (Inf/NaN/NA) first so a single failed release does not poison the summary.
# Returns one value per column (NaN for mean / NA for sd if none are finite).
finite_col_stat <- function(m, stat) {
  apply(m, 2, function(x) stat(x[is.finite(x)]))
}

set.seed(31415)
for (ind in seq_along(mu_list)) {
  mu <- mu_list[[ind]]
  # Laplace scale derived from mu; presumably chosen so the Laplace release
  # has a privacy level comparable to mu-GDP -- confirm.
  lap_sig <- sens / log((1 - pnorm(-mu / 2)) / pnorm(-mu / 2))
  print(paste0("============ ", mu, " ============"))
  dat <- rSPD(n, d, r)
  print(Sys.time())
  true_p <- frechet_mean(dat, d)

  dist_list <- matrix(NA_real_, nrow = n_rep, ncol = 2)
  time_list <- matrix(NA_real_, nrow = n_rep, ncol = 2)
  for (i in seq_len(n_rep)) {
    t_start <- Sys.time()
    # Mechanism 1: Gaussian noise on the d(d+1)/2 vectorized coordinates.
    tangent <- vecd(log_fn(fix_point, true_p))
    noise <- rnorm(d * (d + 1) / 2, 0, sd = sens / mu)
    p_dp <- expm::expm(vecd_inverse(tangent + noise))
    t_mid <- Sys.time()
    # Mechanism 2: a single "laplace" draw from rSPD_dist.
    p_dp_2 <- rSPD_dist(1, d, mu = true_p, sigma = lap_sig, "laplace",
                        n_burn = n_burn)[[1]]
    t_end <- Sys.time()

    dist_list[i, ] <- c(distance_fn(true_p, p_dp),
                        distance_fn(true_p, p_dp_2))
    # Convert durations explicitly to seconds: a bare `t_mid - t_start`
    # chooses its units automatically ("secs", "mins", ...), and that label
    # is lost once stored in a numeric matrix, so slow runs could otherwise
    # be mixed in as minutes.
    time_list[i, ] <- c(
      as.numeric(difftime(t_mid, t_start, units = "secs")),
      as.numeric(difftime(t_end, t_mid, units = "secs"))
    )
    print(dist_list[i, ])
  }

  dist_mean <- finite_col_stat(dist_list, mean)
  dist_std <- finite_col_stat(dist_list, sd)
  time_mean <- finite_col_stat(time_list, mean)
  time_std <- finite_col_stat(time_list, sd)

  print(c(dist_mean, dist_std))
  res_list[ind, ] <- c(dist_mean, dist_std, time_mean, time_std)
}

# Save results ----
# One row per mu: the mu value first, then the 8 summary columns of res_list.
# res_path is pasted directly in front of the file name.
out_file <- paste0(res_path, "GDP_", d, ".csv")
res_list <- cbind(mu_list, res_list)
write.csv(res_list, file = out_file)