rm(list =ls())
library(reticulate)
library(imager)


#' Regularised covariance descriptor of a single image
#'
#' Builds `k` per-pixel feature maps from the image intensity and its first
#' and second derivatives (as returned by `get_gradient()`), flattens them to
#' a pixels-by-`k` matrix and returns the population-scaled covariance of the
#' features plus a ridge term `eta * I`.
#'
#' Supported feature sets:
#' * `k = 9`: row index, column index, intensity, |I_x|, |I_y|, |I_xx|,
#'   |I_yy|, gradient magnitude, `atan(|I_x| / |I_y|)`.
#' * `k = 5`: intensity, |I_x|, |I_y|, |I_xx|, |I_yy|.
#' * `k = 4`: |I_x|, |I_y|, |I_xx|, |I_yy|.
#'
#' @param im A single-channel 2-D image (anything `as.cimg()` accepts).
#' @param eta Non-negative scalar added to the diagonal of the covariance.
#' @param k Number of features; one of 4, 5 or 9.
#' @return A `k` x `k` numeric matrix.
cov_descrip <- function(im, eta, k) {
  # Any other k would leave feature slices unset and make cov() fail with an
  # obscure "no complete element pairs" error, so reject it up front.
  if (length(k) != 1L || !(k %in% c(4, 5, 9))) {
    stop("`k` must be 4, 5 or 9 (got ", paste(k, collapse = ", "), ")",
         call. = FALSE)
  }

  # im <- im / 255 # normalizing the intensity value to be between 0 and 1
  im <- as.cimg(im)
  im_x <- get_gradient(im, "x")[[1]]
  im_y <- get_gradient(im, "y")[[1]]
  im_xx <- get_gradient(im_x, "x")[[1]][, , , 1]
  im_yy <- get_gradient(im_y, "y")[[1]][, , , 1]

  # Drop the singleton depth/channel dimensions to get plain 2-D matrices.
  im <- im[, , , 1]
  im_x <- im_x[, , , 1]
  im_y <- im_y[, , , 1]

  derivative_maps <- list(abs(im_x), abs(im_y), abs(im_xx), abs(im_yy))

  if (k == 9) {
    layers <- c(
      # Pixel coordinates: row index and column index of every pixel.
      list(row(im), col(im), im),
      derivative_maps,
      list(
        sqrt(im_x^2 + im_y^2),
        # 0 / 0 gives NaN where both gradients vanish; those pixels are
        # dropped from the covariance by use = "complete.obs" below.
        atan(abs(im_x) / abs(im_y))
      )
    )
  } else if (k == 5) {
    layers <- c(list(im), derivative_maps)
  } else {
    layers <- derivative_maps
  }

  # Flatten every feature map column-major into one column per feature.
  n_pix <- length(im)
  feat <- matrix(
    unlist(lapply(layers, as.numeric), use.names = FALSE),
    nrow = n_pix,
    ncol = k
  )

  # Rescale the sample covariance (divisor n - 1) to divisor n, using the
  # total pixel count, then add the ridge term.
  cov_mat <- cov(feat, use = "complete.obs") * (n_pix - 1) / n_pix +
    eta * diag(1, k, k)

  cov_mat
}

# ---- Configuration ----
# Both prefixes are empty here; when set they are prepended verbatim by
# paste0(), so they need a trailing path separator.
data_path <- ""
path <- ""
# NOTE(review): frechet_mean() and distance_fn() used below are presumably
# defined in this sourced file -- confirm.
source(paste0(path, "spd_functions.R"))

# ---- Load the OCTMNIST archive through numpy ----
use_virtualenv("")
np <- import("numpy")
npz1 <- np$load(paste0(data_path, "octmnist.npz"))
npz1$files
oct_file <- npz1$f[["train_images"]]

class <- 2   # label of the training images to process
eta <- 1e-6  # ridge added to each covariance descriptor
k <- 5       # number of features per pixel (see cov_descrip)

class_ind <- which(npz1$f[["train_labels"]] == class)
# class_ind <- class_ind[1:10] # small sample for testing
n <- length(class_ind)
if (n == 0L) {
  stop("no training images found with label ", class, call. = FALSE)
}

# ---- Compute one covariance descriptor per image, in parallel ----
# detectCores() may return NA or a value <= 2; always use at least one worker.
no_cores <- max(1L, parallel::detectCores() - 2L, na.rm = TRUE)
cl <- parallel::makeCluster(no_cores)

# The cluster is stopped in `finally` so worker processes are released even
# when a worker fails.
spd_class <- tryCatch(
  {
    # Only the objects the worker function actually references are exported;
    # the indices themselves are shipped as the parLapply() input.
    parallel::clusterExport(cl, c("cov_descrip", "oct_file", "eta", "k"))
    parallel::clusterEvalQ(cl, {
      library(reticulate)
      library(imager)
    })

    descriptors <- parallel::parLapply(cl, class_ind, function(x) {
      cov_descrip(oct_file[x, , ], eta, k)
    })
    # Stack the n k-by-k matrices along the third dimension.
    array(unlist(descriptors, use.names = FALSE), dim = c(k, k, n))
  },
  finally = parallel::stopCluster(cl)
)

# ---- Frechet mean of the descriptors ----
dat <- lapply(seq(dim(spd_class)[3]), function(x) spd_class[, , x])
print(Sys.time())
true_p <- frechet_mean(dat, k, lambda = 1e-3)
fix_point <- diag(rep(1, k))
print(distance_fn(true_p, fix_point) * k * (k + 1) / 2)
print(Sys.time())

# ---- Persist results ----
saveRDS(spd_class, file = paste0(data_path, "spd_train_", class, "_", k, ".rds"))

write.csv(true_p, file = paste0(data_path, "octmnist_frechet_mean_", class, "_", k, ".csv"))

