############################################################
# Global Fréchet Regression for Covariance Matrices
# Parallel Computing Script - Pure R (reads pre-generated data)
############################################################

library(frechet)
library(parallel)
library(doParallel)
library(foreach)

## --- Parameters ---
# Simulation settings. `p` and `q` are not referenced elsewhere in this
# script; p presumably mirrors the predictor count (X is described as n x 4).
p <- 4L
q <- 3L                                      # Matrix dimension (q x q)
sample_sizes    <- c(250L, 500L, 1250L, 2500L)
num_simulations <- 200L
# Pre-converted inputs: one RDS file per (n, seed), produced by
# convert_to_rds.py.
data_dir <- "covariance_data_rds"

## --- Output directory ---
# Per-n result CSVs go into ./covariance; create it on first run only.
if (!dir.exists("covariance")) dir.create("covariance")

## --- Check data directory exists ---
# Fail fast, before any worker processes are started.
if (!dir.exists(data_dir)) {
  stop(sprintf(
    "Data directory '%s' not found. Please run convert_to_rds.py first to convert .npy files to RDS format.",
    data_dir
  ))
}

## --- Log-Cholesky distance function ---
#' Log-Cholesky distance between two symmetric positive-definite matrices.
#'
#' Each input is symmetrised and given a small diagonal ridge for numerical
#' stability, then factored as S = L %*% t(L) with L lower triangular. The
#' distance is the square root of
#'   ||strict_lower(L1) - strict_lower(L2)||_F^2 + ||log diag(L1) - log diag(L2)||^2.
#'
#' @param S1,S2 Square numeric matrices of identical dimension, assumed to be
#'   (near) symmetric positive definite.
#' @param reg Ridge added to the diagonal before factoring (default 1e-8).
#' @return A single non-negative number.
log_cholesky_distance <- function(S1, S2, reg = 1e-8) {
  if (!identical(dim(S1), dim(S2))) {
    stop("S1 and S2 must have the same dimensions", call. = FALSE)
  }
  S1 <- (S1 + t(S1)) / 2
  S2 <- (S2 + t(S2)) / 2
  S1 <- S1 + reg * diag(nrow(S1))
  S2 <- S2 + reg * diag(nrow(S2))

  # chol() returns the UPPER factor R with S = t(R) %*% R; transpose it to get
  # the lower factor L. The previous version zeroed the lower triangle of this
  # lower factor, which always produced an all-zero off-diagonal term.
  L1 <- t(chol(S1))
  L2 <- t(chol(S2))

  strict_lower <- lower.tri(L1)  # strictly lower triangle, diagonal excluded
  off_dist_sq  <- sum((L1[strict_lower] - L2[strict_lower])^2)
  logD_dist_sq <- sum((log(diag(L1)) - log(diag(L2)))^2)
  sqrt(off_dist_sq + logD_dist_sq)
}

## --- Worker function for parallel execution ---
#' Run one simulation replicate for a given sample size.
#'
#' Loads `<data_dir>/n<n>/seed<seed>.rds` (a list with arrays `X`, `M`, `C`),
#' shuffles rows with `set.seed(seed)`, fits global Frechet regression
#' (`GloCovReg`, log-Cholesky metric) on the training split and scores the
#' test-split predictions against `C` with `log_cholesky_distance()`.
#'
#' @param seed Integer replicate id; selects the file and seeds the split.
#' @param n Nominal sample size; selects the `n<n>` sub-directory.
#' @return A list with `seed`, `avg_log_chol_distance` (mean over test points)
#'   and `log_chol_distances` (one distance per test point).
run_single_session <- function(seed, n) {
  ## Construct file path (data_dir is defined at script level)
  n_dir <- file.path(data_dir, paste0("n", n))
  rds_path <- file.path(n_dir, sprintf("seed%03d.rds", seed))

  if (!file.exists(rds_path)) {
    stop(sprintf("RDS file not found for n=%d, seed=%d: %s\nPlease run convert_to_rds.py first.", n, seed, rds_path))
  }

  ## Read and validate data. tryCatch() returns the loaded arrays rather than
  ## relying on assignments escaping the protected expression.
  dat <- tryCatch({
    data_list <- readRDS(rds_path)

    X_r     <- as.array(data_list$X)  # n x 4 predictors
    M_array <- as.array(data_list$M)  # n x q x q matrices
    C_array <- as.array(data_list$C)  # n x q x q conditional means

    if (length(dim(X_r)) != 2) {
      stop(sprintf("X should be 2D but has %d dimensions", length(dim(X_r))))
    }
    if (length(dim(M_array)) != 3) {
      stop(sprintf("M should be 3D but has %d dimensions", length(dim(M_array))))
    }
    if (length(dim(C_array)) != 3) {
      stop(sprintf("C should be 3D but has %d dimensions", length(dim(C_array))))
    }

    n_rows <- dim(X_r)[1]
    if (dim(M_array)[1] != n_rows || dim(C_array)[1] != n_rows) {
      stop(sprintf("Dimension mismatch: X has %d samples, M has %d, C has %d",
                   n_rows, dim(M_array)[1], dim(C_array)[1]))
    }
    list(X = X_r, M = M_array, C = C_array)
  }, error = function(e) {
    stop(sprintf("Failed to load RDS file %s: %s", rds_path, conditionMessage(e)),
         call. = FALSE)
  })

  X_r       <- dat$X
  M_array   <- dat$M
  C_array   <- dat$C
  n_samples <- dim(X_r)[1]

  ## Data splitting: 40% train, 10% validation (reserved; not used here),
  ## then up to 200 test points taken from the remaining shuffled rows.
  set.seed(seed)
  idx <- sample(seq_len(n_samples))
  n_train     <- as.integer(0.4 * n_samples)
  n_val       <- as.integer(0.1 * n_samples)
  n_remaining <- n_samples - n_train - n_val
  n_test      <- min(200L, n_remaining)

  if (n_train < 1L || n_test < 1L) {
    stop(sprintf("n=%d, seed=%d: too few samples (%d) for a train/test split",
                 n, seed, n_samples), call. = FALSE)
  }
  if (n_test < 200L) {
    # Without this cap the test indices would run past the end of `idx` and
    # select NA rows (e.g. 250 rows leave only 125 after train/val).
    warning(sprintf("n=%d, seed=%d: only %d rows left after train/val; using %d test points instead of 200",
                    n, seed, n_remaining, n_test), call. = FALSE)
  }

  idx_train <- idx[seq_len(n_train)]
  idx_test  <- idx[n_train + n_val + seq_len(n_test)]

  X_train <- X_r[idx_train, , drop = FALSE]
  X_test  <- X_r[idx_test, , drop = FALSE]

  # Extract row i of an (n x q x q) array as a q x q matrix; matrix() keeps
  # the 2-D shape even when q = 1 (plain `arr[i, , ]` would drop to a scalar).
  slice_matrix <- function(arr, i) {
    matrix(arr[i, , ], nrow = dim(arr)[2L], ncol = dim(arr)[3L])
  }

  ## List-of-matrices input format for GloCovReg
  M_list_train <- lapply(idx_train, function(i) slice_matrix(M_array, i))

  ## Global Covariance Regression
  res <- GloCovReg(
    x    = X_train,
    M    = M_list_train,
    xout = X_test,
    optns = list(
      corrOut = FALSE,
      metric  = "log_cholesky"
    )
  )

  M_pred <- res$Mout
  if (length(M_pred) != n_test) {
    stop(sprintf("GloCovReg returned %d predictions for %d test points",
                 length(M_pred), n_test), call. = FALSE)
  }

  log_chol_distances <- vapply(seq_len(n_test), function(i) {
    log_cholesky_distance(M_pred[[i]], slice_matrix(C_array, idx_test[i]))
  }, numeric(1))

  avg_log_chol_distance <- mean(log_chol_distances)

  cat(sprintf("[n=%d] Seed %d: %.6f\n", n, seed, avg_log_chol_distance))

  list(
    seed = seed,
    avg_log_chol_distance = avg_log_chol_distance,
    log_chol_distances = log_chol_distances
  )
}

## --- Parallel setup ---
# detectCores() can return NA on some platforms; always keep at least one
# worker, otherwise leave one core free for the master process.
num_cores <- detectCores() - 1L
if (is.na(num_cores) || num_cores < 1L) num_cores <- 1L
cl <- makeCluster(num_cores)
registerDoParallel(cl)
cat("Using", num_cores, "parallel workers\n")

## --- Main execution loop ---
# Runs inside tryCatch(finally = ...) so the worker processes are shut down
# even if a simulation errors; the error itself still propagates.
tryCatch({
  cat("=== Starting parallel GFR simulations ===\n")
  cat("Reading data from:", data_dir, "\n")
  cat("Sample sizes:", paste(sample_sizes, collapse = ", "), "\n")
  cat("Simulations per size:", num_simulations, "\n\n")

  for (n in sample_sizes) {
    cat("=== Running n =", n, "===\n")

    # foreach's default combine returns one list element per seed.
    results <- foreach(seed = seq_len(num_simulations),
                       .packages = c("frechet")) %dopar% {
      run_single_session(seed, n)
    }

    # vapply() keeps the result types fixed regardless of input length.
    all_avg_distances <- vapply(results,
                                function(x) x$avg_log_chol_distance,
                                numeric(1))
    all_seeds <- vapply(results, function(x) x$seed, integer(1))

    results_df <- data.frame(
      seed = all_seeds,
      avg_log_chol_distance = all_avg_distances
    )

    filename <- file.path("covariance", paste0(n, "_cov.csv"))
    write.csv(results_df, filename, row.names = FALSE)

    cat("  Mean:", round(mean(all_avg_distances), 6), "\n")
    cat("  Std :", round(sd(all_avg_distances), 6), "\n")
    cat("  Saved:", filename, "\n\n")
  }
}, finally = stopCluster(cl))

cat("=== All simulations completed ===\n")
