# Parallel DFR Simulation Script
# Tests all combinations of link functions and sample sizes with 200 sessions each
library(reticulate)
library(frechet)
library(vegan)
library(parallel)
library(doParallel)
library(foreach)
library(doRNG)

# Set up the Python environment by probing common interpreter locations on a
# Linux server, in order, and binding reticulate to the first candidate that
# exists on disk and that use_python() accepts without raising an error.
python_paths <- c(
  "/usr/bin/python3",
  "/usr/bin/python",
  "/opt/conda/bin/python", 
  "/opt/miniconda3/bin/python",
  "/usr/local/bin/python3",
  "/usr/local/bin/python"
)

python_found <- FALSE
for (path in python_paths) {
  # Skip candidates that are not installed on this machine.
  if (!file.exists(path)) {
    next
  }
  python_found <- tryCatch({
    use_python(path)
    cat("Using Python at:", path, "\n")
    TRUE
  }, error = function(e) {
    cat("Failed to use Python at:", path, "\n")
    FALSE
  })
  if (python_found) {
    break
  }
}

# Nothing usable: abort before any Python-dependent work begins.
if (!python_found) {
  stop("No suitable Python installation found. Please install Python or specify the correct path.")
}

# Smoke-test the Python side before starting parallel processing: import
# torch, numpy and the local simulationdatagenerator module, then run one tiny
# simulation (n = 10). On any failure, print diagnostics and stop the script.
cat("Testing Python environment...\n")
tryCatch({
  # Both core libraries must import before either success line is printed.
  torch <- import("torch")
  numpy <- import("numpy")
  cat("✓ torch imported successfully\n")
  cat("✓ numpy imported successfully\n")

  # Project module, looked up in the current working directory.
  simulation_module <- import_from_path("simulationdatagenerator", path = ".")
  cat("✓ simulationdatagenerator imported successfully\n")

  # One minimal end-to-end call of the generator.
  test_result <- simulation_module$generate_simulation_data_torch_true(
    n = 10L, qf_size = 10L, p = 4L, link = "linear", seed = 1L
  )
  cat("✓ Test simulation run successful\n")
}, error = function(e) {
  cat("❌ Python environment test failed:\n")
  cat("Error:", conditionMessage(e), "\n")
  config <- py_config()
  cat("Python path:", config$python, "\n")
  cat("Python version:", config$version, "\n")
  cat("\nPlease ensure the following Python packages are installed:\n",
      "- torch (PyTorch)\n",
      "- numpy\n",
      "- simulationdatagenerator.py (in current directory)\n",
      "\nYou can install them using:\n",
      "pip install torch numpy\n",
      sep = "")
  stop("Python environment test failed")
})

# Load the project's R implementations into the global environment (source()
# evaluates in the global environment by default), then define the Python
# helpers from code/DNN.py via reticulate.
invisible(lapply(
  file.path("code", c("DFR.R", "lrem.R", "lnr.R", "lcm.R", "kerFctn.R")),
  source
))
reticulate::source_python("code/DNN.py")

# Simulation settings shared by every session.
p <- 4L          # passed to the generator as `p` (presumably the covariate count)
qf_size <- 100L  # number of points on each quantile grid

# Scenario grid: every sample size is crossed with every link function.
sample_sizes <- c(250L, 500L, 1250L, 2500L)
link_functions <- c("linear", "quadratic", "exp")

# Output directory for the per-scenario CSV files.
if (!dir.exists("results")) dir.create("results")

# Run one simulation session end to end.
#
# Generates data with the Python module simulationdatagenerator, then splits
# the rows into training / validation / test sets. DFR is fitted on the
# training rows, and its predicted quantile functions for the test rows are
# scored against qnorm() quantiles built from the generator's mu and sigma.
#
# Args:
#   seed:      integer seed for the Python generator and for the R-side row
#              permutation.
#   n:         number of observations to simulate.
#   link_type: link name forwarded to the generator ("linear", "quadratic",
#              "exp" in this script).
#   qf_size:   number of points on the quantile grid.
#   p:         forwarded to the generator as `p` (presumably the number of
#              covariates — TODO confirm against simulationdatagenerator.py).
#
# Returns a list with these elements:
#   seed, n, link_type
#   avg_l2_distance: mean over scored test rows of the root-mean-square gap
#                    between predicted and true quantile functions.
#   l2_distances:    per-test-row values, NA where DFR gave no prediction.
#   n_test:          test-set size.
run_single_session <- function(seed, n, link_type, qf_size, p) {

  # --- Python interpreter -------------------------------------------------
  # Choose the interpreter before anything initialises Python.
  # source_python() and import() both initialise it, after which use_python()
  # can no longer switch interpreters. Depending on the reticulate version it
  # then errors or is ignored. In forked workers the parent has already
  # initialised Python, so this is effectively a no-op there.
  python_paths <- c(
    "/usr/bin/python3",
    "/usr/bin/python",
    "/opt/conda/bin/python",
    "/opt/miniconda3/bin/python",
    "/usr/local/bin/python3",
    "/usr/local/bin/python"
  )
  python_found <- FALSE
  for (path in python_paths) {
    if (!file.exists(path)) {
      next
    }
    python_found <- tryCatch({
      use_python(path)
      TRUE
    }, error = function(e) FALSE)  # fall through to the next candidate
    if (python_found) {
      break
    }
  }
  if (!python_found) {
    stop("No suitable Python installation found for worker.")
  }

  # --- Project code ---------------------------------------------------------
  # Reload project R sources and Python helpers so this also works in fresh
  # (non-forked) worker processes. Both go into the global environment, where
  # the unqualified DFR() call below is resolved. The DNN.py definitions are
  # presumably called from DFR(). source_python() would otherwise define them
  # only in this function's frame.
  function_sources <- list.files("code", pattern = "\\.R$",
                                 full.names = TRUE, ignore.case = TRUE)
  for (src in function_sources) {
    source(src, local = .GlobalEnv)
  }
  reticulate::source_python("code/DNN.py", envir = globalenv())

  # --- Python modules -------------------------------------------------------
  simulation_module <- tryCatch({
    import("torch")
    import("numpy")
    import_from_path("simulationdatagenerator", path = ".")
  }, error = function(e) {
    cat("Error importing Python modules:", conditionMessage(e), "\n")
    stop("Failed to import required Python modules")
  })

  # --- Data generation ------------------------------------------------------
  # reticulate turns R doubles into Python floats, so pass counts and seeds
  # explicitly as integers.
  result <- simulation_module$generate_simulation_data_torch_true(
    n = as.integer(n), qf_size = as.integer(qf_size), p = as.integer(p),
    link = link_type, seed = as.integer(seed)
  )

  cat("Step 5: Converting to R arrays...\n")
  # result elements used here: [[1]] X, [[2]] Y, [[4]] mu, [[5]] sigma.
  X_r <- as.array(result[[1]]$numpy())
  Y_r <- as.array(result[[2]]$numpy())
  mu_all <- as.array(result[[4]]$numpy())
  sigma_all <- as.array(result[[5]]$numpy())
  cat("✓ Converted to R arrays\n")

  # --- Train / validation / test split --------------------------------------
  cat("Step 6: Data splitting...\n")
  set.seed(seed)  # reproducible R-side permutation
  idx <- sample.int(n)
  n_train <- as.integer(0.4 * n)  # 40% for training
  n_val <- as.integer(0.1 * n)    # 10% held out as validation (not passed to DFR)
  # Up to 200 test rows, capped by what remains after train/validation.
  # A fixed 200 would overrun n = 250 (only 125 rows remain) and select NA rows.
  n_test <- min(200L, n - n_train - n_val)
  if (n_train < 1L || n_test < 1L) {
    stop(sprintf("n = %s is too small for a train/validation/test split", n))
  }

  train_idx <- idx[seq_len(n_train)]
  test_idx <- idx[n_train + n_val + seq_len(n_test)]
  cat("✓ Data split: train=", n_train, ", validation=", n_val, ", test=", n_test, "\n")

  cat("Step 7: Converting to dataframes...\n")
  # Name covariates V1..Vk to match however many columns the generator returned.
  covariate_names <- paste0("V", seq_len(ncol(X_r)))
  X_train <- as.data.frame(X_r[train_idx, , drop = FALSE])
  X_test <- as.data.frame(X_r[test_idx, , drop = FALSE])
  colnames(X_train) <- covariate_names
  colnames(X_test) <- covariate_names
  cat("✓ Dataframes created\n")

  cat("Step 8: Converting Y to list format...\n")
  # DFR receives one ascending-sorted response vector per training row.
  Y_train_matrix <- Y_r[train_idx, , drop = FALSE]
  Y_train <- lapply(seq_len(nrow(Y_train_matrix)), function(i) {
    sort(Y_train_matrix[i, ])
  })
  cat("✓ Y converted to list format\n")

  # --- DFR fit --------------------------------------------------------------
  cat("Step 9: Calling DFR...\n")
  small_n <- n == 100
  res_dfr <- DFR(
    y = Y_train, x = X_train, xout = X_test,
    optns = list(
      type = "measure",
      manifold = list(method = "isomap", k = if (small_n) 20 else 0.1 * n),
      r = 2, layer = 4, hidden = if (small_n) 32 else 64,
      dropout = 0.3, lr = 0.0005, num_epochs = 2000, seed = 1
    )
  )

  # Predicted quantile functions for the test set. Accept a matrix (one row
  # per test point), a data frame, or a list of per-point vectors. The
  # scoring below indexes rows.
  yPred <- res_dfr$yPred
  if (is.data.frame(yPred)) {
    yPred <- as.matrix(yPred)
  } else if (is.list(yPred)) {
    yPred <- do.call(rbind, yPred)
  }
  cat("✓ DFR completed. yPred class:", class(yPred), "dim:", dim(yPred), "\n")

  # --- Ground truth ---------------------------------------------------------
  # NOTE(review): this assumes the generator's true conditional law is
  # Normal(mu, sigma) and that its grid is
  # seq(0.01, 0.99, length.out = qf_size). Confirm against
  # simulationdatagenerator.py.
  quantile_grid <- seq(0.01, 0.99, length.out = qf_size)
  mu_test <- mu_all[test_idx]
  sigma_test <- sigma_all[test_idx]
  # Column-major fill: column j uses grid[j]; row i uses mu_test[i], sigma_test[i].
  Y_test_true <- matrix(
    qnorm(rep(quantile_grid, each = n_test), mean = mu_test, sd = sigma_test),
    nrow = n_test
  )

  # --- Scoring --------------------------------------------------------------
  cat("\nStep 12: Calculating L2 distances...\n")
  l2_distances <- rep(NA_real_, n_test)
  for (i in seq_len(min(n_test, NROW(yPred)))) {
    pred_vec <- as.numeric(yPred[i, ])
    true_vec <- as.numeric(Y_test_true[i, ])
    # Non-finite values are zeroed so one bad grid point does not turn the
    # whole distance into NA/Inf.
    # NOTE(review): zeroing can bias the score; consider counting/flagging.
    pred_vec[!is.finite(pred_vec)] <- 0
    true_vec[!is.finite(true_vec)] <- 0
    l2_distances[i] <- sqrt(mean((pred_vec - true_vec)^2))
  }

  # Rows without a prediction stay NA and are excluded from the average.
  avg_l2_distance <- mean(l2_distances, na.rm = TRUE)

  cat(sprintf("[%s, n=%d] Session %d completed - L2 distance: %.6f\n",
              link_type, n, seed, avg_l2_distance))

  list(
    seed = seed,
    avg_l2_distance = avg_l2_distance,
    l2_distances = l2_distances,
    n_test = nrow(X_test),
    n = n,
    link_type = link_type
  )
}

# Set up parallel computing. On Unix, sessions run in forked processes via
# mclapply(). Elsewhere a PSOCK cluster is registered as the foreach backend.
# detectCores() can return NA, and detectCores() - 1 is 0 on a single-core
# machine. mclapply() rejects mc.cores < 1, so always keep at least one worker.
num_cores <- max(1L, detectCores() - 1L, na.rm = TRUE)
cat("Setting up parallel computing with", num_cores, "cores\n")

# NOTE(review): this script imports torch in the parent before forking.
# Forked children of a process that has started torch's thread pools can
# stall with some builds. If sessions hang, consider torch$set_num_threads(1L)
# before forking, or PSOCK workers.
if (.Platform$OS.type == "unix") {
  cat("Using mclapply for parallel processing\n")
  use_mclapply <- TRUE
} else {
  cat("Using cluster-based parallel processing\n")
  use_mclapply <- FALSE
  cl <- makeCluster(num_cores)
  registerDoParallel(cl)
}

# Announce the run configuration before any work starts. Helpers live in a
# local scope so nothing new leaks into the global environment.
local({
  show_values <- function(label, values) {
    cat(label, paste(values, collapse = ", "), "\n")
  }
  n_combinations <- length(sample_sizes) * length(link_functions)

  cat("=== Parallel DFR Simulation Script ===\n")
  cat("Starting parallel computation with", num_cores, "cores\n")
  cat("Running all combinations of sample sizes and link functions\n")
  show_values("Sample sizes:", sample_sizes)
  show_values("Link functions:", link_functions)
  cat("Sessions per combination: 200\n")
  cat("Total sessions:", n_combinations * 200, "\n\n")
})



# Loop through all combinations of sample size and link function.
# Each combination runs `n_sessions` sessions (seeds 1..n_sessions) in
# parallel and writes results/<link>_<n>_dfr.csv with one row per seed.
# A session that fails is recorded with avg_l2_distance = NA and reported
# with a warning, rather than crashing the whole run after the parallel step.
n_sessions <- 200L
session_seeds <- seq_len(n_sessions)

for (n in sample_sizes) {
  for (link_type in link_functions) {

    cat("=== Running combination: n =", n, ", link =", link_type, "===\n")
    cat(sprintf("  Running %d sessions with seeds 1-%d\n", n_sessions, n_sessions))

    if (use_mclapply) {
      # Forked workers (Unix). A failing session comes back in its slot as a
      # "try-error" object, or NULL if the child process died.
      results <- mclapply(session_seeds, function(seed) {
        run_single_session(seed, n, link_type, qf_size, p)
      }, mc.cores = num_cores)
    } else {
      # PSOCK workers. .errorhandling = "pass" keeps one list element per
      # seed and stores the error condition for sessions that fail.
      results <- foreach(seed = session_seeds,
                         .packages = c("reticulate", "frechet", "vegan"),
                         .errorhandling = "pass") %dopar% {
        run_single_session(seed, n, link_type, qf_size, p)
      }
    }

    # Both back ends return results in input order, so position k is seed k.
    session_failed <- vapply(results, function(res) {
      !is.list(res) || inherits(res, "condition")
    }, logical(1))
    all_seeds <- session_seeds
    all_avg_l2_distances <- vapply(seq_along(results), function(k) {
      if (session_failed[k]) {
        NA_real_
      } else {
        as.numeric(results[[k]][["avg_l2_distance"]])[1]
      }
    }, numeric(1))

    if (any(session_failed)) {
      first_failure <- results[[which(session_failed)[1]]]
      failure_msg <- if (inherits(first_failure, "condition")) {
        conditionMessage(first_failure)
      } else {
        paste(as.character(first_failure), collapse = " ")
      }
      warning(sprintf(
        "%d of %d sessions failed for n = %d, link = %s (seeds: %s). First error: %s",
        sum(session_failed), n_sessions, n, link_type,
        paste(all_seeds[session_failed], collapse = ", "), trimws(failure_msg)
      ), call. = FALSE, immediate. = TRUE)
    }

    # One row per seed; failed sessions keep an NA score.
    results_summary <- data.frame(
      seed = all_seeds,
      avg_l2_distance = all_avg_l2_distances,
      n = n,
      link_type = link_type
    )

    # Save results with specified naming convention: link_size_dfr.csv
    filename <- paste0("results/", link_type, "_", n, "_dfr.csv")
    write.csv(results_summary, filename, row.names = FALSE)

    # Summary statistics over the sessions that produced a score.
    n_scored <- sum(!is.na(all_avg_l2_distances))
    cat("  Mean L2 distance:", round(mean(all_avg_l2_distances, na.rm = TRUE), 6), "\n")
    cat("  Std L2 distance:", round(sd(all_avg_l2_distances, na.rm = TRUE), 6), "\n")
    cat("  Results saved to:", filename, "\n")
    cat(sprintf("  Combination completed: %d of %d sessions scored\n\n",
                n_scored, n_sessions))
  }
}

# Stop parallel cluster if using cluster approach
if (!use_mclapply) {
  stopCluster(cl)
}

cat("=== All combinations completed! ===\n")
cat("Results saved in results/ folder with naming convention: link_size_dfr.csv\n\n")

# List the expected output file for every combination, flagging any that
# were not actually written. cat() silently drops the NULL when a file exists.
cat("Files created:\n")
for (n in sample_sizes) {
  for (link_type in link_functions) {
    filename <- paste0("results/", link_type, "_", n, "_dfr.csv")
    cat("-", filename, if (!file.exists(filename)) "(missing)", "\n")
  }
}

cat("\nParallel DFR analysis with all combinations completed!\n")
# cat() would print a POSIXct as its raw number of seconds, so format it first.
cat("Script finished at:", format(Sys.time(), "%Y-%m-%d %H:%M:%S %Z"), "\n")
