
source("utility_functions_for_miav_tabpfn_iclr.R")

# To run TabPFN in R: 
# First in a terminal, create a virtual environment, activate it, and install 
# TabPFN. Then download the "generate_tabpfn_predictions.py" script.
#
# In R, load reticulate and source the Python script, which makes the
# TabPFN-based classifiers and regression models available in R.
library(reticulate)
use_virtualenv("~/TabPFN/venv")
source_python("~/TabPFN/venv/generate_tabpfn_predictions.py")


################################################
## Generate illustrative example figures 
## (Figures 18 and 19 in Appendix J)
################################################

## One parameter pair per simulated variable, passed through to
## SimulateCorrelatedBetaData() (presumably the two beta shape parameters --
## confirm against the utility file).
beta_pars_list <- list(c(10, 10),
                       c(0.5, 5),
                       c(30, 3),
                       c(0.5, 0.5),
                       c(3, 10))

## Number of simulated samples.
n <- 1000

## The RNG is re-seeded with this same value before every generator call below.
my_seed <- 123

## Original (reference) dataset.
set.seed(my_seed)
dat_orig <- SimulateCorrelatedBetaData(n = n,
                                       rho = 0.75,
                                       beta_pars_list = beta_pars_list)

## Synthetic copy produced by the MIAV generator.
set.seed(my_seed)
dat_miav <- MiavTabPFNGenerator(X = dat_orig)

## Synthetic copies produced by the noisy MIAV generator, one per noise level.
noisy_copies <- lapply(c(0.1, 0.2, 0.3), function(noise_level) {
  set.seed(my_seed)
  NoisyMiavTabPFNGenerator(X = dat_orig, percent = noise_level)
})
dat_nmiav1 <- noisy_copies[[1]]
dat_nmiav2 <- noisy_copies[[2]]
dat_nmiav3 <- noisy_copies[[3]]

## Synthetic datasets in the order in which they are plotted.
dat_synt_list <- c(list(dat_miav), noisy_copies)

## Display name and color for each entry of dat_synt_list.
methods_names <- c("MIAV",
                   "NMIAV_0.1",
                   "NMIAV_0.2",
                   "NMIAV_0.3")
methods_colors <- rep("red", length(methods_names))

## Legend position for each of the five variables.
leg_positions <- c("bottom", "topright", "topleft", "bottom", "topright")

## Passed as `adjust` to the density plotting helper.
adjust <- 1.5

## Auxiliary variables computed from the original data, labelled M1..M5.
X <- dat_orig
M <- ComputeAuxiliaryVariables(X)
colnames(M) <- paste0("M", seq_len(5))

## Scatter plots of each auxiliary variable M_i against X_i: one row of panels
## per noise level, one column per variable. Row 1 shows M_i as computed; the
## remaining rows add Gaussian noise with standard deviation
## percent * sd(M_i).
##
## Reads the globals X, M and n defined earlier in the script.
noise_levels <- c(0, 0.1, 0.2, 0.3)
n_vars <- 5

## Remember the previous graphical settings so they can be restored afterwards
## instead of being overwritten with hard-coded defaults.
old_par <- par(mfrow = c(length(noise_levels), n_vars),
               mar = c(3, 2.5, 1, 0.25) + 0.1,
               mgp = c(1.75, 0.75, 0))

for (row_idx in seq_along(noise_levels)) {
  percent <- noise_levels[row_idx]
  for (i in seq_len(n_vars)) {
    m_i <- M[, i]
    ## Only draw from the RNG for the noisy rows, so the sequence of random
    ## draws is the same as when each row was written out separately.
    if (percent > 0) {
      m_i <- m_i + rnorm(n, 0, percent * sd(M[, i]))
    }
    plot(X[, i], m_i,
         xlab = bquote(italic(X[.(i)])),
         ylab = bquote(italic(M[.(i)])),
         main = bquote(M[.(i)] ~ " vs " ~ X[.(i)]),
         cex = 0.5)
    ## Panel tag: column letter followed by row number, e.g. "(a1)".
    mtext(paste0("(", letters[i], row_idx, ")"),
          side = 3, adj = 0, line = -1.2, cex = 0.8)
  }
}

par(old_par)


## Grid of marginal density panels: one row per synthetic dataset in
## dat_synt_list, one column per variable. Each panel is drawn by
## MarginalDensityPlotsQC2() and tagged with its column letter and row number.
par(mfrow = c(4, 5), mar = c(3, 2.5, 1, 0.25) + 0.1, mgp = c(1.75, 0.75, 0))

## Enumerate panels row by row: the variable index i varies fastest.
panel_grid <- expand.grid(i = seq_len(5), j = seq_len(4))
for (panel in seq_len(nrow(panel_grid))) {
  i <- panel_grid$i[panel]
  j <- panel_grid$j[panel]
  MarginalDensityPlotsQC2(var_idx = i,
                          dat_real = dat_orig,
                          dat_synt = dat_synt_list[[j]],
                          leg_pos = leg_positions[i],
                          method_name = methods_names[j],
                          method_color = methods_colors[j],
                          main = bquote(italic(X[.(i)])),
                          adjust = adjust)
  mtext(sprintf("(%s%d)", letters[i], j), side = 3, adj = 0)
}

## Back to a single-panel layout with the default margins.
par(mfrow = c(1, 1), mar = c(5, 4, 4, 2) + 0.1, mgp = c(3, 1, 0))


#######################################################################
## Generate real data experiments figure (Figure 20 in Appendix J)
#######################################################################

## Directory prefixes; the empty string means the current working directory.
## When set, output_path is pasted directly in front of each file name.
output_path <- ""
manus_path <- ""

##################################################################
## Load outputs from the simulated correlated beta distributions
## generated by the script:
## "run_simulated_data_experiments_abs_rho.R"
##################################################################

## One results file per correlation level; each is loaded into the workspace.
sim_files <- c("simulation_outputs_abs_rho_0.95.RData",
               "simulation_outputs_abs_rho_0.75.RData",
               "simulation_outputs_abs_rho_0.5.RData",
               "simulation_outputs_abs_rho_0.25.RData",
               "simulation_outputs_abs_rho_0.RData")
for (sim_file in sim_files) {
  load(paste0(output_path, sim_file))
}

## Simulation outputs, ordered from the strongest to the weakest correlation.
output_list <- list(out_0.95,
                    out_0.75,
                    out_0.5,
                    out_0.25,
                    out_0)

## Methods retained when pooling the simulation results.
keep <- c("miav",
          "miav_0.05",
          "miav_0.1",
          "miav_0.15",
          "miav_0.2",
          "miav_0.25",
          "miav_0.3")

## Pool every metric across the correlation levels, in a fixed order.
sim_metrics <- c(ks = "ks_test_stat",
                 l2d = "l2corr_dist",
                 ed = "energy_dist",
                 dt = "detection_test",
                 dcr = "median_dcrs",
                 dbrl = "dbrls",
                 sdid = "sdids")
pooled_sim <- lapply(sim_metrics, function(metric) {
  PoolResults(output_list,
              keep,
              metric_name = metric)
})

out_ks <- pooled_sim[["ks"]]
out_l2d <- pooled_sim[["l2d"]]
out_ed <- pooled_sim[["ed"]]
out_dt <- pooled_sim[["dt"]]
out_dcr <- pooled_sim[["dcr"]]
out_dbrl <- pooled_sim[["dbrl"]]
out_sdid <- pooled_sim[["sdid"]]

## Columns selected for plotting (same methods as `keep`) and their
## axis labels.
nms <- keep
nms2 <- c("MIAV",
          "NMIAV_0.05",
          "NMIAV_0.1",
          "NMIAV_0.15",
          "NMIAV_0.2",
          "NMIAV_0.25",
          "NMIAV_0.3")



####################################################################
## Load outputs from the OpenML-CC18 real data experiments
## generated by the scripts:
## "run_real_world_data_evaluations_on_first_21_datasets.R"
## "run_real_world_data_evaluations_on_additional_15_datasets.R"
####################################################################

## Both sets of result files are loaded into the same workspace objects
## (out_miav, out_nmiav1, ..., out_nmiav6). To make sure that an object left
## over from a previous load can never be silently reused, those objects are
## removed before each set is loaded and their presence is checked afterwards.
real_result_objects <- c("out_miav", paste0("out_nmiav", seq_len(6)))
real_file_suffixes <- c("miav",
                        "noisy_miav_0.05",
                        "noisy_miav_0.1",
                        "noisy_miav_0.15",
                        "noisy_miav_0.2",
                        "noisy_miav_0.25",
                        "noisy_miav_0.3")

## Load the outputs from the experiments based on the 21 initial datasets
## (N <= 2000, p <= 100, categorical variables with at most 10 classes).
##
rm(list = intersect(real_result_objects, ls()))
for (suffix in real_file_suffixes) {
  load(paste0(output_path,
              "real_data_evaluations_first_21_datasets_", suffix, ".RData"))
}
missing_objects <- setdiff(real_result_objects, ls())
if (length(missing_objects) > 0) {
  stop("Objects not found after loading the first 21 datasets results: ",
       paste(missing_objects, collapse = ", "), call. = FALSE)
}

## List names become the labels of the pooled results.
output_list_r1 <- list(MIAV = out_miav,
                       NMIAV_0.05 = out_nmiav1,
                       NMIAV_0.1 = out_nmiav2,
                       NMIAV_0.15 = out_nmiav3,
                       NMIAV_0.2 = out_nmiav4,
                       NMIAV_0.25 = out_nmiav5,
                       NMIAV_0.3 = out_nmiav6)

out_ks_r1 <- PoolResultsReal(output_list_r1, metric_name = "ks_stat")
out_l2d_r1 <- PoolResultsReal(output_list_r1, metric_name = "l2dist")
out_ed_r1 <- PoolResultsReal(output_list_r1, metric_name = "ed")
out_dt_r1 <- PoolResultsReal(output_list_r1, metric_name = "detection_tests")
out_dcr_r1 <- PoolResultsReal(output_list_r1, metric_name = "dcrs")
out_dbrl_r1 <- PoolResultsReal(output_list_r1, metric_name = "dbrls")
out_sdid_r1 <- PoolResultsReal(output_list_r1, metric_name = "sdids")


## Load the outputs from the experiments based on the additional 15 datasets
## (N <= 10000, p <= 500, categorical variables with at most 10 classes).
##
rm(list = intersect(real_result_objects, ls()))
for (suffix in real_file_suffixes) {
  load(paste0(output_path,
              "real_data_evaluations_15_additional_datasets_", suffix, ".RData"))
}
missing_objects <- setdiff(real_result_objects, ls())
if (length(missing_objects) > 0) {
  stop("Objects not found after loading the 15 additional datasets results: ",
       paste(missing_objects, collapse = ", "), call. = FALSE)
}

output_list_r2 <- list(MIAV = out_miav,
                       NMIAV_0.05 = out_nmiav1,
                       NMIAV_0.1 = out_nmiav2,
                       NMIAV_0.15 = out_nmiav3,
                       NMIAV_0.2 = out_nmiav4,
                       NMIAV_0.25 = out_nmiav5,
                       NMIAV_0.3 = out_nmiav6)

out_ks_r2 <- PoolResultsReal(output_list_r2, metric_name = "ks_stat")
out_l2d_r2 <- PoolResultsReal(output_list_r2, metric_name = "l2dist")
out_ed_r2 <- PoolResultsReal(output_list_r2, metric_name = "ed")
out_dt_r2 <- PoolResultsReal(output_list_r2, metric_name = "detection_tests")
out_dcr_r2 <- PoolResultsReal(output_list_r2, metric_name = "dcrs")
out_dbrl_r2 <- PoolResultsReal(output_list_r2, metric_name = "dbrls")
out_sdid_r2 <- PoolResultsReal(output_list_r2, metric_name = "sdids")

## Stack the pooled results of the two dataset collections, metric by metric.
out_ks_r <- rbind(out_ks_r1, out_ks_r2)
out_l2d_r <- rbind(out_l2d_r1, out_l2d_r2)
out_ed_r <- rbind(out_ed_r1, out_ed_r2)
out_dt_r <- rbind(out_dt_r1, out_dt_r2)
out_dcr_r <- rbind(out_dcr_r1, out_dcr_r2)
out_dbrl_r <- rbind(out_dbrl_r1, out_dbrl_r2)
out_sdid_r <- rbind(out_sdid_r1, out_sdid_r2)


######################################################################
## Load outputs from the baseline comparisons real data experiments
## generated by the script:
## "run_real_world_data_evaluations_on_baseline_comparisons.R"
######################################################################

## Load the baseline comparison results; the out_* objects used below are
## expected to come from this file.
baseline_file <- paste0(output_path,
                        "outputs_real_world_experiments_baseline_comparisons.RData")
load(baseline_file)

## One output object per baseline dataset.
output_list_b <- list(out_AB,
                      out_BM,
                      out_CR,
                      out_EM,
                      out_HO,
                      out_MT,
                      out_PO)

## Pool every metric across the baseline datasets, in a fixed order.
baseline_metrics <- c(ks = "ks_stat",
                      l2d = "l2dist",
                      ed = "ed",
                      dt = "detection_tests",
                      dcr = "dcrs",
                      dbrl = "dbrls",
                      sdid = "sdids")
pooled_baseline <- lapply(baseline_metrics, function(metric) {
  PoolResultsB(output_list_b,
               metric_name = metric)
})

out_ks_b <- pooled_baseline[["ks"]]
out_l2d_b <- pooled_baseline[["l2d"]]
out_ed_b <- pooled_baseline[["ed"]]
out_dt_b <- pooled_baseline[["dt"]]
out_dcr_b <- pooled_baseline[["dcr"]]
out_dbrl_b <- pooled_baseline[["dbrl"]]
out_sdid_b <- pooled_baseline[["sdid"]]

## Columns selected from the pooled baseline results ...
bnms <- c("miav",
          "noisy_miav_0.05",
          "noisy_miav_0.1",
          "noisy_miav_0.15",
          "noisy_miav_0.2",
          "noisy_miav_0.25",
          "noisy_miav_0.3")

## ... and the labels shown for them on the plot axes.
bnms2 <- c("MIAV",
           "NMIAV_0.05",
           "NMIAV_0.1",
           "NMIAV_0.15",
           "NMIAV_0.2",
           "NMIAV_0.25",
           "NMIAV_0.3")



## Shared graphical settings for all boxplot panels.
my_las <- 2            # axis label orientation
my_outline <- FALSE    # do not draw outliers
my_line <- 0           # margin line of the panel tag
my_cex <- 1.2          # size of the panel tag

## Box border colors: one per method; no extra baseline color is used.
methods_color <- rep("black", length(output_list_r1))
base_color <- NULL

## Draw one boxplot panel and put its tag, e.g. "(a)", at the top-left.
## Further arguments (such as `names`) are forwarded to boxplot().
DrawMetricBoxplot <- function(dat, title, y_label, tag, ...) {
  boxplot(dat, main = title, las = my_las,
          ylab = y_label, col = "white",
          border = c(methods_color, base_color),
          outline = my_outline, ...)
  mtext(tag, side = 3, adj = 0, line = my_line, cex = my_cex)
}

par(mfrow = c(3, 6), mar = c(6, 3, 1, 0.5) + 0.1, mgp = c(2.25, 0.75, 0))

#### Row 1: simulated data experiments
DrawMetricBoxplot(out_ks[, nms], "KS (simul.)",
                  "ave. KS-statistic", "(a)", names = nms2)
DrawMetricBoxplot(out_l2d[, nms], "L2D (simul.)",
                  "L2 dist. between assoc. matrices", "(b)", names = nms2)
DrawMetricBoxplot(out_dt[, nms], "DT (simul.)",
                  "detection test", "(c)", names = nms2)
DrawMetricBoxplot(out_dcr[, nms], "DCR (simul.)",
                  "median of DCR distribution", "(d)", names = nms2)
DrawMetricBoxplot(out_dbrl[, nms], "    SDBRL (simul.)",
                  "sorted DBRL", "(e)", names = nms2)
DrawMetricBoxplot(out_sdid[, nms], "  SSDID (simul.)",
                  "sorted SDID", "(f)", names = nms2)

#### Row 2: OpenML-CC18 real data experiments (labels come from the data)
DrawMetricBoxplot(out_ks_r, "KS (CC18)",
                  "ks-test statistic", "(g)")
DrawMetricBoxplot(out_l2d_r, "L2D (CC18)",
                  "L2 dist. between assoc. matrices", "(h)")
DrawMetricBoxplot(out_dt_r, "DT (CC18)",
                  "detection test", "(i)")
DrawMetricBoxplot(out_dcr_r, "DCR (CC18)",
                  "median of DCR distribution", "(j)")
DrawMetricBoxplot(out_dbrl_r, "    SDBRL (CC18)",
                  "sorted DBRL", "(k)")
DrawMetricBoxplot(out_sdid_r, "SSDID (CC18)",
                  "sorted SDID", "(l)")

#### Row 3: baseline comparison experiments
DrawMetricBoxplot(out_ks_b[, bnms], "KS (base.)",
                  "ks-test statistic", "(m)", names = bnms2)
DrawMetricBoxplot(out_l2d_b[, bnms], "L2D (base.)",
                  "L2 dist. between assoc. matrices", "(n)", names = bnms2)
DrawMetricBoxplot(out_dt_b[, bnms], "DT (base.)",
                  "detection test", "(o)", names = bnms2)
DrawMetricBoxplot(out_dcr_b[, bnms], "DCR (base.)",
                  "median of DCR distribution", "(p)", names = bnms2)
DrawMetricBoxplot(out_dbrl_b[, bnms], "    SDBRL (base.)",
                  "sorted DBRL", "(q)", names = bnms2)
DrawMetricBoxplot(out_sdid_b[, bnms], "SSDID (base.)",
                  "sorted SDID", "(r)", names = bnms2)

## Back to a single-panel layout with the default margins.
par(mfrow = c(1, 1), mar = c(5, 4, 4, 2) + 0.1, mgp = c(3, 1, 0))

