############################################################
# Index Fréchet Regression (IFR) for Network/Graph Laplacians
# Parallel Computing Script - Pure R (reads pre-generated data)
# Uses SIdxNetReg for Single Index Network Regression
############################################################

library(frechet)
library(parallel)
library(doParallel)
library(foreach)
library(Matrix)

# Source the SIdxNetReg function
source("SIdxNetReg.R")

## --- Parameters ---
# p: number of predictors; q: side length of each q x q response matrix.
p <- 4L
q <- 10L
# Sample sizes to sweep, and Monte Carlo replicates run for each size.
sample_sizes <- c(500L, 1250L, 2500L)
num_simulations <- 200L
# Folder holding the pre-generated per-seed RDS files.
data_dir <- "covariance_data_rds"

## --- Output directory (created on first run only) ---
if (!dir.exists("covariance")) dir.create("covariance")

## --- Fail fast when the input data has not been prepared ---
if (!dir.exists(data_dir)) {
  stop(sprintf(
    "Data directory '%s' not found. Please run convert_to_rds.py first to convert data to RDS format.",
    data_dir
  ))
}

## --- Frobenius distance function ---
# Frobenius distance between the symmetric parts of two square matrices.
#
# Each input is first replaced by (S + t(S)) / 2, so small asymmetries are
# averaged out before comparison. The result is the square root of the sum
# of squared entrywise differences of those symmetric parts.
#
# Arguments:
#   S1, S2 - square numeric matrices of the same dimension.
# Returns a single number.
frobenius_distance <- function(S1, S2) {
  sym_part <- function(A) (A + t(A)) / 2
  delta <- sym_part(S1) - sym_part(S2)
  sqrt(sum(delta^2))
}

## --- Worker function for parallel execution ---
# Run one simulation replicate for sample size `n`.
#
# Loads `<data_dir>/n<n>/seed<NNN>.rds` and shuffles the samples with
# set.seed(seed). It fits SIdxNetReg on a 40% training split, predicts the
# q x q responses at up to 200 held-out test points with NetDirLocLin, and
# scores each prediction against the stored conditional mean with
# frobenius_distance().
#
# Arguments:
#   seed - replicate id; selects the RDS file and seeds the random split.
#   n    - sample size; selects the `n<n>` subdirectory of `data_dir`.
# Returns a list with elements:
#   seed, avg_frob_distance (mean over test points),
#   frob_distances (one per test point), theta_true (NULL when the file
#   does not store it), est_theta (SIdxNetReg's `est`), and n.
# Errors when the file is missing or malformed, or has too few samples to
# form a training and a test split.
# Relies on the script-level `data_dir` and `q`, and on SIdxNetReg and
# NetDirLocLin (presumably defined in SIdxNetReg.R).
run_single_session <- function(seed, n) {
  ## Locate the replicate file
  n_dir <- file.path(data_dir, paste0("n", n))
  rds_path <- file.path(n_dir, sprintf("seed%03d.rds", seed))

  if (!file.exists(rds_path)) {
    stop(sprintf("RDS file not found for n=%d, seed=%d: %s\nPlease run convert_to_rds.py first.", 
                 n, seed, rds_path))
  }

  ## Read and validate the data. tryCatch() evaluates its expression in this
  ## function's frame, so the variables assigned inside stay in scope below.
  tryCatch({
    data_list <- readRDS(rds_path)

    # [[ ]] matches names exactly, unlike $, which partially matches on lists
    X_r <- as.array(data_list[["X"]])                  # n x 4 predictors
    M_array <- as.array(data_list[["M"]])              # n x q x q matrices
    conditional_means_r <- as.array(data_list[["C"]])  # n x q x q conditional means

    if (length(dim(X_r)) != 2) {
      stop(sprintf("X should be 2D but has %d dimensions", length(dim(X_r))))
    }
    if (length(dim(M_array)) != 3) {
      stop(sprintf("M should be 3D but has %d dimensions", length(dim(M_array))))
    }
    if (length(dim(conditional_means_r)) != 3) {
      stop(sprintf("C should be 3D but has %d dimensions", length(dim(conditional_means_r))))
    }

    n_samples <- dim(X_r)[1]

    if (dim(M_array)[1] != n_samples || dim(conditional_means_r)[1] != n_samples) {
      stop(sprintf("Dimension mismatch: X has %d samples, M has %d, C has %d", 
                   n_samples, dim(M_array)[1], dim(conditional_means_r)[1]))
    }

    # The reshape and the q x q prediction buffer below both assume that
    # every M and C slice is q x q.
    if (any(dim(M_array)[2:3] != q) || any(dim(conditional_means_r)[2:3] != q)) {
      stop(sprintf("M and C slices should be %d x %d", q, q))
    }

    # theta_true is optional in the file; keep NULL when absent
    theta_true_r <- NULL
    if (!is.null(data_list[["theta_true"]])) {
      theta_true_r <- as.array(data_list[["theta_true"]])
    }
  }, error = function(e) {
    stop(sprintf("Failed to load RDS file %s: %s", rds_path, e$message))
  })

  ## Data splitting: the first 40% of a seeded shuffle trains the model. The
  ## next 10% is reserved as a validation slice and is not used here. The
  ## test set is the following 200 points, capped at whatever remains so
  ## the indices never run past the shuffled sample (which would yield NA
  ## rows).
  set.seed(seed)
  idx <- sample(seq_len(n_samples))
  n_train <- as.integer(0.4 * n_samples)
  n_val   <- as.integer(0.1 * n_samples)
  n_test  <- min(200L, n_samples - n_train - n_val)
  if (n_train < 1L || n_test < 1L) {
    stop(sprintf("Too few samples (%d) in %s to form train and test splits",
                 n_samples, rds_path))
  }

  train_idx <- idx[seq_len(n_train)]
  test_idx  <- idx[n_train + n_val + seq_len(n_test)]

  X_train <- X_r[train_idx, , drop = FALSE]
  M_train <- M_array[train_idx, , , drop = FALSE]
  X_test <- X_r[test_idx, , drop = FALSE]
  M_test_true <- conditional_means_r[test_idx, , , drop = FALSE]

  ## Reorder n x q x q into the q x q x n layout SIdxNetReg expects
  ## (Min_train[, , i] == M_train[i, , ]); unname() drops carried dimnames.
  Min_train <- unname(aperm(M_train, c(2L, 3L, 1L)))

  ## Single Index Network Regression: estimate the index direction theta
  sidx_result <- SIdxNetReg(
    xin = X_train,
    Min = Min_train,
    iter = 50,
    verbose = FALSE
  )
  est_theta <- sidx_result$est

  ## Project the test predictors onto the estimated index
  xout_test_proj <- X_test %*% est_theta

  ## Predict one q x q matrix per test point with NetDirLocLin, reusing the
  ## bandwidth chosen by SIdxNetReg
  M_pred <- array(0, dim = c(q, q, n_test))
  for (i in seq_len(n_test)) {
    net_result <- NetDirLocLin(
      xin = X_train,
      Min = Min_train,
      direc = est_theta,
      xout = xout_test_proj[i],
      bw = sidx_result$bw
    )
    M_pred[, , i] <- net_result$predict[[1]]
  }

  ## Frobenius distance between each prediction and its true conditional mean
  frob_distances <- vapply(
    seq_len(n_test),
    function(i) frobenius_distance(M_pred[, , i], M_test_true[i, , ]),
    numeric(1)
  )

  avg_frob_distance <- mean(frob_distances)

  cat(sprintf("[n=%d] Seed %d: %.6f\n", n, seed, avg_frob_distance))

  list(
    seed = seed,
    avg_frob_distance = avg_frob_distance,
    frob_distances = frob_distances,
    theta_true = theta_true_r,
    est_theta = est_theta,
    n = n
  )
}

## --- Parallel setup ---
# Workers are only started when there is more than one replicate per size.
use_parallel <- num_simulations > 1
if (use_parallel) {
  # detectCores() can return NA; fall back to one worker in that case, and
  # never start more workers than there are replicates to run.
  num_cores <- detectCores() - 1L
  if (is.na(num_cores) || num_cores < 1L) num_cores <- 1L
  num_cores <- min(num_cores, num_simulations)
  # outfile = "" forwards worker stdout (the per-seed progress lines printed
  # by run_single_session) to the master console; the default discards it.
  cl <- makeCluster(num_cores, outfile = "")
  registerDoParallel(cl)
  cat("Using", num_cores, "parallel workers\n")
} else {
  cat("Running sequentially (1 simulation)\n")
}

## --- Main execution loop ---
cat("=== Starting IFR simulations for network/graph laplacians ===\n")
cat("Reading data from:", data_dir, "\n")
cat("Sample sizes:", paste(sample_sizes, collapse = ", "), "\n")
cat("Simulations per size:", num_simulations, "\n\n")

# The `finally` clause stops the cluster even when a replicate errors, so an
# aborted run does not leave worker processes behind. invisible() keeps
# Rscript from auto-printing the loop's NULL value.
invisible(tryCatch({
  for (n in sample_sizes) {
    cat("=== Running n =", n, "===\n")

    if (use_parallel) {
      # foreach's default combine returns a list of per-seed results in order
      results <- foreach(seed = seq_len(num_simulations),
                         .packages = c("frechet", "Matrix")) %dopar% {
        run_single_session(seed, n)
      }
    } else {
      # Sequential execution for single simulation
      results <- list(run_single_session(1L, n))
    }

    all_avg_distances <- vapply(results, function(x) x$avg_frob_distance, numeric(1))
    all_seeds <- vapply(results, function(x) x$seed, numeric(1))
    # c() flattens each estimate so every replicate contributes one row, even
    # if SIdxNetReg returns theta as a p x 1 matrix rather than a vector.
    est_thetas <- do.call(rbind, lapply(results, function(x) c(x$est_theta)))

    results_df <- data.frame(
      seed = all_seeds,
      avg_frobenius_distance = all_avg_distances
    )

    filename <- file.path("covariance", paste0(n, "_cov_SIFR.csv"))
    write.csv(results_df, filename, row.names = FALSE)

    theta_filename <- file.path("covariance", paste0(n, "_ifr_thetas.csv"))
    write.csv(est_thetas, theta_filename, row.names = FALSE)

    cat("  Mean Frobenius distance:", round(mean(all_avg_distances), 6), "\n")
    if (length(all_avg_distances) > 1) {
      cat("  Std Frobenius distance:", round(sd(all_avg_distances), 6), "\n")
    }
    cat("  Saved:", filename, "\n")
    cat("  Thetas saved:", theta_filename, "\n\n")
  }
}, finally = {
  # Stop parallel cluster if it was created
  if (use_parallel) {
    stopCluster(cl)
  }
}))

cat("=== All simulations completed ===\n")
