# Run this script in a fresh R session (e.g. via `Rscript`) rather than
# clearing the workspace with rm(list = ls()): that only removes global
# variables (not loaded packages or options) and silently wipes the caller's
# workspace when the script is source()d.

source('R/experiments/analysis/metrics.R')
source('R/experiments/analysis/helper_functions.R')

library(tidyverse)

# Parameters ----
# nmc values for which timings are collected (one pass of the main loop each).
nmc_values <- c(250, 500, 1000, 2000, 4000)
# nmc used when looking up "causal_true" explanations instead of nmc_values.
nmc_causal_true <- 4000
# nobs passed to retrieve_explanations() for every type except "causal_true".
nobs <- 1000
# Passed to retrieve_explanations() as-is (presumably the number of
# feature combinations; confirm against retrieve_explanations()).
ncomb <- 1024

n_seeds <- 40       # seeds 1..n_seeds are processed
n_to_explain <- 40  # passed to retrieve_explanations() as-is

# Simulation setting. NOTE(review): the explanation directory used in the
# main loop is hard-coded as ".../11_nodes/2_neighbors_6_parents_y" and must be
# kept in sync with these values by hand.
n_nodes <- 11
n_parents_y <- 6
n_neighbors_values <- 2


# Initialize results tibble: an empty, typed template with one column per
# recorded quantity. Rows are appended to it by add_results().
results <- tibble(
  nmc = integer(0),
  seed = integer(0),
  explanation_type = character(0),
  computation_time_full = double(0),
  computation_time_full_shapr = double(0),
  # Time spent in compute_vS + shapley_computation only.
  computation_time_compute_shapr = double(0),
  n_dags = integer(0)
)

# Helper function to add results to the tibble.
# Builds a single-row tibble from the supplied measurements and returns
# `results` with that row appended (the input tibble itself is not modified).
# `n_dags` defaults to 1, i.e. an explanation computed from a single DAG.
add_results <- function(results, nmc, seed, explanation_type, computation_time_full, computation_time_full_shapr, computation_time_compute_shapr, n_dags = 1) {
  new_row <- tibble(
    nmc = nmc,
    seed = seed,
    explanation_type = explanation_type,
    computation_time_full = computation_time_full,
    computation_time_full_shapr = computation_time_full_shapr,
    computation_time_compute_shapr = computation_time_compute_shapr,
    n_dags = n_dags
  )
  bind_rows(results, new_row)
}

# Main processing loop ----
# For every (nmc, seed) pair, load the saved explanation of each explanation
# type and append its timing information to `results`.
explanation_types <- c("marginal", "conditional", "oracle_mec_sample", "oracle_mec_iw",
                       "pc_1000n_alpha0.05_mec_sample", "pc_1000n_alpha0.05_mec_iw")

# Directory holding the saved explanation files. It does not depend on the loop
# variables, so it is computed once.
# NOTE(review): keep in sync with n_nodes / n_neighbors_values / n_parents_y.
dir <- file.path('experiments/6_explanations/xgboost/non-linear1/11_nodes/2_neighbors_6_parents_y')

for (nmc in nmc_values) {
  for (seed in seq_len(n_seeds)) {

    # One element per explanation type: the explanation object read from disk,
    # or NULL when retrieve_explanations() finds no matching file.
    explanations <- lapply(explanation_types, function(current_explanation_type) {
      current_nobs <- if (current_explanation_type == "causal_true") NULL else nobs
      current_nmc <- if (current_explanation_type == "causal_true") nmc_causal_true else nmc

      # Types whose name contains "iw" are looked up with fixed nmc = nobs = 1024,
      # independent of the current nmc. This must test the type being loaded
      # here, not a global variable (which would be undefined on the first call
      # and stale afterwards).
      # NOTE(review): confirm nobs = nmc = 1024 matches how the iw runs were saved.
      if (grepl('iw', current_explanation_type)) {
        current_nmc <- 1024
        current_nobs <- current_nmc
      }

      explanation_filename <- retrieve_explanations(dir, seed, current_explanation_type, current_nobs, current_nmc, ncomb,
                                                   n_to_explain, return_shapley_values = FALSE, use_first_one = TRUE)
      if (is.null(explanation_filename)) {
        return(NULL)
      }
      readRDS(file.path(dir, explanation_filename))
    })

    for (i in seq_along(explanations)) {
      explanation <- explanations[[i]]
      explanation_type <- explanation_types[i]

      if (is.null(explanation)) {
        message("Missing explanation for ", explanation_type, " for ", nmc, " nmc, seed ", seed)
        next
      }

      # All times are stored in minutes.
      full_time_ours <- as.numeric(explanation$duration, units = "mins")
      full_time_shapr <- explanation$shapr_timings$total_time_secs / 60  # seconds -> minutes
      # Elements 4 and 5 of timing_secs are taken to be compute_vS and
      # shapley_computation. NOTE(review): positional lookup, confirm the order.
      compute_time_shapr <- (as.numeric(explanation$shapr_timings$timing_secs[4]) +
                               as.numeric(explanation$shapr_timings$timing_secs[5])) / 60

      # A data.frame of Shapley values means a single DAG (non-MEC). Otherwise the
      # object is treated as a collection with one entry per DAG.
      n_dags <- if (is.data.frame(explanation$shapley_values)) 1 else length(explanation$shapley_values)

      results <- add_results(results, nmc, seed, explanation_type,
                             full_time_ours, full_time_shapr, compute_time_shapr, n_dags)
    }
  }
}

# Rename explanation types to short labels ----
# recode() leaves unmatched character values unchanged, so types without an
# entry here (e.g. "marginal", "conditional", "causal_true") keep their names.
results$explanation_type <- recode(results$explanation_type,
                                   "oracle_mec_sample" = "oracle_sample",
                                   "oracle_mec_iw" = "oracle_iw",
                                   # Names produced by `explanation_types` in this script.
                                   "pc_1000n_alpha0.05_mec_sample" = "pc_sample",
                                   "pc_1000n_alpha0.05_mec_iw" = "pc_iw",
                                   # Retained from the earlier mapping (1024n variants).
                                   "pc_1024n_alpha0.05_mec_sample" = "pc_sample",
                                   "pc_1024n_alpha0.05_mec_iw" = "pc_iw",
                                   "fges_1024_mec_sample" = "fges_sample",
                                   "fges_1024_mec_iw" = "fges_iw")

# NOTE(review): results are written under "gaussian" while the explanations are
# read from "xgboost/non-linear1"; confirm this is the intended output location.
write_delim(results, 'R/experiments/results/gaussian/results_times_nmc.csv', delim = ',')