#' @rdname setup_approach
#'
#' @param gaussian.mu Numeric vector. (Optional)
#' Containing the mean of the data generating distribution.
#' `NULL` means it is estimated from the `x_train`.
#'
#' @param gaussian.cov_mat Numeric matrix. (Optional)
#' Containing the covariance matrix of the data generating distribution.
#' `NULL` means it is estimated from the `x_train`.
#'
#' @inheritParams default_doc_explain
#'
#' @export
setup_approach.causal <- function(internal,
                                  gaussian.mu = NULL,
                                  gaussian.cov_mat = NULL, ...) {
  # For consistency with the other approaches, the approach-specific arguments are passed through
  # insert_defaults() (defined elsewhere); they are read back from internal$parameters below.
  defaults <- mget(c("gaussian.mu", "gaussian.cov_mat"))
  internal <- insert_defaults(internal, defaults)

  x_train <- internal$data$x_train
  feature_specs <- internal$objects$feature_specs
  causal_dag <- internal$parameters$causal_dag

  # Vertex ids are used as column indices when sampling/weighting, so the vertex order must
  # match the column order of the data
  if (!identical(igraph::V(causal_dag)$name, colnames(x_train))) {
    stop("The order of the vertices in the causal DAG is not equal to the order of the data.")
  }

  # Fail early on graphs that cannot be topologically sorted
  if (!igraph::is_dag(causal_dag)) {
    stop("The causal graph must be a directed acyclic graph (DAG).")
  }

  # Factor features are not supported by this approach
  if (any(feature_specs$classes == "factor")) {
    factor_features <- names(which(feature_specs$classes == "factor"))
    factor_approaches <- get_factor_approaches()
    # Collapse the feature names so that a single message is produced,
    # also when several features are factors
    stop(paste0(
      "The following feature(s) are factor(s): ", paste0(factor_features, collapse = ", "), ".\n",
      "approach = 'causal' does not support factor features.\n",
      "Please change approach to one of ", paste0(factor_approaches, collapse = ", "), "."
    ))
  }

  # Prepare the parents list and the topological order once,
  # as they are re-used for every explicand and feature combination
  if (is.null(internal$parameters$parents_list)) {
    internal$parameters$parents_list <- lapply(igraph::V(causal_dag), function(node) {
      igraph::neighbors(causal_dag, node, mode = "in")
    })
  }

  if (is.null(internal$parameters$topo_order)) {
    internal$parameters$topo_order <- igraph::topo_sort(causal_dag)
  }

  # If gaussian.mu is not provided directly in internal list, use mean of training data
  if (is.null(internal$parameters$gaussian.mu)) {
    internal$parameters$gaussian.mu <- get_mu_vec(x_train)
  }

  # If gaussian.cov_mat is not provided directly in internal list, use sample covariance of training data
  if (is.null(internal$parameters$gaussian.cov_mat)) {
    internal$parameters$gaussian.cov_mat <- get_cov_mat(x_train)
  }

  # Default to Monte Carlo sampling; only "sample" and "iw" are implemented in prepare_data.causal()
  if (is.null(internal$parameters$causal_approximation_method)) {
    internal$parameters$causal_approximation_method <- "sample"
  }
  approximation_method <- internal$parameters$causal_approximation_method
  if (!(is.character(approximation_method) && length(approximation_method) == 1 &&
    approximation_method %in% c("sample", "iw"))) {
    stop("causal_approximation_method must be either 'sample' or 'iw'.")
  }

  return(internal)
}

#' @rdname prepare_data
#' @export
prepare_data.causal <- function(internal, index_features = NULL, ...) {
  x_train <- internal$data$x_train
  x_explain <- internal$data$x_explain
  n_explain <- internal$parameters$n_explain
  n_samples <- internal$parameters$n_samples
  n_features <- internal$parameters$n_features

  gaussian.cov_mat <- internal$parameters$gaussian.cov_mat
  gaussian.mu <- internal$parameters$gaussian.mu

  # Causal-specific parameters (prepared in setup_approach.causal())
  causal_dag <- internal$parameters$causal_dag
  parents_list <- internal$parameters$parents_list
  topo_order <- internal$parameters$topo_order

  # "sample": Monte Carlo samples from the interventional distribution, equally weighted.
  # "iw": the training data with the intervened-on features fixed, weighted by importance weights.
  approximation_method <- internal$parameters$causal_approximation_method
  if (!(is.character(approximation_method) && length(approximation_method) == 1 &&
    approximation_method %in% c("sample", "iw"))) {
    stop("causal_approximation_method must be either 'sample' or 'iw'.")
  }

  X <- internal$objects$X
  x_explain0 <- as.matrix(x_explain)

  # Feature combinations (coalitions) to generate data for
  if (is.null(index_features)) {
    features <- X$features
  } else {
    features <- X$features[index_features]
  }

  dt_l <- vector("list", n_explain)
  for (i in seq_len(n_explain)) {
    x_explain_i <- x_explain0[i, , drop = FALSE]

    if (approximation_method == "sample") {
      l <- lapply(
        X = features,
        FUN = sample_causal_gaussian,
        n_samples = n_samples,
        mu = gaussian.mu,
        cov_mat = gaussian.cov_mat,
        m = n_features,
        x_explain = x_explain_i,
        causal_dag = causal_dag,
        parents_list = parents_list,
        topo_order = topo_order
      )
      dt_i <- data.table::rbindlist(l, idcol = "id_combination")
      dt_i[, w := 1 / n_samples]
    } else {
      l <- lapply(
        X = features,
        FUN = augment_and_weight_gaussian,
        training_data = x_train,
        x_to_explain = x_explain_i,
        mu = gaussian.mu,
        cov_mat = gaussian.cov_mat,
        parents_list = parents_list,
        dag = causal_dag
      )
      # augment_and_weight_gaussian() already stores the importance weights in column `w`
      dt_i <- data.table::rbindlist(l, idcol = "id_combination")
    }

    dt_i[, id := i]
    # idcol numbers the list elements 1, 2, ... (assuming `features` is unnamed);
    # map them back to the combination ids in X
    if (!is.null(index_features)) dt_i[, id_combination := index_features[id_combination]]
    dt_l[[i]] <- dt_i
  }

  dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE)
  return(dt)
}



#' Sample from the interventional distribution given a graph, assuming the variables follow a Gaussian distribution,
#' code adapted from Heskes et al.
#'
#' @inheritParams sample_gaussian
#'
#' @param causal_dag An igraph DAG that represents the underlying data generation process; the order of the vertices
#' should be equal to the column order of the data (this is checked in setup).
#' @param parents_list List where element `j` holds the parent vertices of vertex `j` in `causal_dag`.
#' @param topo_order A topological ordering of the vertices of `causal_dag`.
#'
#' @return data.table
#'
#' @keywords internal
sample_causal_gaussian <- function(index_given, n_samples, mu, cov_mat, m, x_explain,
                                   causal_dag, parents_list, topo_order) {
  # Check input
  stopifnot(is.matrix(x_explain))
  stopifnot(igraph::is_dag(causal_dag))

  # The unconditional and the full conditional cases are handled separately when predicting
  cnms <- colnames(x_explain)
  if (length(index_given) %in% c(0, m)) {
    return(data.table::as.data.table(x_explain))
  }

  # The intervened-on features are fixed to their values in x_explain (x_explain has one row)
  samples <- matrix(NA, ncol = m, nrow = n_samples)
  samples[, index_given] <- matrix(x_explain[index_given],
    nrow = n_samples, ncol = length(index_given), byrow = TRUE
  )

  # Sample the remaining nodes in topological order, so that the parents of a node are
  # already fixed or sampled when the node itself is sampled
  nodes_to_sample <- setdiff(as.numeric(topo_order), index_given)
  for (idx_node in nodes_to_sample) {
    node_parents <- as.numeric(parents_list[[idx_node]])

    if (length(node_parents) == 0) {
      # Root node: sample from its marginal distribution
      samples[, idx_node] <- mvnfast::rmvn(n_samples,
        mu = mu[idx_node],
        sigma = as.matrix(cov_mat[idx_node, idx_node])
      )
    } else {
      # Sample from the Gaussian conditional distribution of the node given its parents:
      # mean mu_j + C D^{-1} (x_pa - mu_pa), variance Sigma_jj - C D^{-1} C^T
      C <- cov_mat[idx_node, node_parents, drop = FALSE]
      D <- cov_mat[node_parents, node_parents, drop = FALSE]
      CDinv <- C %*% solve(D)
      cVar <- cov_mat[idx_node, idx_node] - CDinv %*% t(C)

      mu_sample <- matrix(mu[idx_node], nrow = n_samples, ncol = 1)
      mu_cond <- matrix(mu[node_parents], nrow = n_samples, ncol = length(node_parents), byrow = TRUE)
      cMU <- mu_sample + t(CDinv %*% t(samples[, node_parents, drop = FALSE] - mu_cond))
      newsamples <- mvnfast::rmvn(n_samples, mu = matrix(0, 1, 1), sigma = as.matrix(cVar))
      samples[, idx_node] <- newsamples + cMU
    }
  }

  colnames(samples) <- cnms
  return(data.table::as.data.table(samples))
}

# Build the importance-sampling data for one explicand and one feature combination.
# The training data is copied, the intervened-on features (`index_given`) are fixed to the
# explicand's values, and the importance weights are attached in a new column `w`.
#
# `dag` is not used in the body; it is accepted so callers can pass the causal graph.
augment_and_weight_gaussian <- function(training_data, x_to_explain, index_given, mu, cov_mat, parents_list, dag) {
  importance_sample <- augment_data(training_data, x_to_explain, index_given)

  # Weights are computed before the `w` column is added, so they only see the feature columns
  importance_weights <- calculate_importance_weights_gaussian(
    augmented_data = importance_sample,
    index_given = index_given,
    mu = mu,
    cov_mat = cov_mat,
    parents_list = parents_list
  )

  importance_sample[, w := importance_weights]
  importance_sample
}

# Return a copy of `training_data` in which every column in `index_given` is overwritten
# with the corresponding value of the explicand `x_to_explain`. The input is not modified.
#
# training_data: data.table of training observations.
# x_to_explain: the explicand (in this file a one-row matrix); value k is used for column k.
# index_given: column positions of the intervened-on features.
augment_data <- function(training_data, x_to_explain, index_given) {
  augmented_data <- data.table::copy(training_data)
  n_rows <- nrow(augmented_data)

  # data.table::set() is used instead of `:=`: the RHS of `:=` is evaluated inside the data.table,
  # where a feature column named e.g. `i` or `x_to_explain` would mask the local variables.
  for (col in index_given) {
    data.table::set(augmented_data, j = col, value = rep(x_to_explain[[col]], n_rows))
  }

  augmented_data
}

# Importance weights for the "iw" approximation of the interventional distribution.
#
# For each row x of `augmented_data` the weight is p(x_notS | do(x_S)) / p(x_notS), where
#   * the numerator is the product, over the non-intervened nodes j, of the Gaussian conditional
#     densities p(x_j | pa_j) implied by `mu` and `cov_mat`, and
#   * the denominator is the joint Gaussian density of the non-intervened features.
# Both are computed on the log scale and exponentiated at the end.
#
# augmented_data: data.table whose columns are the features, in the same order as `mu`.
# index_given: positions of the intervened-on features (S); may be empty.
# mu, cov_mat: mean vector and covariance matrix of the Gaussian.
# parents_list: element j holds the parent vertices of feature j.
#
# Returns a numeric vector with one weight per row of `augmented_data`.
calculate_importance_weights_gaussian <- function(augmented_data, index_given, mu, cov_mat,
                                                  parents_list) {
  n_samples <- nrow(augmented_data)

  # Non-intervened features. setdiff() is used instead of negative indexing because
  # `mu[-integer(0)]` selects nothing (not everything) when no feature is intervened on.
  not_given <- setdiff(seq_along(mu), index_given)

  # All features intervened on: numerator and denominator are empty products, so every weight is 1
  if (length(not_given) == 0) {
    return(rep(1, n_samples))
  }

  # log p(x_notS | do(S)) = sum_{j in notS} log p(x_j | pa_j)
  log_numerator <- numeric(n_samples)
  for (idx_node in not_given) {
    node_parents <- as.numeric(parents_list[[idx_node]])

    if (length(node_parents) == 0) {
      # Root node: log density of its marginal
      log_numerator <- log_numerator + mvnfast::dmvn(as.matrix(augmented_data[[idx_node]]),
        mu = mu[idx_node],
        sigma = as.matrix(cov_mat[idx_node, idx_node]),
        log = TRUE
      )
    } else {
      # Gaussian conditional of the node given its parents:
      # mean mu_j + C D^{-1} (x_pa - mu_pa), variance Sigma_jj - C D^{-1} C^T
      C <- cov_mat[idx_node, node_parents, drop = FALSE]
      D <- cov_mat[node_parents, node_parents, drop = FALSE]
      CDinv <- C %*% solve(D)
      cVar <- cov_mat[idx_node, idx_node] - CDinv %*% t(C)

      parent_values <- as.matrix(augmented_data[, node_parents, with = FALSE])
      mu_cond <- matrix(mu[node_parents], nrow = n_samples, ncol = length(node_parents), byrow = TRUE)
      cMU <- mu[idx_node] + (parent_values - mu_cond) %*% t(CDinv)

      log_numerator <- log_numerator + mvnfast::dmvn(augmented_data[[idx_node]] - cMU,
        mu = matrix(0, 1, 1),
        sigma = as.matrix(cVar),
        log = TRUE
      )
    }
  }

  # log p(x_notS)
  log_denominator <- mvnfast::dmvn(as.matrix(augmented_data[, not_given, with = FALSE]),
    mu = mu[not_given],
    sigma = cov_mat[not_given, not_given, drop = FALSE],
    log = TRUE
  )

  exp(log_numerator - log_denominator)
}