# Simulation study: online estimators (Huber or L2 loss, run in 'non' mode or a
# private 'DP' mode with an (eps, delta) budget) compared with an offline
# estimator, followed by a step-size sensitivity study and a
# contamination-robustness study.
# NOTE(review): rm(list = ls()) clears the caller's global environment when
# this file is source()d; running it via Rscript in a fresh session is safer.
rm(list=ls())
set.seed(20250410)        # seed for the main-session RNG (e.g. the pilot data below)
timestart <- Sys.time()   # start time; total runtime is printed after the main experiment


# Project code, resolved relative to the working directory. Judging by the file
# names these provide helpers, data generators, the true regression functions
# and the online / offline estimators -- TODO confirm in each file.
source("functions.R")
source("data_gen.R")
source("true_gen.R")
source("run_online.R")
source("run_offline.R")

library(MASS)      # mvrnorm, also exported to the cluster workers below
library(ggplot2)   # heatmap in the step-size sensitivity study
library(parallel)  # makeCluster / clusterExport / parLapply

# =============================
# Parameters
# =============================
n <- 10000         # length of each simulated data stream
rep_num <- 200     # Monte Carlo replications in the main experiment
eps_privacy <- 3   # (epsilon, delta) budget passed to online_updating
delta_privacy <- 0.1
noise_type <- 't'  # 'Gaussian', 't'
sigma2 <- 0.25     # noise parameter passed to gene_data
df <- 3            # presumably the degrees of freedom of the 't' noise

# Step-size schedule passed to online_updating as (gamma0, gamma_exp):
#   'constant'     -> gamma0 = 2 * n^(-zeta), gamma_exp = 0
#   'non-constant' -> gamma0 = 4,             gamma_exp = 0.6
gamma_type <- 'constant' # 'constant', 'non-constant'
# Validate explicitly so a typo does not silently select the non-constant
# schedule through the else branch below.
if (!gamma_type %in% c('constant', 'non-constant')) {
  stop("gamma_type must be 'constant' or 'non-constant', got '", gamma_type, "'",
       call. = FALSE)
}
if (gamma_type == 'constant') {
  zeta <- 0.6
  gamma0 <- 2 * n ^ (-zeta)
  gamma_exp <- 0
} else {
  gamma0 <- 4
  gamma_exp <- 0.6
}


# Evaluation grid on [0, 1], the number of initially dropped observations
# (presumably a burn-in) and the sample sizes at which fits are recorded.
lent <- 101
grid <- seq(0, 1, length.out = lent)
record_num <- 10
n_drop <- ceiling(0.1 * n) - 1 # 999
record_size <- seq(n_drop + 1, n, length.out = record_num)
true_f0 <- f_true1 # f_true1, f_true2

# Kernel Rep_K with h = 0.2, its Gram matrix on the grid plus a tiny diagonal
# jitter so the matrix is safely invertible, and the inverse of that matrix.
kernel <- function(x, y) Rep_K(x, y, h = 0.2)
Ker_cov <- outer(grid, grid, kernel)
K_grid <- Ker_cov + 1e-8 * diag(lent)
Kinv_grid <- solve(K_grid)


# =============================
# Choice of tau
# =============================
# Pilot run: fit the non-private L2 online estimator on n_pilot observations,
# take residuals at the nearest grid point, and set the Huber threshold to
# 1.345 times a MAD-based residual scale (dividing by 0.6745 ~ qnorm(0.75)
# makes the MAD consistent for a Gaussian standard deviation).
n_pilot <- 1000
data_tmp <- gene_data(n_pilot, true_f0, noise_type, sigma2, df)
X_all_tmp <- data_tmp$X
Y_all_tmp <- data_tmp$Y
# NOTE(review): `tau` appears not to be defined yet (the workspace was cleared
# at the top). The call works only through lazy evaluation, presumably because
# the 'L2' loss never reads tau -- confirm in run_online.R.
res_L2_non_tmp <- online_updating(X_all_tmp, Y_all_tmp, gamma0, gamma_exp, kernel, inner_product_H, grid, eps_privacy, delta_privacy, tau, n_drop, 'L2', 'non')
# res_L2_non_tmp[[1]] is presumably the first recorded fit evaluated on `grid`
# (record_size[1] = 1000 = n_pilot here); look up each x's nearest grid point.
nearest_idx <- vapply(X_all_tmp, function(x) which.min(abs(grid - x)), integer(1))
res_all <- Y_all_tmp - res_L2_non_tmp[[1]][nearest_idx]

sigma0_hat <- median(abs(res_all - median(res_all))) / 0.6745
tau <- sigma0_hat * 1.345



# =============================
# Main experiments
# =============================
# recursive = TRUE creates "results" and "results/plots" in one call.
dir.create("results/plots", recursive = TRUE, showWarnings = FALSE)

# Leave 5 cores free, but never request fewer than one worker: detectCores()
# may return NA, or a count <= 5 on small machines, and makeCluster(0) fails.
num_cores <- max(1L, detectCores() - 5L, na.rm = TRUE)
cl <- makeCluster(num_cores)
cat("Start parallel computing, core number: ", num_cores, "\n")

# Workers start as empty R sessions: ship every global object and helper the
# replication function needs, and attach MASS on each of them.
clusterExport(cl, varlist = c("n", "lent", "record_num", "record_size", "n_drop", "sigma2", "df", "tau", "noise_type",
                              "eps_privacy", "delta_privacy", "gamma0", "gamma_exp", "grid", "K_grid",
                              "Ker_cov", "kernel", "gene_data", "true_f0", "mvrnorm", "Rep_K", "Kinv_grid",
                              "inner_product_H", "L2_MSE_compute", "online_updating", "offline_updating"))
invisible(clusterEvalQ(cl, {
  library(MASS)
}))


# Main Monte Carlo loop, one call per replication tt. Each replication draws a
# fresh data stream and fits the non-private Huber and L2 online estimators and
# the offline estimator. It returns the fits, their L2 MSEs and run times.
results <- parLapply(cl, seq_len(rep_num), function(tt) {
  # Evaluate `expr` and return its value together with the elapsed wall-clock
  # time. Units are pinned to seconds: a bare `Sys.time() - t0` picks "secs",
  # "mins" or "hours" per call, so timings from different replications would
  # not be comparable once converted with as.numeric().
  timed <- function(expr) {
    t0 <- Sys.time()
    value <- expr # forces the lazily evaluated argument after t0 is taken
    list(value = value, time = difftime(Sys.time(), t0, units = "secs"))
  }

  set.seed(20250410 + tt) # per-replication seed for reproducibility
  data_all <- gene_data(n, true_f0, noise_type, sigma2, df)
  X_all <- data_all$X
  Y_all <- data_all$Y

  # The Huber run passes (eps, delta) = (2, 0.2) while the L2 run uses the
  # global budget; in 'non' mode these are presumably unused.
  huber_non <- timed(online_updating(X_all, Y_all, gamma0, gamma_exp, kernel, inner_product_H, grid,
                                     eps_privacy = 2, delta_privacy = 0.2, tau, n_drop, 'huber', 'non'))
  L2_non <- timed(online_updating(X_all, Y_all, gamma0, gamma_exp, kernel, inner_product_H, grid,
                                  eps_privacy, delta_privacy, tau, n_drop, 'L2', 'non'))
  offline <- timed(offline_updating(X_all, Y_all, gamma0, gamma_exp, kernel, grid, n_drop))

  # Private ('DP') Huber runs, currently disabled:
  # huber_DP202 <- timed(online_updating(X_all, Y_all, gamma0, gamma_exp, kernel, inner_product_H, grid,
  #                                      eps_privacy = 2, delta_privacy = 0.2, tau, n_drop, 'huber', 'DP'))
  # huber_DP301 <- timed(online_updating(X_all, Y_all, gamma0, gamma_exp, kernel, inner_product_H, grid,
  #                                      eps_privacy = 3, delta_privacy = 0.1, tau, n_drop, 'huber', 'DP'))
  # L2_MSE_huber_DP202 <- L2_MSE_compute(huber_DP202$value, true_f0, grid, 'online')
  # L2_MSE_huber_DP301 <- L2_MSE_compute(huber_DP301$value, true_f0, grid, 'online')

  list(res_huber_non = huber_non$value,
       res_L2_non = L2_non$value,
       res_offline = offline$value,
       L2_MSE_huber_non = L2_MSE_compute(huber_non$value, true_f0, grid, 'online'),
       L2_MSE_L2_non = L2_MSE_compute(L2_non$value, true_f0, grid, 'online'),
       L2_MSE_offline = L2_MSE_compute(offline$value, true_f0, grid, 'offline'),
       time_huber_non = huber_non$time,
       time_L2_non = L2_non$time,
       time_offline = offline$time)
})
# Shut down the workers, persist every replication, and report the runtime
# measured since the top of the script. Wrapped in local() so the helper
# names below do not leak into the global environment.
local({
  stopCluster(cl)
  cat("Parallel computing completed. \n")

  out_path <- "results/results_set1t3n1wrep200cons206h02.rds"
  saveRDS(object = results, file = out_path)

  cat("\n All done! Total time consumption: \n")
  total_elapsed <- Sys.time() - timestart
  print(total_elapsed)
  invisible(NULL)
})




# =============================
# Sensitivity of step sizes
# =============================
# Grid over the exponent zeta and the step-size multiplier.
# NOTE: despite its name, gamma_exp_vec holds the multiplier used as gamma0
# (constant schedule: gamma0 = value * n^(-zeta)); it is plotted as gamma[0].
zeta_vec <- c(0.3, 0.4, 0.5, 0.6, 0.7, 0.8)
gamma_exp_vec <- c(4, 8, 12, 16, 20, 24)
n_rep <- 50
gamma_type <- 'constant' # 'constant', 'non-constant'

library(foreach)
library(doParallel)

# Leave 5 cores free, but never request fewer than one worker (detectCores()
# may return NA, or a count <= 5 on small machines).
n_cores <- max(1L, parallel::detectCores() - 5L, na.rm = TRUE)
cl <- makeCluster(n_cores)
registerDoParallel(cl)

# Average final-stage MSE of the non-private Huber estimator for every
# (zeta, multiplier) pair: rows index zeta_vec, columns index gamma_exp_vec.
mse_all <- matrix(NA_real_, length(zeta_vec), length(gamma_exp_vec))

for (i in seq_along(zeta_vec)) {
  for (j in seq_along(gamma_exp_vec)) {
    zeta <- zeta_vec[i]
    gamma_exp1 <- gamma_exp_vec[j]

    if (gamma_type == 'constant') {
      gamma0 <- gamma_exp1 * n ^ (-zeta)
      gamma_exp <- 0
    } else {
      gamma0 <- gamma_exp1
      gamma_exp <- zeta
    }

    # Workers are fresh R sessions: foreach exports the globals referenced
    # here, but packages must be requested explicitly. MASS is loaded as the
    # main experiment does for its workers (presumably needed for mvrnorm).
    mse_vec <- foreach(b = seq_len(n_rep), .combine = c, .packages = "MASS") %dopar% {
      set.seed(20250410 + b)
      data_all <- gene_data(n, true_f0, noise_type, sigma2, df)
      X_all <- data_all$X
      Y_all <- data_all$Y

      res_huber_non <- online_updating(X_all, Y_all, gamma0, gamma_exp,
                                       kernel, inner_product_H, grid,
                                       eps_privacy = 2, delta_privacy = 0.2,
                                       tau, n_drop, 'huber', 'non')

      # MSE at the last recorded sample size (index record_num rather than a
      # hard-coded 10, so changing record_num keeps this consistent).
      L2_MSE_compute(res_huber_non, true_f0, grid, 'online')[record_num]
    }

    mse_all[i, j] <- mean(mse_vec)
    cat(sprintf("Finished (%d, %d): MSE = %.4f\n", i, j, mse_all[i, j]))
  }
}

stopCluster(cl)

library(ggplot2)

# Long format for geom_tile(). as.vector() flattens mse_all column by column,
# so within each column (one multiplier) zeta, the row index, varies fastest.
# The label columns must follow that same order, otherwise the heatmap
# silently shows transposed labels.
mse_long <- data.frame(
  zeta = rep(zeta_vec, times = length(gamma_exp_vec)),
  gamma_exp = rep(gamma_exp_vec, each = length(zeta_vec)),
  MSE = as.vector(mse_all)
)

ggplot(mse_long, aes(x = gamma_exp, y = zeta, fill = MSE)) +
  geom_tile(color = "white") +
  scale_fill_gradient(low = "yellow", high = "red") +
  labs(
    title = "Heatmap of MSE using H-FSGD with constant step size",
    x = expression(gamma[0]),
    y = expression(zeta),
    fill = "MSE"
  ) +
  theme_minimal(base_size = 14)




# ===================================
# Robustness in contamination models
# ===================================
n <- 10000
n_rep <- 50
eps_privacy <- 2
delta_privacy <- 0.2
noise_type <- 'Gaussian' # 'Gaussian', 't'
sigma2 <- 0.25
df <- 0.3 # presumably only used for 't' noise -- confirm in data_gen.R

# Step-size schedule passed to online_updating as (gamma0, gamma_exp):
#   'constant'     -> gamma0 = 20 * n^(-zeta), gamma_exp = 0
#   'non-constant' -> gamma0 = 4,              gamma_exp = 0.6
gamma_type <- 'constant' # 'constant', 'non-constant'
# Validate explicitly so a typo does not silently select the non-constant
# schedule through the else branch below.
if (!gamma_type %in% c('constant', 'non-constant')) {
  stop("gamma_type must be 'constant' or 'non-constant', got '", gamma_type, "'",
       call. = FALSE)
}
if (gamma_type == 'constant') {
  zeta <- 0.6 # 0.2, 0.4, 0.6, 0.8, 1.0, 1.5, 2.0
  gamma0 <- 20 * n ^ (-zeta)
  gamma_exp <- 0
} else {
  gamma0 <- 4
  gamma_exp <- 0.6 # 0.2, 0.4, 0.6, 0.8, 1.0, 1.5, 2.0
}

# Regression functions handed to gene_data_con: f_true1 is the target and
# f_true2 presumably the contaminating function.
true_f0 <- f_true1
noise_f <- f_true2

# Evaluation grid on [0, 1], the number of initially dropped observations
# (presumably a burn-in) and the sample sizes at which fits are recorded.
lent <- 101
grid <- seq(0, 1, length.out = lent)
record_num <- 10
n_drop <- ceiling(0.1 * n) - 1 # 999
record_size <- seq(n_drop + 1, n, length.out = record_num)

# Kernel Rep_K with h = 0.2, its Gram matrix on the grid plus a tiny diagonal
# jitter so the matrix is safely invertible, and the inverse of that matrix.
kernel <- function(x, y) Rep_K(x, y, h = 0.2)
Ker_cov <- outer(grid, grid, kernel)
K_grid <- Ker_cov + 1e-8 * diag(lent)
Kinv_grid <- solve(K_grid)

# Pilot run on uncontaminated data (eps_con = 0). Fit the non-private L2
# online estimator, take residuals at the nearest grid point, and set the Huber
# threshold to 1.345 times a MAD-based residual scale (dividing by
# 0.6745 ~ qnorm(0.75) makes the MAD consistent for a Gaussian sd).
n_pilot <- 1000
data_tmp <- gene_data_con(n_pilot, true_f0, noise_f, eps_con=0, noise_type, sigma2, df)
X_all_tmp <- data_tmp$X
Y_all_tmp <- data_tmp$Y
# NOTE(review): `tau` passed here is still the value chosen for the main
# experiment; the 'L2' fit presumably ignores it -- confirm in run_online.R.
res_L2_non_tmp <- online_updating(X_all_tmp, Y_all_tmp, gamma0, gamma_exp, kernel, inner_product_H, grid, eps_privacy, delta_privacy, tau, n_drop, 'L2', 'non')
# res_L2_non_tmp[[1]] is presumably the first recorded fit evaluated on `grid`;
# look up each x's nearest grid point, vectorised over the whole pilot sample.
nearest_idx <- vapply(X_all_tmp, function(x) which.min(abs(grid - x)), integer(1))
res_all <- Y_all_tmp - res_L2_non_tmp[[1]][nearest_idx]

sigma0_hat <- median(abs(res_all - median(res_all))) / 0.6745
tau <- sigma0_hat * 1.345 # normal

# Leave 5 cores free, but never request fewer than one worker (detectCores()
# may return NA, or a count <= 5 on small machines).
n_cores <- max(1L, parallel::detectCores() - 5L, na.rm = TRUE)
cl <- makeCluster(n_cores)
registerDoParallel(cl)

# Contamination levels eps_con = 0, 0.1, ..., 0.9.
eps_con_list <- seq(0, 1, length.out=11)[-11]
# Final-stage MSE summary, one column per contamination level; rows are
# Huber mean, Huber sd, L2 mean, L2 sd.
mse_all <- matrix(NA_real_, 4, length(eps_con_list))

for (i in seq_along(eps_con_list)) {
  eps_con <- eps_con_list[i]
  # Workers are fresh R sessions: foreach exports the globals referenced here,
  # but packages must be requested explicitly. MASS is loaded as the main
  # experiment does for its workers (presumably needed for mvrnorm).
  mse_vec <- foreach(b = seq_len(n_rep), .combine = c, .packages = "MASS") %dopar% {
    set.seed(20250410 + b)
    data_all <- gene_data_con(n, true_f0, noise_f, eps_con, noise_type, sigma2, df)
    X_all <- data_all$X
    Y_all <- data_all$Y

    res_huber_non <- online_updating(X_all, Y_all, gamma0, gamma_exp,
                                     kernel, inner_product_H, grid,
                                     eps_privacy = 2, delta_privacy = 0.2,
                                     tau, n_drop, 'huber', 'non')
    res_L2_non <- online_updating(X_all, Y_all, gamma0, gamma_exp,
                                  kernel, inner_product_H, grid,
                                  eps_privacy, delta_privacy,
                                  tau, n_drop, 'L2', 'non')
    # MSE at the last recorded sample size for each loss.
    mse_huber <- L2_MSE_compute(res_huber_non, true_f0, grid, 'online')[record_num]
    mse_L2 <- L2_MSE_compute(res_L2_non, true_f0, grid, 'online')[record_num]
    c(mse_huber, mse_L2)
  }

  # .combine = c interleaves the per-replication pairs as
  # (huber_1, L2_1, huber_2, L2_2, ...); filling by row splits the two losses.
  mse_mat <- matrix(mse_vec, ncol = 2, byrow = TRUE)
  mse_all[, i] <- c(mean(mse_mat[, 1]), sd(mse_mat[, 1]),
                    mean(mse_mat[, 2]), sd(mse_mat[, 2]))
  cat(sprintf("Finished %d\n", i))
}

stopCluster(cl)

# Build "mean (sd)" strings for selected positions of two summary vectors.
#   mean_vec, std_vec: numeric vectors of means and standard deviations
#   idx:   positions to format
#   scale: multiplier applied before rounding (100 reports values x 100)
#   sig:   number of significant digits kept (and printed with %g)
# Returns a character vector with one "m (s)" entry per element of idx.
format_row <- function(mean_vec, std_vec, idx, scale = 100, sig = 3) {
  to_sig <- function(v) signif(scale * v[idx], sig)
  sprintf("%.*g (%.*g)", sig, to_sig(mean_vec), sig, to_sig(std_vec))
}

# Summary table: "mean (sd)" of the final-stage MSE (x 100) per contamination
# level, with Huber in the first row and L2 in the second. The columns are
# taken from mse_all instead of a hard-coded 1:10, so the table stays aligned
# if eps_con_list changes.
cols <- seq_len(ncol(mse_all))
row1 <- format_row(mse_all[1, ], mse_all[2, ], cols)
row2 <- format_row(mse_all[3, ], mse_all[4, ], cols)
result_mat <- rbind(Huber = row1, L2 = row2)
colnames(result_mat) <- eps_con_list
result_mat
