# Positional command-line arguments (all parsed as numeric):
#   1. seed   - passed to set.seed()
#   2. r      - index used to pick the target race out of `Races`
#   3. M      - grid length; also forwarded to the data/estimation helpers
#   4. rate   - forwarded to compute_f1_hat() and cv_search_lambda()
#   5. Gender - value matched against the Gender column when filtering rows
args <- commandArgs(TRUE)

print(args)

# Fail early with a usage hint rather than an opaque "subscript out of bounds".
if (length(args) < 5) {
  stop(
    "Expected 5 arguments: seed r M rate Gender (got ", length(args), ")",
    call. = FALSE
  )
}

seed <- as.numeric(args[[1]])
r <- as.numeric(args[[2]])
M <- as.numeric(args[[3]])
rate <- as.numeric(args[[4]])
Gender <- as.numeric(args[[5]])

# as.numeric() turns unparsable text into NA (with a warning); every later
# step needs real numbers, so stop here instead of failing far downstream.
if (anyNA(c(seed, r, M, rate, Gender))) {
  stop("All of seed, r, M, rate and Gender must be numeric", call. = FALSE)
}

library(rnhanesdata)
library(tidyverse)
library(truncnorm)
library(pracma)
library(osqp)  
library(Matrix)
library(parallel)
library(doParallel)
library(foreach)
library(doSNOW)

source("RealDataFunc.R") 

# Seed the RNG before any data handling so the subsampling below is
# reproducible for a given `seed` argument.
set.seed(seed)

Races <- c("Black", "White")
race <- Races[r]

data <- get_realdata(M)

# Keep only rows where BMI, Gender and Age are all observed.
realdata <- data[!is.na(data$BMI) & !is.na(data$Gender) & !is.na(data$Age), ]

# Restrict to the two races listed in `Races`, then to the requested gender.
realdata <- realdata[realdata$Race %in% Races, ]
realdata <- realdata[realdata$Gender == Gender, ]

# Gender is constant after the filter above, so the column is removed.
realdata$Gender <- NULL

realdata <- realdata_process(realdata, M, race)

# NOTE(review): K, z_grid and p_grid are not referenced again in this script;
# they are kept because the sourced helpers may read them as globals - confirm.
K <- length(realdata$source_data)
num_fold <- 5

z_grid <- seq(1 / M, 1 - 1 / M, length.out = M)
p_grid <- seq(1 / M, 1 - 1 / M, length.out = M)

X_t <- realdata$X_t
Y_t <- realdata$Y_t
source_d <- realdata$source_data
n_t <- dim(X_t)[1]

# Randomly subsample at most 200 target rows. Capping the size at n_t avoids
# the "cannot take a sample larger than the population" error when fewer than
# 200 rows survive the filters; for n_t >= 200 the draw is unchanged.
Index_random <- sample(n_t, min(200, n_t))
X_t <- X_t[Index_random, ]
Y_t <- Y_t[Index_random, ]

n_t <- dim(X_t)[1]

cat("Target Sample Size: ", n_t, "\n")

# Outer cross-validation over the target sample.
# For every held-out target point three predictions are scored against the
# observed response row via compute_L2_distance():
#   d_1        - f_hat from compute_f_L2() with a CV-selected lambda
#   d_2_target - fit from lrem() on the target training fold only
#   d_2_source - f1_hat from compute_f1_hat(), which receives only source data
# Results are collected per fold in `all_results`; `fold_sizes` records the
# number of test points in each fold.

# Create outer fold indices
outer_folds <- create_folds(n_t, num_fold)

all_results <- list()
fold_sizes <- numeric(num_fold)

# Candidate penalty values for the inner lambda search. The grid does not
# depend on the fold or the test point, so it is built once up front.
lambda_candidates <- sort(unique(c(
  seq(0.0001, 0.001, length.out = 5),
  seq(0.001, 0.01, length.out = 5),
  seq(0.01, 0.1, length.out = 5),
  seq(0.1, 0.5, length.out = 5)
)))

# For each outer fold
for (outer_k in seq_len(num_fold)) {
  cat("===== Outer Fold:", outer_k, "=====\n")

  # Split data into outer training and validation sets
  outer_test_idx <- which(outer_folds == outer_k)
  X_t_train <- X_t[-outer_test_idx, ]
  Y_t_train <- Y_t[-outer_test_idx, ]
  X_t_test <- X_t[outer_test_idx, ]
  Y_t_test <- Y_t[outer_test_idx, ]

  n_t <- dim(X_t_train)[1]

  # Gaussian kernel; each bandwidth is 15% of that covariate's range within
  # the current training fold.
  optns <- list(
    kernel = "gauss",
    bw = c(
      diff(range(X_t_train$BMI)) * 0.15,
      diff(range(X_t_train$Age)) * 0.15
    )
  )

  # One numeric vector per training row of Y, as passed to LocalLinWeights()
  # and lrem() below.
  qin.target <- split(Y_t_train, row(Y_t_train))
  qin.target <- lapply(qin.target, as.numeric)
  xin.t <- to_matrix(X_t_train)

  cat("Dimension of X Train:", dim(X_t_train), "\n")
  cat("Dimension of Y Train:", dim(Y_t_train), "\n")

  n_test <- nrow(X_t_test)
  fold_sizes[outer_k] <- n_test
  fold_results <- matrix(NA_real_, nrow = n_test, ncol = 3)
  colnames(fold_results) <- c("d_1", "d_2_target", "d_2_source")

  # seq_len() keeps the loop empty when a fold has no test points
  # (1:0 would iterate over c(1, 0)).
  for (i in seq_len(n_test)) {
    X_value <- as.vector(as.numeric(X_t_test[i, ]))
    true_Et <- Y_t_test[i, ]

    # Weights of the training rows for predicting at X_value.
    s_vec_t_train <- LocalLinWeights(qin.target, xin.t, X_value, optns)

    cat("Finish Calculate Weights \n")

    # Compute f1_hat
    f1_hat <- compute_f1_hat(source_data_list = source_d,
                             X_value = X_value,
                             M = M,
                             rate)

    cat("Finish Calculate f1_hat \n")

    d <- compute_L2_distance(true_Et, f1_hat)

    # Inner cross-validation for lambda selection
    cv_out <- cv_search_lambda(
      Y_t = Y_t_train,
      s_vec = s_vec_t_train,
      X_t = X_t_train,
      source_d = source_d,
      M = M,
      rate = rate,
      lambda_grid = lambda_candidates,
      n_folds = 3,
      max_iter = 500,
      step_size = 0.1,
      tol = 1e-8,
      n_cores = 8,
      val_sample_size = 50
    )

    best_lambda <- cv_out$best_lambda
    cat("Best lambda =", best_lambda, "\n")

    # Compute f_hat using best lambda.
    # The selected value is passed through here; previously a fixed 0.25 was
    # used and the result of the inner search was discarded.
    f_hat <- compute_f_L2(Y_t = Y_t_train,
                          s_vec = s_vec_t_train,
                          f1_hat = f1_hat,
                          lambda = best_lambda,
                          M = M,
                          max_iter = 200,
                          step_size = 0.1,
                          tol = 1e-3)

    f_hat <- Project(f_hat)

    # Compute d_1
    d_1 <- compute_L2_distance(true_Et, f_hat)

    # Compute d_2 (Target)
    f_target <- lrem(qin.target, xin.t, X_value, optns)$qp[1, ]
    d_2 <- compute_L2_distance(true_Et, f_target)

    fold_results[i, "d_1"] <- d_1
    fold_results[i, "d_2_target"] <- d_2
    fold_results[i, "d_2_source"] <- d

    cat("d:", d, "d1:", d_1, "d2:", d_2, "\n")

    if (i %% 10 == 0 || i == n_test) {
      cat(sprintf("Processed %d/%d test points in fold %d\n", i, n_test, outer_k))
    }
  }

  all_results[[outer_k]] <- fold_results

  cat("Average results for fold", outer_k, ":\n")
  print(colMeans(fold_results, na.rm = TRUE))
}

# Aggregate the per-point distances over all outer folds and persist them.
total_points <- sum(fold_sizes)

# Stack every fold's result matrix once; both summaries below use it.
all_results_matrix <- do.call(rbind, all_results)

# Column means over the pooled test points, i.e. fold means weighted by the
# number of test points in each fold.
weighted_means <- colMeans(all_results_matrix, na.rm = TRUE)
cat("\nOverall weighted averages:\n")
print(weighted_means)

# Standard error of each column mean. The denominator counts only the non-NA
# entries so it matches the values that sd(..., na.rm = TRUE) actually used.
standard_errors <- apply(all_results_matrix, 2, function(x) {
  sd(x, na.rm = TRUE) / sqrt(sum(!is.na(x)))
})
cat("\nStandard Errors:\n")
print(standard_errors)

# Make sure the output directory exists so save() cannot fail at the very end
# of the run and discard the results.
dir.create("data", showWarnings = FALSE, recursive = TRUE)

save(all_results, fold_sizes, weighted_means, standard_errors,
     file = paste0('data/', seed, '_', r, '_', M, '_', rate, '_', Gender, '_.RData'))
