library(expm)


#' Vectorise a square matrix into diagonal and scaled upper-triangular parts.
#'
#' The diagonal entries come first, followed by the strictly upper-triangular
#' entries (in column-major order) multiplied by sqrt(2).
#'
#' @param m A square matrix.
#' @return A vector holding the diagonal followed by the scaled off-diagonal
#'   entries.
vecd <- function(m) {
  off_diag <- m[upper.tri(m, diag = FALSE)]
  c(diag(m), sqrt(2) * off_diag)
}


#' Rebuild a symmetric matrix from its diagonal / scaled off-diagonal vector.
#'
#' The first m entries of `u` become the diagonal; the remaining entries are
#' divided by sqrt(2), written into the strict upper triangle (column-major
#' order) and mirrored into the lower triangle.
#'
#' @param u A vector (or matrix, read in linear index order) whose length is
#'   a triangular number m * (m + 1) / 2.
#' @return A symmetric m x m matrix.
vecd_inverse <- function(u) {
  n_elem <- length(u)
  # Solve m * (m + 1) / 2 == n_elem for m and verify the solution is integral,
  # rather than letting matrix() silently truncate a fractional dimension.
  m <- as.integer(round((sqrt(1 + 8 * n_elem) - 1) / 2))
  if (m * (m + 1) / 2 != n_elem) {
    stop(
      "length(u) must be a triangular number m * (m + 1) / 2, got ", n_elem,
      call. = FALSE
    )
  }
  diag_idx <- seq_len(m)
  mat <- matrix(0, m, m)
  diag(mat) <- u[diag_idx]
  mat[upper.tri(mat, diag = FALSE)] <- u[-diag_idx] / sqrt(2)
  # Mirror the upper triangle so the result is symmetric.
  lower <- lower.tri(mat)
  mat[lower] <- t(mat)[lower]
  mat
}

#' Trace of a matrix.
#'
#' @param mat A matrix.
#' @return The sum of the diagonal entries of `mat`.
tr_fn <- function(mat) {
  diagonal_entries <- diag(mat)
  sum(diagonal_entries)
}

#' Distance between two matrices computed from their matrix logarithms.
#'
#' Evaluates sqrt(tr(D %*% D)) with D = log_mat(X) - log_mat(Y). When D is
#' symmetric this equals the Frobenius norm of D.
#'
#' @param X,Y Matrices accepted by `log_mat()`.
#' @return A single numeric distance.
dist_le <- function(X, Y) {
  delta <- log_mat(X) - log_mat(Y)
  delta_sq <- delta %*% delta
  sqrt(tr_fn(delta_sq))
}

#' Matrix exponential.
#'
#' Thin wrapper around `expm::expm()` using that function's default settings.
#'
#' @param X A square matrix.
#' @return The value returned by `expm::expm(X)`.
exp_mat <- function(X) {
  expm::expm(X)
}

#' Matrix logarithm.
#'
#' Thin wrapper around `expm::logm()` that always selects the "Eigen" method.
#'
#' @param X A square matrix.
#' @return The value returned by `expm::logm(X, method = "Eigen")`.
log_mat <- function(X) {
  expm::logm(X, method = "Eigen")
}


#' Mean of a list of matrices computed in the matrix-logarithm domain.
#'
#' Each matrix is mapped through `log_mat()`, observations whose logarithm
#' contains non-finite values are discarded, the remaining logarithms are
#' averaged entry-wise, and the average is mapped back with `exp_mat()`.
#'
#' @param data A non-empty list of d x d matrices.
#' @param d Matrix dimension.
#' @param n_iter,step_size,lambda Not used by this implementation; kept so
#'   existing callers that pass them keep working.
#' @return A d x d matrix.
frechet_mean_le <- function(data, d, n_iter = 500, step_size = 0.5, lambda = 1e-5) {
  if (length(data) == 0) {
    stop("`data` must contain at least one matrix", call. = FALSE)
  }
  # One column per observation holding the vectorised matrix logarithm.
  # vapply() guarantees a (d * d) x n matrix regardless of n.
  dat_log <- vapply(
    data,
    function(x) as.vector(log_mat(x)),
    numeric(d * d)
  )
  keep <- is.finite(colSums(dat_log))
  if (!any(keep)) {
    stop("no observation has a finite matrix logarithm", call. = FALSE)
  }
  # drop = FALSE keeps the matrix shape when only one column survives;
  # otherwise rowMeans() fails on the resulting plain vector.
  dat_log <- dat_log[, keep, drop = FALSE]
  dat_log_mean <- matrix(rowMeans(dat_log), d, d)
  exp_mat(dat_log_mean)
}

#' Draw one random symmetric matrix E D E^T.
#'
#' D is diagonal with entries drawn uniformly from
#' [exp(-r / sqrt(d)), exp(r / sqrt(d))]; E comes from `ICtest::rorth(d)`
#' (presumably a random orthogonal matrix -- confirm against ICtest docs).
#'
#' @param d Matrix dimension.
#' @param r Controls the range of the diagonal entries of D.
#' @return A d x d matrix.
rSPD_single <- function(d, r) {
  bound <- r / sqrt(d)
  eigenvalues <- runif(d, exp(-bound), exp(bound))
  # nrow = d is required: diag() applied to a length-one vector would build
  # an identity matrix of that (truncated) size instead of a 1 x 1 diagonal.
  D <- diag(eigenvalues, nrow = d)
  E <- ICtest::rorth(d)
  E %*% D %*% t(E)
}

#' Draw a list of independent samples from `rSPD_single()`.
#'
#' @param n Number of samples; `n = 0` yields an empty list.
#' @param d,r Passed through to `rSPD_single()`.
#' @return A list of length `n` of d x d matrices.
rSPD_alt <- function(n, d, r) {
  # seq_len() avoids the 1:0 == c(1, 0) trap that would yield two draws.
  lapply(seq_len(n), function(i) rSPD_single(d, r))
}

#' Shrink a vector onto the ball of radius `r` when it lies outside it.
#'
#' @param x A numeric vector.
#' @param r Radius of the ball.
#' @return `x` unchanged when its Euclidean norm is at most `r`; otherwise
#'   `x` rescaled to have norm `r`.
censor <- function(x, r) {
  magnitude <- sqrt(sum(x^2))
  exceeds <- magnitude > r
  if (!exceeds) {
    return(x)
  }
  x / magnitude * r
}

#' Rejection sampler: Wishart draws within distance `r` of `center`.
#'
#' Candidates are drawn from `rWishart()` with `k` degrees of freedom and
#' scale matrix diag(1 / k); a candidate is kept when
#' `dist_le(candidate, center) <= r`. Sampling continues until `n` candidates
#' have been accepted, so the loop may run for a long time when `r` is small.
#'
#' @param n Number of samples to return; `n = 0` yields an empty list.
#' @param k Matrix dimension (also used as the Wishart degrees of freedom).
#' @param r Acceptance radius.
#' @param center Matrix the candidates are compared against.
#' @return A list of `n` accepted k x k draws.
rSPD <- function(n, k, r, center) {
  scale_mat <- diag(1 / k, k)
  # Collect accepted draws in a preallocated list instead of growing a 3-d
  # array on every acceptance.
  accepted <- vector("list", max(n, 0))
  n_accepted <- 0L
  while (n_accepted < n) {
    candidate <- rWishart(1, k, scale_mat)[, , 1]
    if (dist_le(candidate, center) <= r) {
      n_accepted <- n_accepted + 1L
      accepted[[n_accepted]] <- candidate
    }
  }
  accepted
}

#' Draw points on the unit sphere by normalising standard normal vectors.
#'
#' @param n Number of points.
#' @param d Dimension of each point.
#' @return An n x d matrix whose rows have unit Euclidean norm.
rSphere <- function(n, d) {
  norm_pt <- matrix(rnorm(n * d, 0, 1), n, d)
  # Dividing by a length-n vector recycles down the columns, so row i is
  # scaled by its own norm. Unlike apply() + t(), this keeps the n x d shape
  # when d == 1.
  norm_pt / sqrt(rowSums(norm_pt^2))
}

#' Draw random vectors as location + scale * radius * direction.
#'
#' Directions come from `rSphere()`, radii from a Gamma(shape = d, rate = 1)
#' distribution. Presumably this is a spherical Laplace-type sampler --
#' confirm against the intended model.
#'
#' @param n Number of draws.
#' @param d Dimension of each draw.
#' @param mu Location: a scalar or a length-d vector.
#' @param sig Scale: a scalar or a length-d vector.
#' @return An n x d matrix with one draw per row.
rLap <- function(n, d, mu, sig) {
  # Guarantee an n x d layout even when d == 1.
  dir_list <- matrix(rSphere(n, d), n, d)
  len_list <- rgamma(n, shape = d, rate = 1)
  # Row i of dir_list is scaled by len_list[i]. After transposing, each
  # column is one draw, so mu and sig recycle per coordinate.
  t(mu + t(dir_list * len_list) * sig)
}

#' Draw one matrix exp(N), where N is a noise matrix of norm below `r`.
#'
#' Noise vectors of length d * (d + 1) / 2 are drawn with `rLap()` and turned
#' into matrices with `vec2mat()` until the root-sum-of-squares of the matrix
#' entries is strictly below `r`; the matrix exponential of the accepted
#' matrix is returned.
#'
#' NOTE(review): `vec2mat` is not defined in the visible part of this file --
#' confirm it is available (`vecd_inverse()` performs a similar
#' vector-to-symmetric-matrix mapping).
#'
#' @param d Matrix dimension.
#' @param r Upper bound (exclusive) on the norm of the noise matrix.
#' @param sig Scale passed to `rLap()`.
#' @return The matrix exponential of the accepted noise matrix.
rSPD_dis_single <- function(d, r, sig) {
  n_coord <- d * (d + 1) / 2
  repeat {
    noise <- rLap(1, n_coord, 0, sig = sig)
    noise_mat <- vec2mat(noise)
    noise_norm <- sqrt(sum(c(noise_mat)^2))
    if (noise_norm < r) {
      break
    }
  }
  # Exponentiate only the accepted draw; rejected draws never needed it.
  exp_mat(noise_mat)
}

#' Draw a list of independent samples from `rSPD_dis_single()`.
#'
#' @param n Number of samples; `n = 0` yields an empty list.
#' @param d,r,sig Passed through to `rSPD_dis_single()`.
#' @return A list of length `n` of matrices.
rSPD_dis <- function(n, d, r, sig) {
  # seq_len() avoids the 1:0 == c(1, 0) trap that would yield two draws.
  lapply(seq_len(n), function(i) rSPD_dis_single(d, r, sig))
}

#' Perturb each centre matrix with random noise in the matrix-log domain.
#'
#' For every element x of `center`, a noise vector of length d * (d + 1) / 2
#' is drawn, mapped to a symmetric matrix with `vecd_inverse()`, added to
#' `log_mat(x)`, and mapped back with `exp_mat()`.
#'
#' @param n Not used by this implementation; the number of samples equals
#'   `length(center)`.
#' @param d Matrix dimension.
#' @param center A list of d x d matrices.
#' @param sigma Noise scale (standard deviation for "gauss", `sig` of
#'   `rLap()` for "laplace").
#' @param type Noise family: "gauss" or "laplace".
#' @return A list with one perturbed matrix per element of `center`.
rSPD_dist_le <- function(n, d, center, sigma, type = "gauss") {
  # Fail early with a clear message instead of an "object 'noise' not found"
  # error from inside the sampling closure.
  type <- match.arg(type, c("gauss", "laplace"))
  n_coord <- d * (d + 1) / 2
  lapply(center, function(x) {
    noise <- switch(type,
      gauss = rnorm(n_coord, 0, sd = sigma),
      laplace = rLap(1, n_coord, 0, sig = sigma)
    )
    exp_mat(log_mat(x) + vecd_inverse(noise))
  })
}


