#' @rdname setup_approach
#'
#' @param gaussian.mu Numeric vector. (Optional)
#' Containing the mean of the data generating distribution.
#' `NULL` means it is estimated from the `x_train`.
#'
#' @param gaussian.cov_mat Numeric matrix. (Optional)
#' Containing the covariance matrix of the data generating distribution.
#' `NULL` means it is estimated from the `x_train`.
#'
#' @inheritParams default_doc_explain
#'
#' @export
setup_approach.causal_true <- function(internal, ...) {
  # Sanity-checks `internal$parameters$scm` against the training data and hands `internal` back
  # unchanged when every check passes. Nothing is added to or removed from `internal`.
  scm <- internal$parameters$scm
  if (is.null(scm)) {
    stop("scm must be accessible to shapr")
  }

  # The columns of the data determine the interventions, so the DAG has to line up with them.
  dag_nodes <- igraph::V(scm$dag)$name
  feature_names <- colnames(internal$data$x_train)

  # One vertex per column of `x_train`, plus exactly one additional vertex.
  n_expected_nodes <- length(feature_names) + 1
  if (length(dag_nodes) != n_expected_nodes) {
    stop("The DAG does not contain N+1 nodes where N is the number of columns of x_train.")
  }

  # `intersect()` keeps the ordering of its first argument, so this lists the feature vertices in
  # DAG order; it must reproduce the column order of `x_train` exactly (this also fails when a
  # column has no vertex of the same name).
  features_in_dag_order <- intersect(dag_nodes, feature_names)
  if (!identical(features_in_dag_order, feature_names)) {
    stop("The order of the vertices in the causal DAG is not equal to the order of the data.")
  }

  internal
}

#' @rdname prepare_data
#' @export
prepare_data.causal_true <- function(internal, index_features = NULL, ...) {
  # Builds one long data.table of samples obtained from
  # `internal$parameters$generate_data_under_intervention()`, for every observation in
  # `internal$data$x_explain` and every feature combination in `internal$objects$X$features`
  # (or only the combinations selected by `index_features`).
  #
  # Returned columns: `id_combination` (position of the combination in `X$features`), the sampled
  # columns (minus a column named "Y", if present), `w` (equal weight `1 / n_samples`) and `id`
  # (row number of the explained observation).
  n_explain <- internal$parameters$n_explain
  n_samples <- internal$parameters$n_samples
  scm <- internal$parameters$scm
  generate_data_under_intervention <- internal$parameters$generate_data_under_intervention

  x_explain0 <- as.matrix(internal$data$x_explain)

  X <- internal$objects$X
  if (is.null(index_features)) {
    features <- X$features
  } else {
    features <- X$features[index_features]
  }

  # Samples for a single feature combination and a single (1 x p) row of `x_explain0`.
  sample_true_interventional_distribution <- function(coalition, x_row) {
    # Appropriate format for intervention: the explained values at the coalition's positions,
    # named by those positions.
    intervention <- x_row[coalition]
    names(intervention) <- coalition

    interventional_sample <- data.table::as.data.table(
      generate_data_under_intervention(scm, n_samples, intervention)
    )

    # Remove the column of the predictive variable; since this is only used for our experiments we
    # can assume this column is named Y. Select by name mask instead of `[, -which(...)]`: a
    # zero-length negative index would drop *every* column when no "Y" column exists, and a
    # matrix/data.frame with a single remaining column would collapse to an unnamed vector.
    keep <- colnames(interventional_sample) != "Y"
    interventional_sample[, which(keep), with = FALSE]
  }

  dt_l <- vector("list", n_explain)
  for (i in seq_len(n_explain)) {
    samples <- lapply(
      X = features,
      FUN = sample_true_interventional_distribution,
      x_row = x_explain0[i, , drop = FALSE]
    )

    dt_i <- data.table::rbindlist(samples, idcol = "id_combination")
    dt_i[, w := 1 / n_samples]
    dt_i[, id := i]
    # `idcol` numbers the combinations 1..length(features); map them back to their positions in
    # `X$features` when only a subset was requested.
    if (!is.null(index_features)) {
      dt_i[, id_combination := index_features[id_combination]]
    }
    dt_l[[i]] <- dt_i
  }

  data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE)
}

