## Script to estimate the causal Shapley value for a single specified DAG and model using either
## the 'sample' method or the 'iw' method.
## Since in our setup a model is tied to a dataset, we extract the data to explain from the model path.
library(optparse)

source('R/scm/graph_functions.R')
source('R/scm/node_functions.R')
source('R/scm/binary_scm.R')
source('R/scm/additive_noise_scm.R')
source('R/scm/linear_gaussian_scm.R')

source('R/experiments/helper_functions.R')

library(shapr)

#' Estimate causal Shapley values for one DAG and one fitted model.
#'
#' Reads the model stored at `model_path`, follows its `data_path` entry to the
#' matching dataset, picks the DAG to use (the DAG stored with the SCM, or one
#' entry of a CPDAG file's `dags` list), runs `shapr::explain()` and passes the
#' result to `save_explanation()`.
#'
#' @param model_path Path to an RDS file containing a list with the entries
#'   `model`, `data_path`, `n_nodes`, `seed`, `scm_type`,
#'   `n_expected_neighbors` and `n_expected_parents_y`. `model` may also be a
#'   character path, in which case it is loaded with `xgboost::xgb.load()`.
#' @param n_to_explain Number of leading rows of `data_to_explain` to explain.
#'   `NULL` falls back to 40.
#' @param method Forwarded as `causal_approximation_method`; must be
#'   `"sample"` or `"iw"`. `NULL` falls back to `"sample"`.
#' @param n_observational Number of leading rows of `data_observational` made
#'   available for the estimation. `NULL` falls back to 1024.
#' @param n_samples_expectation Forwarded to `shapr::explain()` as
#'   `n_samples`. `NULL` falls back to 1024.
#' @param n_combinations Forwarded to `shapr::explain()`. `NULL` means
#'   `determine_n_combinations(n_nodes, n_samples_expectation)` is used.
#' @param n_batches Forwarded to `shapr::explain()`. `NULL` means
#'   `determine_n_batches(n_nodes, n_samples_expectation)` is used.
#' @param cpdag_path Optional path to an RDS file whose `dags` entry is a list
#'   of DAGs. When `NULL`, the DAG is read from the SCM file referenced by the
#'   dataset (`scm_path`).
#' @param cpdag_idx Index into `dags`; required whenever `cpdag_path` is given.
#' @return Called for its side effects (logging and `save_explanation()`); the
#'   value is whatever the final `log_to_shell()` call returns.
compute_causal_shapley_estimate <- function(model_path, n_to_explain = 40, method = "sample", n_observational = 1024,
                                            n_samples_expectation = 1024, n_combinations = NULL,
                                            n_batches = NULL, cpdag_path = NULL, cpdag_idx = NULL) {
  # Default values for optional arguments (the command line passes NULL for
  # every option that was not supplied)
  if (is.null(n_to_explain)) {
    n_to_explain <- 40
  }
  if (is.null(method)) {
    method <- "sample"
  }
  if (is.null(n_observational)) {
    n_observational <- 1024
  }
  if (is.null(n_samples_expectation)) {
    n_samples_expectation <- 1024
  }

  # Validate the method before any file is read or any computation is started,
  # so that a typo does not cost a full explain() run.
  explanation_type <- switch(method,
    sample = "causal_sampled",
    iw = "causal_iw",
    stop("Unknown method")
  )

  # Read necessary files
  model_and_parameters <- readRDS(model_path)
  data_path <- model_and_parameters$data_path
  data_and_parameters <- readRDS(data_path)

  n_nodes <- model_and_parameters$n_nodes
  seed <- model_and_parameters$seed

  # If cpdag_path is not provided, we assume we are estimating the causal Shapley values for the true DAG
  if (is.null(cpdag_path)) {
    scm_and_parameters <- readRDS(data_and_parameters$scm_path)
    dag <- scm_and_parameters$scm$dag
  } else {
    if (is.null(cpdag_idx)) {
      stop("Please provide cpdag index.")
    }
    cpdag_and_parameters <- readRDS(cpdag_path)
    dag <- cpdag_and_parameters$dags[[cpdag_idx]]
  }

  # Retrieve default settings for n_combinations_shapr and n_batches if not specified
  if (is.null(n_combinations)) {
    n_combinations <- determine_n_combinations(n_nodes, n_samples_expectation)
  }
  if (is.null(n_batches)) {
    n_batches <- determine_n_batches(n_nodes, n_samples_expectation)
  }

  # Log the effective settings. cpdag_idx is guaranteed to be non-NULL in the
  # CPDAG branch because of the check made while selecting the DAG.
  if (is.null(cpdag_path)) {
    log_to_shell(paste0("compute_causal_shapley_estimate, true dag",
                        ", model_path = ", model_path,
                        ", n_to_explain = ", n_to_explain,
                        ", method = ", method,
                        ", n_observational =  ", n_observational,
                        ", n_samples_expectation =  ", n_samples_expectation,
                        ", n_combinations =  ", n_combinations,
                        ", n_batches = ", n_batches))
  } else {
    log_to_shell(paste0("compute_causal_shapley_estimate",
                        ", cpdag_path = ", cpdag_path,
                        ", cpdag_idx = ", cpdag_idx,
                        ", model_path = ", model_path,
                        ", n_to_explain = ", n_to_explain,
                        ", method = ", method,
                        ", n_observational = ", n_observational,
                        ", n_samples_expectation =  ", n_samples_expectation,
                        ", n_combinations =  ", n_combinations,
                        ", n_batches = ", n_batches
                        ))
  }

  # Assert settings match
  stopifnot(model_and_parameters$seed == data_and_parameters$seed)
  stopifnot(model_and_parameters$n_nodes == data_and_parameters$n_nodes)
  stopifnot(model_and_parameters$n_expected_neighbors == data_and_parameters$n_expected_neighbors)
  stopifnot(model_and_parameters$n_expected_parents_y == data_and_parameters$n_expected_parents_y)
  stopifnot(model_and_parameters$scm_type == data_and_parameters$scm_type)


  # Get model and data
  # This method of loading does not work well with xgboost models
  # Make sure to use xgboost version 1.7.5.1
  # if (model_and_parameters$model_type == "xgboost") {
  #   # Get xgboost model from model path
  #   # We assume filename is of the form seed.rds
  #   dir <- dirname(model_path)
  #   xgb_filename <- sub("\\.rds$", "_xgb.model", basename(model_path))
  #   model_to_explain <- xgboost::xgb.load(paste0(dir, "/", xgb_filename))
  # } else {
  #   model_to_explain <- model_and_parameters$model
  # }
  model_to_explain <- model_and_parameters$model

  # Model can be a string to the xgboost model path
  if (is.character(model_to_explain)) {
    model_to_explain <- xgboost::xgb.load(model_to_explain)
  }

  data_observational <- data_and_parameters$data_observational
  data_to_explain <- data_and_parameters$data_to_explain

  # Requesting more rows than exist would either fail with an opaque subscript
  # error or silently introduce NA rows, so reject it explicitly.
  if (n_to_explain > NROW(data_to_explain)) {
    stop("n_to_explain (", n_to_explain, ") exceeds the number of rows in data_to_explain (",
         NROW(data_to_explain), ").")
  }
  if (n_observational > NROW(data_observational)) {
    stop("n_observational (", n_observational, ") exceeds the number of rows in data_observational (",
         NROW(data_observational), ").")
  }

  # The first n_nodes - 1 columns are the features, column n_nodes is y.
  # seq_len() stays empty for a count of 0 (seq() would give c(1, 0)) and
  # drop = FALSE keeps the two-dimensional shape for a single row or column.
  feature_cols <- seq_len(n_nodes - 1)

  data_to_explain <- data_to_explain[seq_len(n_to_explain), , drop = FALSE]
  x_to_explain <- data_to_explain[, feature_cols, drop = FALSE] # Exclude y

  # Limit (if desired) amount of data available for estimation of Shapley values
  data_observational <- data_observational[seq_len(n_observational), , drop = FALSE]
  x_observational <- data_observational[, feature_cols, drop = FALSE]
  y_observational <- data_observational[, n_nodes]

  # Baseline prediction: mean of y over the observational rows in use
  prediction_zero <- mean(y_observational)


  log_to_shell("Computing...")

  # Binary SCMs use a dedicated approach; everything else about the call is shared.
  approach <- if (model_and_parameters$scm_type == "binary") "causal_binary" else "causal"

  start_time <- Sys.time()
  explanation_causal <- shapr::explain(
    approach = approach,
    causal_approximation_method = method,
    causal_dag = dag,
    model = model_to_explain,
    x_train = x_observational,
    x_explain = x_to_explain,
    prediction_zero = prediction_zero,
    n_samples = n_samples_expectation,
    n_combinations = n_combinations,
    n_batches = n_batches,
    seed = seed
  )
  end_time <- Sys.time()
  # NOTE(review): subtracting POSIXct values picks the difftime unit
  # automatically (secs, mins, ...); confirm save_explanation() reads the unit
  # rather than only the bare number.
  duration <- end_time - start_time

  log_to_shell("Done computing. Saving results...")

  save_explanation(explanation_causal,
                   explanation_type,
                   model_and_parameters,
                   model_path,
                   n_observational,
                   n_samples_expectation,
                   n_combinations,
                   n_batches,
                   duration,
                   cpdag_idx,
                   cpdag_path,
                   data_path = data_path,
                   dag = dag
                   )
  log_to_shell("Results saved.")
}

# Command-line entry point: parse the options and run the estimation.
opt_parser <- get_opt_parser()
opts <- optparse::parse_args(opt_parser)

# The model path is the only mandatory option.
if (is.null(opts$modelpath)) {
  stop("Please provide model path.")
}

# Options that are NULL here fall back to defaults inside
# compute_causal_shapley_estimate().
compute_causal_shapley_estimate(
  model_path = opts$modelpath,
  n_to_explain = opts$ntoexplain,
  method = opts$method,
  n_observational = opts$nobs,
  n_samples_expectation = opts$nmc,
  n_combinations = opts$ncomb,
  n_batches = opts$nbatches,
  cpdag_path = opts$cpdagpath,
  cpdag_idx = opts$cpdagidx
)

# print("DEBUG")
# compute_causal_shapley_estimate("experiments/4_models/xgboost/gaussian/6_nodes/2_neighbors_2_parents_y/1.rds", NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)


# args <- commandArgs(trailingOnly = TRUE)
# log_to_shell(paste0("shapley_causal_estimate.R ", paste(args, collapse = " ")))
#
# # Default settings
# # We either provide a path to the model (which links to the true dag), or a path to the cpdag and corresponding dag index
# if (length(args) == 3) {
#   model_path <- args[1]
#   n_observational <- as.numeric(args[2])
#   n_samples_expectation <- as.numeric(args[3])
#
#   model_and_parameters <- readRDS(model_path)
#   data_path <- model_and_parameters$data_path
#
#   data_and_parameters <- readRDS(data_path)
#   scm_path <- data_and_parameters$scm_path
#
#   scm_and_parameters <- readRDS(scm_path)
#
#   dag <- scm_and_parameters$scm$dag
#
#   compute_causal_shapley_estimate(dag, model_path, data_path, n_observational, approximation_method, true_dag_path = true_dag_path)
# }
# if (length(args) == 4) {
#   cpdag_path <- args[1]
#   cpdag_idx <- as.numeric(args[2])
#   n_observational <- as.numeric(args[3])
#   n_samples_expectation <- as.numeric(args[4])
#
#   cpdag_and_parameters <- readRDS(cpdag_path)
#
#   # Extract parameters to get model and data paths
#   scm_type <- cpdag_and_parameters$scm_type
#   n_nodes <- cpdag_and_parameters$n_nodes
#   n_expected_neighbors <- cpdag_and_parameters$n_expected_neighbors
#   n_expected_parents_y <- cpdag_and_parameters$n_expected_parents_y
#   seed <- cpdag_and_parameters$seed
#
#   model_path <- paste0("experiments/4_models/xgboost/", scm_type, "/", n_nodes, "_nodes/", n_expected_neighbors, "_neighbors_", n_expected_parents_y, "_parents_y/", seed, ".rds")
#   #data_path <- paste0("experiments/3_data/", scm_type, "/", n_nodes, "_nodes/", n_expected_neighbors, "_neighbors_", n_expected_parents_y, "_parents_y/", seed, ".rds")
#
#   # Extract dag
#   dag <- cpdag_and_parameters$dags[[cpdag_idx]]
#
#   compute_causal_shapley_estimate(dag, model_path, n_observational, n_samples_expectation = n_samples_expectation, cpdag_path = cpdag_path, cpdag_idx = cpdag_idx)
# }
#
# # Custom settings
# if (length(args) == 6) {
#   model_path <- args[1]
#   n_observational <- as.numeric(args[2])
#   approximation_method <- args[3]
#   n_samples_expectation <- as.numeric(args[4])
#   n_combinations_shapr <- as.numeric(args[5])
#   n_batches <- as.numeric(args[6])
#
#   if (n_combinations_shapr == 0) {
#     n_combinations_shapr <- NULL
#   }
#   if (n_batches == 0) {
#     n_batches <- NULL
#   }
#
#   model_and_parameters <- readRDS(model_path)
#   data_path <- model_and_parameters$data_dir
#
#   data_and_parameters <- readRDS(data_path)
#   scm_path <- data_and_parameters$scm_dir
#
#   scm_and_parameters <- readRDS(scm_path)
#   true_dag_path <- scm_and_parameters$dag_path
#
#   dag <- scm_and_parameters$scm$dag
#
#   compute_causal_shapley_estimate(dag, model_path, data_path, n_observational, approximation_method, n_samples_expectation, n_combinations_shapr, n_batches, true_dag_path = true_dag_path)
# }
# if (length(args) == 7) {
#   cpdag_path <- args[1]
#   cpdag_idx <- as.numeric(args[2])
#   n_observational <- as.numeric(args[3])
#   approximation_method <- args[4]
#   n_samples_expectation <- as.numeric(args[5])
#   n_combinations_shapr <- as.numeric(args[6])
#   n_batches <- as.numeric(args[7])
#
#   if (n_combinations_shapr == 0) {
#     n_combinations_shapr <- NULL
#   }
#   if (n_batches == 0) {
#     n_batches <- NULL
#   }
#
#   cpdag_and_parameters <- readRDS(cpdag_path)
#
#   # Extract parameters to get model and data paths
#   scm_type <- cpdag_and_parameters$scm_type
#   n_nodes <- cpdag_and_parameters$n_nodes
#   n_expected_neighbors <- cpdag_and_parameters$n_expected_neighbors
#   n_expected_parents_y <- cpdag_and_parameters$n_expected_parents_y
#   seed <- cpdag_and_parameters$seed
#
#   model_path <- paste0("experiments/4_models/", scm_type, "/", n_nodes, "_nodes/", n_expected_neighbors, "_neighbors_", n_expected_parents_y, "_parents_y/", seed, ".rds")
#   data_path <- paste0("experiments/3_data/", scm_type, "/", n_nodes, "_nodes/", n_expected_neighbors, "_neighbors_", n_expected_parents_y, "_parents_y/", seed, ".rds")
#
#   # Extract dag
#   dag <- cpdag_and_parameters$dags[[cpdag_idx]]
#
#   compute_causal_shapley_estimate(dag, model_path, data_path, n_observational, approximation_method, n_samples_expectation, n_combinations_shapr, n_batches, cpdag_path = cpdag_path, cpdag_idx = cpdag_idx)
# }