
###############################################################################
## Script for generating Figure 4 in the main text, Figure 11 in Appendix H,
## Figures 12, 13, 14, 15, 16, and 17 in Appendix I.6, and Figures 21 and 22
## in Appendix K.
###############################################################################

# source utility functions
source("utility_functions_for_miav_tabpfn_iclr.R")

# Folder holding the saved experiment outputs (.RData and .csv files). File
# names are appended to it with paste0(), so a non-empty value needs a
# trailing path separator.
output_path <- ""
# Folder intended for the generated figures. NOTE(review): not referenced by
# any code visible here; the figures are drawn on the active graphics device.
manus_path <- ""


## ---------------------------------------------------------------------------
## Simulated correlated beta distributions, produced by
## "run_simulated_data_experiments_abs_rho.R".
## ---------------------------------------------------------------------------

# Each file is expected to provide the object out_<|rho|> used below
# (out_0.95, out_0.75, out_0.5, out_0.25, out_0); everything is loaded into
# the script-level environment.
invisible(lapply(
  paste0(output_path, "simulation_outputs_abs_rho_",
         c("0.95", "0.75", "0.5", "0.25", "0"), ".RData"),
  load,
  envir = environment()
))

# Simulation results ordered from the strongest to the weakest |rho|.
output_list <- list(out_0.95, out_0.75, out_0.5, out_0.25, out_0)

# Methods retained when pooling.
keep <- c("hold", "jf", "fc", "miav", "smote")

# Pool every evaluation metric over all simulation settings: KS, L2D, ED, DT,
# DCR, SDBRL and SSDID, respectively.
out_ks   <- PoolResults(output_list, keep, metric_name = "ks_test_stat")
out_l2d  <- PoolResults(output_list, keep, metric_name = "l2corr_dist")
out_ed   <- PoolResults(output_list, keep, metric_name = "energy_dist")
out_dt   <- PoolResults(output_list, keep, metric_name = "detection_test")
out_dcr  <- PoolResults(output_list, keep, metric_name = "median_dcrs")
out_dbrl <- PoolResults(output_list, keep, metric_name = "dbrls")
out_sdid <- PoolResults(output_list, keep, metric_name = "sdids")

# Column keys and display labels of the five methods, in plotting order.
nms <- c("hold", "miav", "jf", "fc", "smote")
nms2 <- c("holdout", "MIAV", "JF", "FC", "SMOTE")


## ---------------------------------------------------------------------------
## OpenML-CC18 real data experiments, produced by
## "run_real_world_data_evaluations_on_first_21_datasets.R" and
## "run_real_world_data_evaluations_on_additional_15_datasets.R".
## ---------------------------------------------------------------------------

## First 21 datasets (N <= 2000, p <= 100, categorical variables with at most
## 10 classes). The files are expected to provide out_miav, out_jf, out_fc,
## out_smote and out_hold.
invisible(lapply(
  paste0(output_path, "real_data_evaluations_first_21_datasets_",
         c("miav", "jf", "fc", "smote", "hold"), ".RData"),
  load,
  envir = environment()
))
output_list_r1 <- list(holdout = out_hold, MIAV = out_miav, JF = out_jf,
                       FC = out_fc, SMOTE = out_smote)

# Pool each evaluation metric over the first 21 datasets.
out_ks_r1   <- PoolResultsReal(output_list_r1, metric_name = "ks_stat")
out_l2d_r1  <- PoolResultsReal(output_list_r1, metric_name = "l2dist")
out_ed_r1   <- PoolResultsReal(output_list_r1, metric_name = "ed")
out_dt_r1   <- PoolResultsReal(output_list_r1, metric_name = "detection_tests")
out_dcr_r1  <- PoolResultsReal(output_list_r1, metric_name = "dcrs")
out_dbrl_r1 <- PoolResultsReal(output_list_r1, metric_name = "dbrls")
out_sdid_r1 <- PoolResultsReal(output_list_r1, metric_name = "sdids")

## Additional 15 datasets (N <= 10000, p <= 500, categorical variables with at
## most 10 classes). These files reuse the object names above and overwrite
## them, which is why output_list_r1 had to be built first.
invisible(lapply(
  paste0(output_path, "real_data_evaluations_15_additional_datasets_",
         c("miav", "jf", "fc", "hold", "smote"), ".RData"),
  load,
  envir = environment()
))
output_list_r2 <- list(holdout = out_hold, MIAV = out_miav, JF = out_jf,
                       FC = out_fc, SMOTE = out_smote)

# Pool each evaluation metric over the 15 additional datasets.
out_ks_r2   <- PoolResultsReal(output_list_r2, metric_name = "ks_stat")
out_l2d_r2  <- PoolResultsReal(output_list_r2, metric_name = "l2dist")
out_ed_r2   <- PoolResultsReal(output_list_r2, metric_name = "ed")
out_dt_r2   <- PoolResultsReal(output_list_r2, metric_name = "detection_tests")
out_dcr_r2  <- PoolResultsReal(output_list_r2, metric_name = "dcrs")
out_dbrl_r2 <- PoolResultsReal(output_list_r2, metric_name = "dbrls")
out_sdid_r2 <- PoolResultsReal(output_list_r2, metric_name = "sdids")

# Stack both dataset collections (first 21 rows-block, then the 15 extra).
out_ks_r   <- rbind(out_ks_r1, out_ks_r2)
out_l2d_r  <- rbind(out_l2d_r1, out_l2d_r2)
out_ed_r   <- rbind(out_ed_r1, out_ed_r2)
out_dt_r   <- rbind(out_dt_r1, out_dt_r2)
out_dcr_r  <- rbind(out_dcr_r1, out_dcr_r2)
out_dbrl_r <- rbind(out_dbrl_r1, out_dbrl_r2)
out_sdid_r <- rbind(out_sdid_r1, out_sdid_r2)


## ---------------------------------------------------------------------------
## Baseline-comparison real data experiments, produced by
## "run_real_world_data_evaluations_on_baseline_comparisons.R".
## ---------------------------------------------------------------------------

# Expected to provide out_AB, out_BM, out_CR, out_EM, out_HO, out_MT, out_PO.
load(paste0(output_path, "outputs_real_world_experiments_baseline_comparisons.RData"))

# One results object per dataset (deliberately unnamed), order AB ... PO.
output_list_b <- list(out_AB, out_BM, out_CR, out_EM, out_HO, out_MT, out_PO)

# Pool each evaluation metric over all baseline-comparison datasets.
out_ks_b   <- PoolResultsB(output_list_b, metric_name = "ks_stat")
out_l2d_b  <- PoolResultsB(output_list_b, metric_name = "l2dist")
out_ed_b   <- PoolResultsB(output_list_b, metric_name = "ed")
out_dt_b   <- PoolResultsB(output_list_b, metric_name = "detection_tests")
out_dcr_b  <- PoolResultsB(output_list_b, metric_name = "dcrs")
out_dbrl_b <- PoolResultsB(output_list_b, metric_name = "dbrls")
out_sdid_b <- PoolResultsB(output_list_b, metric_name = "sdids")

# Column keys and display labels: the five methods followed by five external
# baseline generators.
bnms <- c("holdout", "miav", "jf", "fc", "smote",
          "ddpm", "ctgan", "tvae", "arf", "bayesnet")
bnms2 <- c("holdout", "MIAV", "JF", "FC", "SMOTE",
           "DDPM", "CTGAN", "TVAE", "ARF", "BN")



## ---------------------------------------------------------------------------
## Main text Figure 4: a 3 x 6 grid of boxplots. Rows: simulated data (panels
## a-f), OpenML-CC18 data (g-l), baseline comparisons (m-r). Columns: KS, L2D,
## DT, DCR, SDBRL and SSDID.
## ---------------------------------------------------------------------------

my_las <- 2          # axis labels perpendicular to the axes
my_line <- 0         # margin line used for the panel tags
my_cex <- 1.2        # size of the panel tags
my_outline <- FALSE  # do not draw outlier points

# Box borders for holdout, MIAV, JF, FC and SMOTE (in that order), plus black
# for the five extra baseline generators.
methods_color <- c("green", "red", "orange", "blue", "purple")
base_color <- rep("black", 5)

par(mfrow = c(3, 6), mar = c(3.5, 3, 1, 0.5) + 0.1, mgp = c(2.25, 0.75, 0))
local({
  # Draw one white-filled boxplot and tag it in the top-left corner. `...`
  # forwards `border` and, only when supplied, `names` to boxplot().
  panel <- function(x, title, y_label, tag, ...) {
    boxplot(x, main = title, las = my_las, ylab = y_label, col = "white",
            outline = my_outline, ...)
    mtext(tag, side = 3, adj = 0, line = my_line, cex = my_cex)
  }
  all_border <- c(methods_color, base_color)

  # Row 1: simulated data.
  panel(out_ks[, nms], "KS (simul.)", "ave. KS-statistic", "(a)",
        names = nms2, border = methods_color)
  panel(out_l2d[, nms], "L2D (simul.)", "L2 dist. between assoc. matrices",
        "(b)", names = nms2, border = methods_color)
  panel(out_dt[, nms], "DT (simul.)", "detection test", "(c)",
        names = nms2, border = methods_color)
  panel(out_dcr[, nms], "DCR (simul.)", "median of DCR distribution", "(d)",
        names = nms2, border = methods_color)
  panel(out_dbrl[, nms], "    SDBRL (simul.)", "sorted DBRL", "(e)",
        names = nms2, border = methods_color)
  panel(out_sdid[, nms], "  SSDID (simul.)", "sorted SDID", "(f)",
        names = nms2, border = methods_color)

  # Row 2: OpenML-CC18 data. No `names` is passed, so boxplot() labels the
  # boxes from the data itself.
  panel(out_ks_r, "KS (CC18)", "ks-test statistic", "(g)",
        border = methods_color)
  panel(out_l2d_r, "L2D (CC18)", "L2 dist. between assoc. matrices", "(h)",
        border = methods_color)
  panel(out_dt_r, "DT (CC18)", "detection test", "(i)",
        border = methods_color)
  panel(out_dcr_r, "DCR (CC18)", "median of DCR distribution", "(j)",
        border = methods_color)
  panel(out_dbrl_r, "    SDBRL (CC18)", "sorted DBRL", "(k)",
        border = methods_color)
  panel(out_sdid_r, "SSDID (CC18)", "sorted SDID", "(l)",
        border = methods_color)

  # Row 3: baseline comparisons (five methods + five external generators).
  panel(out_ks_b[, bnms], "KS (base.)", "ks-test statistic", "(m)",
        border = all_border, names = bnms2)
  panel(out_l2d_b[, bnms], "L2D (base.)", "L2 dist. between assoc. matrices",
        "(n)", border = all_border, names = bnms2)
  panel(out_dt_b[, bnms], "DT (base.)", "detection test", "(o)",
        border = all_border, names = bnms2)
  panel(out_dcr_b[, bnms], "DCR (base.)", "median of DCR distribution", "(p)",
        border = all_border, names = bnms2)
  panel(out_dbrl_b[, bnms], "    SDBRL (base.)", "sorted DBRL", "(q)",
        border = all_border, names = bnms2)
  panel(out_sdid_b[, bnms], "SSDID (base.)", "sorted SDID", "(r)",
        border = all_border, names = bnms2)
})
par(mfrow = c(1, 1), mar = c(5, 4, 4, 2) + 0.1, mgp = c(3, 1, 0))


################################################################
## Generate supplementary Figures 12, 13, 14, 15, 16, and 17 
################################################################

## Supplementary figure for the simulated data: a 5 x 6 grid filled by one
## PlotRow() call per |rho| level (0.95, 0.75, 0.5, 0.25, 0).

# Method order for the supplementary panels. NOTE(review): PlotRow() is not
# passed this vector, so it presumably reads the global `keep`; confirm in
# utility_functions_for_miav_tabpfn_iclr.R before renaming or removing it.
keep <- c("hold", "miav", "jf", "fc", "smote")

# Tag placement/size forwarded to PlotRow() and, later, to PlotRowReal() and
# PlotRowRealB().
my_line <- -1.5
my_cex <- 1

methods_color <- c("green", "red", "orange", "blue", "purple")

par(mfrow = c(5, 6), mar = c(3.5, 3, 1.5, 0.5) + 0.1, mgp = c(2.25, 0.75, 0))
PlotRow(out = out_0.95, my_rho = 0.95, i = 1, my_line = my_line, my_cex = my_cex)
PlotRow(out = out_0.75, my_rho = 0.75, i = 2, my_line = my_line, my_cex = my_cex)
PlotRow(out = out_0.5, my_rho = 0.5, i = 3, my_line = my_line, my_cex = my_cex)
PlotRow(out = out_0.25, my_rho = 0.25, i = 4, my_line = my_line, my_cex = my_cex)
PlotRow(out = out_0, my_rho = 0, i = 5, my_line = my_line, my_cex = my_cex)
par(mfrow = c(1, 1), mar = c(5, 4, 4, 2) + 0.1, mgp = c(3, 1, 0))


# Metric names taken from a simulation output and reused for the real data.
# NOTE(review): simulation and real-data outputs use different metric keys
# (e.g. "ks_test_stat" vs "ks_stat"), so ReorganizeRealResults() presumably
# uses these only as labels or matches metrics by position -- confirm.
metric_names <- names(out_0.5)

## Per-dataset results for the first 21 datasets: out_f1, ..., out_f21.
dataset_names_f <- colnames(output_list_r1[[1]][[1]])
invisible(list2env(
  setNames(
    lapply(dataset_names_f[1:21], function(sel) {
      ReorganizeRealResults(output_list_r1, sel_data = sel, metric_names)
    }),
    paste0("out_f", 1:21)
  ),
  envir = environment()
))

## Per-dataset results for the additional 15 datasets: out_a1, ..., out_a15.
dataset_names_a <- colnames(output_list_r2[[1]][[1]])
invisible(list2env(
  setNames(
    lapply(dataset_names_a[1:15], function(sel) {
      ReorganizeRealResults(output_list_r2, sel_data = sel, metric_names)
    }),
    paste0("out_a", 1:15)
  ),
  envir = environment()
))

## Rename to the D1-D36 labels of the paper table: the k-th entry below is
## the object that becomes out_d<k> (six entries per line).
invisible(list2env(
  setNames(
    mget(paste0("out_", c(
      "f1",  "a1",  "f2",  "f3",  "f4",  "f5",
      "f6",  "a2",  "f7",  "a3",  "f8",  "a4",
      "f9",  "f10", "f11", "f12", "a5",  "f13",
      "f14", "a6",  "f15", "a7",  "a8",  "f16",
      "a9",  "a10", "f17", "f18", "a11", "f19",
      "f20", "f21", "a12", "a13", "a14", "a15"
    )), envir = environment()),
    paste0("out_d", 1:36)
  ),
  envir = environment()
))


## Supplementary figures for D1-D36: four pages (D1-D9, D10-D18, D19-D27,
## D28-D36), each a 9 x 6 layout filled by one PlotRowReal() call per row.
local({
  for (first in c(1, 10, 19, 28)) {
    par(mfrow = c(9, 6), mar = c(3.5, 3, 1, 0.5) + 0.1, mgp = c(2.25, 0.75, 0))
    for (row in 1:9) {
      d <- first + row - 1
      PlotRowReal(out = get(paste0("out_d", d)), i = row,
                  dataset_name = paste0("D", d),
                  my_line = my_line, my_cex = my_cex)
    }
    par(mfrow = c(1, 1), mar = c(5, 4, 4, 2) + 0.1, mgp = c(3, 1, 0))
  }
})


## Supplementary figure for the baseline comparisons: a 7 x 6 layout with one
## PlotRowRealB() call (row) per dataset.
par(mfrow = c(7, 6), mar = c(4, 3, 1, 0.5) + 0.1, mgp = c(2.25, 0.75, 0))
local({
  baseline_sets <- list(AB = out_AB, BM = out_BM, CR = out_CR, EM = out_EM,
                        HO = out_HO, MT = out_MT, PO = out_PO)
  all_colors <- c(methods_color, base_color)
  for (row in seq_along(baseline_sets)) {
    # my_line and my_cex stay positional, matching the original call shape.
    PlotRowRealB(baseline_sets[[row]], i = row,
                 dataset_name = names(baseline_sets)[row], my_line, my_cex,
                 keep = bnms, method_names = bnms2,
                 methods_color = all_colors)
  }
})
par(mfrow = c(1, 1), mar = c(5, 4, 4, 2) + 0.1, mgp = c(3, 1, 0))


############################################################################
## Load outputs for the categorical real data experiments comparing
## TabICL and TabPFN. The outputs were generated using the script:
## "run_real_world_data_evaluations_on_categorical_datasets.R"
############################################################################

# NOTE(review): "rexperiments" in the file name looks like a typo, but it must
# match the name the generating script saved; only rename both together.
load(paste0(output_path, "outputs_real_data_rexperiments_categorical.RData"))

# The element names presumably become the box labels of Figures 21 and 22.
# "holdout" corrects the former misspelling "houldout" and matches the label
# used by output_list_r1 / output_list_r2.
output_list_r <- list(holdout = out_hold,
                      MIAV_TabPFN = out_tabpfn_miav,
                      MIAV_TabICL = out_tabicl_miav,
                      JF_TabPFN = out_tabpfn_jf,
                      JF_TabICL = out_tabicl_jf,
                      FC_TabPFN = out_tabpfn_fc,
                      FC_TabICL = out_tabicl_fc)

metric_names <- c("kl_dive", "l2dist")
dataset_names <- colnames(output_list_r[[1]][[1]])

# Per-dataset results for the categorical datasets C1-C8. These reuse (and
# overwrite) the out_d1, ..., out_d8 names from the D1-D36 section above.
out_d1 <- ReorganizeRealResults(output_list_r, sel_data = dataset_names[1], metric_names)
out_d2 <- ReorganizeRealResults(output_list_r, sel_data = dataset_names[2], metric_names)
out_d3 <- ReorganizeRealResults(output_list_r, sel_data = dataset_names[3], metric_names)
out_d4 <- ReorganizeRealResults(output_list_r, sel_data = dataset_names[4], metric_names)
out_d5 <- ReorganizeRealResults(output_list_r, sel_data = dataset_names[5], metric_names)
out_d6 <- ReorganizeRealResults(output_list_r, sel_data = dataset_names[6], metric_names)
out_d7 <- ReorganizeRealResults(output_list_r, sel_data = dataset_names[7], metric_names)
out_d8 <- ReorganizeRealResults(output_list_r, sel_data = dataset_names[8], metric_names)


# Layout settings shared by Figures 21 and 22.
my_las <- 2
my_line <- 0
my_cex <- 1.2

# Figure 21: average KL-divergence for the categorical datasets C1-C8,
# panels (a)-(h).
par(mfrow = c(2, 4), mar = c(6.5, 3, 1.5, 0.5) + 0.1, mgp = c(2.25, 0.75, 0))
local({
  for (k in 1:8) {
    boxplot(get(paste0("out_d", k))$kl_dive, las = my_las,
            main = paste0("C", k), ylab = "average KL-divergence")
    mtext(paste0("(", letters[k], ")"), side = 3, adj = 0,
          line = my_line, cex = my_cex)
  }
})
par(mfrow = c(1, 1), mar = c(5, 4, 4, 2) + 0.1, mgp = c(3, 1, 0))


# Figure 22: L2 distance between association matrices for the categorical
# datasets C1-C8, panels (a)-(h).
par(mfrow = c(2, 4), mar = c(6.5, 3, 1.5, 0.5) + 0.1, mgp = c(2.25, 0.75, 0))
local({
  for (k in 1:8) {
    boxplot(get(paste0("out_d", k))$l2dist, las = my_las,
            main = paste0("C", k), ylab = "L2 dist. between assoc. matrices")
    mtext(paste0("(", letters[k], ")"), side = 3, adj = 0,
          line = my_line, cex = my_cex)
  }
})
# Restore the default layout like every other figure section; otherwise the
# 2 x 4 grid and wide bottom margin remain active for whatever is drawn next.
par(mfrow = c(1, 1), mar = c(5, 4, 4, 2) + 0.1, mgp = c(3, 1, 0))



## ---------------------------------------------------------------------------
## Runtime benchmarking outputs, produced by the Python notebook
## "runtime_benchmarking.ipynb".
## ---------------------------------------------------------------------------

# Creates out_<method>_<n> for method in {miav, jf, fc} and n in
# {1000, 2000, 3000}, each read from time_bench_<method>_n_<n>.csv.
invisible(list2env(
  lapply(
    c(out_miav_1000 = "miav_n_1000", out_jf_1000 = "jf_n_1000",
      out_fc_1000 = "fc_n_1000",
      out_miav_2000 = "miav_n_2000", out_jf_2000 = "jf_n_2000",
      out_fc_2000 = "fc_n_2000",
      out_miav_3000 = "miav_n_3000", out_jf_3000 = "jf_n_3000",
      out_fc_3000 = "fc_n_3000"),
    function(tag) {
      read.csv(paste0(output_path, "time_bench_", tag, ".csv"), header = TRUE)
    }
  ),
  envir = environment()
))

# Numbers of features benchmarked; used as box labels / x values in Figure 11.
xaxis <- seq(10, 100, by = 10)

# Figure 11: runtime of MIAV, JF and FC versus the number of features.
# Panels (a)-(c) show n = 1,000 / 2,000 / 3,000; panel (d) overlays the MIAV
# median runtimes for the three sample sizes.
#
# All working objects live inside local(), so this section no longer
# overwrites the global out_miav / out_jf / out_fc objects that hold the
# real-data results loaded earlier in the script.
par(mfrow = c(1, 4), mar = c(3.5, 3, 1, 0.5) + 0.1, mgp = c(2.0, 0.75, 0))
local({
  # Column-wise medians of a data.frame of runtimes (one per feature count).
  col_medians <- function(x) vapply(x, median, numeric(1))

  # One panel: red/orange/blue boxplots for MIAV/JF/FC, each overlaid with
  # its median curve, plus legend and panel tag.
  draw_runtime_panel <- function(miav, jf, fc, title, tag) {
    boxplot(miav, ylim = c(5, 900), border = "red", col = "white",
            names = xaxis, main = title, xlab = "number of features",
            ylab = "runtime (in secs)")
    lines(col_medians(miav), type = "b", col = "red")
    boxplot(jf, add = TRUE, border = "orange", col = "white", xaxt = "n")
    lines(col_medians(jf), type = "b", col = "orange")
    boxplot(fc, add = TRUE, border = "blue", col = "white", xaxt = "n")
    lines(col_medians(fc), type = "b", col = "blue")
    legend("topleft", legend = c("FC", "JF", "MIAV"),
           text.col = c("blue", "orange", "red"), bty = "n")
    mtext(tag, side = 3, adj = 0)
  }

  draw_runtime_panel(out_miav_1000, out_jf_1000, out_fc_1000, "n = 1,000", "(a)")
  draw_runtime_panel(out_miav_2000, out_jf_2000, out_fc_2000, "n = 2,000", "(b)")
  draw_runtime_panel(out_miav_3000, out_jf_3000, out_fc_3000, "n = 3,000", "(c)")

  # Panel (d): MIAV medians only; line type encodes the sample size.
  plot(xaxis, col_medians(out_miav_1000), col = "red", type = "l",
       ylim = c(5, 100), lty = 3, main = "MIAV",
       ylab = "runtime (in secs)", xlab = "number of features")
  lines(xaxis, col_medians(out_miav_2000), col = "red", lty = 2)
  lines(xaxis, col_medians(out_miav_3000), col = "red", lty = 1)
  legend("topleft", legend = c("n = 3,000", "n = 2,000", "n = 1,000"),
         lty = c(1, 2, 3), bty = "n")
  mtext("(d)", side = 3, adj = 0)
})
par(mfrow = c(1, 1), mar = c(5, 4, 4, 2) + 0.1, mgp = c(3, 1, 0))





