library(abind)
library(expm)
library(MASS)
library(plotly)
library(Rfast)
library(stats)
library(RSpectra)


# ======================== vectorization ========================


vecd <- function(m){
  # Half-vectorisation of a square matrix: the diagonal entries first, then
  # the strict upper triangle (column-major order) multiplied by sqrt(2).
  # For symmetric m the Euclidean norm of the result equals the Frobenius
  # norm of m.
  off_diag <- m[upper.tri(m, diag = FALSE)]
  c(diag(m), off_diag * sqrt(2))
}


vecd_inverse <- function(u){
  # Inverse of vecd(): rebuild the symmetric matrix from its half-vectorised
  # form. Since length(u) = m (m + 1) / 2, the dimension is
  # m = (sqrt(1 + 8 length(u)) - 1) / 2.
  dim_m <- sqrt(1 + 8 * length(u)) / 2 - 0.5
  diag_idx <- 1:dim_m
  out <- matrix(0, dim_m, dim_m)
  # Fill the strict upper triangle (undoing the sqrt(2) scaling), mirror it
  # into the lower triangle, then place the diagonal.
  out[upper.tri(out, diag = FALSE)] <- u[-diag_idx] / sqrt(2)
  out <- out + t(out)
  diag(out) <- u[diag_idx]
  out
}


# ======================== Basic functions on Riemannian manifold ========================

tr_fn <- function(mat){
  # Matrix trace: the sum of the main-diagonal entries of `mat`.
  diagonal <- diag(mat)
  sum(diagonal)
}

euclid_dist_fn <- function(x, y){
  # Euclidean distance between two equally shaped numeric objects (for
  # matrices this is the Frobenius norm of their difference).
  delta <- x - y
  sqrt(sum(delta^2))
}

distance_fn <- function(p, q){
  # Riemannian distance between SPD matrices p and q: the square root of
  # square_distance_fn(p, q). An NA from square_distance_fn() propagates.
  sq_dist <- square_distance_fn(p, q)
  sqrt(sq_dist)
}


square_distance_fn <- function(p, x){
  # Squared affine-invariant Riemannian distance between SPD matrices p and x:
  #   d^2(p, x) = sum_i log(lambda_i)^2,
  # where lambda_i are the eigenvalues of p^{-1/2} x p^{-1/2}.
  #
  # Returns NA, with a warning instead of printing debug dumps, when the
  # eigen-decomposition fails (e.g. p not positive definite produces NaNs)
  # or when an eigenvalue is negative (x not positive semi-definite relative
  # to p). A zero eigenvalue yields Inf.
  p_eigen <- eigen(p, symmetric = TRUE)
  n_dim <- length(p_eigen$values)
  # g_inverse = Lambda^{-1/2} V'. Then g_inverse x g_inverse' is orthogonally
  # similar to p^{-1/2} x p^{-1/2}, so both have the same eigenvalues.
  # nrow = n_dim stops diag() from treating a length-1 vector as a size
  # (which broke 1 x 1 inputs).
  g_inverse <- diag(1 / sqrt(p_eigen$values), nrow = n_dim) %*%
    t(p_eigen$vectors)
  y <- g_inverse %*% x %*% t(g_inverse)
  y_values <- tryCatch(
    eigen(y, symmetric = TRUE, only.values = TRUE)$values,
    error = function(e) {
      warning("square_distance_fn: eigen-decomposition failed: ",
              conditionMessage(e), call. = FALSE)
      NULL
    }
  )
  if (is.null(y_values)) {
    return(NA)
  }
  if (any(y_values < 0)) {
    warning("square_distance_fn: negative eigenvalue (input not SPD); ",
            "returning NA", call. = FALSE)
    return(NA)
  }
  sum(log(y_values)^2)
}

exp_fn <- function(p, v){
  # Riemannian exponential map at the SPD base point p applied to the tangent
  # matrix v (presumably symmetric; not checked here):
  #   Exp_p(v) = p^{1/2} expm(p^{-1/2} v p^{-1/2}) p^{1/2}.
  root_p <- expm::sqrtm(p)
  inv_root_p <- spdinv(root_p)  # Rfast SPD inverse of p^{1/2}
  inner <- inv_root_p %*% v %*% inv_root_p
  root_p %*% expm(inner) %*% root_p
}
# 
log_fn <- function(p, q){
  # Riemannian logarithm map with base point q: the tangent matrix at q that
  # points towards p,
  #   Log_q(p) = q^{1/2} logm(q^{-1/2} p q^{-1/2}) q^{1/2}.
  # Note the argument order: p is the target, q is the base point.
  root_q <- expm::sqrtm(q)
  inv_root_q <- spdinv(root_q)  # Rfast SPD inverse of q^{1/2}
  inner <- inv_root_q %*% p %*% inv_root_q
  root_q %*% logm(inner) %*% root_q
}

# ======================== Sampling from Multivariate Laplace Distribution ========================

rSphere <- function(n, d){
  # Draw n points uniformly on the unit sphere S^{d-1} in R^d by normalising
  # standard Gaussian vectors. Returns an n x d matrix, one point per row.
  #
  # Normalisation is vectorised: dividing the n x d matrix by the length-n
  # vector of row norms recycles column-wise, i.e. row i is divided by its own
  # norm. Unlike an apply()/t() formulation, this keeps the n x d shape when
  # d == 1 (apply() would return a plain vector and t() a 1 x n matrix).
  norm_pt <- matrix(rnorm(n * d, 0, 1), n, d)
  norm_pt / sqrt(rowSums(norm_pt^2))
}

rLap <- function(n, d, mu, sig){
  # Draw n points from the multivariate Laplace distribution on R^d with
  # density proportional to exp(-||x - mu|| / sig). Each point is built from
  # a uniform direction (rSphere) and a Gamma(shape = d, rate = 1) radius,
  # scaled by sig and shifted by mu.
  # Returns an n x d matrix (one point per row), including when d == 1 or n == 1.
  #
  # matrix(..., n, d) enforces the n x d layout even if rSphere() hands back a
  # 1 x n matrix for d == 1; the column-major values are unchanged.
  dir_list <- matrix(rSphere(n, d), n, d)
  len_list <- rgamma(n, shape = d, rate = 1)
  # Column-major recycling multiplies row i by len_list[i].
  steps <- dir_list * len_list * sig
  # Add mu to every row.
  sweep(steps, 2, mu, "+")
}


# ======================== sampling from Wishart distribution ========================

rSPD <- function(n, k, r){
  # Draw n random k x k SPD matrices from a Wishart(df = k, Sigma = I / k)
  # distribution (mean I) truncated to the geodesic ball of radius r around
  # the identity. This is rejection sampling: a draw X is kept only if
  # ||log(X)||_F <= r.
  # Returns a list of n k x k matrices (an empty list when n == 0).
  sigma <- diag(rep(1 / k, k), nrow = k)
  samples <- vector("list", n)  # preallocated; avoids O(n^2) array growth
  n_kept <- 0
  while (n_kept < n) {
    # matrix(..., k, k) keeps 1 x 1 draws as matrices instead of scalars.
    draw <- matrix(rWishart(1, k, sigma)[, , 1], k, k)
    # For SPD X, ||logm(X)||_F^2 = sum(log(eigenvalues(X))^2), so no full
    # matrix logarithm is needed.
    eig_vals <- eigen(draw, symmetric = TRUE, only.values = TRUE)$values
    dist_sq <- sum(log(eig_vals)^2)
    if (dist_sq <= r^2) {
      n_kept <- n_kept + 1
      samples[[n_kept]] <- draw
    }
  }
  samples
}

# === testing for rSPD(): ===
# test <- rSPD(100, 2, 1.5)
# test <- t(sapply(test, function(x) x[lower.tri(x, diag = TRUE)]))
# test_plot <- scatterplot3d::scatterplot3d(test)


# ======================== sampling from Riemannian Laplace distributions ========================

sinh_fn <- function(x){
  # Product over all index pairs i > j of sinh(|x_i - x_j| / 2).
  # upper.tri() of outer(x, x, "-") visits the pairs in the same order as a
  # nested loop (i = 2..d, j = 1..i-1). For length(x) <= 1 there are no pairs
  # and the empty product is 1. A literal 2:d loop would instead run
  # backwards and read past the end of x, returning NA.
  gaps <- outer(x, x, "-")
  prod(sinh(abs(gaps[upper.tri(gaps)]) / 2))
}

sinh_ratio_fn <- function(x_new, x){
  # Pairwise ratio of sinh products: the product over i > j of
  #   sinh(|x_new_i - x_new_j| / 2) / sinh(|x_i - x_j| / 2).
  # This equals sinh_fn(x_new) / sinh_fn(x), computed term by term. For
  # length(x) <= 1 there are no pairs and the result is 1 (a 2:d loop would
  # index out of range and give NA).
  stopifnot(length(x_new) == length(x))
  gaps_new <- outer(x_new, x_new, "-")
  gaps_old <- outer(x, x, "-")
  keep <- upper.tri(gaps_old)
  prod(sinh(abs(gaps_new[keep]) / 2) / sinh(abs(gaps_old[keep]) / 2))
}

density_ratio_fn <- function(x, x_new, sigma, type = "gauss"){
  # Metropolis-Hastings density ratio pi(x_new) / pi(x) for the distribution
  # of the log-eigenvalue vector (rSPD_dist() uses exp() of these values as
  # eigenvalues):
  #   "gauss":   exp(-||x||^2 / (2 sigma^2)) * sinh_fn(x)
  #   "laplace": exp(-||x|| / sigma)         * sinh_fn(x)
  # The sinh_fn() factor is presumably the volume-element term of the
  # eigen-decomposition on the SPD manifold; verify against the reference
  # derivation.
  # An unrecognised `type` is an error. Previously it silently returned NULL,
  # which made the caller's if() fail with an unrelated message.
  type <- match.arg(type, c("gauss", "laplace"))
  sinh_ratio <- sinh_fn(x_new) / sinh_fn(x)
  if (type == "gauss") {
    exp((-sum(x_new^2) + sum(x^2)) / (2 * sigma^2)) * sinh_ratio
  } else {
    exp((-sqrt(sum(x_new^2)) + sqrt(sum(x^2))) / sigma) * sinh_ratio
  }
}

new_pt_fun <- function(x, d, sigma, type){
  # One Metropolis-Hastings step from the current state x (a length-d vector).
  # Proposes x_new ~ N(x, (sigma / 2^d)^2 I), a symmetric random walk, and
  # accepts it with probability min(1, density_ratio_fn(x, x_new, sigma, type)).
  # On rejection the chain stays at x. Re-drawing proposals until one is
  # accepted would break detailed balance: that chain targets
  # pi(x) * P(accept | x) instead of pi.
  # A NaN ratio (e.g. 0/0 when coordinates are tied, as at the all-zero start)
  # is treated as 1, i.e. the proposal is accepted.
  x_new <- rnorm(d, mean = x, sd = sigma / 2^d)
  p <- density_ratio_fn(x, x_new, sigma, type)
  if (is.nan(p)) {
    p <- 1
  }
  if (runif(1) > min(p, 1)) {
    return(x)
  }
  x_new
}

mh_sample_fn <- function(n, n_burn, d, sigma, type){
  # Run a Metropolis-Hastings chain of n + n_burn states in R^d. The chain
  # starts at the origin, and that starting state counts as the first state.
  # The first n_burn states are discarded; the remaining n are returned as an
  # n x d matrix (one state per row, also when n == 1 or d == 1).
  n_total <- n + n_burn
  # Preallocate; row 1 holds the all-zero starting state.
  samp_list <- matrix(0, nrow = n_total, ncol = d)
  x <- rep(0, d)
  for (i in seq_len(n_total)[-1]) {
    x <- new_pt_fun(x, d, sigma, type)
    samp_list[i, ] <- x
  }
  samp_list[n_burn + seq_len(n), , drop = FALSE]
}


# === testing for mh_sample_fn(): ===
# test <- mh_sample_fn(2000, 5000, 2, 2, "gauss")
# plot(test[,1], test[,2])
# plot(exp(test[,1]), exp(test[,2]))
# hist(apply(test, 1, sum), breaks = 25, freq = FALSE)
# points(seq(-3,3,0.1), dnorm(seq(-3,3,0.1)), type = "l")


rOrtho_mat <- function(n, d){
  # Draw n Haar-distributed (uniformly random) d x d orthogonal matrices.
  # Take the QR decomposition G = Q R of a standard Gaussian matrix and
  # negate every COLUMN i of Q with R[i, i] < 0. This makes the decomposition
  # unique (positive diagonal of R), which is what gives Haar measure.
  # Negating rows of Q instead would be wrong: for a diagonal sign matrix D
  # and diagonal L, t(D Q) L (D Q) = t(Q) L Q, so a row flip would leave
  # conjugations of diagonal matrices unchanged.
  # Returns a list of n matrices (an empty list when n == 0).
  lapply(seq_len(n), function(i){
    gauss_mat <- matrix(rnorm(d * d), d, d)
    QR <- qr(gauss_mat)
    Q <- qr.Q(QR)
    flip <- diag(qr.R(QR)) < 0
    Q[, flip] <- -Q[, flip]
    Q
  })
}

# === testing for rOrtho_mat(): ===
# test <- rOrtho_mat(300, 2)
# test <- t(sapply(test, function(x) x[lower.tri(x, diag = TRUE)]))
# test_plot <- scatterplot3d::scatterplot3d(test)

# Sampling from the Riemannian Gaussian / Laplace distribution on SPD matrices.
rSPD_dist <- function(n, d, mu, sigma, type, n_burn = 3000){
  # Draw n d x d SPD matrices centred at the SPD matrix `mu`:
  #   X = t(mu^{1/2}) t(O) diag(exp(lambda)) O mu^{1/2},
  # where:
  #   - O is a random orthogonal matrix from rOrtho_mat();
  #   - lambda is a log-eigenvalue vector drawn by mh_sample_fn() with scale
  #     `sigma`, density `type` ("gauss" or "laplace") and `n_burn` burn-in
  #     steps.
  # Returns a list of n d x d matrices.
  ortho_mat <- rOrtho_mat(n, d)
  p_samp <- mh_sample_fn(n, n_burn, d, sigma, type)
  p_samp <- matrix(p_samp, n, d)
  # sqrtm(mu) does not depend on the sample index, so compute it once.
  mu_sqrt <- expm::sqrtm(mu)
  lapply(seq_len(n), function(i){
    # nrow = d stops diag() from building an identity matrix of size
    # floor(exp(lambda)) when d == 1.
    eig_mat <- diag(exp(p_samp[i, ]), nrow = d)
    t(mu_sqrt) %*% t(ortho_mat[[i]]) %*% eig_mat %*% ortho_mat[[i]] %*% mu_sqrt
  })
}

# samp <- rSPD(500, 2, 1.5)
# mean <- frechet_mean(samp)
# test <- rSPD_dist(500, 2, mean, 0.5, "gauss")
# t(sapply(test, function(x) eigen(x)$values))
# test <- t(sapply(test, function(x) x[lower.tri(x, diag = TRUE)]))
# test <- data.frame(test)
# test_plot <- plot_ly(test, x = ~X1, y = ~X2, z = ~X3)
# test_plot <- test_plot %>% add_markers()
# test_plot <- test_plot %>% layout(scene = list(xaxis = list(title = 'X'),
#                                                yaxis = list(title = 'Y'),
#                                                zaxis = list(title = 'Z')))
# # test_plot <- test_plot %>% add_trace(x = 1, y = 0, z = 1, name = 'center', mode = 'markers')
# test_plot <- test_plot %>% add_trace(x = mean[1,1], y = mean[2,1], z = mean[2,2], name = 'center', mode = 'markers')
# test_plot


# ======================== frechet mean computation ========================

frechet_mean <- function(data, d = nrow(data[[1]]), n_iter = 500,
                         step_size = 0.5, lambda = 1e-5){
  # Frechet (Karcher) mean of a non-empty list of SPD matrices under the
  # affine-invariant metric, computed by Riemannian gradient descent:
  #   mean <- Exp_mean(step_size * average_i Log_mean(data_i)).
  # The iteration starts from a randomly chosen data point. It stops when two
  # consecutive iterates are within `lambda` in geodesic distance, or after
  # `n_iter` steps.
  # `d` (the matrix dimension) defaults to nrow(data[[1]]), so the function
  # can be called as frechet_mean(data).
  stopifnot(length(data) > 0)
  mean_old <- data[[sample.int(length(data), 1)]]
  step_dist <- 1
  iter <- 1
  while (step_dist > lambda && iter <= n_iter) {
    # Average tangent vector at the current estimate.
    tangent_sum <- Reduce(`+`, lapply(data, function(x) log_fn(x, mean_old)))
    v_new <- matrix(tangent_sum / length(data), d, d)
    mean_new <- exp_fn(mean_old, step_size * v_new)
    step_dist <- distance_fn(mean_new, mean_old)
    mean_old <- mean_new
    iter <- iter + 1
  }
  mean_old
}

# === testing for frechet_mean(): ===
# test <- rSPD(100, 2, 1.5)
# mean <- frechet_mean(test)
# test <- t(sapply(test, function(x) x[lower.tri(x, diag = TRUE)]))
# test_plot <- scatterplot3d::scatterplot3d(test)
# test_plot$points3d(mean[1,1], mean[2,1], mean[2,2], col = "red")




