#' @rdname setup_approach
#'
#' @param gaussian.mu Numeric vector. (Optional)
#' Containing the mean of the data generating distribution.
#' `NULL` means it is estimated from the `x_train`.
#'
#' @param gaussian.cov_mat Numeric matrix. (Optional)
#' Containing the covariance matrix of the data generating distribution.
#' `NULL` means it is estimated from the `x_train`.
#'
#' @inheritParams default_doc_explain
#'
#' @export
setup_approach.causal_chain <- function(internal,
                                        gaussian.mu = NULL,
                                        gaussian.cov_mat = NULL, ...) {
  # For consistency: hand the approach-specific arguments to insert_defaults().
  # NOTE(review): insert_defaults() is defined elsewhere; presumably it stores them in
  # internal$parameters unless already set there -- confirm.
  defaults <- mget(c("gaussian.mu", "gaussian.cov_mat"))
  internal <- insert_defaults(internal, defaults)

  x_train <- internal$data$x_train
  feature_specs <- internal$objects$feature_specs

  # The approach cannot handle factor features, so abort early and point to approaches that can
  is_factor <- feature_specs$classes == "factor"
  if (any(is_factor)) {
    factor_features <- names(which(is_factor))
    factor_approaches <- get_factor_approaches()
    # Collapse the feature names into one string. Without this, paste0() is vectorised over
    # the names and the whole message is repeated once per factor feature.
    stop(paste0(
      "The following feature(s) are factor(s): ", paste0(factor_features, collapse = ", "), ".\n",
      "approach = 'causal_chain' does not support factor features.\n",
      "Please change approach to one of ", paste0(factor_approaches, collapse = ", "), "."
    ))
  }

  # If gaussian.mu is not provided directly in internal list, use mean of training data
  if (is.null(internal$parameters$gaussian.mu)) {
    internal$parameters$gaussian.mu <- get_mu_vec(x_train)
  }

  # If gaussian.cov_mat is not provided directly in internal list, use sample covariance of training data
  if (is.null(internal$parameters$gaussian.cov_mat)) {
    internal$parameters$gaussian.cov_mat <- get_cov_mat(x_train)
  }

  return(internal)
}

#' @rdname prepare_data
#' @export
prepare_data.causal_chain <- function(internal, index_features = NULL, ...) {
  # Sampling settings and the Gaussian parameters stored in the internal list
  n_samples <- internal$parameters$n_samples
  n_explain <- internal$parameters$n_explain
  n_features <- internal$parameters$n_features
  gaussian.mu <- internal$parameters$gaussian.mu
  gaussian.cov_mat <- internal$parameters$gaussian.cov_mat
  partial_causal_ordering <- internal$parameters$partial_causal_ordering
  confounding <- internal$parameters$confounding

  x_explain0 <- as.matrix(internal$data$x_explain)

  # Use every feature combination unless a subset of them is requested
  all_features <- internal$objects$X$features
  features <- if (is.null(index_features)) all_features else all_features[index_features]

  # One data.table per explained observation, holding the samples for all feature combinations
  dt_l <- lapply(seq_len(n_explain), function(i) {
    l <- lapply(
      X = features,
      FUN = sample_causal_chain,
      n_samples = n_samples,
      mu = gaussian.mu,
      cov_mat = gaussian.cov_mat,
      m = n_features,
      x_explain = x_explain0[i, , drop = FALSE],
      partial_causal_ordering = partial_causal_ordering,
      confounding = confounding
    )

    dt_i <- data.table::rbindlist(l, idcol = "id_combination")
    dt_i[, w := 1 / n_samples]
    dt_i[, id := i]
    # rbindlist() numbers the combinations 1, 2, ... within the subset; map back to the requested indices
    if (!is.null(index_features)) {
      dt_i[, id_combination := index_features[id_combination]]
    }
    dt_i
  })

  data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE)
}



#' Sample from a causal chain graph assuming the variables follow a Gaussian distribution
#'
#' Code adapted from Heskes et al. The chain components are visited in causal order. In each component, the
#' features that are not in `index_given` are sampled from a Gaussian distribution conditioned on all features
#' in the preceding components (using their fixed or already sampled values). If the component is not
#' confounded, the distribution is additionally conditioned on the features of the same component that are in
#' `index_given`. When there is nothing to condition on, the marginal Gaussian distribution is used.
#'
#' @inheritParams sample_gaussian
#'
#' @param partial_causal_ordering List of vectors. Each vector contains the indices of the variables in a chain
#' component, components are causally ordered.
#'
#' @param confounding Logical vector of length one (TRUE if all components are confounded) or length equal to the
#' number of components in the causal chain. If TRUE, the component is confounded.
#'
#' @return data.table with `m` columns named as the columns of `x_explain`. It has a single row (a copy of
#' `x_explain`) when `index_given` is empty or contains `m` features, and `n_samples` rows otherwise.
#'
#' @keywords internal
#'
sample_causal_chain <- function(index_given, n_samples, mu, cov_mat, m, x_explain,
                                partial_causal_ordering, confounding) {
  # Check input
  stopifnot(is.matrix(x_explain))
  stopifnot(is.list(partial_causal_ordering))

  n_components <- length(partial_causal_ordering)
  if (length(confounding) > 1 && length(confounding) != n_components) {
    stop("Confounding must be specified globally (one value for all components), or separately for each causal component in a vector.")
  }
  if (length(confounding) == 1) {
    confounding <- rep(confounding, n_components) # replicate value for each component
  }

  # Handles the unconditional and full conditional separately when predicting: no sampling is needed
  if (length(index_given) %in% c(0, m)) {
    return(data.table::as.data.table(x_explain))
  }

  dependent_ind <- seq_along(mu)[-index_given]

  # Initialize matrix of samples; the conditioned-on features keep their observed value in every row
  samples <- matrix(NA_real_, nrow = n_samples, ncol = m)
  samples[, index_given] <- rep(x_explain[index_given], each = n_samples)

  # Loop over chain components in causal order
  for (idx_chain_component in seq_along(partial_causal_ordering)) {
    component_variables <- partial_causal_ordering[[idx_chain_component]]

    # Variables that are both free (i.e. not being conditioned on) and in the current component
    to_be_sampled <- intersect(component_variables, dependent_ind)
    if (length(to_be_sampled) == 0) {
      next
    }

    # Condition upon all variables in ancestor components; these columns of `samples` are already filled
    to_be_conditioned <- unlist(partial_causal_ordering[seq_len(idx_chain_component - 1)])

    # If this is not a confounded component, also condition on the given variables in the same component
    if (!confounding[[idx_chain_component]]) {
      to_be_conditioned <- union(intersect(component_variables, index_given), to_be_conditioned)
    }

    if (length(to_be_conditioned) == 0) {
      # Nothing to condition on: draw new samples from the marginal distribution
      to_be_sampled_samples <- mvnfast::rmvn(
        n_samples,
        mu = mu[to_be_sampled],
        sigma = cov_mat[to_be_sampled, to_be_sampled, drop = FALSE]
      )
    } else {
      # The conditioning values differ between rows of `samples`. condMVNorm does not support providing
      # multiple conditioning values, so the conditional Gaussian is computed directly.
      # drop = FALSE keeps every piece a matrix also when a single variable is involved.
      C <- cov_mat[to_be_sampled, to_be_conditioned, drop = FALSE]
      D <- cov_mat[to_be_conditioned, to_be_conditioned, drop = FALSE]
      CDinv <- C %*% solve(D)

      # The conditional covariance is the same for all rows
      cVar <- cov_mat[to_be_sampled, to_be_sampled, drop = FALSE] - CDinv %*% t(C)
      # Makes the matrix symmetric in the rare case where numerical instability made it unsymmetric
      if (!isSymmetric(cVar)) {
        cVar <- Matrix::symmpart(cVar)
      }

      # Row-wise conditional means: mu_s + (x_c - mu_c) %*% t(C D^-1)
      centered <- sweep(samples[, to_be_conditioned, drop = FALSE], 2, mu[to_be_conditioned])
      cMU <- sweep(centered %*% t(CDinv), 2, mu[to_be_sampled], "+")

      # Draw zero-mean noise with the conditional covariance and shift each row by its conditional mean
      newsamples <- mvnfast::rmvn(n_samples, mu = rep(0, length(to_be_sampled)), sigma = as.matrix(cVar))
      to_be_sampled_samples <- newsamples + cMU
    }

    samples[, to_be_sampled] <- to_be_sampled_samples
  }

  colnames(samples) <- colnames(x_explain)
  return(data.table::as.data.table(samples))
}
