# Libraries and functions ----
suppressMessages(library(tidyverse))
library(ADRecommender)
library(furrr)

# Fit and evaluate the recommender on a single resampling split for one
# (eps, minpts) configuration.
#
# Args:
#   dat_split: a split object; it is passed to
#     ADRecommender:::get_data_from_split() and rsample::testing().
#   eps, minpts: hyper-parameter values forwarded unchanged to fit_evaluate().
#
# Returns:
#   The output of fit_evaluate() with four extra columns: `provider` and
#   `dataset` (taken from the testing portion of the split), `epsilon` and
#   `min_points` (the configuration that produced the rows).
analyze_single_split <- function(dat_split, eps, minpts) {
  d <- ADRecommender:::get_data_from_split(dat_split)
  # Extract the held-out rows once and reuse them below instead of calling
  # rsample::testing() again for every identifier column.
  testing <- rsample::testing(dat_split)
  res <- fit_evaluate(eps = eps,
                      minpts = minpts,
                      train_performance = d$train$performance,
                      train_metafeatures = d$train$metafeatures,
                      test_performance = d$test$performance,
                      test_metafeatures = d$test$metafeatures)
  res %>%
    mutate(provider = testing$provider,
           dataset = testing$dataset,
           epsilon = eps,
           min_points = minpts)
}

# Train an 'orthus_tunable' recommender with the given eps / minpts on the
# training data, predict for the test meta-features and score those
# predictions against the test performance.
#
# Returns whatever evaluate_recommendations() returns.
fit_evaluate <- function(eps, minpts, train_performance, train_metafeatures, test_performance, test_metafeatures) {
  fitted_model <- fit(recommender('orthus_tunable'),
                      train_performance = train_performance,
                      train_metafeatures = train_metafeatures,
                      eps = eps,
                      minpts = minpts)
  recommendations <- predict(fitted_model, test_metafeatures)
  evaluate_recommendations(recommendations, test_performance)
}


# Experimental definition ----

# Full factorial grid: every combination of eps, minpts and split index
# (6 * 4 * 30 rows). Each row is one unit of work for the run below.
experiments <- expand.grid(eps = c(0.1, 0.2, 0.5, 1, 2, 5),
                           minpts = c(2, 5, 10, 20),
                           split = 1:30,
                           stringsAsFactors = FALSE)

# Load the cross-validation data for the 'pr_auc' metric with a fixed seed.
# TRUE / FALSE are spelled out because T and F are ordinary, reassignable
# variables in R.
dat <- get_data_stratified_kfolds_cv(metric = 'pr_auc',
                                     mfs_metaod = FALSE,
                                     mfs_scaled = TRUE,
                                     mfs_catch22 = TRUE,
                                     seed = 7777)

# Splits indexed by `experiments$split` in the run below.
splits <- dat$data$splits

# Actual run ----
prepare_parallelization()

# Directory holding one cached .rds result per (split, eps, minpts) triple.
# It is created up front so that the first write on a fresh checkout does not
# fail because the directory is missing.
cache_dir <- 'paper_scripts/appendices/ablation_cache'
dir.create(cache_dir, recursive = TRUE, showWarnings = FALSE)

# Evaluate every row of `experiments` in parallel. A configuration whose
# result file already exists is read from the cache instead of recomputed.
res <- future_pmap(list(experiments$split, experiments$eps, experiments$minpts),
            function(index, eps, minpts){
              filepath <- file.path(cache_dir, paste0(index, '_', eps, '_', minpts, '.rds'))
              if(file.exists(filepath)){
                res <- read_rds(filepath)
              }else{
                res <- analyze_single_split(splits[[index]], eps, minpts)
                write_rds(res, filepath, compress = 'gz')
              }
              res
            }, .progress = TRUE)


# Analysis ----

# Stack the per-experiment results into one data frame.
res_df <- bind_rows(res)

# Mean percentile per (epsilon, min_points) configuration, computed on a wide
# table with one column per configuration. The result is not assigned; it is
# only printed when the script is run interactively or via Rscript.
res_df %>%
  group_by(epsilon, min_points, dataset) %>%
  # Number the repeated rows of each dataset within a configuration so that
  # (provider, dataset, index) identifies a row of the wide table.
  # row_number() adapts to the group size, whereas a literal 1:3 errors
  # unless every group has exactly three rows.
  mutate(index = row_number()) %>%
  ungroup() %>%
  pivot_wider(id_cols = c(provider, dataset, index), values_from = percentile, names_from = c(epsilon, min_points)) %>%
  select(-provider, -dataset, -index) %>%
  map_dbl(mean)


# Mean and standard deviation of the percentile for every
# (epsilon, min_points) configuration, best mean first.
# `.groups = "drop"` returns an ungrouped table, so no leftover grouping by
# epsilon leaks into later steps and no regrouping message is printed.
summarized <- res_df %>%
  group_by(epsilon, min_points) %>%
  summarise(mean = mean(percentile),
            sd = sd(percentile),
            .groups = "drop") %>%
  arrange(desc(mean))

# Range of the configuration means for each epsilon (across min_points).
df1 <- summarized %>%
  group_by(epsilon) %>%
  summarise(min = min(mean), max = max(mean), .groups = "drop")

# Range of the configuration means for each min_points (across epsilon).
df2 <- summarized %>%
  group_by(min_points) %>%
  summarise(min = min(mean), max = max(mean), .groups = "drop")

# Plots ----


# Range plot for epsilon: one vertical segment per epsilon value running from
# the lowest to the highest configuration mean, with a point at each end.
p1 <- ggplot(df1, aes(x = epsilon)) +
  geom_segment(aes(xend = epsilon, y = min, yend = max), color = "black") +
  geom_point(aes(y = min)) +
  geom_point(aes(y = max)) +
  scale_x_continuous(breaks = c(0.1, 0.5, 1, 2, 5)) +
  labs(x = "DBSCAN epsilon hyper-parameter",
       y = "Min-Max Average PRAUC Percentile") +
  theme_light()

# Range plot for min points: one vertical segment per min_points value running
# from the lowest to the highest configuration mean, with a point at each end.
p2 <- ggplot(df2, aes(x = min_points)) +
  geom_segment(aes(xend = min_points, y = min, yend = max), color = "black") +
  geom_point(aes(y = min)) +
  geom_point(aes(y = max)) +
  scale_x_continuous(breaks = c(2, 5, 10, 20)) +
  labs(x = "DBSCAN min points hyper-parameter",
       y = "Min-Max Average PRAUC Percentile") +
  theme_light()

# Write both range plots side by side into a single 22 x 8 cm PDF located in
# the paper_scripts/appendices directory of the project.
ggsave(
  plot = cowplot::plot_grid(p1, p2, nrow = 1),
  filename = 'orthus_clustering_ablation.pdf',
  path = here::here('paper_scripts', 'appendices'),
  device = 'pdf',
  width = 22, height = 8, units = 'cm'
)



