# Parallel 10-Fold Cross-Validation for Global Spherical Regression (GloSpheReg)
# Applied to emotional dataset
# This script performs 100 parallel sessions of 10-fold cross-validation

library(dplyr)
library(parallel)
library(trust)

# Source the GFR functions
# Load the GFR implementation (expected to define gsr(), which is called
# below but not defined in this script); abort early when the file is absent.
if (!file.exists("GFR_code.R")) {
  stop("GFR_code.R not found. Please ensure GFR_code.R is in the same directory.")
}
source("GFR_code.R")
cat("✓ GFR_code.R loaded\n")

# Set Python path for shuffle generation script
# Use RETICULATE_PYTHON when it is set; otherwise take the first existing
# interpreter from a list of common install locations (plus whatever
# python3/python resolves to on the PATH).
python_path <- Sys.getenv("RETICULATE_PYTHON", unset = "")
if (python_path == "") {
  # Try common locations (prioritize /usr/local/bin/python3 where numpy is installed)
  possible_paths <- c(
    "/usr/local/bin/python3",
    "/usr/local/python-3.8.0/bin/python3",
    "/opt/miniconda3/bin/python",
    "/opt/anaconda3/bin/python",
    "/usr/bin/python3",
    "/usr/bin/python",
    "/opt/conda/bin/python",
    Sys.which("python3"),
    Sys.which("python"),
    file.path(Sys.getenv("HOME"), "miniconda3", "bin", "python"),
    file.path(Sys.getenv("HOME"), "anaconda3", "bin", "python")
  )

  for (path in possible_paths) {
    if (path != "" && file.exists(path)) {
      python_path <- path
      break
    }
  }
}

if (python_path == "" || !file.exists(python_path)) {
  stop("Python not found. Please set the Python path manually by editing the script or setting RETICULATE_PYTHON environment variable.")
}

cat("Using Python at:", python_path, "\n")

# Test if numpy is available in Python
cat("Testing numpy availability...\n")
# shQuote() keeps an interpreter path that contains spaces as a single shell
# word; without it the command would be split at the first space.
test_cmd <- paste0(shQuote(python_path), " -c \"import numpy; print('OK')\"")
test_status <- system(test_cmd, intern = TRUE, ignore.stderr = TRUE)
exit_code <- attr(test_status, "status")
# Success requires a zero (or absent) exit status AND the expected "OK" line,
# so an interpreter that runs but prints nothing is not mistaken for success.
numpy_ok <- (is.null(exit_code) || exit_code == 0) &&
  any(trimws(test_status) == "OK")
if (!numpy_ok) {
  cat("ERROR: numpy is not available in Python at:", python_path, "\n")
  cat("Please install numpy using:\n")
  cat("  ", python_path, "-m pip install numpy\n")
  stop("numpy is required but not found. Please install numpy for the Python interpreter.")
}
cat("✓ numpy is available\n")

# Path to the Python script for shuffle generation
python_script <- "../SIFR/test_numpy.py"

# Get a numpy-generated row permutation (matching test.R) as 1-based indices.
#
# Runs `python script n session_id` and reads one 0-based index per output
# line. Presumably the script prints a numpy permutation seeded by session_id;
# confirm against test_numpy.py.
#
# Args:
#   n:          number of rows to permute.
#   session_id: value forwarded to the script as its second argument (seed).
#   python:     Python interpreter path (defaults to the global python_path).
#   script:     shuffle script path (defaults to the global python_script).
# Returns:
#   Integer vector that is a permutation of seq_len(n).
# Errors:
#   Stops if the script exits with a non-zero status or its output is not a
#   permutation of 0..n-1, rather than returning indices that would silently
#   corrupt the cross-validation folds.
get_numpy_shuffle <- function(n, session_id, python = python_path,
                              script = python_script) {
  # system2() does not quote `args` itself, so quote them for the shell.
  out <- system2(python, shQuote(c(script, n, session_id)), stdout = TRUE)

  status <- attr(out, "status")
  if (!is.null(status) && status != 0) {
    stop("Shuffle script ", script, " failed with exit status ", status,
         call. = FALSE)
  }

  # Convert to integer and add 1 for 1-based indexing in R
  idx <- as.integer(out) + 1L
  if (anyNA(idx) || !identical(sort(idx), seq_len(n))) {
    stop("Shuffle script ", script, " did not return a permutation of 0..",
         n - 1, call. = FALSE)
  }
  idx
}

# Read the emotional dataset
data_file <- "../data_emotion.csv"
if (!file.exists(data_file)) {
  stop(paste("Data file not found:", data_file))
}

cat("Reading data from:", data_file, "\n")
data <- read.csv(data_file)

# Use only the first 100 data points. head() keeps at most the rows that
# exist, whereas data[1:100, ] would pad a shorter file with all-NA rows.
n_keep <- 100
if (nrow(data) < n_keep) {
  warning("Data file has only ", nrow(data), " rows; using all of them.",
          call. = FALSE)
}
data <- head(data, n_keep)

# Extract response variables (b2_1 to b2_4)
response_cols <- c("b2_1", "b2_2", "b2_3", "b2_4")
missing_cols <- setdiff(response_cols, colnames(data))
if (length(missing_cols) > 0) {
  stop("Response column(s) missing from data: ",
       paste(missing_cols, collapse = ", "), call. = FALSE)
}
responses <- as.matrix(data[, response_cols])

# Normalize responses to sum to 1 (compositional constraint)
# The values appear to be percentages, so divide by 100
responses <- responses / 100.0
# Ensure they sum to exactly 1 (handle any rounding issues). Rows with a
# missing value, a negative entry or a non-positive total cannot be mapped
# onto the sphere (the division or sqrt() below would yield NA/NaN).
row_sums <- rowSums(responses)
bad_rows <- which(is.na(row_sums) | row_sums <= 0 |
                    rowSums(responses < 0, na.rm = TRUE) > 0)
if (length(bad_rows) > 0) {
  stop("Invalid response rows (missing, negative or non-positive total): ",
       paste(bad_rows, collapse = ", "), call. = FALSE)
}
# Matrix / length-nrow vector recycles down columns, i.e. divides each row
# by its own sum.
responses <- responses / row_sums

# Take square root of responses (they are on sphere S^3)
Y <- sqrt(responses)

# Extract predictor variables (all other columns); they are standardized with
# colMeans()/sd() later, so every one must be numeric.
predictor_cols <- setdiff(colnames(data), response_cols)
is_numeric_col <- vapply(data[predictor_cols], is.numeric, logical(1))
if (!all(is_numeric_col)) {
  stop("Non-numeric predictor column(s): ",
       paste(predictor_cols[!is_numeric_col], collapse = ", "), call. = FALSE)
}
X <- as.matrix(data[, predictor_cols, drop = FALSE])

cat("\nData shapes:\n")
cat("  X (predictors):", nrow(X), "x", ncol(X), "\n")
cat("  Y (square-root responses):", nrow(Y), "x", ncol(Y), "\n")
cat("  Number of features (p):", ncol(X), "\n")
cat("  Total samples:", nrow(X), "\n")

# Run one session of K-fold cross-validation for global spherical regression.
#
# Rows are permuted with get_numpy_shuffle() (seeded by session_id, to match
# the Python implementation) and split into n_folds contiguous folds. Each
# fold is held out once:
#   - X is standardized with the training folds' column means and SDs;
#   - gsr() is fit on the training folds;
#   - the fold's MSPE is the mean squared geodesic (arc-cosine) distance
#     between predicted and observed test responses.
#
# Args:
#   session_id:    session number; also the seed passed to the shuffle script.
#   X:             n x p numeric predictor matrix.
#   Y:             n x m response matrix whose rows are unit vectors.
#   python_path:   Python interpreter forwarded to get_numpy_shuffle().
#   python_script: accepted for interface compatibility; the shuffle script
#                  path is resolved by get_numpy_shuffle() itself.
#   n_folds:       number of folds (default 10).
# Returns:
#   list(session_id, mspe_folds, mean_mspe). mspe_folds holds the per-fold
#   MSPE, NA where gsr() or the distance computation failed. mean_mspe is the
#   mean over non-NA folds, NaN if every fold failed.
run_single_session <- function(session_id, X, Y, python_path, python_script,
                               n_folds = 10) {
  n <- nrow(X)
  # With fewer rows than folds some folds are empty, and start:end would then
  # count *down* and select the wrong rows.
  if (n < n_folds) {
    stop("Need at least ", n_folds, " rows for ", n_folds,
         "-fold cross-validation, got ", n, call. = FALSE)
  }

  # Get shuffle using Python script (matching Python implementation);
  # session_id doubles as the numpy random seed.
  indices <- get_numpy_shuffle(n, session_id, python_path)

  # Same split as numpy.array_split: the first `remainder` folds are one row
  # larger than the others.
  fold_size <- n %/% n_folds
  remainder <- n %% n_folds

  mspe_folds <- rep(NA_real_, n_folds)

  for (fold_idx in seq_len(n_folds)) {
    start_idx <- (fold_idx - 1) * fold_size + min(fold_idx - 1, remainder) + 1
    end_idx <- start_idx + fold_size - 1 + (fold_idx <= remainder)
    fold_pos <- start_idx:end_idx

    test_idx <- indices[fold_pos]
    train_idx <- indices[-fold_pos]  # all remaining folds form the training set

    X_train <- X[train_idx, , drop = FALSE]
    X_test <- X[test_idx, , drop = FALSE]
    Y_train <- Y[train_idx, , drop = FALSE]
    Y_test <- Y[test_idx, , drop = FALSE]

    # Standardize with training-set statistics only (no test leakage);
    # constant columns get scale 1 to avoid division by zero.
    X_train_mean <- colMeans(X_train)
    X_train_sd <- apply(X_train, 2, sd)
    X_train_sd[X_train_sd == 0] <- 1
    X_train_scaled <- scale(X_train, center = X_train_mean, scale = X_train_sd)
    X_test_scaled <- scale(X_test, center = X_train_mean, scale = X_train_sd)

    n_test <- nrow(X_test)

    # Assign the *value* returned by tryCatch(). An assignment made inside the
    # error handler would only modify the handler's local copy of mspe_folds,
    # leaving the pre-filled value in place for a failed fold.
    mspe_folds[fold_idx] <- tryCatch({
      # gsr(xin = n x p predictors, yin = n x m unit-vector responses,
      #     xout = k x p output predictors); predictions are read from $yout.
      res <- gsr(xin = X_train_scaled, yin = Y_train, xout = X_test_scaled)
      y_pred <- res$yout
      if (!is.matrix(y_pred)) {
        y_pred <- as.matrix(y_pred)
      }

      # Compare only rows present in both matrices; any others stay NA.
      n_common <- min(nrow(y_pred), n_test)
      geodesic_distances <- rep(NA_real_, n_test)
      if (n_common > 0) {
        rows <- seq_len(n_common)
        pred <- y_pred[rows, , drop = FALSE]
        truth <- Y_test[rows, , drop = FALSE]
        # Re-normalize each row to unit length (matrix / length-nrow vector
        # divides row-wise).
        pred <- pred / sqrt(rowSums(pred^2))
        truth <- truth / sqrt(rowSums(truth^2))
        # Clamp to [-1, 1] so rounding error cannot push acos() to NaN.
        cos_angle <- pmax(-1, pmin(1, rowSums(pred * truth)))
        geodesic_distances[rows] <- acos(cos_angle)
      }

      # MSPE is the mean squared geodesic distance
      mean(geodesic_distances^2, na.rm = TRUE)
    }, error = function(e) {
      cat("Error in session", session_id, "fold", fold_idx, ":", e$message, "\n")
      NA_real_
    })
  }

  # Calculate mean MSPE for this session
  mean_mspe <- mean(mspe_folds, na.rm = TRUE)

  cat("  Completed session", session_id, "- Mean MSPE:", sprintf("%.6f", mean_mspe), "\n")

  list(
    session_id = session_id,
    mspe_folds = mspe_folds,
    mean_mspe = mean_mspe
  )
}

# Parallel execution parameters
n_sessions <- 100
# Use at most 50 workers, but never more than the machine reports
# (detectCores() may return NA, hence na.rm). mclapply() cannot fork on
# Windows, where mc.cores must be 1.
n_cores <- min(50L, detectCores(), na.rm = TRUE)
if (.Platform$OS.type == "windows") {
  n_cores <- 1L
}

cat("\n", paste(rep("=", 50), collapse = ""), "\n")
cat("Starting Parallel Cross-Validation\n")
cat("  Sessions:", n_sessions, "\n")
cat("  Cores:", n_cores, "\n")
cat("  Folds per session:", 10, "\n")
cat(paste(rep("=", 50), collapse = ""), "\n")

# Run parallel sessions
cat("\nRunning", n_sessions, "parallel sessions...\n")
start_time <- Sys.time()

# Use mclapply for parallel execution
# Pass python_path and python_script to each worker
results_list <- mclapply(seq_len(n_sessions), function(session_id) {
  run_single_session(session_id, X, Y, python_path, python_script)
}, mc.cores = n_cores)

end_time <- Sys.time()
elapsed_time <- difftime(end_time, start_time, units = "secs")
cat("Completed in", round(elapsed_time, 2), "seconds\n")

# A session that errored inside its worker comes back as a "try-error"
# string (or NULL if the worker died) instead of the result list. Report it
# and record NA rather than letting `$` fail on a non-list.
session_ok <- vapply(results_list, is.list, logical(1))
if (!all(session_ok)) {
  cat("\nWarning:", sum(!session_ok), "session(s) failed in their worker:\n")
  for (i in which(!session_ok)) {
    failure <- results_list[[i]]
    msg <- if (is.null(failure)) {
      "no result returned"
    } else {
      trimws(paste(as.character(failure), collapse = " "))
    }
    cat("  session", i, ":", msg, "\n")
  }
}

# Extract mean MSPE for each session (NA for failed sessions)
mean_mspe_results <- vapply(results_list, function(res) {
  value <- if (is.list(res)) res[["mean_mspe"]] else NULL
  if (is.numeric(value) && length(value) == 1) value else NA_real_
}, numeric(1))

# Create results dataframe
results_df <- data.frame(
  iteration = seq_len(n_sessions),
  seed = seq_len(n_sessions),
  mean_mspe = mean_mspe_results
)

# Save results to CSV
output_file <- "CV_GFR_emotional_results.csv"
write.csv(results_df, file = output_file, row.names = FALSE)
cat("\nResults saved to:", output_file, "\n")

# Print summary statistics
cat("\n", paste(rep("=", 50), collapse = ""), "\n")
cat("Summary Statistics\n")
cat(paste(rep("=", 50), collapse = ""), "\n")

# Remove NA values for statistics (is.na() is also TRUE for NaN)
results_clean <- mean_mspe_results[!is.na(mean_mspe_results)]

if (length(results_clean) > 0) {
  cat("Mean MSPE across all sessions:", sprintf("%.6f", mean(results_clean)), "\n")
  if (length(results_clean) > 1) {
    cat("Std MSPE:", sprintf("%.6f", sd(results_clean)), "\n")
  } else {
    cat("Std MSPE: N/A (only one value)\n")
  }
  cat("Min MSPE:", sprintf("%.6f", min(results_clean)), "\n")
  cat("Max MSPE:", sprintf("%.6f", max(results_clean)), "\n")
} else {
  cat("No valid MSPE values (all NA)\n")
}

cat("NA values:", sum(is.na(mean_mspe_results)), "out of", length(mean_mspe_results), "\n")
cat("Valid values:", length(results_clean), "out of", length(mean_mspe_results), "\n")

cat("\nParallel cross-validation completed successfully!\n")

