library(abind)
library(expm)
library(MASS)
library(plotly)
library(Rfast)
library(stats)
library(RSpectra)


# ======================== Sampling from Laplace Distribution ========================

rSphere <- function(n, d){
  # Draw n points uniformly on the unit sphere in R^d by normalising
  # standard Gaussian vectors.
  #
  # n: number of points; d: ambient dimension.
  # Returns an n x d matrix with one unit-norm point per row (also for d == 1,
  # where the previous apply()/t() construction returned a 1 x n matrix).
  norm_pt <- matrix(rnorm(n * d, 0, 1), n, d)
  # rowSums() gives one norm per row; division recycles down the columns, so
  # row i is divided by its own norm.
  norm_pt / sqrt(rowSums(norm_pt^2))
}

rLap <- function(n, d, mu, sig){
  # Draw n samples from the d-dimensional Laplace-type distribution with
  # density proportional to exp(-||x - mu|| / sig).
  # Each sample is a uniform direction (rSphere) times a Gamma(d, 1) radius
  # scaled by sig, shifted by mu.
  #
  # n: number of samples; d: dimension; mu: centre (length-d vector or
  # scalar); sig: scale.
  # Returns an n x d matrix, one sample per row (also when d == 1).
  dir_list <- rSphere(n, d)
  len_list <- rgamma(n, shape = d, rate = 1)
  # len_list recycles down the columns, so row i is scaled by len_list[i].
  scaled <- dir_list * len_list * sig
  # Transposing makes mu recycle along each sample (row), as in mu + row.
  t(t(scaled) + mu)
}

# ======================== Matrix to Vector ========================

vec2mat <- function(x){
  # Rebuild a symmetric p x p matrix from its upper triangle (diagonal
  # included), listed in column-major order. Inverse of mat2vec().
  #
  # x: numeric vector of length p * (p + 1) / 2.
  # Returns the symmetric p x p matrix.

  # Solve length(x) = p (p + 1) / 2 for p.
  p <- sqrt(1 + 8 * length(x)) / 2 - 0.5
  out <- matrix(0, p, p)
  out[upper.tri(out, diag = TRUE)] <- x
  # Mirror the strict upper triangle into the strict lower triangle.
  below <- lower.tri(out)
  mirrored <- t(out)
  out[below] <- mirrored[below]
  out
}

mat2vec <- function(m){
  # Flatten the upper triangle (diagonal included) of matrix m into a vector,
  # in column-major order. Inverse of vec2mat() for symmetric input.
  #
  # Uses base upper.tri() (the same mask vec2mat() uses) instead of the
  # Rfast::upper_tri() dimension-vector form; the selected entries and
  # their order are the same for square m.
  m[upper.tri(m, diag = TRUE)]
}


# ======================== Basic functions on Riemannian manifold ========================



tr_fn <- function(mat){
  # Trace of a matrix: the sum of its main-diagonal entries.
  diag_entries <- diag(mat)
  sum(diag_entries)
}

euclid_dist_fn <- function(x, y){
  # Euclidean (L2) distance between two numeric vectors of equal length.
  delta <- x - y
  sqrt(sum(delta * delta))
}

distance_fn <- function(p, q){
  # Riemannian distance between SPD matrices p and q: the square root of
  # square_distance_fn(p, q).
  sq_dist <- square_distance_fn(p, q)
  sqrt(sq_dist)
}

# discarded; too slow
# square_distance_fn <- function(p, q){
#   q_sqrt <- sqrtm(q)
#   # q_sqrt_inv <- solve(q_sqrt)
#   # q_sqrt_inv <- chol2inv(chol(q_sqrt))
#   q_sqrt_inv <- spdinv(q_sqrt)
#   # print(q_sqrt)
#   # print(q_sqrt_inv)
#   # print(p)
#   # print( q_sqrt_inv %*% p %*% q_sqrt_inv )
#   mat <- q_sqrt_inv %*% p %*% q_sqrt_inv
#   return( sum(log(eigen(mat)$values)^2) )
#   # return( tr_fn( mat_log %*% mat_log) )
# }

# from "Principal Geodesic Analysis on Symmetric Spaces:Statistics of Diffusion Tensors"
# square_distance_fn <- function(p, x){
#   y_eigen <- NA
#   p_eigen <- eigen(p)
#   g <- p_eigen$vector %*% diag(sqrt(p_eigen$values))
#   g_inverse <- diag(1 / sqrt(p_eigen$values)) %*% t(p_eigen$vector)
#   y <- g_inverse %*% x %*% t(g_inverse)
#   tryCatch(y_eigen <- eigen(y),
#            error = function(c){
#              print(p_eigen$values)
#              print(p_eigen$vector)
#              print(p)
#              print(y)
#            }
#   )
#   suppressWarnings(if(is.na(y_eigen)){
#     return(NA)
#   } else{
#     return( sum(log(y_eigen$values)^2) )  
#   })
# }

square_distance_fn <- function(p, x){
  # Squared affine-invariant Riemannian distance between SPD matrices p and x:
  #   sum_i log(lambda_i)^2,
  # where lambda_i are the eigenvalues of p^{-1/2} x p^{-1/2}.
  # The whitening is built from the eigendecomposition of p.
  #
  # Returns a single number, or NA_real_ (after printing diagnostics) when
  # the eigendecomposition of the whitened matrix fails or the log-eigenvalue
  # sum is NaN (e.g. x not positive definite).
  p_eigen <- eigen(p, symmetric = TRUE)
  d <- length(p_eigen$values)
  # g_inverse = Lambda^{-1/2} V^T, so g_inverse %*% p %*% t(g_inverse) = I.
  # nrow = d stops diag() from building an identity matrix when d == 1.
  g_inverse <- diag(1 / sqrt(p_eigen$values), nrow = d) %*% t(p_eigen$vectors)
  y <- g_inverse %*% x %*% t(g_inverse)

  y_values <- tryCatch(
    eigen(y, symmetric = TRUE, only.values = TRUE)$values,
    error = function(c){
      print(p_eigen$values)
      print(p_eigen$vectors)
      print(p)
      print(y)
      NULL
    }
  )
  if (is.null(y_values)){
    return(NA_real_)
  }

  # log() of a negative eigenvalue warns and gives NaN; handled below.
  sq_dist <- suppressWarnings(sum(log(y_values)^2))
  if (is.nan(sq_dist)){
    print("y_eigen")
    print(y_values)
    print("p_eigen")
    print(p_eigen$values)
    print("x_eigen")
    print(eigen(x, symmetric = TRUE, only.values = TRUE)$values)
    # Return a scalar NA (the old code leaked the last print()'s value, a
    # vector of eigenvalues of x) so callers can filter with is.na().
    return(NA_real_)
  }
  sq_dist
}

exp_fn <- function(p, v){
  # Riemannian exponential map on SPD matrices at base point p applied to the
  # symmetric tangent vector v:
  #   p^{1/2} expm(p^{-1/2} v p^{-1/2}) p^{1/2}.
  # Uses expm::sqrtm / expm::expm and Rfast::spdinv.
  root <- sqrtm(p)
  root_inv <- spdinv(root)
  whitened <- root_inv %*% v %*% root_inv
  root %*% expm(whitened) %*% root
}
# 
log_fn <- function(p, q){
  # Riemannian logarithm map on SPD matrices: the tangent vector at base
  # point q pointing towards p,
  #   q^{1/2} logm(q^{-1/2} p q^{-1/2}) q^{1/2}.
  # Note the argument order: p is the target, q is the base point.
  base_root <- sqrtm(q)
  base_root_inv <- spdinv(base_root)
  whitened <- base_root_inv %*% p %*% base_root_inv
  base_root %*% logm(whitened) %*% base_root
}

# from "Principal Geodesic Analysis on Symmetric Spaces:Statistics of Diffusion Tensors"
exp_fn_1 <- function(p, v){
  # Eigendecomposition-based Riemannian exponential map on SPD matrices
  # (base point p, tangent vector v), following the diffusion-tensor PGA
  # construction: with g = V Lambda^{1/2} (so g g^T = p) and
  # y = g^{-1} v g^{-T} = U diag(s) U^T, the result is
  #   (g U) diag(exp(s)) (g U)^T.
  #
  # Uses the full `$vectors` name (not partial matching) and nrow = d in
  # diag() so 1 x 1 inputs work. Symmetric eigensolvers are forced so
  # round-off asymmetry in y cannot trigger the general (complex) path.
  p_eigen <- eigen(p, symmetric = TRUE)
  d <- length(p_eigen$values)
  root_vals <- sqrt(p_eigen$values)
  g <- p_eigen$vectors %*% diag(root_vals, nrow = d)
  g_inverse <- diag(1 / root_vals, nrow = d) %*% t(p_eigen$vectors)
  y <- g_inverse %*% v %*% t(g_inverse)
  y_eigen <- eigen(y, symmetric = TRUE)
  left_mat <- g %*% y_eigen$vectors
  left_mat %*% diag(exp(y_eigen$values), nrow = d) %*% t(left_mat)
}

# from "Principal Geodesic Analysis on Symmetric Spaces:Statistics of Diffusion Tensors" 
# log_fn <- function(p, x){
#   p_eigen <- eigen(p)
#   g <- p_eigen$vector %*% diag(sqrt(p_eigen$values))
#   g_inverse <- diag(1 / sqrt(p_eigen$values)) %*% t(p_eigen$vector)
#   y <- g_inverse %*% x %*% t(g_inverse)
#   y_eigen <- eigen(y)
#   left_mat <- g %*% y_eigen$vector
#   return( left_mat %*% diag(log(y_eigen$values)) %*% t(left_mat) )
# }



# ======================== sampling from distributions ========================

sinh_fn <- function(x){
  # Product over all pairs j < i of sinh(|x_i - x_j| / 2).
  # The product over an empty set of pairs (length(x) <= 1) is 1; the old
  # `2:d` loop ran backwards for d == 1 and produced NA.
  d <- length(x)
  res <- 1
  for (i in seq_len(d)[-1]){
    for (j in seq_len(i - 1)){
      res <- res * sinh(abs(x[i] - x[j]) / 2)
    }
  }
  res
}

sinh_ratio_fn <- function(x_new, x){
  # Ratio sinh_fn(x_new) / sinh_fn(x), accumulated pair by pair: the product
  # over j < i of sinh(|x_new_i - x_new_j| / 2) / sinh(|x_i - x_j| / 2).
  # Returns 1 when there are no pairs (length(x) <= 1) instead of NA.
  d <- length(x)
  res <- 1
  for (i in seq_len(d)[-1]){
    for (j in seq_len(i - 1)){
      res <- res * sinh(abs(x_new[i] - x_new[j]) / 2) / sinh(abs(x[i] - x[j]) / 2)
    }
  }
  res
}

density_ratio_fn <- function(x, x_new, sigma, type = "gauss"){
  # Ratio target(x_new) / target(x) for the unnormalised density of the
  # log-eigenvalue vector used by the MH sampler. The density is a radial
  # kernel times sinh_fn():
  #   "gauss":   exp(-|x|^2 / (2 sigma^2)) * prod sinh(|x_i - x_j| / 2)
  #   "laplace": exp(-|x| / sigma)         * prod sinh(|x_i - x_j| / 2)
  # May return NaN or Inf when sinh_fn() is 0 (tied coordinates); the
  # caller decides how to treat that.
  # Errors on any other `type` instead of silently returning NULL.
  sinh_ratio <- sinh_fn(x_new) / sinh_fn(x)
  switch(type,
    gauss = exp((-sum(x_new^2) + sum(x^2)) / (2 * sigma^2)) * sinh_ratio,
    laplace = exp((-sqrt(sum(x_new^2)) + sqrt(sum(x^2))) / sigma) * sinh_ratio,
    stop("density_ratio_fn: unknown type '", type,
         "'; expected \"gauss\" or \"laplace\"", call. = FALSE)
  )
}

new_pt_fun <- function(x, d, sigma, type){
  # One Metropolis-Hastings step for the log-eigenvalue chain.
  #
  # x: current state (length-d vector); d: dimension; sigma, type: passed
  # to density_ratio_fn().
  # Proposal: Gaussian random walk N(x, (sigma / 2^d)^2 I), which is
  # symmetric, so the acceptance probability is min(1, target ratio).
  # Returns the proposal if accepted, otherwise the current state x.
  #
  # A rejected proposal must repeat the current state. Re-proposing until
  # something is accepted (the previous behaviour) changes the transition
  # kernel and biases the chain away from the target density.
  x_new <- mvrnorm(1, x, (sigma / 2^d)^2 * diag(d))
  p <- suppressWarnings(density_ratio_fn(x, x_new, sigma, type))
  # 0/0 in the sinh ratio (e.g. tied coordinates at the start) gives NaN;
  # treat it as an automatic accept, as before.
  if (is.nan(p)){
    p <- 1
  }
  if (runif(1) <= min(p, 1)){
    x_new
  } else {
    x
  }
}

mh_sample_fn <- function(n, n_burn, d, sigma, type){
  # Run the MH chain (new_pt_fun) from the origin for n + n_burn states,
  # counting the starting state. Returns the last n states, one per row.
  # As before, the result drops to a vector when n == 1 or d == 1.
  stopifnot(n >= 1, n_burn >= 0)
  n_total <- n + n_burn
  # Preallocate instead of growing with rbind(), which copies the whole chain
  # on every step (quadratic time).
  chain <- matrix(0, nrow = n_total, ncol = d)
  x <- rep(0, d)  # initialize at the origin
  chain[1, ] <- x
  for (step in seq_len(n_total)[-1]){
    x <- new_pt_fun(x, d, sigma, type)
    chain[step, ] <- x
  }
  chain[(n_burn + 1):n_total, ]
}


# === testing for mh_sample_fn(): ===
# test <- mh_sample_fn(2000, 5000, 2, 2, "gauss")
# plot(test[,1], test[,2])
# plot(exp(test[,1]), exp(test[,2]))
# hist(apply(test, 1, sum), breaks = 25, freq = FALSE)
# points(seq(-3,3,0.1), dnorm(seq(-3,3,0.1)), type = "l")


rOrtho_mat <- function(n, d){
  # Draw n random d x d orthogonal matrices from the Haar measure.
  # Method: QR-factorise a Gaussian matrix G = Q R, then multiply each COLUMN
  # of Q by the sign of the matching diagonal entry of R. This makes the
  # factorisation unique (positive diag(R)).
  # The old code flipped ROWS of Q, which is not this correction.
  #
  # Returns a list of n matrices (an empty list when n == 0).
  lapply(seq_len(n), function(i){
    gauss_mat <- matrix(rnorm(d * d), d, d)
    QR <- qr(gauss_mat)
    Q <- qr.Q(QR)
    R <- qr.R(QR)
    col_sign <- ifelse(diag(R) < 0, -1, 1)
    # Equivalent to Q %*% diag(col_sign), without diag()'s length-one
    # special case.
    sweep(Q, 2, col_sign, "*")
  })
}

# === testing for rOrtho_mat(): ===
# test <- rOrtho_mat(300, 2)
# test <- t(sapply(test, function(x) x[lower.tri(x, diag = TRUE)]))
# test_plot <- scatterplot3d::scatterplot3d(test)

rSPD_dist <- function(n, d, mu, sigma, type, n_burn = 3000){
  # Draw n SPD matrices centred at mu:
  #   t(mu^{1/2}) O^T diag(exp(lambda)) O mu^{1/2},
  # with O from rOrtho_mat() and log-eigenvalues lambda from the MH chain
  # mh_sample_fn(n, n_burn, d, sigma, type).
  # Returns a list of n d x d matrices.
  ortho_mat <- rOrtho_mat(n, d)
  # matrix() restores the n x d shape mh_sample_fn drops when n or d is 1.
  p_samp <- matrix(mh_sample_fn(n, n_burn, d, sigma, type), n, d)
  # sqrtm(mu) does not depend on the sample, so compute it once rather than
  # twice per sample.
  mu_sqrt <- sqrtm(mu)
  mu_sqrt_t <- t(mu_sqrt)
  lapply(seq_len(n), function(k){
    o_k <- ortho_mat[[k]]
    # nrow = d keeps diag() from building an identity matrix when d == 1.
    mu_sqrt_t %*% t(o_k) %*% diag(exp(p_samp[k, ]), nrow = d) %*% o_k %*% mu_sqrt
  })
}

# samp <- rSPD(500, 2, 1.5)
# mean <- frechet_mean(samp)
# test <- rSPD_dist(500, 2, mean, 0.5, "gauss")
# t(sapply(test, function(x) eigen(x)$values))
# test <- t(sapply(test, function(x) x[lower.tri(x, diag = TRUE)]))
# test <- data.frame(test)
# test_plot <- plot_ly(test, x = ~X1, y = ~X2, z = ~X3)
# test_plot <- test_plot %>% add_markers()
# test_plot <- test_plot %>% layout(scene = list(xaxis = list(title = 'X'),
#                                                yaxis = list(title = 'Y'),
#                                                zaxis = list(title = 'Z')))
# # test_plot <- test_plot %>% add_trace(x = 1, y = 0, z = 1, name = 'center', mode = 'markers')
# test_plot <- test_plot %>% add_trace(x = mean[1,1], y = mean[2,1], z = mean[2,2], name = 'center', mode = 'markers')
# test_plot




# ======================== monte carlo integration ========================

# log_density_ratio_fn <- function(mu_top, mu_bottom, x, sigma, eps, type = "gauss"){
#   if (type == "gauss"){
#     return( ( - distance_fn(x, mu_top)^2 + distance_fn(x, mu_bottom)^2 ) >= (2 * sigma^2 * eps) )  
#   } else if (type == "laplace"){
#     return( ( - distance_fn(x, mu_top) + distance_fn(x, mu_bottom) ) >= (sigma * eps) )  
#   }
# }
# 
# integral_fn <- function(samp_pts, eps, mu_1, mu_2, sigma, type){
#   
#   count_list <- sapply(samp_pts, function(x){
#     if ( log_density_ratio_fn(mu_1, mu_2, x, sigma, eps, type) ){
#       return(1)
#     } else{
#       return(0)
#     }
#   })
#   return( sum(count_list)/n )
# }

distance_diff_fn <- function(mu_top, mu_bottom, x, type = "gauss"){
  # Difference of (powered) distances from x to the two centres, which is
  # proportional to log p_top(x) / p_bottom(x):
  #   "gauss":   d(x, mu_bottom)^2 - d(x, mu_top)^2
  #   "laplace": d(x, mu_bottom)   - d(x, mu_top)
  # May be NA when distance_fn() cannot evaluate a distance.
  # Errors on any other `type` instead of silently returning NULL.
  switch(type,
    gauss = - distance_fn(x, mu_top)^2 + distance_fn(x, mu_bottom)^2,
    laplace = - distance_fn(x, mu_top) + distance_fn(x, mu_bottom),
    stop("distance_diff_fn: unknown type '", type,
         "'; expected \"gauss\" or \"laplace\"", call. = FALSE)
  )
}

log_density_ratio_fn <- function(dist_diff, sigma, eps, type = "gauss"){
  # Elementwise indicator that the log density ratio (expressed through the
  # distance difference) is at least eps.
  #   "gauss":   dist_diff / (2 sigma^2) >= eps
  #   "laplace": dist_diff / sigma       >= eps
  # Returns a logical vector like dist_diff, or NULL for any other `type`.
  switch(type,
    gauss = dist_diff >= (2 * sigma^2 * eps),
    laplace = dist_diff >= (sigma * eps)
  )
}

integral_fn <- function(dist_diff, eps, sigma, type){
  # Monte-Carlo estimate of P[log density ratio >= eps]: the fraction of the
  # sampled distance differences for which log_density_ratio_fn() is TRUE.
  #
  # Vectorised: log_density_ratio_fn() already works on the whole vector, so
  # there is no need for one R call per sample inside the (hot) eps loop.
  # Errors when dist_diff is empty, contains NA, or `type` is unknown.
  # Each of these previously produced an error (or garbage) anyway.
  n <- length(dist_diff)
  if (n == 0){
    stop("integral_fn: dist_diff is empty", call. = FALSE)
  }
  hits <- log_density_ratio_fn(dist_diff, sigma, eps, type)
  if (length(hits) != n || anyNA(hits)){
    stop("integral_fn: could not evaluate the density-ratio event ",
         "(unknown type or NA in dist_diff)", call. = FALSE)
  }
  sum(hits) / n
}


delta_fn <- function(u, eps){
  # Privacy profile of u-Gaussian differential privacy:
  #   delta(eps) = Phi(-eps/u + u/2) - exp(eps) * Phi(-eps/u - u/2).
  shift <- eps / u
  half_u <- u / 2
  upper_term <- pnorm(half_u - shift)
  lower_term <- pnorm(-shift - half_u)
  upper_term - exp(eps) * lower_term
}

intial_interval <- function(f, target, eps){
  # Bracket for binary_search(): double the upper end, starting from 1, until
  # f(upper, eps) exceeds target. Returns c(0, upper).
  # Does not terminate if f never exceeds target (for delta_fn this happens
  # when target >= 1).
  upper <- 1
  while (f(upper, eps) <= target){
    upper <- upper * 2
  }
  c(0, upper)
}

binary_search <- function(f, target, l, h, precis, eps){
  # Bisection on [l, h] for a non-decreasing f(., eps): keep the lower end
  # where f <= target and the upper end where f > target. Stop once the
  # bracket is no wider than precis, and return the lower end.
  lo <- l
  hi <- h
  repeat {
    if (hi - lo <= precis){
      break
    }
    mid <- (lo + hi) / 2
    if (f(mid, eps) > target){
      hi <- mid
    } else {
      lo <- mid
    }
  }
  lo
}




# gdp_u <- function(d, sensi, sigma, r = 1, n = 3000, n_burn = 3000, eps_max = 3, eps_step = 0.05){
#   
#   mu_1 <- diag(d)
#   # print(mu_1)
#   mu_2 <- exp_fn( mu_1, diag(d) * sensi / sqrt(d) )
#   
#   u_list <- c()
#   eps_list <- seq(eps_step, eps_max, eps_step)
#   
#   for (eps in eps_list){
#     print(paste0("========== ", eps, " ========== "))
#     print(Sys.time())
#     lhs <- integral_fn_1(eps, mu_1, mu_2, sigma, d, n, n_burn) - integral_fn_2(eps, mu_1, mu_2, sigma, d, n, n_burn)
#     # print("1")
#     if (lhs > 0){
#       interval <- intial_interval(delta_fn, lhs, eps)
#       u <- binary_search(delta_fn, lhs, interval[1], interval[2], 0.001, eps)
#       # print("2")
#       u_list <- cbind(u_list, u)
#     } else{
#       u_list <- cbind(u_list, 0)
#     }
#   }
#   return(max(u_list))
# }

# gdp_u <- function(d, sensi, sigma, r = 1, n = 3000, n_burn = 3000, eps_max = 3, eps_step = 0.05, type){
#   
#   mu_1 <- diag(d)
#   mu_2 <- exp_fn( mu_1, diag(d) * sensi / sqrt(d) )
#   
#   samp_pts_1 <- rSPD_dist(n, d, mu_1, sigma, type = type, n_burn)
#   samp_pts_2 <- rSPD_dist(n, d, mu_1, sigma, type = type, n_burn)
#   
#   u_list <- c()
#   eps_list <- seq(eps_step, eps_max, eps_step)
#   
#   for (eps in eps_list){
#     print(paste0("========== ", eps, " ========== "))
#     print(Sys.time())
#     lhs <- integral_fn_1(samp_pts_1, eps, mu_1, mu_2, sigma, d, n, n_burn, type) - integral_fn_2(samp_pts_2, eps, mu_1, mu_2, sigma, d, n, n_burn, type)
#     # print("1")
#     if (lhs > 0){
#       interval <- intial_interval(delta_fn, lhs, eps)
#       u <- binary_search(delta_fn, lhs, interval[1], interval[2], 0.001, eps)
#       # print("2")
#       u_list <- cbind(u_list, u)
#     } else{
#       u_list <- cbind(u_list, 0)
#     }
#   }
#   return(max(u_list))
# }

# gdp_u <- function(sensi, sigma, r = 1, n = 3000, n_burn = 3000, eps_max = 3, eps_step = 0.05, type = "gauss", width = 1){
#   
#   mu_1 <- c(r, 0, 0)
#   mu_2 <- c(cos(sensi) / r, sqrt(r^2 - (cos(sensi) / r)^2), 0)
#   
#   u_list <- c()
#   
#   samp_pts_1 <- mh_sample_fn(n, n_burn, mu_1, sigma, r, type = type, width)
#   samp_pts_2 <- mh_sample_fn(n, n_burn, mu_2, sigma, r, type = type, width)
#   
#   eps_list <- seq(eps_step, eps_max, eps_step)
#   
#   for (eps in eps_list){
#     lhs <- integral_fn_1(samp_pts_1, eps, mu_1, mu_2, sigma, n, n_burn, r, type) - integral_fn_2(samp_pts_2, eps, mu_1, mu_2, sigma, n, n_burn, r, type)
#     if (lhs > 0){
#       interval <- intial_interval(delta_fn, lhs, eps)
#       u <- binary_search(delta_fn, lhs, interval[1], interval[2], 0.001, eps)
#       u_list <- cbind(u_list, u)
#     } else{
#       u_list <- cbind(u_list, 0)
#     }
#   }
#   return(max(u_list))
# }


# Distance differences (see distance_diff_fn) between the two centres for
# every sampled matrix in samp_pts, with NA results dropped.
.gdp_dist_diff_vec <- function(samp_pts, mu_1, mu_2, type){
  dist_diff <- vapply(samp_pts, function(x){
    distance_diff_fn(mu_1, mu_2, x, type = type)
  }, numeric(1))
  dist_diff[!is.na(dist_diff)]
}

# GDP parameter u approximately solving delta_fn(u, eps) = lhs (to within
# 0.001, from below); 0 when lhs is outside the open interval (0, 1).
.gdp_u_from_lhs <- function(lhs, eps){
  if (lhs > 0 && lhs < 1){
    interval <- intial_interval(delta_fn, lhs, eps)
    binary_search(delta_fn, lhs, interval[1], interval[2], 0.001, eps)
  } else {
    0
  }
}

gdp_u_alt <- function(sensi, sigma, n = 1000, n_burn = 3000, eps_max = 10, n_eps = 1000, type = "gauss", m, d = 2){
  # Monte-Carlo estimate of the GDP parameter u of the SPD-manifold mechanism.
  #
  # Two neighbouring centres are used: mu_1 = I and mu_2 = exp_fn(I, V),
  # with ||V||_F = sensi, both moved by a random congruence g (.) g^T.
  # For each eps on a grid of n_eps points in (0, eps_max] it estimates
  #   delta(eps) = P_{mu_1}[L >= eps] - exp(eps) * P_{mu_2}[L >= eps],
  # where L is the log density ratio. The estimate is averaged over m
  # independent replicates of n samples each (after n_burn MH burn-in).
  # It then inverts delta_fn() at each eps.
  #
  # Returns a list with:
  #   u:          max over the grid
  #   u_list:     per-eps u
  #   eps_list:   the grid
  #   lhs_list:   averaged delta estimates
  #   lhs_list_1: averaged first-term probabilities
  #   lhs_list_2: averaged second-term probabilities
  #
  # Changes from the previous version:
  #   * the second probability is estimated from samples around mu_2 (the
  #     neighbouring mechanism), not mu_1;
  #   * `type` is passed through to distance_diff_fn (it was hard-coded to
  #     "gauss");
  #   * removes an `if` on a length-n condition (an error in R >= 4.2).
  type <- match.arg(type, c("gauss", "laplace"))
  if (missing(m) || m < 1){
    stop("gdp_u_alt: m (number of replicates) must be a positive integer",
         call. = FALSE)
  }

  mu_1 <- diag(d)
  mu_2 <- exp_fn(mu_1, diag(d) * sensi / sqrt(d))
  g <- rSPD(1, d, 10)[[1]]
  mu_1 <- g %*% mu_1 %*% t(g)
  mu_2 <- g %*% mu_2 %*% t(g)

  eps_list <- seq(eps_max / n_eps, eps_max, eps_max / n_eps)
  n_grid <- length(eps_list)
  lhs_list <- rep(0, n_grid)
  lhs_list_1 <- rep(0, n_grid)
  lhs_list_2 <- rep(0, n_grid)

  for (i in seq_len(m)){

    print(Sys.time())
    print(paste0("============= m: ", i, " ============= "))

    samp_pts_1 <- rSPD_dist(n, d, mu_1, sigma, type = type, n_burn)
    samp_pts_2 <- rSPD_dist(n, d, mu_2, sigma, type = type, n_burn)

    dist_diff_1 <- .gdp_dist_diff_vec(samp_pts_1, mu_1, mu_2, type)
    dist_diff_2 <- .gdp_dist_diff_vec(samp_pts_2, mu_1, mu_2, type)

    for (k in seq_len(n_grid)){
      eps <- eps_list[k]
      lhs_1 <- integral_fn(dist_diff_1, eps, sigma, type)
      lhs_2 <- integral_fn(dist_diff_2, eps, sigma, type)
      # The thresholds grow with eps, so once both probabilities are 0 they
      # stay 0. The remaining grid entries keep their 0 contribution.
      if (lhs_1 == 0 && lhs_2 == 0){
        break
      }
      lhs_list_1[k] <- lhs_list_1[k] + lhs_1
      lhs_list_2[k] <- lhs_list_2[k] + lhs_2
      lhs_list[k] <- lhs_list[k] + max(lhs_1 - exp(eps) * lhs_2, 0)
    }
  }

  lhs_list_1 <- lhs_list_1 / m
  lhs_list_2 <- lhs_list_2 / m
  lhs_list <- lhs_list / m

  u_list <- vapply(seq_len(n_grid), function(k){
    .gdp_u_from_lhs(lhs_list[k], eps_list[k])
  }, numeric(1))

  list(
    u = max(u_list),
    u_list = u_list,
    eps_list = eps_list,
    lhs_list = lhs_list,
    lhs_list_1 = lhs_list_1,
    lhs_list_2 = lhs_list_2
  )
}


sensi_fn <- function(n, r){
  # Sensitivity bound 2 r / n, for a dataset of size n whose points lie
  # within radius r.
  2 * r / n
}

#=== testing sensi_fn() & gdp_u() ===
# n <- 10
# d <- 2
# r <- 1.5
# sensi <- sensi_fn(n, r)
# sigma <- 2 * sensi
# eps_max <- max(5 * sensi / sigma + sensi^2 / (2 * sigma^2), 10)
# print(Sys.time())
# u_list <- sapply(1:10, function(x){
#   u <-gdp_u_alt(sensi, sigma, n = 1000, n_burn = 3000, eps_max = eps_max, n_eps = 1000, type = "gauss", m = 100, d = d)$u
# })

# [1] 0.5644531 0.5830078 0.5566406 0.5791016 0.6015625 0.5761719 0.5849609 0.6142578 0.6005859 0.5761719

# 0.5830078 # diag(d)
# 0.1210938 
# 0.3798828
# 0.01269531 # 0.1 * diag(d)
# 1.621094; 1.641602

# print(Sys.time())
# u <- gdp_u(sensi, 0.02, eps_max = 10) # 0.7119141
# u <- gdp_u(sensi, 0.02, n = 5000, eps_max = 10) # 0.7119141

# sensi <- sensi_fn(300, 1.5)
# sigma <- 1 * sensi
# eps_max <- 10
# print(Sys.time())
# res_list <- gdp_u_alt(sensi = sensi, sigma = sigma, n = 1000, n_burn = 3000, eps_max = eps_max, n_eps = 1000, type = "gauss", m = 100, d = 2); res_list$u # 1.15625
# print(Sys.time())
# lhs_list_true <- apply(cbind(res_list$eps_list, sensi/sigma), 1, function(x) delta_fn(x[2], x[1]))
# plot(res_list$eps_list, lhs_list_true, xlab = "eps", ylab = "delta", type ="l")
# points(res_list$eps_list, res_list$lhs_list, col = "red")
# plot(res_list$eps_list, lhs_list_true, xlab = "eps", ylab = "delta", type ="l", ylim = c(0, 0.005))
# points(res_list$eps_list, res_list$lhs_list, col = "red")





# ======================== pure DP comparison ========================


laplace_compare <- function(d, sensi, sigma, r = 1.5, n = 3000, n_burn = 3000, eps_max = 3, eps_step = 0.05, m = 1){
  # Compare the MCMC-estimated GDP parameter of the Laplace mechanism with
  # the u of the eps = sensi / sigma pure-DP bound,
  # u_true = -2 * qnorm(1 / (1 + exp(eps))).
  #
  # The old body called gdp_u(), which is no longer defined (all versions
  # are commented out), and fixed eps_max at 3. It now calls gdp_u_alt()
  # on the grid seq(eps_step, eps_max, eps_step) with m replicates
  # (default 1, i.e. a single run).
  # `r` is unused and kept only for backward compatibility.
  #
  # Returns c(u_true, u_mcmc).
  n_eps <- round(eps_max / eps_step)
  u_mcmc <- gdp_u_alt(sensi, sigma, n = n, n_burn = n_burn, eps_max = eps_max,
                      n_eps = n_eps, type = "laplace", m = m, d = d)$u
  eps <- sensi / sigma
  u_true <- -2 * qnorm(1 / (1 + exp(eps)))
  c(u_true, u_mcmc)
}

# sensi <- sensi_fn(300, 1.5)
# laplace_compare(2, sensi, 0.02)


# ======================== Laplace mechanism vs Gauss mechanism ========================

frechet_mean <- function(data, d = NROW(data[[1]]), n_iter = 500, step_size = 0.5, lambda = 1e-5){
  # Frechet (Karcher) mean of a list of SPD matrices by Riemannian gradient
  # descent. Start from a random data point. Each step moves along the
  # average log_fn() tangent vector scaled by step_size, using exp_fn().
  # Stops when a step moves less than lambda (distance_fn) or after n_iter
  # iterations.
  #
  # d now defaults to the dimension of the first matrix. Callers such as
  # frechet_mean(data) used to fail with "argument d is missing".
  # Returns a d x d matrix.
  if (length(data) == 0){
    stop("frechet_mean: data is empty", call. = FALSE)
  }
  mean_old <- data[[sample.int(length(data), 1)]]
  step_dist <- 1
  iter <- 1
  while (step_dist > lambda && iter <= n_iter){

    # Averaging the matrices directly (instead of sapply + apply) also works
    # for d == 1, where sapply() simplified to a plain vector.
    tangents <- lapply(data, function(x) log_fn(x, mean_old))
    v_new <- matrix(Reduce(`+`, tangents) / length(data), d, d)

    mean_new <- exp_fn(mean_old, step_size * v_new)

    step_dist <- distance_fn(mean_new, mean_old)
    mean_old <- mean_new

    iter <- iter + 1
  }

  mean_old
}

# === testing for frechet_mean(): ===
# test <- rSPD(100, 2, 1.5)
# mean <- frechet_mean(test)
# test <- t(sapply(test, function(x) x[lower.tri(x, diag = TRUE)]))
# test_plot <- scatterplot3d::scatterplot3d(test)
# test_plot$points3d(mean[1,1], mean[2,1], mean[2,2], col = "red")

rSPD <- function(n, k, r){
  # Draw n random k x k SPD matrices within squared distance r^2 of the
  # identity, by rejection sampling. Each candidate is rWishart(df = k,
  # scale = sig_sq * I); it is accepted when sum(log(eigenvalues)^2) <= r^2.
  #
  # Returns a list of n k x k matrices (an empty list for n == 0).
  # Accepted draws go into a preallocated list rather than being appended
  # with abind(), which copied the whole array on every acceptance.
  #
  # NOTE(review): the scale shrinks to 1/k for k > 3, presumably to keep
  # the acceptance rate usable in higher dimensions; confirm.
  sig_sq <- if (k > 3) 1 / k else 1
  scale_mat <- sig_sq * diag(rep(1, k))
  samples <- vector("list", n)
  n_accepted <- 0
  while (n_accepted < n){
    draw <- matrix(rWishart(1, k, scale_mat), k, k)
    dist_sq <- sum(log(eigen(draw, symmetric = TRUE, only.values = TRUE)$values)^2)
    if (dist_sq <= r^2){
      n_accepted <- n_accepted + 1
      samples[[n_accepted]] <- draw
    }
  }
  samples
}

# === testing for rSPD(): ===
# test <- rSPD(100, 2, 1.5)
# test <- t(sapply(test, function(x) x[lower.tri(x, diag = TRUE)]))
# test_plot <- scatterplot3d::scatterplot3d(test)

gdp_to_dp <- function(u){
  # Pure-DP eps matching a GDP parameter u:
  #   eps = log((1 - p) / p),  p = Phi(-u/2).
  # Evaluated in log space with 1 - p = Phi(u/2). This stays finite when
  # Phi(-u/2) underflows to 0 (u around 75 and above), where the direct
  # formula returned Inf.
  # Vectorised over u.
  pnorm(u / 2, log.p = TRUE) - pnorm(-u / 2, log.p = TRUE)
}


frechet_mean_compare <- function(n, d, r, sigma, u = NA){
  # Compare privatised Frechet means of a random SPD dataset (rSPD(n, d, r)).
  #   * mean_gdp: manifold Gaussian mechanism with scale sigma.
  #   * mean_laplace: manifold Laplace mechanism with scale sensi / eps,
  #     where eps = gdp_to_dp(u).
  #   * mean_gdp_euclid: Euclidean Gaussian noise, sd (exp(sensi) - 1) / u,
  #     added to the upper-triangle entries.
  # If u is NA it is estimated with gdp_u_alt().
  #
  # Returns c(riemannian dist of gdp mean, riemannian dist of laplace mean,
  #           euclid dist of gdp mean, euclid dist of laplace mean,
  #           euclid dist of euclidean-gdp mean), all measured to the true mean.

  data <- rSPD(n, d, r)
  sensi <- sensi_fn(n, r)

  # Pass d explicitly: frechet_mean() needs it to reshape the tangent mean.
  mean_true <- frechet_mean(data, d)

  if (is.na(u)){
    eps_max <- max(5 * sensi / sigma + sensi^2 / (2 * sigma^2), 10)
    res_list <- gdp_u_alt(sensi, sigma, n = 1000, n_burn = 3000, eps_max = eps_max, n_eps = 1000, type = "gauss", m = 100, d = d)
    u <- res_list$u
  }
  eps <- gdp_to_dp(u)

  mean_gdp <- rSPD_dist(1, d, mu = mean_true, sigma = sigma, "gauss", n_burn = 3000)[[1]]
  mean_laplace <- rSPD_dist(1, d, mu = mean_true, sigma = sensi / eps, "laplace", n_burn = 3000)[[1]]

  sensi_euclid <- exp(sensi) - 1
  # Base upper.tri() selects the same entries, in the same order, as
  # Rfast::upper_tri(c(d, d), diag = TRUE).
  ind <- upper.tri(mean_true, diag = TRUE)
  mean_true_vec <- mean_true[ind]
  # One vectorised draw gives the same values as d(d+1)/2 separate
  # rnorm(1) calls.
  mean_gdp_euclid <- mean_true_vec + rnorm(length(mean_true_vec), 0, sensi_euclid / u)

  mean_gdp_vec <- mean_gdp[ind]
  mean_laplace_vec <- mean_laplace[ind]

  c(distance_fn(mean_gdp, mean_true),
    distance_fn(mean_laplace, mean_true),
    euclid_dist_fn(mean_gdp_vec, mean_true_vec),
    euclid_dist_fn(mean_laplace_vec, mean_true_vec),
    euclid_dist_fn(mean_gdp_euclid, mean_true_vec))
}


# n <- 10
# r <- 5
# d <- 2
# sensi <- sensi_fn(n, r)
# sigma <- 1 * sensi
# eps_max <- eps_max <- max(5 * sensi / sigma + sensi^2 / (2 * sigma^2), 10)
# res_list <- gdp_u_alt(sensi, sigma, n = 1000, n_burn = 3000, eps_max = eps_max, n_eps = 1000, type = "gauss", m = 100, d = d)
# u <- res_list$u
# frechet_mean_compare(n, 2, r, sigma, u = u)

