# Parallel 10-Fold Cross-Validation for Single Index Spherical Regression (SIFR)
# Applied to emotional dataset
# This script performs 50 Monte Carlo runs, each with 10-fold cross-validation
# Sessions run in parallel via parallel::mclapply; the core count is set by `n_cores` below

library(dplyr)
library(parallel)
library(frechet)

# Locate the Python interpreter used by the shuffle-generation script.
# The RETICULATE_PYTHON environment variable wins when it is set; otherwise the
# first existing entry of a fixed candidate list is used.
python_path <- Sys.getenv("RETICULATE_PYTHON", unset = "")
if (!nzchar(python_path)) {
  home_dir <- Sys.getenv("HOME")
  # Candidates in priority order (/usr/local/bin/python3 first, where numpy is
  # installed)
  possible_paths <- c(
    "/usr/local/bin/python3",
    "/usr/local/python-3.8.0/bin/python3",
    "/opt/miniconda3/bin/python",
    "/opt/anaconda3/bin/python",
    "/usr/bin/python3",
    "/usr/bin/python",
    "/opt/conda/bin/python",
    Sys.which("python3"),
    Sys.which("python"),
    file.path(home_dir, "miniconda3", "bin", "python"),
    file.path(home_dir, "anaconda3", "bin", "python")
  )

  # Keep only non-empty candidates that exist on disk; take the first one.
  usable_paths <- possible_paths[nzchar(possible_paths) & file.exists(possible_paths)]
  if (length(usable_paths) > 0) {
    python_path <- usable_paths[[1]]
  }
}

# Abort early when no interpreter could be resolved (or the configured one is
# missing on disk).
if (!nzchar(python_path) || !file.exists(python_path)) {
  stop("Python not found. Please set the Python path manually by editing the script or setting RETICULATE_PYTHON environment variable.")
}

cat("Using Python at:", python_path, "\n")

# Test if numpy is available in Python
cat("Testing numpy availability...\n")
# system2() shell-quotes the interpreter path itself (so paths containing
# spaces work) and, with stdout/stderr discarded, returns the exit status
# directly instead of hiding it in an attribute. Only the -c payload needs
# explicit quoting because system2() passes `args` through unquoted.
exit_code <- system2(
  python_path,
  args = c("-c", shQuote("import numpy")),
  stdout = FALSE,
  stderr = FALSE
)
# Any non-zero status means the import failed or the interpreter could not run.
if (!identical(as.integer(exit_code), 0L)) {
  cat("ERROR: numpy is not available in Python at:", python_path, "\n")
  cat("Please install numpy using:\n")
  cat("  ", python_path, "-m pip install numpy\n")
  stop("numpy is required but not found. Please install numpy for the Python interpreter.")
}
cat("✓ numpy is available\n")

# Path to the Python script for shuffle generation
python_script <- "test_numpy.py"

# Obtain a permutation from the external Python shuffle script.
#
# Arguments:
#   n          Number of elements to shuffle (passed to the script as argv[1]).
#   session_id Seed value (passed to the script as argv[2]).
#   python     Path of the Python interpreter to run.
#   script     Path of the Python script; defaults to the global `python_script`.
#
# Returns:
#   An integer vector: one value per line printed by the script, shifted by +1
#   to convert to R's 1-based indexing. Lines that do not parse as integers
#   become NA (callers are expected to validate the result).
get_numpy_shuffle <- function(n, session_id, python = python_path,
                              script = python_script) {
  # Format numbers explicitly: implicit coercion would turn a double such as
  # 100000 into "1e+05" on the command line.
  cli_args <- c(
    shQuote(script),
    format(n, scientific = FALSE, trim = TRUE),
    format(session_id, scientific = FALSE, trim = TRUE)
  )
  out <- system2(python, cli_args, stdout = TRUE)

  # system2(stdout = TRUE) attaches a "status" attribute on non-zero exit.
  exit_status <- attr(out, "status")
  if (!is.null(exit_status) && exit_status != 0) {
    warning(
      "Shuffle script exited with status ", exit_status,
      " (n = ", n, ", seed = ", session_id, ")",
      call. = FALSE
    )
  }

  # Convert to integer and add 1 for 1-based indexing in R
  as.integer(out) + 1L
}

# Load the SIdxSpheReg implementation from the working directory; the script
# cannot continue without it.
cat("Loading SIdxSpheReg function...\n")
if (!file.exists("SIdxSpheReg.R")) {
  stop("SIdxSpheReg.R not found. Please ensure SIdxSpheReg.R is in the same directory.")
}
source("SIdxSpheReg.R")
cat("✓ SIdxSpheReg loaded\n")

# ---- Load the emotional dataset and build the design/response matrices ----
data_file <- "../data_emotion.csv"
if (!file.exists(data_file)) {
  stop("Data file not found: ", data_file)
}

cat("Reading data from:", data_file, "\n")
data <- read.csv(data_file)

# Responses are the four columns b2_1 ... b2_4
response_cols <- paste0("b2_", 1:4)
responses <- as.matrix(data[, response_cols])

# The values appear to be percentages: rescale by 100, then renormalise each
# row by its own sum so that every composition sums to exactly 1.
responses <- responses / 100.0
row_sums <- rowSums(responses)
responses <- sweep(responses, 1, row_sums, "/")

# Square-root transform maps each composition onto the unit sphere S^3
Y <- sqrt(responses)

# Every column that is not a response is used as a predictor
predictor_cols <- setdiff(colnames(data), response_cols)
X <- as.matrix(data[, predictor_cols])

# Report the dimensions of what was loaded
n_samples <- nrow(X)
n_features <- ncol(X)
cat("\nData shapes:\n")
cat("  X (predictors):", n_samples, "x", n_features, "\n")
cat("  Y (square-root responses):", nrow(Y), "x", ncol(Y), "\n")
cat("  Number of features (p):", n_features, "\n")
cat("  Total samples:", n_samples, "\n")

# Run one Monte Carlo session of 10-fold cross-validation.
#
# For each fold: hold out the fold as the test set, split the remaining rows
# 80/20 (only the 80% part is used for fitting), standardise X with training
# statistics, fit SIdxSpheReg, predict the test responses with LocSpheReg on
# the projected index, and score the fold by the mean squared geodesic
# distance (MSPE).
#
# Arguments:
#   session_id    Session number; also used as the seed for the outer shuffle
#                 (fold splits use seed session_id + fold_idx).
#   X             Numeric predictor matrix, one row per sample.
#   Y             Response matrix with the same number of rows as X.
#   python_path   Python interpreter passed on to get_numpy_shuffle().
#   python_script Accepted for interface compatibility; this function calls
#                 get_numpy_shuffle(n, seed, python_path) and does not use it.
#
# Returns:
#   list(session_id, mspe_folds, mean_mspe). mspe_folds has one entry per
#   fold and is NA for folds that were skipped or failed; mean_mspe is the
#   mean over the non-NA folds (NaN when every fold is NA).
#
# Raises:
#   An error when the outer shuffle does not return a valid permutation of 1..n.
run_single_session <- function(session_id, X, Y, python_path, python_script) {
  # Reload SIdxSpheReg function for each worker (needed for parallel execution)
  if (file.exists("SIdxSpheReg.R")) {
    source("SIdxSpheReg.R")
  }

  n <- nrow(X)
  n_folds <- 10

  # Get shuffle using Python script (matching Python implementation);
  # session_id is used directly as the numpy random seed.
  indices <- get_numpy_shuffle(n, session_id, python_path)

  # The shuffle must be a permutation of 1..n; anything else would silently
  # corrupt every fold, so fail the whole session.
  if (any(is.na(indices)) || length(indices) != n ||
      length(unique(indices)) != n || any(indices < 1L | indices > n)) {
    stop(paste("Invalid indices from shuffle for session", session_id))
  }

  # Fold sizes: the first `remainder` folds get one extra element
  fold_size <- floor(n / n_folds)
  remainder <- n %% n_folds

  # One MSPE per fold. Start from NA so that a fold only contributes to the
  # session mean when it actually produced a score.
  mspe_folds <- rep(NA_real_, n_folds)

  for (fold_idx in seq_len(n_folds)) {
    # Positions (within `indices`) of this fold's test samples
    start_idx <- (fold_idx - 1) * fold_size + min(fold_idx - 1, remainder) + 1
    end_idx <- start_idx + fold_size - 1 + if (fold_idx <= remainder) 1 else 0

    # Skip malformed or empty folds (end_idx < start_idx happens when
    # n < n_folds and would otherwise yield a descending range).
    if (is.na(start_idx) || is.na(end_idx) || start_idx < 1 ||
        end_idx > length(indices) || end_idx < start_idx) {
      next
    }

    test_pos <- start_idx:end_idx
    test_idx <- indices[test_pos]
    train_idx <- indices[-test_pos]  # All remaining 9 folds for training

    # Further split training into train (80%) and validation (20%)
    n_train_val <- length(train_idx)
    n_train <- floor(0.8 * n_train_val)
    # Use numpy shuffle for train/val split as well (with fold-specific seed)
    train_val_perm <- get_numpy_shuffle(n_train_val, session_id + fold_idx, python_path)

    if (any(is.na(train_val_perm)) || length(train_val_perm) != n_train_val) {
      next
    }

    # Map shuffled positions back to actual row indices. seq_len() keeps the
    # training part empty (instead of c(1, 0)) when n_train is 0.
    train_idx_final <- train_idx[train_val_perm[seq_len(n_train)]]
    val_idx <- train_idx[train_val_perm[(n_train + 1):n_train_val]]

    if (any(is.na(train_idx_final)) || any(is.na(val_idx)) ||
        length(train_idx_final) == 0 || length(val_idx) == 0) {
      next
    }

    # Create splits. The validation rows are only held out of the fit; they
    # are not used for anything else in this function.
    X_train <- X[train_idx_final, , drop = FALSE]
    X_test <- X[test_idx, , drop = FALSE]
    Y_train <- Y[train_idx_final, , drop = FALSE]
    Y_test <- Y[test_idx, , drop = FALSE]

    # Standardize X using training set statistics only
    X_train_mean <- colMeans(X_train)
    X_train_sd <- apply(X_train, 2, sd)
    # Avoid division by zero for constant columns
    X_train_sd[X_train_sd == 0] <- 1

    X_train_scaled <- scale(X_train, center = X_train_mean, scale = X_train_sd)
    X_test_scaled <- scale(X_test, center = X_train_mean, scale = X_train_sd)

    n_test <- nrow(X_test)

    # The tryCatch VALUE is stored in mspe_folds. Assigning inside the error
    # handler would only modify a handler-local copy and leave the fold's
    # score untouched.
    mspe_folds[fold_idx] <- tryCatch({
      sidx_result <- SIdxSpheReg(X_train_scaled, Y_train, verbose = FALSE, iter = 50)

      # Estimated index direction and bandwidth
      est_theta <- sidx_result$est
      est_bw <- sidx_result$bw

      # 1D projections of the training and test predictors onto est_theta
      Z_train <- as.vector(X_train_scaled %*% est_theta)
      Z_test <- as.vector(X_test_scaled %*% est_theta)

      # Local Fréchet regression on the projected index (frechet::LocSpheReg).
      # On failure, fall back to the normalised mean training response for
      # every test point.
      predicted_responses <- tryCatch({
        res <- LocSpheReg(
          xin = Z_train,
          yin = Y_train,
          xout = Z_test,
          optns = list(bw = est_bw, kernel = "gauss")
        )
        res$yout
      }, error = function(e) {
        message(sprintf("[Session %d] Fold %d: Warning - LocSpheReg failed, using mean response as fallback: %s",
                       session_id, fold_idx, e$message))
        flush.console()
        mean_response <- colMeans(Y_train)
        mean_response <- mean_response / l2norm(mean_response)
        matrix(mean_response, nrow = n_test, ncol = ncol(Y_test), byrow = TRUE)
      })

      # Geodesic distance between each prediction and the true response.
      # Test points without a matching prediction row stay NA.
      geodesic_distances <- rep(NA_real_, n_test)
      n_scored <- min(n_test, nrow(predicted_responses))
      for (i in seq_len(n_scored)) {
        pred_vec <- as.numeric(predicted_responses[i, ])
        true_vec <- as.numeric(Y_test[i, ])

        # Ensure both are unit vectors
        pred_vec <- pred_vec / l2norm(pred_vec)
        true_vec <- true_vec / l2norm(true_vec)

        geodesic_distances[i] <- SpheGeoDist(pred_vec, true_vec)
      }

      # MSPE is the mean squared geodesic distance over the valid distances
      valid_distances <- geodesic_distances[!is.na(geodesic_distances)]
      if (length(valid_distances) > 0) {
        mspe <- mean(valid_distances^2)
        # Print with session ID for parallel execution visibility
        message(sprintf("[Session %d] Fold %d/%d MSPE: %.6f",
                       session_id, fold_idx, n_folds, mspe))
        flush.console()
        mspe
      } else {
        message(sprintf("[Session %d] Fold %d/%d MSPE: NA (no valid distances)",
                       session_id, fold_idx, n_folds))
        flush.console()
        NA_real_
      }
    }, error = function(e) {
      message(sprintf("[Session %d] Error in fold %d/%d: %s",
                     session_id, fold_idx, n_folds, e$message))
      flush.console()
      NA_real_
    })
  }

  # Mean MSPE over the folds that produced a score
  mean_mspe <- mean(mspe_folds, na.rm = TRUE)

  message(sprintf("[Session %d] Completed - Mean MSPE: %.6f (across %d folds)",
                 session_id, mean_mspe, n_folds))
  flush.console()

  list(
    session_id = session_id,
    mspe_folds = mspe_folds,
    mean_mspe = mean_mspe
  )
}

# Parallel execution parameters
n_sessions <- 50  # 50 Monte Carlo runs
n_cores <- 50    # Use 50 cores

banner <- paste(rep("=", 50), collapse = "")

cat("\n", banner, "\n")
cat("Starting Parallel Cross-Validation\n")
cat("  Sessions:", n_sessions, "\n")
cat("  Cores:", n_cores, "\n")
cat("  Folds per session:", 10, "\n")
cat(banner, "\n")

# Run parallel sessions
cat("\nRunning", n_sessions, "parallel sessions...\n")
start_time <- Sys.time()

# Use mclapply for parallel execution (one forked worker per session).
# Pass python_path and python_script to each worker.
results_list <- mclapply(seq_len(n_sessions), function(session_id) {
  run_single_session(session_id, X, Y, python_path, python_script)
}, mc.cores = n_cores)

end_time <- Sys.time()
elapsed_time <- difftime(end_time, start_time, units = "secs")
cat("Completed in", round(elapsed_time, 2), "seconds\n")

# mclapply() does not stop when a worker fails: the corresponding element is
# a "try-error" object (or NULL if the child process died). Identify usable
# results first so that one failed session cannot break the extraction.
session_ok <- vapply(results_list, function(res) {
  is.list(res) && is.numeric(res$mean_mspe) && length(res$mean_mspe) == 1L
}, logical(1))

if (any(!session_ok)) {
  failed_sessions <- which(!session_ok)
  cat("WARNING:", length(failed_sessions),
      "session(s) failed and are recorded as NA:",
      paste(failed_sessions, collapse = ", "), "\n")
  for (failed_id in failed_sessions) {
    failure <- results_list[[failed_id]]
    if (inherits(failure, "try-error")) {
      cat("  [Session", failed_id, "]", trimws(as.character(failure)), "\n")
    }
  }
}

# Extract mean MSPE for each session (NA for failed sessions)
mean_mspe_results <- rep(NA_real_, length(results_list))
mean_mspe_results[session_ok] <- vapply(
  results_list[session_ok],
  function(res) res$mean_mspe,
  numeric(1)
)

# Create results dataframe
results_df <- data.frame(
  iteration = seq_len(n_sessions),
  seed = seq_len(n_sessions),
  mean_mspe = mean_mspe_results
)

# Save results to CSV
output_file <- "CV_SIFR_emotional_results.csv"
write.csv(results_df, file = output_file, row.names = FALSE)
cat("\nResults saved to:", output_file, "\n")

# Print summary statistics
cat("\n", banner, "\n")
cat("Summary Statistics\n")
cat(banner, "\n")

# Remove NA values for statistics (is.na() is also TRUE for NaN, which a
# session reports when none of its folds produced a score)
results_clean <- mean_mspe_results[!is.na(mean_mspe_results)]

if (length(results_clean) > 0) {
  cat("Mean MSPE across all sessions:", sprintf("%.6f", mean(results_clean)), "\n")
  if (length(results_clean) > 1) {
    cat("Std MSPE:", sprintf("%.6f", sd(results_clean)), "\n")
  } else {
    cat("Std MSPE: N/A (only one value)\n")
  }
  cat("Min MSPE:", sprintf("%.6f", min(results_clean)), "\n")
  cat("Max MSPE:", sprintf("%.6f", max(results_clean)), "\n")
} else {
  cat("No valid MSPE values (all NA)\n")
}

cat("NA values:", sum(is.na(mean_mspe_results)), "out of", length(mean_mspe_results), "\n")
cat("Valid values:", length(results_clean), "out of", length(mean_mspe_results), "\n")

cat("\nParallel cross-validation completed successfully!\n")

