## Script to estimate the true causal Shapley values. The file at model_path stores the path to
## the data being explained, and that data file in turn stores the path to the corresponding SCM.
## The SCM is used to sample from the true interventional distributions.

source('R/scm/graph_functions.R')
source('R/scm/node_functions.R')
source('R/scm/binary_scm.R')
source('R/scm/additive_noise_scm.R')
source('R/scm/linear_gaussian_scm.R')

source('R/experiments/helper_functions.R')

library(shapr)
library(future)

# Estimate the true causal Shapley values for the model stored at `model_path`.
#
# The model file stores `data_path`; the data file stores `scm_path`. The SCM
# is passed to shapr's "causal_true" approach, which is run under a future
# multisession plan. The explanation is written with save_explanation().
#
# Arguments:
#   model_path            Path to the .rds file holding the model and its
#                         parameters.
#   n_to_explain          Number of leading rows of the stored data to explain.
#                         NULL means 40. Capped (with a warning) at the number
#                         of available rows.
#   n_samples_expectation Passed to shapr::explain() as `n_samples`.
#                         NULL means 4096.
#   n_combinations        Passed to shapr::explain(). NULL means the value
#                         returned by determine_n_combinations().
#   n_batches             Passed to shapr::explain(). NULL means the value
#                         returned by determine_n_batches().
#   n_workers             Number of future multisession workers (default 8).
compute_causal_true_shapley_parallel <- function(model_path, n_to_explain = 40,
                                                 n_samples_expectation = 4096,
                                                 n_combinations = NULL,
                                                 n_batches = NULL,
                                                 n_workers = 8) {

  # Default values for optional arguments (callers may pass NULL explicitly)
  if (is.null(n_to_explain)) {
    n_to_explain <- 40
  }
  if (is.null(n_samples_expectation)) {
    n_samples_expectation <- 4096
  }
  stopifnot(length(n_workers) == 1L, n_workers >= 1)

  log_to_shell(paste0("compute_causal_true_shapley_parallel",
                      ", model_path = ", model_path,
                      ", n_to_explain = ", n_to_explain,
                      ", n_samples_expectation =  ", n_samples_expectation,
                      ", n_combinations =  ", n_combinations,
                      ", n_batches = ", n_batches))

  # Read necessary files: model -> data -> SCM
  model_and_parameters <- readRDS(model_path)
  data_path <- model_and_parameters$data_path
  data_and_parameters <- readRDS(data_path)
  scm_and_parameters <- readRDS(data_and_parameters$scm_path)

  # Assert settings match
  stopifnot(model_and_parameters$seed == data_and_parameters$seed)
  stopifnot(model_and_parameters$n_nodes == data_and_parameters$n_nodes)
  stopifnot(model_and_parameters$n_expected_neighbors == data_and_parameters$n_expected_neighbors)
  stopifnot(model_and_parameters$n_expected_parents_y == data_and_parameters$n_expected_parents_y)
  stopifnot(model_and_parameters$scm_type == data_and_parameters$scm_type)

  # Get model and data.
  # Loading an xgboost model straight from the .rds does not work well; make
  # sure to use xgboost version 1.7.5.1. The model may instead be stored as a
  # character path to a saved xgboost model, which is loaded here.
  model_to_explain <- model_and_parameters$model
  if (is.character(model_to_explain)) {
    model_to_explain <- xgboost::xgb.load(model_to_explain)
  }

  data_observational <- data_and_parameters$data_observational
  data_to_explain <- data_and_parameters$data_to_explain

  # n_nodes is asserted equal across model and data above, so one source suffices.
  n_nodes <- model_and_parameters$n_nodes
  seed <- model_and_parameters$seed
  scm <- scm_and_parameters$scm

  # Guard against asking for more rows than exist (would yield NA rows).
  n_available <- nrow(data_to_explain)
  if (n_to_explain > n_available) {
    warning("n_to_explain (", n_to_explain, ") exceeds the ", n_available,
            " available rows; explaining all rows.", call. = FALSE)
    n_to_explain <- n_available
  }
  feature_cols <- seq_len(n_nodes - 1)
  data_to_explain <- data_to_explain[seq_len(n_to_explain), , drop = FALSE]
  x_to_explain <- data_to_explain[, feature_cols, drop = FALSE]

  # Get true prediction_zero by taking a large number of observational samples
  # (the response is stored in the last column).
  y_observational <- data_observational[, n_nodes]
  prediction_zero <- mean(y_observational)

  # Not used by causal_true, but a matrix has to be passed; drop = FALSE keeps
  # it a matrix even when there is a single feature.
  x_observational <- data_observational[seq_len(2), feature_cols, drop = FALSE]

  # Retrieve default settings
  if (is.null(n_combinations)) {
    n_combinations <- determine_n_combinations(n_nodes, n_samples_expectation)
  }
  if (is.null(n_batches)) {
    n_batches <- determine_n_batches(n_nodes, n_samples_expectation)
  }

  log_to_shell(paste0("Initializing ", n_workers, " workers..."))
  old_plan <- future::plan(future::multisession, workers = n_workers)
  # Restore the previous plan (and shut down the workers) even if explain() fails.
  on.exit(future::plan(old_plan), add = TRUE)

  log_to_shell("Computing...")

  start_time <- Sys.time()

  explanation_causal_true <- shapr::explain(
    approach = "causal_true",
    model = model_to_explain,
    x_train = x_observational,
    x_explain = x_to_explain,
    prediction_zero = prediction_zero,
    n_samples = n_samples_expectation,
    n_combinations = n_combinations,
    n_batches = n_batches,
    seed = seed,
    scm = scm
  )

  end_time <- Sys.time()
  duration <- end_time - start_time
  # Release the workers before saving; the on.exit handler re-applies the same plan harmlessly.
  future::plan(old_plan)
  log_to_shell("Done computing. Saving results...")

  save_explanation(explanation_causal_true,
                   "causal_true",
                   model_and_parameters,
                   model_path,
                   NULL,
                   n_samples_expectation,
                   n_combinations,
                   n_batches,
                   duration)

  log_to_shell("Results saved.")

}

# Command-line entry point ----
# Parse options with the project's parser, require a model path, then run the
# computation, mapping each CLI option to its named function argument.
opt_parser <- get_opt_parser()
opts <- optparse::parse_args(opt_parser)

if (is.null(opts$modelpath)) stop("Please provide model path.")

compute_causal_true_shapley_parallel(
  model_path = opts$modelpath,
  n_to_explain = opts$ntoexplain,
  n_samples_expectation = opts$nmc,
  n_combinations = opts$ncomb,
  n_batches = opts$nbatches
)



