# Load required libraries
library(reticulate)
library(frechet)
library(parallel)
library(doParallel)
library(foreach)
library(Matrix)
library(osqp)

# Source the SIdxDenReg function
source("SIdxDenReg.R")

# Set Python path and import the simulation module
# Interpreter locations commonly present on Linux servers are probed in the
# order listed; the first path that exists and that use_python() accepts wins.
python_paths <- c(
  "/usr/bin/python3",
  "/usr/bin/python",
  "/opt/conda/bin/python",
  "/opt/miniconda3/bin/python",
  "/usr/local/bin/python3",
  "/usr/local/bin/python"
)

python_found <- FALSE
for (path in python_paths) {
  if (!file.exists(path)) {
    next
  }
  tryCatch({
    use_python(path)
    cat("Using Python at:", path, "\n")
    python_found <- TRUE
    break
  }, error = function(e) {
    cat("Failed to use Python at:", path, "\n")
  })
}

if (!python_found) {
  # None of the candidate paths was usable: fall back to reticulate's own
  # interpreter discovery and fail fast when it cannot initialise Python.
  # The earlier version wrapped a bare cat() in tryCatch(), so its error
  # branch could never be reached.
  cat("Trying reticulate's default Python...\n")
  if (!py_available(initialize = TRUE)) {
    stop("No suitable Python installation found. Please install Python or specify the correct path.")
  }
}

# Test Python environment before starting parallel processing
# A tiny end-to-end run is performed in the main process so that missing
# Python packages are reported once, up front, instead of inside every worker.
cat("Testing Python environment...\n")
tryCatch({
  # Test basic imports
  torch <- import("torch")
  numpy <- import("numpy")
  cat("✓ torch imported successfully\n")
  cat("✓ numpy imported successfully\n")

  # Test simulation module (looked up in the current working directory)
  simulation_module <- import_from_path("simulationdatagenerator", path = ".")
  cat("✓ simulationdatagenerator imported successfully\n")

  # Test basic functionality with a very small problem
  test_result <- simulation_module$generate_simulation_data_torch_true(n=10L, qf_size=10L, p=4L, link="linear", seed=1L)
  cat("✓ Test simulation run successful\n")

}, error = function(e) {
  cat("❌ Python environment test failed:\n")
  cat("Error:", conditionMessage(e), "\n")

  # py_config() can itself fail when no interpreter could be initialised.
  # Guard it (and call it only once) so the original error and the install
  # hints below are always shown.
  py_info <- tryCatch(py_config(), error = function(e2) NULL)
  if (is.null(py_info)) {
    cat("Python path:", "<unavailable>", "\n")
    cat("Python version:", "<unavailable>", "\n")
  } else {
    cat("Python path:", py_info$python, "\n")
    cat("Python version:", py_info$version, "\n")
  }

  cat("\nPlease ensure the following Python packages are installed:\n")
  cat("- torch (PyTorch)\n")
  cat("- numpy\n")
  cat("- simulationdatagenerator.py (in current directory)\n")
  cat("\nYou can install them using:\n")
  cat("pip install torch numpy\n")
  stop("Python environment test failed")
})

# Experiment grid: every sample size is crossed with every link function
link_functions <- c("linear", "quadratic", "exp")
sample_sizes <- c(250L, 500L, 1250L, 2500L)

# Fixed problem dimensions passed to the Python data generator:
# `qf_size` is forwarded as its qf_size argument, `p` as its p argument
qf_size <- 100L
p <- 4L

# Function to run a single session
#
# Runs one simulation replicate: generates data through the Python module,
# estimates the index direction with SIdxDenReg() on the training split,
# predicts test quantile functions with LocDenReg() on the projected
# predictors, and scores them against Gaussian quantile curves built from the
# generator's mu and sigma.
#
# Arguments:
#   seed      - seed forwarded to the Python generator and used for set.seed()
#   n         - number of observations requested from the generator
#   link_type - link name forwarded to the generator as `link`
#
# Reads the script-level globals `qf_size` and `p`.
#
# Returns a list with elements seed, avg_l2_distance, l2_distances, n_test,
# theta_true, est_theta, n and link_type. `n_test` is 200 unless fewer rows
# remain after the training/validation split, in which case it is the number
# of rows left over.
run_single_session <- function(seed, n, link_type) {
  # Set Python path and import module for each worker
  # Same candidate list as the main process; first usable path wins
  python_paths <- c(
    "/usr/bin/python3",
    "/usr/bin/python",
    "/opt/conda/bin/python",
    "/opt/miniconda3/bin/python",
    "/usr/local/bin/python3",
    "/usr/local/bin/python"
  )

  python_found <- FALSE
  for (path in python_paths) {
    if (!file.exists(path)) {
      next
    }
    tryCatch({
      use_python(path)
      python_found <- TRUE
      break
    }, error = function(e) {
      # Continue to next path
    })
  }

  # Fall back to reticulate's default discovery and actually verify it; the
  # earlier empty tryCatch() could never raise this error.
  if (!python_found && !py_available(initialize = TRUE)) {
    stop("No suitable Python installation found for worker.")
  }

  # Import required Python modules. torch and numpy are imported only to
  # confirm that they are available in this worker.
  simulation_module <- tryCatch({
    import("torch")
    import("numpy")
    import_from_path("simulationdatagenerator", path = ".")
  }, error = function(e) {
    cat("Error importing Python modules:", conditionMessage(e), "\n")
    # py_config() may fail too; never let it hide the import error
    py_info <- tryCatch(py_config(), error = function(e2) NULL)
    cat("Python path being used:",
        if (is.null(py_info)) "<unavailable>" else py_info$python, "\n")
    cat("Available Python packages:\n")
    tryCatch({
      py_run_string("import sys; print('\\n'.join(sys.path))")
    }, error = function(e2) {
      cat("Could not check Python path\n")
    })
    stop("Failed to import required Python modules")
  })

  # Generate simulation data. n and seed are coerced so that a plain R double
  # is still handed to Python as an integer.
  result <- simulation_module$generate_simulation_data_torch_true(
    n = as.integer(n), qf_size = qf_size, p = p,
    link = link_type, seed = as.integer(seed)
  )

  # Extract results (X, Y, theta_true, mu, sigma) and convert to R arrays
  X_r <- as.array(result[[1]]$numpy())
  Y_r <- as.array(result[[2]]$numpy())
  theta_true_r <- as.array(result[[3]]$numpy())
  mu_r <- as.array(result[[4]]$numpy())
  sigma_r <- as.array(result[[5]]$numpy())

  # Data splitting (same as simu_linear100.py)
  set.seed(seed)  # Set R seed for reproducible splitting
  idx <- sample(seq_len(n))  # Random permutation
  n_train <- as.integer(0.4 * n)  # 40% for training
  n_val <- as.integer(0.1 * n)    # 10% for validation (reserved, not used here)
  n_test_target <- 200L           # Desired test set size

  if (n_train < 1L) {
    stop("n is too small to form a training set: n = ", n)
  }

  # The test block must fit inside the permutation. Indexing past its end
  # would yield NA indices and therefore NA rows (e.g. n = 250 leaves only
  # 125 rows after the training and validation blocks).
  n_remaining <- as.integer(n - n_train - n_val)
  if (n_remaining < 1L) {
    stop("No observations left for the test set: n = ", n)
  }
  n_test <- min(n_test_target, n_remaining)
  if (n_test < n_test_target) {
    warning(sprintf(
      "Only %d observations remain for testing (requested %d) with n = %d",
      n_test, n_test_target, as.integer(n)
    ), call. = FALSE)
  }

  idx_train <- idx[seq_len(n_train)]
  idx_test <- idx[n_train + n_val + seq_len(n_test)]

  # Split data; drop = FALSE keeps matrices even for a single row
  X_train <- X_r[idx_train, , drop = FALSE]
  X_test <- X_r[idx_test, , drop = FALSE]
  Y_train <- Y_r[idx_train, , drop = FALSE]
  mu_test <- mu_r[idx_test]
  sigma_test <- sigma_r[idx_test]

  # Run SIdxDenReg on training set to get estimated theta
  sidx_result <- SIdxDenReg(X_train, Y_train)
  est_theta <- sidx_result$est

  # Use estimated theta to get 1D projections for both training and test sets
  xin_train_proj <- X_train %*% est_theta
  xout_test_proj <- X_test %*% est_theta

  # LocDenReg works with the projected (1D) predictor and returns the
  # predicted quantile functions for all test points at once
  loc_result <- LocDenReg(xin = xin_train_proj, qin = Y_train, xout = xout_test_proj)
  predicted_curves <- loc_result$qout

  # Ensure predicted_curves is a plain numeric matrix with one row per test point
  if (!is.matrix(predicted_curves)) {
    predicted_curves <- matrix(predicted_curves, nrow = length(xout_test_proj), ncol = qf_size)
  }
  predicted_curves <- matrix(as.numeric(predicted_curves),
                             nrow = nrow(predicted_curves),
                             ncol = ncol(predicted_curves))

  # True quantile functions for the test set: Normal(mu, sigma) quantiles on
  # an equally spaced grid from 0.01 to 0.99
  quantile_grid <- seq(0.01, 0.99, length.out = qf_size)
  Y_test_true <- matrix(0, nrow = n_test, ncol = qf_size)
  for (i in seq_len(n_test)) {
    Y_test_true[i, ] <- qnorm(quantile_grid, mean = mu_test[i], sd = sigma_test[i])
  }

  # Replace any Inf or NA values with 0 so the distance is always computable.
  # NOTE(review): zero-filling hides failed predictions in the final score;
  # kept because it is the existing scoring convention.
  predicted_curves[!is.finite(predicted_curves)] <- 0
  Y_test_true[!is.finite(Y_test_true)] <- 0

  # Root-mean-square distance between predicted and true quantile functions,
  # one value per test point; test points without a predicted row stay NA
  n_scored <- min(n_test, nrow(predicted_curves))
  l2_distances <- rep(NA_real_, n_test)
  l2_distances[seq_len(n_scored)] <- vapply(
    seq_len(n_scored),
    function(i) sqrt(mean((predicted_curves[i, ] - Y_test_true[i, ])^2)),
    numeric(1)
  )

  avg_l2_distance <- mean(l2_distances)

  # Print completion message
  cat(sprintf("[%s, n=%d] Session %d completed - L2 distance: %.6f\n",
              link_type, n, seed, avg_l2_distance))

  # Return results for this session
  list(
    seed = seed,
    avg_l2_distance = avg_l2_distance,
    l2_distances = l2_distances,
    n_test = n_test,
    theta_true = theta_true_r,
    est_theta = est_theta,
    n = n,
    link_type = link_type
  )
}

# Set up parallel computing
# Leave one core free, but always start at least one worker: detectCores()
# may return 1 or NA, and makeCluster() cannot start zero workers.
num_cores <- max(1L, detectCores() - 1L, na.rm = TRUE)
cl <- makeCluster(num_cores)
registerDoParallel(cl)

# Number of replicate sessions per combination (seeds 1..n_sessions)
n_sessions <- 200L

cat("=== Local Frechet Regression Server Script ===\n")
cat("Starting parallel computation with", num_cores, "cores\n")
cat("Running all combinations of sample sizes and link functions\n")
cat("Sample sizes:", paste(sample_sizes, collapse=", "), "\n")
cat("Link functions:", paste(link_functions, collapse=", "), "\n")
cat(sprintf("Sessions per combination: %d\n", n_sessions))
cat("Total sessions:", length(sample_sizes) * length(link_functions) * n_sessions, "\n\n")

# Loop through all combinations. The finally clause guarantees that the
# worker processes are shut down even when a combination fails.
tryCatch({
  for (n in sample_sizes) {
    for (link_type in link_functions) {

      cat("=== Running combination: n =", n, ", link =", link_type, "===\n")

      # Run parallel computation for this combination; foreach's default
      # combiner returns a list with one session result per seed
      results <- foreach(seed = seq_len(n_sessions),
                         .packages = c("reticulate", "frechet", "Matrix", "osqp")) %dopar% {
        run_single_session(seed, n, link_type)
      }

      # Extract results with type-checked vapply()
      all_avg_l2_distances <- vapply(results, function(x) x$avg_l2_distance, numeric(1))
      all_seeds <- vapply(results, function(x) as.integer(x$seed), integer(1))
      est_thetas <- do.call(rbind, lapply(results, function(x) x$est_theta))

      # Create results dataframe
      results_summary <- data.frame(
        seed = all_seeds,
        avg_l2_distance = all_avg_l2_distances
      )

      # Save results with specified naming convention: link_size_IFR.csv
      filename <- paste0(link_type, "_", n, "_IFR.csv")
      write.csv(results_summary, filename, row.names = FALSE)

      # Save estimated thetas
      theta_filename <- paste0(link_type, "_", n, "_IFR_thetas.csv")
      write.csv(est_thetas, theta_filename, row.names = FALSE)

      # Print summary statistics for this combination
      cat("  Mean L2 distance:", round(mean(all_avg_l2_distances), 6), "\n")
      cat("  Std L2 distance:", round(sd(all_avg_l2_distances), 6), "\n")
      cat("  Results saved to:", filename, "\n")
      cat("  Thetas saved to:", theta_filename, "\n")
      cat("  Combination completed successfully!\n\n")
    }
  }
}, finally = {
  # Stop parallel cluster
  stopCluster(cl)
})

# Final report. The CSV files are written with bare file names, i.e. into the
# current working directory, so that directory is what gets reported.
cat("=== All combinations completed! ===\n")
cat("Results saved in", getwd(), "with naming convention: link_size_IFR.csv\n")
cat("Theta files saved with naming convention: link_size_IFR_thetas.csv\n\n")

# Print summary of all files created
cat("Files created:\n")
for (n in sample_sizes) {
  for (link_type in link_functions) {
    filename <- paste0(link_type, "_", n, "_IFR.csv")
    theta_filename <- paste0(link_type, "_", n, "_IFR_thetas.csv")
    cat("-", filename, "\n")
    cat("-", theta_filename, "\n")
  }
}

cat("\nLocal Frechet parallel analysis with all combinations completed!\n")
# cat() prints a POSIXct value as raw seconds since the epoch; format it first
cat("Script finished at:", format(Sys.time()), "\n")
