# Command-line arguments, all numeric, in positional order:
#   1. M       - number of points in the quantile grid
#   2. n_t     - passed to the data simulator (presumably the target sample
#                size -- confirm against SimulationFunc.R)
#   3. seed    - RNG seed; also part of the output file name
#   4. setting - simulation setting, 1 or 2
#   5. ns      - multiplier applied to the source sample sizes
args <- commandArgs(trailingOnly = TRUE)

print(args)

# Fail fast with a usage message instead of a bare "subscript out of bounds"
# (missing argument) or a late failure on NA (non-numeric argument).
n_expected_args <- 5
if (length(args) < n_expected_args) {
  stop("Expected ", n_expected_args,
       " arguments (M n_t seed setting ns) but got ", length(args),
       call. = FALSE)
}

arg_values <- as.numeric(args[seq_len(n_expected_args)])
if (anyNA(arg_values)) {
  stop("All arguments must be numeric; got: ",
       paste(args[seq_len(n_expected_args)], collapse = " "),
       call. = FALSE)
}

M <- arg_values[1]
n_t <- arg_values[2]
seed <- arg_values[3]
setting <- arg_values[4]
ns <- arg_values[5]

library(truncnorm)
library(pracma)
library(Matrix)
library(osqp)
library(parallel)
library(doParallel)
library(foreach)

source("SimulationFunc.R") 

# Seed the RNG of this (master) process. Worker processes started by
# makeCluster() keep their own, separate RNG state.
set.seed(seed)

# Number of source samples and their sizes, scaled by `ns`.
K <- 5
n_vec <- c(100, 200, 300, 250, 150) * ns

# Number of Monte Carlo replications per evaluation point.
nRepeat <- 50

# Common grid of M equally spaced points from 1/M to 1 - 1/M. The same grid
# is used both as the abscissa of the simulated quantile functions (p_grid)
# and as the evaluation grid (z_grid).
z_grid <- seq(1 / M, 1 - 1 / M, length.out = M)
p_grid <- z_grid

# Evaluation points for the predictor; the range depends on the setting.
if (setting == 1) {
  X_values <- seq(0, 1, length.out = 100)
} else if (setting == 2) {
  X_values <- seq(-1, 1, length.out = 100)
} else {
  # Without this branch X_values would be undefined and the script would fail
  # later with an unrelated "object not found" error.
  stop("Unsupported setting: ", setting, " (expected 1 or 2)", call. = FALSE)
}

# Start 10 worker processes and register them as the foreach backend.
cl <- makeCluster(10)
registerDoParallel(cl)

# Result cube: evaluation point x replication x metric.
all_results <- array(
  NA,
  dim = c(length(X_values), nRepeat, 4),
  dimnames = list(NULL, NULL, c("d_1", "d_2_target", "d_2_source", "d"))
)

# Make the simulation constants visible in every worker's global environment.
clusterExport(cl, c("n_t", "n_vec", "M", "K", "setting", "z_grid", "p_grid"))

# set.seed() on the master does not reach the worker processes, so without
# this call the data simulated inside %dopar% would not depend on `seed`.
# clusterSetRNGStream() gives every worker its own L'Ecuyer-CMRG stream
# derived from `seed`; together with `preschedule = TRUE` below (fixed
# task-to-worker assignment) the worker-side random draws become repeatable.
# NOTE(review): any RNG use in processes spawned by cv_search_lambda itself
# is outside this guarantee -- confirm if bit-exact reruns are required.
clusterSetRNGStream(cl, iseed = seed)

# For each evaluation point X_value, run nRepeat independent replications in
# parallel. Each replication returns c(d_1, d_2, d_3, d):
#   d_1 - distance between the true quantile function and f_hat (compute_f_L2)
#   d_2 - same for the grem() fit on the target sample only
#   d_3 - same for the grem() fit on the pooled source samples
#   d   - largest discrepancy between a source sample and the target sample
for (x_idx in seq_along(X_values)) {
  X_value <- X_values[x_idx]
  cat(sprintf("\nProcessing X_value = %.3f (%d/%d)\n",
              X_value, x_idx, length(X_values)))

  results <- foreach(iter = seq_len(nRepeat),
                     .combine = 'rbind',
                     .packages = c('truncnorm', 'pracma', 'Matrix', 'osqp'),
                     .options.snow = list(preschedule = TRUE)) %dopar% {

    # Per-observation weights at X_value:
    #   1 + (x_i - xbar)' S^{-1} (X_value - xbar),
    # where S is the covariance of the rows of X_mat with divisor n.
    regression_weights <- function(X_mat) {
      n_obs <- nrow(X_mat)
      xbar <- colMeans(X_mat)
      inv_sigma <- solve(cov(X_mat) * (n_obs - 1) / n_obs)
      vapply(seq_len(n_obs), function(i) {
        as.numeric(1 + (X_mat[i, ] - xbar) %*% inv_sigma %*% (X_value - xbar))
      }, numeric(1))
    }

    # Linearly interpolate every row of Y (values given on p_grid) onto
    # z_grid, holding the end values constant outside the grid (rule = 2).
    interpolate_rows <- function(Y) {
      t(apply(Y, 1, function(q_row) {
        approxfun(p_grid, q_row, method = "linear", rule = 2)(z_grid)
      }))
    }

    # Setting-specific pieces: simulated data, lambda search grid, and the
    # true conditional quantile function at X_value.
    if (setting == 1) {
      simu_data <- simulate_data_setting1(n_t, n_vec, M, K)
      lambda_candidates <- seq(0, 3, by = 0.1)
      true_Et <- (1 - X_value) * z_grid +
        X_value * qtruncnorm(z_grid, a = 0, b = 1, mean = 0.5, sd = 1)
    } else if (setting == 2) {
      simu_data <- simulate_data_setting2(n_t, n_vec, M, K)
      lambda_candidates <- seq(0, 0.05, by = 0.01)
      true_Et <- 3 * X_value +
        (3 + 3 * X_value) * qtruncnorm(z_grid, a = 0, b = 1, mean = 0.5, sd = 1)
    } else {
      stop("Unsupported setting: ", setting, " (expected 1 or 2)",
           call. = FALSE)
    }

    X_t <- simu_data$X_t
    Y_t <- simu_data$Y_t
    source_d <- simu_data$source_data

    f1_hat <- compute_f1_hat(source_data_list = source_d,
                             X_value = X_value,
                             M = M)

    # Weighted mean quantile function of the target sample.
    s_vec_t <- regression_weights(to_matrix(X_t))
    target_mean <- colMeans(s_vec_t * interpolate_rows(Y_t))

    # Euclidean distance between each source sample's weighted mean quantile
    # function and the target's; d is the largest of the K distances.
    d_k_vec <- vapply(seq_len(K), function(k) {
      s_vec_s <- regression_weights(to_matrix(source_d[[k]]$X_s))
      source_mean <- colMeans(s_vec_s * interpolate_rows(source_d[[k]]$Y_s))
      sqrt(sum((source_mean - target_mean)^2))
    }, numeric(1))
    d <- max(d_k_vec)

    # Choose the penalty by cross-validation, then refit with it.
    # NOTE(review): n_cores = 8 is requested inside each of the 10 workers;
    # if cv_search_lambda really starts that many processes the machine may
    # be oversubscribed -- confirm against its implementation.
    cv_out <- cv_search_lambda(
      Y_t = Y_t,
      s_vec = s_vec_t,
      X_t = X_t,
      source_d = source_d,
      M = M,
      lambda_grid = lambda_candidates,
      n_folds = 3,
      max_iter = 500,
      step_size = 0.1,
      tol = 1e-8,
      n_cores = 8,
      val_sample_size = 50
    )
    best_lambda <- cv_out$best_lambda

    f_hat <- compute_f_L2(Y_t = Y_t,
                          s_vec = s_vec_t,
                          f1_hat = f1_hat,
                          lambda = best_lambda,
                          M = M,
                          max_iter = 1000,
                          step_size = 0.25,
                          tol = 1e-8)
    d_1 <- compute_L2_distance(true_Et, f_hat)

    # Baseline 1: grem() fitted on the target sample only; the first row of
    # its $qp output is used as the fit at X_value.
    qin.target <- lapply(split(Y_t, row(Y_t)), as.numeric)
    xin.t <- to_matrix(X_t)
    f_target <- grem(qin.target, xin.t, X_value)$qp[1, ]
    d_2 <- compute_L2_distance(true_Et, f_target)

    # Baseline 2: grem() fitted on all source samples pooled together. The
    # samples are stacked from k = K down to k = 1, which is the order the
    # previous prepend-in-a-loop construction produced.
    pool_order <- rev(seq_len(K))
    X_s_all <- do.call(c, lapply(pool_order, function(k) source_d[[k]]$X_s))
    Y_s_all <- do.call(rbind, lapply(pool_order, function(k) source_d[[k]]$Y_s))

    qin.source <- lapply(split(Y_s_all, row(Y_s_all)), as.numeric)
    xin.s <- to_matrix(X_s_all)
    f_source <- grem(qin.source, xin.s, X_value)$qp[1, ]
    d_3 <- compute_L2_distance(true_Et, f_source)

    c(d_1, d_2, d_3, d)
  }

  # One row per replication, one column per metric.
  all_results[x_idx, , ] <- results

  cat(sprintf("Average for X_value = %.3f:\n", X_value))
  print(colMeans(results))
}

# All evaluation points are done: shut down the worker processes.
stopCluster(cl)


# Grand mean of each metric over every evaluation point and replication,
# skipping entries that are NA.
cat("\nOverall averages:\n")
overall_means <- apply(
  all_results,
  MARGIN = 3,
  FUN = function(metric_slice) mean(metric_slice, na.rm = TRUE)
)
print(overall_means)


# Results are written to Setting<setting>/<M>_<n_t>_<seed>_<ns>.RData; the
# directory is created on first use.
out_dir <- paste0("Setting", setting)
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

out_file <- paste0(out_dir, "/", M, "_", n_t, "_", seed, "_", ns, ".RData")
save(all_results, X_values, file = out_file)