 library(reticulate)
library(frechet)
library(parallel)
library(doParallel)
library(foreach)

# Bind reticulate to a Python interpreter.
# Common Linux server install locations are probed in priority order; the
# first path that exists on disk and that use_python() accepts is used.
python_paths <- c(
  "/usr/bin/python3",
  "/usr/bin/python",
  "/opt/conda/bin/python",
  "/opt/miniconda3/bin/python",
  "/usr/local/bin/python3",
  "/usr/local/bin/python"
)

python_found <- FALSE
for (path in python_paths) {
  if (!file.exists(path)) {
    next
  }
  # Turn the outcome of use_python() into a flag so the loop is left with a
  # plain `break` instead of one buried inside the tryCatch() expression.
  bound <- tryCatch(
    {
      use_python(path)
      TRUE
    },
    error = function(e) {
      cat("Failed to use Python at:", path, "\n")
      FALSE
    }
  )
  if (bound) {
    cat("Using Python at:", path, "\n")
    python_found <- TRUE
    break
  }
}

if (!python_found) {
  # None of the known paths worked: fall back to reticulate's own interpreter
  # discovery and check that it really yields a usable Python, so a missing
  # installation is reported here rather than as an obscure import error later.
  cat("Trying reticulate's default Python...\n")
  if (!isTRUE(py_available(initialize = TRUE))) {
    stop("No suitable Python installation found. Please install Python or specify the correct path.")
  }
}

# Test Python environment before starting parallel processing.
# Importing the modules and running one tiny simulation in the main process
# means a broken setup is reported once, with installation guidance, before
# any worker is launched.
cat("Testing Python environment...\n")
tryCatch({
  # Test basic imports
  torch <- import("torch")
  numpy <- import("numpy")
  cat("✓ torch imported successfully\n")
  cat("✓ numpy imported successfully\n")

  # Test simulation module (looked up in the current working directory)
  simulation_module <- import_from_path("simulationdatagenerator", path = ".")
  cat("✓ simulationdatagenerator imported successfully\n")

  # Test basic functionality with a deliberately small problem
  test_result <- simulation_module$generate_simulation_data_torch_true(
    n = 10L, qf_size = 10L, p = 4L, link = "linear", seed = 1L
  )
  cat("✓ Test simulation run successful\n")
}, error = function(e) {
  cat("❌ Python environment test failed:\n")
  cat("Error:", conditionMessage(e), "\n")

  # py_config() can itself fail when no interpreter could be initialised.
  # Query it defensively so the original error above and the guidance below
  # are always shown instead of being replaced by a secondary failure.
  config <- tryCatch(py_config(), error = function(e2) NULL)
  python_bin <- if (is.null(config)) "<unavailable>" else config$python
  python_version <- if (is.null(config)) "<unavailable>" else config$version
  cat("Python path:", python_bin, "\n")
  cat("Python version:", python_version, "\n")

  cat("\nPlease ensure the following Python packages are installed:\n")
  cat("- torch (PyTorch)\n")
  cat("- numpy\n")
  cat("- simulationdatagenerator.py (in current directory)\n")
  cat("\nYou can install them using:\n")
  cat("pip install torch numpy\n")
  stop("Python environment test failed")
})

# Simulation settings ----
# qf_size: number of grid points used for each quantile function.
# p: forwarded to the Python data generator as its `p` argument
#    (presumably the number of predictors -- confirm against the generator).
qf_size <- 100L
p <- 4L

# Experimental grid: every link function is run at every sample size
link_functions <- c("linear", "quadratic", "exp")
sample_sizes <- c(250L, 500L, 1250L, 2500L)

# Output folder for the per-combination CSV files; created on first run only
if (!dir.exists("wasserstein")) {
  dir.create("wasserstein")
}

# Function to run a single session
#
# Binds reticulate to a Python interpreter, draws one data set from the Python
# module `simulationdatagenerator`, fits GloDenReg() on the training split and
# compares the predicted quantile functions on the test split with Gaussian
# quantile functions built from the `mu` and `sigma` returned by the generator.
#
# Reads `p` and `qf_size` from the enclosing (global) environment.
#
# Arguments:
#   seed      Seed passed to the Python generator and to set.seed() for the
#             train/validation/test permutation.
#   n         Sample size requested from the generator.
#   link_type Link name forwarded to the generator as `link`.
#
# Returns a list with elements:
#   seed, n, link_type  the inputs, echoed back
#   avg_l2_distance     mean of `l2_distances`
#   l2_distances        one distance per test observation
#   n_test              number of test observations actually used
#   theta_true          third element of the generator output, as an R array
#
# Stops with an error if the Python modules cannot be imported, if `n` is too
# small to form a training and a test split, or if GloDenReg() returns `qout`
# in an unrecognised shape.
run_single_session <- function(seed, n, link_type) {
  # Every worker process needs its own Python binding, so the interpreter is
  # selected again here: first existing path accepted by use_python() wins.
  python_paths <- c(
    "/usr/bin/python3",
    "/usr/bin/python",
    "/opt/conda/bin/python",
    "/opt/miniconda3/bin/python",
    "/usr/local/bin/python3",
    "/usr/local/bin/python"
  )

  for (path in python_paths) {
    if (!file.exists(path)) {
      next
    }
    bound <- tryCatch(
      {
        use_python(path)
        TRUE
      },
      error = function(e) FALSE
    )
    if (bound) {
      break
    }
  }
  # If no path was bound, reticulate's default discovery applies when import()
  # below first touches Python; a failure is reported by the handler there.

  # Import required Python modules. torch and numpy are imported only to check
  # that they are available; the generator module is the one used below.
  simulation_module <- tryCatch(
    {
      import("torch")
      import("numpy")
      import_from_path("simulationdatagenerator", path = ".")
    },
    error = function(e) {
      cat("Error importing Python modules:", conditionMessage(e), "\n")
      # py_config() may fail too when no interpreter is usable; do not let
      # that secondary failure replace the diagnostics.
      python_bin <- tryCatch(
        py_config()$python,
        error = function(e2) "<unavailable>"
      )
      cat("Python path being used:", python_bin, "\n")
      cat("Available Python packages:\n")
      tryCatch({
        py_run_string("import sys; print('\\n'.join(sys.path))")
      }, error = function(e2) {
        cat("Could not check Python path\n")
      })
      # Carry the underlying cause in the error itself: console output of
      # cluster workers is typically not visible in the main process.
      stop("Failed to import required Python modules: ", conditionMessage(e))
    }
  )

  # Generate simulation data
  result <- simulation_module$generate_simulation_data_torch_true(
    n = n, qf_size = qf_size, p = p, link = link_type, seed = seed
  )

  # The generator returns (X, Y, theta_true, mu, sigma); each element exposes
  # a $numpy() method whose result is converted to an R array.
  X_r <- as.array(result[[1]]$numpy())
  Y_r <- as.array(result[[2]]$numpy())
  theta_true_r <- as.array(result[[3]]$numpy())
  mu_r <- as.array(result[[4]]$numpy())
  sigma_r <- as.array(result[[5]]$numpy())

  # Data splitting (same as simu_linear100.py)
  set.seed(seed)  # Set R seed for reproducible splitting
  idx <- sample(seq_len(n))  # Random permutation
  n_train <- as.integer(0.4 * n)  # 40% for training
  n_val <- as.integer(0.1 * n)    # 10% for validation (reserved, unused here)
  # At most 200 test observations, but never more than what is left after the
  # training and validation splits. With a fixed 200 and n = 250 the indices
  # ran past the end of `idx` and produced NA rows.
  # NOTE(review): truncation mirrors Python slice semantics -- confirm this
  # matches simu_linear100.py.
  n_test <- as.integer(min(200L, n - n_train - n_val))
  if (n_train < 1L || n_test < 1L) {
    stop("Sample size n = ", n, " is too small to form training and test splits")
  }

  idx_train <- idx[seq_len(n_train)]
  # The validation block sits between train and test in the permutation, so
  # the test indices start after n_train + n_val.
  idx_test <- idx[n_train + n_val + seq_len(n_test)]

  # Split data; drop = FALSE keeps matrices even for a single row
  X_train <- X_r[idx_train, , drop = FALSE]
  X_test <- X_r[idx_test, , drop = FALSE]
  Y_train <- Y_r[idx_train, , drop = FALSE]
  mu_test <- mu_r[idx_test]
  sigma_test <- sigma_r[idx_test]

  # Prepare data for GloDenReg
  # Quantile support: equally spaced grid on [0, 1]
  qSup <- seq(0, 1, length.out = qf_size)

  # Training responses: one list(x = grid, y = quantile values) per observation
  qin <- lapply(
    seq_len(n_train),
    function(i) list(x = qSup, y = as.numeric(Y_train[i, ]))
  )

  # Run global density regression (predictors are used without normalization)
  res <- GloDenReg(
    xin = X_train, qin = qin, xout = X_test,
    optns = list(qSup = qSup)
  )
  qout <- res$qout

  # True quantile functions for the test set
  qSup_safe <- pmax(pmin(qSup, 0.999), 0.001)  # Clamp to avoid Inf/NaN
  qf_true_test <- matrix(0, nrow = n_test, ncol = qf_size)
  for (i in seq_len(n_test)) {
    sigma_i <- max(sigma_test[i], 1e-8)  # Ensure positive sigma
    qf_true_test[i, ] <- qnorm(qSup_safe, mean = mu_test[i], sd = sigma_i)
  }

  # Predicted quantile functions: accept the shapes qout may come in
  qf_pred_test <- if (is.matrix(qout)) {
    qout
  } else if (is.list(qout) && length(qout) == n_test) {
    pred <- matrix(0, nrow = n_test, ncol = qf_size)
    for (i in seq_len(n_test)) {
      entry <- qout[[i]]
      pred[i, ] <- if (is.list(entry) && "y" %in% names(entry)) {
        entry[["y"]]
      } else {
        as.numeric(entry)
      }
    }
    pred
  } else if (is.numeric(qout)) {
    # Flat numeric vector: reshape row by row
    matrix(qout, nrow = n_test, ncol = qf_size, byrow = TRUE)
  } else {
    stop("Unexpected structure of qout from GloDenReg")
  }

  # Guard against silent recycling in the distance computation below
  if (nrow(qf_pred_test) != n_test || ncol(qf_pred_test) != qf_size) {
    stop("Predicted quantile functions have unexpected dimensions")
  }

  # Distance per test observation: Euclidean norm of the difference on the
  # grid, scaled by sqrt(qf_size) (equal to the former constant 10 when
  # qf_size is 100), i.e. a root-mean-square difference.
  l2_distances <- vapply(
    seq_len(n_test),
    function(i) {
      sqrt(sum((qf_pred_test[i, ] - qf_true_test[i, ])^2)) / sqrt(qf_size)
    },
    numeric(1)
  )

  avg_l2_distance <- mean(l2_distances)

  # Print completion message
  cat(sprintf("[%s, n=%d] Session %d completed - L2 distance: %.6f\n",
              link_type, n, seed, avg_l2_distance))

  # Results for this session
  list(
    seed = seed,
    avg_l2_distance = avg_l2_distance,
    l2_distances = l2_distances,
    n_test = n_test,
    theta_true = theta_true_r,
    n = n,
    link_type = link_type
  )
}

# Set up parallel computing
# One core is left free for the main process. detectCores() may return NA or
# 1, so fall back to a single worker instead of requesting zero/NA workers.
num_cores <- detectCores()
num_cores <- if (is.na(num_cores)) 1L else max(1L, num_cores - 1L)
cl <- makeCluster(num_cores)
registerDoParallel(cl)

# Number of replicate sessions per combination; seeds run from 1 to n_sessions
n_sessions <- 200L

cat("=== Global Frechet Regression Server Script ===\n")
cat("Starting parallel computation with", num_cores, "cores\n")
cat("Running all combinations of sample sizes and link functions\n")
cat("Sample sizes:", paste(sample_sizes, collapse = ", "), "\n")
cat("Link functions:", paste(link_functions, collapse = ", "), "\n")
cat("Sessions per combination: ", n_sessions, "\n", sep = "")
cat("Total sessions:",
    length(sample_sizes) * length(link_functions) * n_sessions, "\n\n")

# Loop through all combinations. The cluster is stopped in `finally` so the
# worker processes are released even when a combination fails.
tryCatch({
  for (n in sample_sizes) {
    for (link_type in link_functions) {

      cat("=== Running combination: n =", n, ", link =", link_type, "===\n")

      # Run parallel computation for this combination; foreach returns one
      # list element per seed, in seed order.
      results <- foreach(seed = seq_len(n_sessions),
                         .packages = c("reticulate", "frechet")) %dopar% {
        run_single_session(seed, n, link_type)
      }

      # Extract results (vapply enforces one scalar per session)
      all_avg_l2_distances <- vapply(
        results, function(x) x$avg_l2_distance, numeric(1)
      )
      all_seeds <- vapply(results, function(x) x$seed, integer(1))

      # Create results dataframe
      results_summary <- data.frame(
        seed = all_seeds,
        avg_l2_distance = all_avg_l2_distances
      )

      # Save results with specified naming convention: link_size_wass.csv
      filename <- file.path("wasserstein", paste0(link_type, "_", n, "_wass.csv"))
      write.csv(results_summary, filename, row.names = FALSE)

      # Print summary statistics for this combination
      cat("  Mean L2 distance:", round(mean(all_avg_l2_distances), 6), "\n")
      cat("  Std L2 distance:", round(sd(all_avg_l2_distances), 6), "\n")
      cat("  Results saved to:", filename, "\n")
      cat("  Combination completed successfully!\n\n")
    }
  }
}, finally = {
  # Stop parallel cluster
  stopCluster(cl)
})

cat("=== All combinations completed! ===\n")
cat("Results saved in wasserstein/ folder with naming convention: link_size_wass.csv\n\n")

# Print summary of all files created
cat("Files created:\n")
for (n in sample_sizes) {
  for (link_type in link_functions) {
    filename <- file.path("wasserstein", paste0(link_type, "_", n, "_wass.csv"))
    cat("-", filename, "\n")
  }
}

cat("\nGlobal Frechet parallel analysis with all combinations completed!\n")
# format() is required: cat() on a POSIXct value prints raw epoch seconds
cat("Script finished at:", format(Sys.time()), "\n")