source('R/scm/graph_functions.R')
source('R/scm/node_functions.R')
source('R/scm/binary_scm.R')
source('R/scm/additive_noise_scm.R')
source('R/scm/linear_gaussian_scm.R')

source('R/experiments/helper_functions.R')

library(xgboost)

#' Train and save an xgboost model for one generated data set.
#'
#' Reads the RDS file at `data_path`, takes the first `n_training_points` rows
#' of its `data_observational` element, fits an xgboost model that predicts the
#' last column from all other columns, and passes the fitted model together
#' with its metadata to `save_model()`.
#'
#' @param data_path Path to an RDS file holding a list with at least the
#'   elements `data_observational`, `scm_type`, `n_expected_neighbors`,
#'   `n_expected_parents_y` and `seed`.
#' @param n_training_points Number of leading rows of the observational data
#'   used for training. Must be between 1 and the number of available rows.
#' @return Called for its side effects (logging and saving the model); the
#'   value is whatever the final `log_to_shell()` call returns.
train_model <- function(data_path, n_training_points = 10^4) {
  file_path <- data_path
  log_to_shell(paste0("Training model for ", file_path, "..."))
  data_and_parameters <- readRDS(file_path)
  observational_data <- data_and_parameters$data_observational

  scm_type <- data_and_parameters$scm_type
  n_expected_neighbors <- data_and_parameters$n_expected_neighbors
  n_expected_parents_y <- data_and_parameters$n_expected_parents_y
  seed <- data_and_parameters$seed

  if (is.null(dim(observational_data))) {
    stop("'data_observational' in ", file_path,
         " must be a two-dimensional object.", call. = FALSE)
  }

  # The node count is derived from the data itself: one column per node, with
  # the last column used as the prediction target.
  n_nodes <- ncol(observational_data)
  n_rows <- nrow(observational_data)

  if (n_nodes < 2) {
    stop("Need at least one feature column and one target column in ",
         file_path, ", got ", n_nodes, " column(s).", call. = FALSE)
  }
  if (n_training_points < 1 || n_training_points > n_rows) {
    stop("n_training_points (", n_training_points,
         ") must be between 1 and the number of observational rows (",
         n_rows, ") in ", file_path, ".", call. = FALSE)
  }

  training_rows <- seq_len(n_training_points)
  feature_columns <- seq_len(n_nodes - 1)

  # drop = FALSE keeps the features two-dimensional even when there is only a
  # single feature column.
  # NOTE(review): the `data =` interface of xgboost() presumably expects a
  # matrix here - confirm what the data generation step stores.
  x_training_model <- observational_data[training_rows, feature_columns,
                                         drop = FALSE]
  y_training_model <- observational_data[training_rows, n_nodes]

  # Seed right before training because subsample < 1 makes fitting stochastic.
  set.seed(seed)
  model_type <- "xgboost"
  eta <- 0.1
  nrounds <- 100
  subsample <- 0.8
  model_to_explain <- xgboost(data = x_training_model,
                              label = y_training_model,
                              eta = eta,
                              nrounds = nrounds,
                              subsample = subsample,
                              verbose = FALSE)
  log_to_shell(paste0("Model trained. Saving model for ", file_path, "..."))
  save_model(model_to_explain, model_type, scm_type, file_path, n_training_points,
             n_nodes, n_expected_neighbors, n_expected_parents_y, seed,
             eta, nrounds, subsample)
  log_to_shell(paste0("Model saved for ", file_path, "."))
}

#' Train a model for every data file below the experiment data folder.
#'
#' Walks the three directory levels below "experiments/3_data/" (named
#' scm type / number of nodes / dag type in this code) and calls
#' `train_model()` on every file whose name ends in ".rds".
#'
#' @param n_training_points Number of training rows, forwarded unchanged to
#'   `train_model()`.
#' @return `NULL`, invisibly; called for its side effects.
train_all_models <- function(n_training_points = 10^4) {
  data_base_folder <- "experiments/3_data/"
  scm_type_folders <- list.dirs(data_base_folder, recursive = FALSE)
  for (scm_type_folder in scm_type_folders) {
    n_nodes_folders <- list.dirs(scm_type_folder, recursive = FALSE)
    for (n_nodes_folder in n_nodes_folders) {
      dag_type_folders <- list.dirs(n_nodes_folder, recursive = FALSE)
      for (dag_type_folder in dag_type_folders) {
        # The dot is escaped so only a literal ".rds" extension matches;
        # an unescaped "." would accept any character before "rds".
        rds_paths <- list.files(dag_type_folder, pattern = "\\.rds$",
                                full.names = TRUE)
        for (file_path in rds_paths) {
          train_model(file_path, n_training_points)
        }
      }
    }
  }
}

# Script entry point: sourcing or running this file immediately trains models
# for all discovered data files, using the default number of training points.
train_all_models()