# Setup ----
# No rm(list = ls()): wiping the global environment is a side effect on
# whoever sources this script, and it does not reset packages/options anyway.
# Run in a fresh R session (e.g. via Rscript) for a clean state instead.
library(ExtDist)
library(exactLTRE)

# Directory prefixes are joined with paste0(), so a non-empty value needs a
# trailing "/". Empty strings mean "relative to the working directory".
path <- ""       # location of spd_functions_le.R
data_path <- ""  # location of the input .rds / .csv files read below
res_path <- ""   # location where the results CSV is written

# NOTE(review): presumably defines frechet_mean_le(), rSPD_dist_le() and
# dist_le(), which this script calls but does not define — confirm.
source(paste0(path, "spd_functions_le.R"))

# Data ----
class <- 3   # class index: selects the input/output files and seeds the RNG
eta <- 1e-6  # small positive constant used in the bound `r` below
d <- 5       # presumably the side length of each (d x d) SPD matrix

MyArray <- readRDS(paste0(data_path, "spd_train_", class, "_", d, ".rds"))
# One list element per slice along the third array dimension. seq_len()
# (unlike seq()) gives an empty index when there are no slices.
dat <- lapply(seq_len(dim(MyArray)[3]), function(x) MyArray[, , x])

# Random subsample containing a fraction p of the matrices (at least one).
# Seed first so the subsample is reproducible; the simulation later calls
# set.seed(class) again, so its random stream is unaffected by this.
p <- 0.05
set.seed(class)
sub_ind <- sample(seq_along(dat), max(1, round(p * length(dat))))
dat_sub <- dat[sub_ind]
# NOTE(review): true_p_sub is not used anywhere later in this script.
true_p_sub <- frechet_mean_le(dat_sub, d = d)

print(Sys.time())
# Precomputed reference mean; the first CSV column is dropped (presumably
# the row-name column written by write.csv — confirm).
true_p <- data.matrix(read.csv(
  file = paste0(data_path, "octmnist_frechet_mean_", class, "_", d, ".csv")
)[, -1])

print(Sys.time())

# Sensitivity ----
n <- length(dat)

# Bound r used for the noise scale. Both supported dimensions use the same
# form, sqrt(d) * max(|log(eta)|, |log(d * 255^2 + eta)|); any other d used to
# leave `r` undefined and fail later with an obscure "object not found".
# TODO(review): confirm whether the bound is valid for other d before
# widening the supported set.
if (!(d %in% c(4, 5))) {
  stop("No bound r is defined for d = ", d,
       "; supported values are 4 and 5.", call. = FALSE)
}
r <- sqrt(d) * max(abs(log(eta)), abs(log(d * 255^2 + eta)))

# Noise scale base used below (presumably the sensitivity of the mean of
# n points, 2 * r / n — confirm).
sens <- 2 * r / n
# Simulation ----
mu_list <- c(seq(0.1, 0.7, 0.1), 1, 1.5, 2)
n_rep <- 100  # repetitions per mu value

# Apply summary `f` to each column of matrix `m`, ignoring non-finite entries.
col_stat_finite <- function(m, f) {
  apply(m, 2, function(x) f(x[is.finite(x)]))
}

# One row per mu. Columns, each as (gauss, laplace) pairs:
#   1-2 mean distance, 3-4 sd distance, 5-6 mean time (s), 7-8 sd time (s).
res_list <- matrix(NA, nrow = length(mu_list), ncol = 8)

set.seed(class)
for (ind in seq_along(mu_list)) {
  mu <- mu_list[ind]
  # Laplace scale sens / eps, eps = log((1 - Phi(-mu/2)) / Phi(-mu/2)).
  lap_sig <- sens / log((1 - pnorm(-mu / 2)) / pnorm(-mu / 2))

  print(lap_sig)

  # Preallocated instead of grown with rbind() on every iteration.
  # NOTE(review): assumes dist_le() returns a single number — confirm.
  dist_list <- matrix(NA_real_, nrow = n_rep, ncol = 2)
  time_list <- matrix(NA_real_, nrow = n_rep, ncol = 2)
  for (i in seq_len(n_rep)) {
    print(Sys.time())
    start_1 <- Sys.time()
    p_dp <- rSPD_dist_le(1, d, list(true_p), sigma = sens / mu,
                         type = "gauss")[[1]]
    start_2 <- Sys.time()
    p_dp_2 <- rSPD_dist_le(1, d, list(true_p), sigma = lap_sig,
                           type = "laplace")[[1]]
    start_3 <- Sys.time()
    dist_list[i, ] <- c(dist_le(true_p, p_dp), dist_le(true_p, p_dp_2))
    # Plain `start_2 - start_1` picks its unit (secs, mins, ...) from the
    # magnitude, so rows could silently mix units; pin everything to seconds.
    time_list[i, ] <- c(
      as.numeric(difftime(start_2, start_1, units = "secs")),
      as.numeric(difftime(start_3, start_2, units = "secs"))
    )
  }

  dist_mean <- col_stat_finite(dist_list, mean)
  dist_std <- col_stat_finite(dist_list, sd)
  time_mean <- col_stat_finite(time_list, mean)
  time_std <- col_stat_finite(time_list, sd)

  print(c(dist_mean, dist_std))
  res_list[ind, ] <- c(dist_mean, dist_std, time_mean, time_std)
}

# Output ----
# Put the mu grid in front of the per-mu summaries and write one CSV row
# per mu value.
res_list <- cbind(mu_list, res_list)
out_csv <- paste0(res_path, "octmnist_GDP_le_", class, "_", d, ".csv")
write.csv(res_list, file = out_csv)