############################################################
# Index Fréchet Regression (IFR) for Covariance Matrices
# Parallel Computing Script - Pure R (reads pre-generated data)
############################################################

library(frechet)
library(parallel)
library(doParallel)
library(foreach)
library(Matrix)

# Source the SIdxCovReg function
source("SIdxCovReg.R")

## --- Simulation settings ---
# NOTE(review): `p` is not referenced elsewhere in this script; presumably the
# number of predictors (X is described below as n x 4) -- confirm before removing.
p <- 4L
q <- 3L  # Side length of the q x q covariance matrices
sample_sizes <- c(250L, 500L, 1250L, 2500L)  # One batch of runs per size
num_simulations <- 200L                      # Seeds 1..num_simulations per size
data_dir <- "covariance_data_rds"            # Root folder of the pre-generated RDS files

## --- Make sure the results folder is there ---
if (!dir.exists("covariance")) dir.create("covariance")

## --- Abort early when the input data has not been generated yet ---
if (!dir.exists(data_dir)) {
  missing_dir_msg <- sprintf(
    "Data directory '%s' not found. Please run convert_to_rds.py first to convert data to RDS format.",
    data_dir
  )
  stop(missing_dir_msg)
}

## --- Log-Cholesky distance function ---
#' Log-Cholesky distance between two symmetric positive-definite matrices.
#'
#' Both inputs are symmetrised, a small ridge `reg * I` is added, and each is
#' Cholesky-factorised. The distance is
#' `sqrt(||offdiag(R1) - offdiag(R2)||_F^2 + ||log diag(R1) - log diag(R2)||^2)`,
#' where `R1`, `R2` are the Cholesky factors.
#'
#' @param S1,S2 Square numeric matrices of the same size.
#' @param reg Non-negative ridge added to the diagonal before factorising
#'   (default `1e-8`, the value previously hard-coded).
#' @return A single non-negative number.
log_cholesky_distance <- function(S1, S2, reg = 1e-8) {
  S1 <- (S1 + t(S1)) / 2
  S2 <- (S2 + t(S2)) / 2
  if (!identical(dim(S1), dim(S2))) {
    stop("S1 and S2 must have the same dimensions", call. = FALSE)
  }
  S1 <- S1 + reg * diag(nrow(S1))
  S2 <- S2 + reg * diag(nrow(S2))

  # chol() returns the UPPER-triangular factor R with t(R) %*% R == S, so the
  # off-diagonal part of the factor lives in upper.tri(R). (Transposing first
  # and then blanking the lower triangle would wipe out every entry.)
  R1 <- chol(S1)
  R2 <- chol(S2)

  off_diag <- upper.tri(R1)
  off_dist_sq <- sum((R1[off_diag] - R2[off_diag])^2)
  logD_dist_sq <- sum((log(diag(R1)) - log(diag(R2)))^2)
  sqrt(off_dist_sq + logD_dist_sq)
}

## --- Worker function for parallel execution ---
#' Run one simulation replicate for a given seed and sample size.
#'
#' Reads `<data_dir>/n<n>/seed<seed>.rds` (expects list elements `X`, `M`, `C`
#' and optionally `theta_true`), splits the samples, fits `SIdxCovReg()` on the
#' training part, predicts covariance matrices for the test part with
#' `CovDirLocLin()`, and scores them against the conditional means in `C`
#' using `log_cholesky_distance()`.
#'
#' Relies on the globals `data_dir`, `log_cholesky_distance`, `SIdxCovReg`,
#' `CovDirLocLin` and `ker_gauss` being available where it runs.
#'
#' @param seed Integer seed; selects the RDS file and seeds the data split.
#' @param n Integer sample-size label; selects the `n<n>` sub-directory.
#' @return A list with `seed`, `avg_log_chol_distance`, `log_chol_distances`,
#'   `theta_true` (or `NULL`), `est_theta` and `n`.
run_single_session <- function(seed, n) {
  ## Construct file path
  n_dir <- file.path(data_dir, paste0("n", n))
  rds_path <- file.path(n_dir, sprintf("seed%03d.rds", seed))

  ## Check if file exists
  if (!file.exists(rds_path)) {
    stop(sprintf("RDS file not found for n=%d, seed=%d: %s\nPlease run convert_to_rds.py first.",
                 n, seed, rds_path))
  }

  ## Read data from RDS file; any failure is re-raised with the file path
  tryCatch({
    data_list <- readRDS(rds_path)

    # Extract arrays from the list
    X_r <- as.array(data_list$X)  # predictors, one row per sample
    M_array <- as.array(data_list$M)  # observed matrices, sample index first
    conditional_means_r <- as.array(data_list$C)  # conditional means, sample index first

    # Verify dimensions
    if (length(dim(X_r)) != 2) {
      stop(sprintf("X should be 2D but has %d dimensions", length(dim(X_r))))
    }
    if (length(dim(M_array)) != 3) {
      stop(sprintf("M should be 3D but has %d dimensions", length(dim(M_array))))
    }
    if (length(dim(conditional_means_r)) != 3) {
      stop(sprintf("C should be 3D but has %d dimensions", length(dim(conditional_means_r))))
    }

    n_samples <- dim(X_r)[1]

    # Verify dimensions match
    if (dim(M_array)[1] != n_samples || dim(conditional_means_r)[1] != n_samples) {
      stop(sprintf("Dimension mismatch: X has %d samples, M has %d, C has %d",
                   n_samples, dim(M_array)[1], dim(conditional_means_r)[1]))
    }

    # The trailing two axes must describe square matrices of one common size
    mat_dims <- dim(M_array)[2:3]
    if (mat_dims[1] != mat_dims[2] ||
        any(dim(conditional_means_r)[2:3] != mat_dims)) {
      stop("M and C must hold square matrices of equal size")
    }

    # Try to get theta_true if available, otherwise set to NULL
    theta_true_r <- NULL
    if (!is.null(data_list$theta_true)) {
      theta_true_r <- as.array(data_list$theta_true)
    }
  }, error = function(e) {
    stop(sprintf("Failed to load RDS file %s: %s", rds_path, e$message))
  })

  # Matrix side length taken from the data rather than from a global constant
  q_dim <- dim(M_array)[2]

  ## Data splitting: 40% train, 10% val (skipped here), up to 200 test
  set.seed(seed)
  idx <- sample(seq_len(n_samples))
  n_train <- as.integer(0.4 * n_samples)
  n_val   <- as.integer(0.1 * n_samples)
  # Cap the test size at what is left after train + val. Indexing past the end
  # of `idx` would silently yield NA indices (and hence NA rows) in R.
  n_test  <- min(200L, n_samples - n_train - n_val)
  if (n_train < 1L || n_test < 1L) {
    stop(sprintf("Not enough samples (%d) in %s to form train and test sets",
                 n_samples, rds_path))
  }

  train_idx <- idx[seq_len(n_train)]
  # The validation slice idx[(n_train + 1):(n_train + n_val)] is not used by
  # this method; it is only skipped over so the test slice starts after it.
  test_idx  <- idx[n_train + n_val + seq_len(n_test)]

  X_train <- X_r[train_idx, , drop = FALSE]
  M_train <- M_array[train_idx, , , drop = FALSE]
  X_test <- X_r[test_idx, , drop = FALSE]
  M_test_true <- conditional_means_r[test_idx, , , drop = FALSE]

  ## Move the sample axis last for SIdxCovReg: Min_train[, , i] == M_train[i, , ]
  Min_train <- aperm(M_train, c(2L, 3L, 1L))

  ## Run Single Index Covariance Regression to get estimated theta
  sidx_result <- SIdxCovReg(
    xin = X_train,
    Min = Min_train,
    iter = 500,
    verbose = FALSE
  )
  est_theta <- sidx_result$est

  ## Use estimated theta to get projections for test set
  xout_test_proj <- X_test %*% est_theta  # one projected value per test row

  ## Use CovDirLocLin to get predicted covariance matrices for test points
  M_pred <- array(0, dim = c(q_dim, q_dim, n_test))
  for (i in seq_len(n_test)) {
    M_pred[, , i] <- CovDirLocLin(
      xin = X_train,
      Min = Min_train,
      direc = est_theta,
      xout = xout_test_proj[i],
      bw = sidx_result$bw,
      ker = ker_gauss,
      lower = -Inf,
      upper = Inf
    )
  }

  ## Calculate log-Cholesky distances between predictions and true means
  log_chol_distances <- vapply(
    seq_len(n_test),
    function(i) log_cholesky_distance(M_pred[, , i], M_test_true[i, , ]),
    numeric(1)
  )

  avg_log_chol_distance <- mean(log_chol_distances)

  cat(sprintf("[n=%d] Seed %d: %.6f\n", n, seed, avg_log_chol_distance))

  list(
    seed = seed,
    avg_log_chol_distance = avg_log_chol_distance,
    log_chol_distances = log_chol_distances,
    theta_true = theta_true_r,
    est_theta = est_theta,
    n = n
  )
}

## --- Parallel setup ---
# detectCores() can return NA when the core count cannot be determined;
# fall back to a single worker in that case.
num_cores <- detectCores() - 1
if (is.na(num_cores) || num_cores < 1) num_cores <- 1
cl <- makeCluster(num_cores)
registerDoParallel(cl)
cat("Using", num_cores, "parallel workers\n")

# Cluster workers start as fresh R sessions. foreach only ships the symbols
# that appear literally in the %dopar% body (run_single_session, n), so the
# globals and helper functions that run_single_session uses internally have
# to be made available on every worker explicitly.
invisible(clusterEvalQ(cl, {
  library(frechet)
  library(Matrix)
  source("SIdxCovReg.R")
  NULL
}))
clusterExport(cl, c("data_dir", "q", "log_cholesky_distance"))

## --- Main execution loop ---
cat("=== Starting parallel IFR simulations for covariance matrices ===\n")
cat("Reading data from:", data_dir, "\n")
cat("Sample sizes:", paste(sample_sizes, collapse = ", "), "\n")
cat("Simulations per size:", num_simulations, "\n\n")

# The finally clause shuts the workers down even if a simulation fails.
tryCatch({
  for (n in sample_sizes) {
    cat("=== Running n =", n, "===\n")

    # Without .combine, foreach returns a list with one element per seed.
    results <- foreach(seed = seq_len(num_simulations),
                       .packages = c("frechet", "Matrix")) %dopar% {
      run_single_session(seed, n)
    }

    all_avg_distances <- vapply(results,
                                function(x) x$avg_log_chol_distance,
                                numeric(1))
    all_seeds <- vapply(results, function(x) x$seed, integer(1))
    # One row of estimated index coefficients per seed
    est_thetas <- do.call(rbind, lapply(results, function(x) x$est_theta))

    results_df <- data.frame(
      seed = all_seeds,
      avg_log_chol_distance = all_avg_distances
    )

    filename <- file.path("covariance", paste0(n, "_ifr.csv"))
    write.csv(results_df, filename, row.names = FALSE)

    theta_filename <- file.path("covariance", paste0(n, "_ifr_thetas.csv"))
    write.csv(est_thetas, theta_filename, row.names = FALSE)

    cat("  Mean Log-Cholesky distance:", round(mean(all_avg_distances), 6), "\n")
    cat("  Std Log-Cholesky distance:", round(sd(all_avg_distances), 6), "\n")
    cat("  Saved:", filename, "\n")
    cat("  Thetas saved:", theta_filename, "\n\n")
  }
}, finally = {
  # Stop parallel cluster
  stopCluster(cl)
})

cat("=== All simulations completed ===\n")
