library(expm)


# Trace of a matrix: the sum of the entries on its main diagonal.
#
# Args:
#   mat: a matrix (passed to diag()).
# Returns: a numeric scalar.
tr_fn <- function(mat) {
  diagonal_entries <- diag(mat)
  sum(diagonal_entries)
}

# Frobenius norm of X: the square root of the sum of squared entries.
#
# Args:
#   X: numeric matrix (or vector).
#   square: if TRUE, return the squared norm instead (skips the sqrt).
# Returns: a numeric scalar.
f_norm <- function(X, square = FALSE) {
  sum_sq <- sum(X^2)
  if (square) sum_sq else sqrt(sum_sq)
}

# Keep the strictly lower triangle of W, halve its diagonal and zero
# everything above the diagonal.
#
# Args:
#   W: a matrix.
# Returns: a matrix of the same dimensions as W.
diag_half <- function(W) {
  W[upper.tri(W)] <- 0
  diag(W) <- 0.5 * diag(W)
  W
}

# Computes L %*% diag_half(L^{-1} W L^{-T}), where L = t(chol(P)) is the
# lower Cholesky factor of P. This is presumably the differential of the
# Cholesky map at P in direction W (log-Cholesky geometry) — confirm.
#
# Args:
#   W: square (presumably symmetric) matrix, the direction.
#   P: SPD matrix, the base point; ignored when identity = TRUE.
#   identity: if TRUE, treat P as the identity, so L = I and the result is
#     just diag_half(W).
# Returns: a lower-triangular matrix.
d_chol <- function(W, P, identity = TRUE) {
  if (identity) {
    return(diag_half(W))
  }
  lower <- t(chol(P))
  lower_inv <- solve(lower)
  sandwiched <- lower_inv %*% W %*% t(lower_inv)
  lower %*% diag_half(sandwiched)
}

# Computes P %*% t(W) + W %*% t(P): the directional derivative of
# L -> L %*% t(L) at L = P in direction W.
#
# Unlike d_chol(), P is used directly as the factor (no chol() is taken);
# presumably it is the lower Cholesky factor — confirm with callers.
#
# Args:
#   W: square matrix, the direction (previously the body referred to an
#      undefined `X`, which errored or silently used a global).
#   P: the factor matrix; ignored when identity = TRUE.
#   identity: if TRUE, treat P as the identity, giving t(W) + W.
# Returns: a symmetric matrix of the same dimensions as W.
d_chol_inv <- function(W, P, identity = TRUE) {
  if (identity) {
    t(W) + W
  } else {
    P %*% t(W) + W %*% t(P)
  }
}

# Exponential map for lower-triangular matrices (log-Cholesky geometry).
#
# identity = TRUE: exponentiate the diagonal of X, leaving off-diagonal
#   entries untouched. The exponentiated diagonal is floored at 1e-06 for
#   numerical stability.
# identity = FALSE: based at L, the result has the off-diagonal entries of
#   X + L and the diagonal diag(L) * exp(diag(X) / diag(L)). No floor is
#   applied in this branch.
#
# Args:
#   X: square matrix (tangent direction).
#   L: base-point lower-triangular matrix; only used when identity = FALSE.
#   identity: whether to use the identity as base point.
# Returns: a matrix of the same dimensions as X.
exp_lower <- function(X, L, identity = TRUE) {
  if (!identity) {
    out <- X + L
    diag(out) <- diag(L) * exp(diag(X) / diag(L))
    return(out)
  }
  # Floor so the diagonal never underflows to zero (improves stability).
  diag(X) <- pmax(exp(diag(X)), 1e-06)
  X
}

# Logarithm map for lower-triangular matrices (log-Cholesky geometry);
# the inverse of exp_lower().
#
# identity = TRUE: take log() of the diagonal of K, leaving off-diagonal
#   entries untouched.
# identity = FALSE: based at L, the result has the off-diagonal entries of
#   K - L and the diagonal diag(L) * log(diag(K) / diag(L)).
#   (Previously this branch used exp(), so it did not invert exp_lower().)
#
# Args:
#   K: square matrix with positive diagonal (e.g. a lower Cholesky factor).
#   L: base-point lower-triangular matrix; only used when identity = FALSE
#      (it is never evaluated otherwise, so it may be omitted).
#   identity: whether to use the identity as base point.
# Returns: a matrix of the same dimensions as K.
log_lower <- function(K, L, identity = TRUE) {
  if (identity) {
    diag(K) <- log(diag(K))
    return(K)
  }
  res <- K - L
  diag(res) <- diag(L) * log(diag(K) / diag(L))
  res
}

# Distance between two lower-triangular matrices in the log-Cholesky metric.
# It is the Frobenius norm of the difference of their identity-based
# log_lower() images (log of the diagonal, other entries unchanged).
#
# Args:
#   X, Y: square matrices with positive diagonals (e.g. lower Cholesky factors).
# Returns: a non-negative numeric scalar.
dist_lower <- function(X, Y) {
  log_gap <- log_lower(X, identity = TRUE) - log_lower(Y, identity = TRUE)
  f_norm(log_gap, square = FALSE)
}

# Log-Cholesky distance between two SPD matrices. Each matrix is reduced to
# its lower Cholesky factor t(chol(.)), and those factors are compared with
# dist_lower().
#
# Args:
#   X, Y: SPD matrices of the same dimension.
# Returns: a non-negative numeric scalar.
dist_lc <- function(X, Y) {
  dist_lower(t(chol(X)), t(chol(Y)))
}

# Entry-wise arithmetic mean of a list of equally-sized matrices.
#
# Each matrix is flattened (column-major) into one column of a stacked
# matrix, the columns are averaged, and the result is reshaped into a
# matrix with sqrt(#entries) columns. The inputs are therefore assumed
# to be square.
#
# Args:
#   dat: a non-empty list of square matrices of identical size.
# Returns: a square numeric matrix without dimnames.
meanlist <- function(dat) {
  avg <- rowMeans(matrix(unlist(dat), ncol = length(dat)))
  side <- sqrt(length(avg))
  matrix(avg, ncol = side)
}


# Frechet mean of SPD matrices under the log-Cholesky metric.
#
# Each matrix is reduced to its lower Cholesky factor. The factors are
# averaged entry-wise, except on the diagonal, where the geometric mean
# (exp of the mean log) is used. The mean factor M is returned as M %*% t(M).
#
# Args:
#   data: non-empty list of SPD matrices of the same dimension (1x1 allowed).
#   d: unused; kept for interface compatibility.
# Returns: an SPD matrix of the same dimension as the inputs.
frechet_mean_lc <- function(data, d) {
  if (length(data) == 0) {
    stop("`data` must contain at least one SPD matrix.", call. = FALSE)
  }
  dat_lower <- lapply(data, function(x) t(chol(x)))
  n_obs <- length(dat_lower)

  # Arithmetic mean of the factors. This fixes the off-diagonal part; the
  # diagonal is overwritten below.
  mean_lower <- Reduce(`+`, dat_lower) / n_obs

  # Average the log-diagonals via Reduce rather than sapply + rowMeans:
  # for 1x1 inputs sapply collapses to a vector and rowMeans errors.
  log_diags <- lapply(dat_lower, function(x) log(diag(x)))
  mean_log_diag <- Reduce(`+`, log_diags) / n_obs
  diag(mean_lower) <- exp(mean_log_diag)

  mean_lower %*% t(mean_lower)
}

# Draw `n` random k x k SPD matrices within log-Cholesky distance `r` of
# `center`, by rejection sampling.
#
# Candidates are drawn from rWishart(1, k, diag(1/k, k)) and kept only if
# dist_lc(candidate, center) <= r. The original body never initialised its
# accumulator. It read the `utils::data` function instead and errored on
# the first loop test. It also relied on abind(), which is never loaded.
# Samples are now collected in a preallocated list.
#
# Args:
#   n: number of matrices to return (0 returns an empty list).
#   k: matrix dimension, also used as the Wishart degrees of freedom.
#   r: acceptance radius in the log-Cholesky metric.
#   center: k x k SPD matrix.
# Returns: a list of `n` k x k matrices.
# NOTE: there is no iteration cap; a small `r` relative to the Wishart
# spread can make this run for a very long time.
rSPD <- function(n, k, r, center) {
  samples <- vector("list", n)
  n_accepted <- 0L
  wishart_scale <- diag(rep(1 / k, k))
  while (n_accepted < n) {
    # matrix() keeps the k x k shape even when k == 1 (`[, , 1]` drops it).
    candidate <- matrix(rWishart(1, k, wishart_scale), nrow = k, ncol = k)
    if (dist_lc(candidate, center) <= r) {
      n_accepted <- n_accepted + 1L
      samples[[n_accepted]] <- candidate
    }
  }
  samples
}

# Draw `n` points uniformly on the unit sphere in R^d by normalising
# standard Gaussian vectors.
#
# Rows are normalised with a vectorised division. The previous
# apply() + t() version returned a 1 x n matrix when d == 1, because apply
# simplifies length-1 results to a vector.
#
# Args:
#   n: number of points.
#   d: ambient dimension.
# Returns: an n x d numeric matrix whose rows have unit Euclidean norm.
rSphere <- function(n, d) {
  norm_pt <- matrix(rnorm(n * d, 0, 1), n, d)
  # A length-n divisor recycles down columns, so row i is divided by its
  # own norm.
  norm_pt / sqrt(rowSums(norm_pt^2))
}

# Draw `n` samples in R^d as mu + sig * R * U. U is a uniform direction
# from rSphere(), and R ~ Gamma(shape = d, rate = 1) is the radius.
#
# The radius/direction construction is unchanged, and so is the RNG draw
# order (directions first, then radii). The per-row sapply(1:n, ...) loop
# is replaced by vectorised arithmetic. The result is always n x d; the old
# version returned 1 x n when d == 1 and misbehaved for n == 0.
#
# Args:
#   n: number of samples.
#   d: dimension.
#   mu: location, a scalar or a length-d vector.
#   sig: scale (presumably a positive scalar).
# Returns: an n x d numeric matrix, one sample per row.
rLap <- function(n, d, mu, sig) {
  # Coerce to n x d in case the direction sampler returns a 1 x n matrix
  # when d == 1.
  dir_list <- matrix(rSphere(n, d), nrow = n, ncol = d)
  len_list <- rgamma(n, shape = d, rate = 1)
  # Row i of dir_list is scaled by len_list[i] (column-wise recycling).
  shift <- matrix(mu, nrow = n, ncol = d, byrow = TRUE)
  shift + dir_list * len_list * sig
}

# Draw one random SPD matrix around `p` in the log-Cholesky geometry.
#
# Steps:
#   1. Draw a noise vector of d * (d + 1) / 2 entries until its Euclidean
#      norm is strictly below `r` (rejection sampling).
#   2. Place the noise in the lower triangle (diagonal included) of a d x d
#      matrix and add it to log_lower() of p's lower Cholesky factor.
#   3. Map back with exp_lower() to get a factor S, and return S %*% t(S).
#
# Args:
#   d: matrix dimension; `p` is assumed to be d x d.
#   r: rejection radius for the noise norm. It must be positive, since a
#      norm is never negative and the loop would otherwise never end.
#   sig: scale for rLap() ("laplace") or sd for rnorm() ("gauss").
#   p: d x d SPD matrix the sample is centred on.
#   type: "laplace" or "gauss". Any other value now errors. Before, it
#      left `noise` unbound, so the lookup either errored cryptically or
#      silently used a global `noise`.
# Returns: a d x d matrix.
rSPD_dis_lc_single <- function(d, r, sig, p, type) {
  if (!(length(type) == 1 && type %in% c("laplace", "gauss"))) {
    stop("`type` must be \"laplace\" or \"gauss\".", call. = FALSE)
  }
  if (!(length(r) == 1 && !is.na(r) && r > 0)) {
    stop("`r` must be a single positive number.", call. = FALSE)
  }

  n_free <- d * (d + 1) / 2
  # Start at Inf so the loop body always runs at least once.
  noise_norm <- Inf
  while (noise_norm >= r) {
    noise <- if (type == "laplace") {
      rLap(1, n_free, 0, sig = sig)
    } else {
      rnorm(n_free, 0, sd = sig)
    }
    noise_norm <- sqrt(sum(noise^2))
  }

  p_lower <- log_lower(t(chol(p)), identity = TRUE)
  noise_mat <- matrix(0, d, d)
  noise_mat[lower.tri(noise_mat, diag = TRUE)] <- noise
  samp <- exp_lower(p_lower + noise_mat, identity = TRUE)
  samp %*% t(samp)
}

# Draw `n` independent SPD matrices around `p`. Each one is produced by
# rSPD_dis_lc_single().
#
# Uses seq_len(n): with the previous 1:n, n == 0 iterated over c(1, 0)
# and returned two samples instead of none.
#
# Args:
#   n: number of samples (0 returns an empty list).
#   d, r, sig, p, type: forwarded unchanged to rSPD_dis_lc_single().
# Returns: a list of `n` d x d matrices.
rSPD_dis_lc <- function(n, d, r, sig, p, type = "laplace") {
  lapply(seq_len(n), function(i) rSPD_dis_lc_single(d, r, sig, p, type))
}

# Perturb each SPD matrix in `center` once, in the log-Cholesky geometry.
#
# For each element x of `center`:
#   1. Draw d * (d + 1) / 2 noise entries (Gaussian or rLap()).
#   2. Place them in the lower triangle (diagonal included) and add them to
#      log_lower() of x's lower Cholesky factor.
#   3. Map back with exp_lower() to get S, and return S %*% t(S).
# Unlike rSPD_dis_lc_single(), the noise norm is not restricted.
#
# Args:
#   n: unused by this function (kept for interface compatibility).
#      Presumably meant to equal length(center) — confirm with callers.
#   d: matrix dimension of every element of `center`.
#   center: list of d x d SPD matrices.
#   sigma: sd for rnorm() ("gauss") or scale for rLap() ("laplace").
#   type: "gauss" or "laplace". Any other value now errors. Before, it
#      left `noise` unbound, so the lookup could silently hit a global.
# Returns: a list with one d x d matrix per element of `center`.
rSPD_dp_lc <- function(n, d, center, sigma, type = "gauss") {
  if (!(length(type) == 1 && type %in% c("gauss", "laplace"))) {
    stop("`type` must be \"gauss\" or \"laplace\".", call. = FALSE)
  }
  n_free <- d * (d + 1) / 2

  lapply(center, function(x) {
    p_lower <- log_lower(t(chol(x)), identity = TRUE)
    noise <- if (type == "gauss") {
      rnorm(n_free, 0, sd = sigma)
    } else {
      rLap(1, n_free, 0, sig = sigma)
    }
    noise_mat <- matrix(0, d, d)
    noise_mat[lower.tri(noise_mat, diag = TRUE)] <- noise
    samp <- exp_lower(p_lower + noise_mat, identity = TRUE)
    samp %*% t(samp)
  })
}


