############################################################
# Deep Fréchet Regression (DFR) - Parallel Computing Script
# Data loading is pure R (reads pre-generated data from RDS files)
# Includes: reticulate cluster init + torch check
############################################################

## ---------------------------------------------------------
## 0) Choose the exact Python reticulate should use
## ---------------------------------------------------------
# Resolve the Python interpreter reticulate should use:
#   1. RETICULATE_PYTHON, if it is already set and non-empty;
#   2. otherwise the first interpreter that exists among common install paths;
#   3. otherwise whatever reticulate itself discovers.
# The choice is written back to RETICULATE_PYTHON so that the master session
# and the cluster workers (which receive `py` via clusterExport) all bind the
# same interpreter.
py <- Sys.getenv("RETICULATE_PYTHON", unset = "")
if (!nzchar(py)) {
  # Common Python locations, in order of preference
  python_paths <- c(
    "/usr/bin/python3",
    "/usr/bin/python",
    "/opt/conda/bin/python",
    "/opt/miniconda3/bin/python",
    "/usr/local/bin/python3",
    "/usr/local/bin/python"
  )
  existing_paths <- python_paths[file.exists(python_paths)]
  py <- if (length(existing_paths) > 0) {
    existing_paths[[1]]
  } else {
    # Fall back to reticulate's default. reticulate has not been attached
    # with library() yet at this point of the script, so the call must be
    # namespace-qualified (a bare py_config() would not be found).
    reticulate::py_config()$python
  }
}
Sys.setenv(RETICULATE_PYTHON = py)

library(frechet)
library(vegan)
library(parallel)
library(doParallel)
library(foreach)
library(reticulate)

# Log which interpreter the master R session is bound to.
cat("Master reticulate python:\n")
print(py_config()[["python"]])

## ---------------------------------------------------------
## 1) Source all required R functions
## ---------------------------------------------------------
# Load the project's R implementations into the global environment
# (source() with its default local = FALSE evaluates in .GlobalEnv).
# Paths are relative to the current working directory.
invisible(lapply(
  file.path("code", c("DFR.R", "lrem.R", "lnr.R", "lcr.R", "lcm.R", "kerFctn.R")),
  source
))

## ---------------------------------------------------------
## 2) Parameters
## ---------------------------------------------------------
# p: not read anywhere in this script section; presumably the number of
#    covariate columns in X (TODO confirm against the data files).
p <- 4L
# q: presumably the side length of the q x q response matrices
#    (TODO confirm against the data files).
q <- 10L
# Each sample size n is read from <data_dir>/n<n>/ and written to
# covariance/<n>_cov_DFR.csv.
sample_sizes <- c(500L, 1250L, 2500L)
# Replicates per sample size; seeds run over 1..num_simulations.
num_simulations <- 200L
# Directory of pre-converted RDS input files (relative to the working dir).
data_dir <- "covariance_data_rds"
# Working directory at startup; exported to the workers, which setwd() into it
# so that their relative paths resolve the same way as on the master.
script_dir <- getwd()

## ---------------------------------------------------------
## 3) Output / checks
## ---------------------------------------------------------
# Results are written under ./covariance; create it on first use.
if (!dir.exists("covariance")) {
  dir.create("covariance")
}

# Fail fast when the pre-converted RDS input directory is absent.
if (!dir.exists(data_dir)) {
  stop(
    sprintf(
      "Data directory '%s' not found. Please run convert_to_rds.py first to convert data to RDS format.",
      data_dir
    )
  )
}

## ---------------------------------------------------------
## 4) Helpers
## ---------------------------------------------------------
# Frobenius distance between the symmetric parts of two square matrices.
#
# Each argument is first replaced by (A + t(A)) / 2, so any asymmetry in the
# inputs is ignored. The result is the square root of the sum of squared
# entrywise differences of those symmetric parts.
#
# @param S1,S2 Square numeric matrices of the same dimension.
# @return A single non-negative number.
frobenius_distance <- function(S1, S2) {
  symmetric_part <- function(A) (A + t(A)) / 2
  diff_mat <- symmetric_part(S1) - symmetric_part(S2)
  sqrt(sum(diff_mat^2))
}

## ---------------------------------------------------------
## 5) Worker function
## ---------------------------------------------------------
# Fit DFR for one (n, seed) replicate and score it on a held-out test set.
#
# Reads <data_dir>/n<n>/seed<seed>.rds, which must contain a list with
#   X: a 2-D covariate array (rows = observations),
#   M: a 3-D array of response matrices (observations on the 1st axis),
#      used as the training responses,
#   C: a 3-D array of conditional-mean matrices (observations on the 1st
#      axis), used as the ground truth for the test rows.
# The rows are shuffled with set.seed(seed). The first 40% are used for
# training, the next 10% are skipped (a validation block that this function
# does not use), and the following 200 rows form the test set. Each test
# prediction is compared with C via frobenius_distance().
#
# @param seed Integer replicate id; selects the RDS file and seeds both the
#   split and DFR.
# @param n Sample size; selects the n<n> sub-directory and sets DFR's
#   manifold `k` and `hidden` sizes.
# @return list(seed = <seed as integer>,
#              avg_frobenius_distance = <mean distance over test rows,
#                                         NA rows dropped>).
# Errors if the file is missing or unreadable, if the arrays have the wrong
# rank or disagree on the number of rows, or if there are too few rows for
# the split.
run_single_session <- function(seed, n) {
  n_dir <- file.path(data_dir, paste0("n", n))
  rds_path <- file.path(n_dir, sprintf("seed%03d.rds", seed))

  if (!file.exists(rds_path)) {
    stop(sprintf(
      "RDS file not found for n=%d, seed=%d: %s\nPlease run convert_to_rds.py first.",
      n, seed, rds_path
    ))
  }

  data_list <- tryCatch(
    readRDS(rds_path),
    error = function(e) {
      stop(sprintf("Failed to load RDS file %s: %s", rds_path, conditionMessage(e)))
    }
  )

  X_r <- as.array(data_list$X)
  M_array <- as.array(data_list$M)
  conditional_means_r <- as.array(data_list$C)

  if (length(dim(X_r)) != 2) stop(sprintf("X should be 2D but has %d dimensions", length(dim(X_r))))
  if (length(dim(M_array)) != 3) stop(sprintf("M should be 3D but has %d dimensions", length(dim(M_array))))
  if (length(dim(conditional_means_r)) != 3) stop(sprintf("C should be 3D but has %d dimensions", length(dim(conditional_means_r))))

  n_samples <- dim(X_r)[1]
  if (dim(M_array)[1] != n_samples || dim(conditional_means_r)[1] != n_samples) {
    stop(sprintf(
      "Dimension mismatch: X has %d samples, M has %d, C has %d",
      n_samples, dim(M_array)[1], dim(conditional_means_r)[1]
    ))
  }

  ## Split: shuffle, then 40% train, skip 10% (validation, unused here),
  ## then a fixed-size test set of 200 rows.
  set.seed(seed)
  idx <- sample(seq_len(n_samples))
  n_train <- as.integer(0.4 * n_samples)
  n_val   <- as.integer(0.1 * n_samples)
  n_test  <- 200L   # Fixed test set size

  # Without this guard, too few rows make the test indices run past the end
  # of `idx`, which silently yields NA rows in X_test / M_test_true.
  n_needed <- n_train + n_val + n_test
  if (n_train < 1L || n_needed > n_samples) {
    stop(sprintf(
      "Not enough samples for the split: need %d train (>= 1) + %d val + %d test = %d rows, but data has %d",
      n_train, n_val, n_test, n_needed, n_samples
    ))
  }

  train_idx <- idx[seq_len(n_train)]
  test_idx  <- idx[n_train + n_val + seq_len(n_test)]

  M_train <- M_array[train_idx, , , drop = FALSE]
  M_test_true <- conditional_means_r[test_idx, , , drop = FALSE]

  # Name covariate columns V1..Vp with p = ncol(X); for four covariates this
  # is exactly V1..V4.
  covariate_names <- paste0("V", seq_len(ncol(X_r)))

  X_train <- as.data.frame(X_r[train_idx, , drop = FALSE])
  colnames(X_train) <- covariate_names

  X_test <- as.data.frame(X_r[test_idx, , drop = FALSE])
  colnames(X_test) <- covariate_names

  # DFR takes the training responses as a list of matrices.
  M_list_train <- lapply(seq_len(n_train), function(i) M_train[i, , ])

  # NOTE(review): `k` is 10% of the nominal sample size n, while DFR only
  # sees n_train = 40% of the rows; confirm this scaling is intended.
  small_n <- n <= 100
  res_dfr <- DFR(
    y = M_list_train,
    x = X_train,
    xout = X_test,
    optns = list(
      type = "network",
      manifold = list(
        method = "isomap",
        k = if (small_n) 20L else as.integer(0.1 * n)
      ),
      r = 2,
      layer = 4,
      hidden = if (small_n) 32L else 64L,
      dropout = 0.3,
      lr = 0.0005,
      num_epochs = 2000,
      seed = seed
    )
  )

  M_pred <- res_dfr$yPred

  # Response matrix shape, taken from the data rather than a global constant.
  response_dims <- dim(M_test_true)[2:3]

  # Coerce one prediction to a matrix. Accepted forms are a matrix, a list
  # whose first element is a matrix, or anything that can be reshaped
  # column-major into the response dimensions.
  as_prediction_matrix <- function(pred) {
    if (is.matrix(pred)) {
      pred
    } else if (is.list(pred) && length(pred) >= 1 && is.matrix(pred[[1]])) {
      pred[[1]]
    } else {
      matrix(pred, nrow = response_dims[1], ncol = response_dims[2])
    }
  }

  # Test rows without a corresponding prediction are scored as NA.
  frob_distances <- vapply(seq_len(n_test), function(i) {
    if (i > length(M_pred)) {
      return(NA_real_)
    }
    frobenius_distance(as_prediction_matrix(M_pred[[i]]), M_test_true[i, , ])
  }, numeric(1))

  avg_frob_distance <- mean(frob_distances, na.rm = TRUE)
  cat(sprintf("[n=%d] Seed %d: %.6f\n", n, seed, avg_frob_distance))

  # Integer seed keeps the caller's vapply(..., integer(1)) safe even when
  # the function is called with a double seed such as 1.
  list(seed = as.integer(seed), avg_frobenius_distance = avg_frob_distance)
}

## ---------------------------------------------------------
## 6) Parallel setup
## ---------------------------------------------------------
# Use all but one core, and never fewer than one; detectCores() can return NA.
num_cores <- detectCores() - 1
if (is.na(num_cores) || num_cores < 1) num_cores <- 1

cl <- makeCluster(num_cores)
registerDoParallel(cl)
cat("Using", num_cores, "parallel workers\n")

# Workers are fresh R processes. Give them the chosen interpreter and the
# master's working directory so that relative paths resolve on the workers.
clusterExport(cl, c("script_dir", "py"))

# Initialise every worker: bind reticulate to `py`, verify that torch can be
# imported, and load code/DNN.py and code/NN_class.py. If any worker fails,
# shut the cluster down before re-raising; otherwise the worker processes
# would be left running.
tryCatch(
  clusterEvalQ(cl, {
    Sys.setenv(RETICULATE_PYTHON = py)
    library(reticulate)
    setwd(script_dir)
    use_python(py, required = TRUE)

    # Check torch: NULL on success, otherwise the import error message.
    torch_error <- tryCatch(
      {
        import("torch", delay_load = FALSE)
        NULL
      },
      error = function(e) conditionMessage(e)
    )
    if (!is.null(torch_error)) {
      stop(paste0(
        "Worker python is: ", py_config()$python, "\n",
        "Missing module 'torch'. Install into this python, e.g.:\n",
        "  ", py_config()$python, " -m pip install torch\n",
        "Original error: ", torch_error
      ))
    }

    source_python("code/DNN.py")
    source_python("code/NN_class.py")
    NULL
  }),
  error = function(e) {
    stopCluster(cl)
    stop(e)
  }
)

## ---------------------------------------------------------
## 7) Main execution loop (COMPLETE)
## ---------------------------------------------------------
cat("=== Starting parallel DFR simulations for covariance matrices ===\n")
cat("Reading data from:", data_dir, "\n")
cat("Sample sizes:", paste(sample_sizes, collapse = ", "), "\n")
cat("Simulations per size:", num_simulations, "\n\n")

# For each sample size, run all seeds on the cluster and write one CSV
# (columns: seed, avg_frobenius_distance). foreach re-raises a worker error by
# default, so the cluster is shut down in `finally`; this avoids leaving
# orphaned worker processes when a simulation fails.
tryCatch(
  {
    for (n in sample_sizes) {
      cat("=== Running n =", n, "===\n")

      results <- foreach(
        seed = seq_len(num_simulations),
        .packages = c("frechet", "vegan")
      ) %dopar% {
        run_single_session(seed, n)
      }

      all_avg_distances <- vapply(results, function(x) x$avg_frobenius_distance, numeric(1))
      all_seeds <- vapply(results, function(x) x$seed, integer(1))

      results_df <- data.frame(
        seed = all_seeds,
        avg_frobenius_distance = all_avg_distances
      )

      filename <- file.path("covariance", paste0(n, "_cov_DFR.csv"))
      write.csv(results_df, filename, row.names = FALSE)

      cat("  Mean Frobenius distance:", round(mean(all_avg_distances, na.rm = TRUE), 6), "\n")
      cat("  Std  Frobenius distance:", round(sd(all_avg_distances, na.rm = TRUE), 6), "\n")
      cat("  Saved:", filename, "\n\n")
    }
  },
  finally = stopCluster(cl)
)

cat("=== All simulations completed ===\n")
