 #' @rdname setup_approach
#'
#' @param gaussian.mu Numeric vector. (Optional)
#' Containing the mean of the data generating distribution.
#' `NULL` means it is estimated from the `x_train`.
#'
#' @param gaussian.cov_mat Numeric matrix. (Optional)
#' Containing the covariance matrix of the data generating distribution.
#' `NULL` means it is estimated from the `x_train`.
#'
#' @inheritParams default_doc_explain
#'
#' @export
setup_approach.mec <- function(internal,
                               gaussian.mu = NULL,
                               gaussian.cov_mat = NULL, ...) {
  # For consistency
  defaults <- mget(c("gaussian.mu", "gaussian.cov_mat"))
  internal <- insert_defaults(internal, defaults)

  x_train <- internal$data$x_train
  feature_specs <- internal$objects$feature_specs
  dags <- internal$parameters$dags
  feature_names <- internal$parameters$feature_names

  # Check if order of vertices is equal to order of data
  for (dag in dags) {
    if (!identical(igraph::V(dag)$name, colnames(x_train))) {
      stop("The order of the vertices in a DAG in the MEC is not equal to the order of the data.")
    }
  }

  # Checking if factor features are present
  if (any(feature_specs$classes == "factor")) {
    factor_features <- names(which(feature_specs$classes == "factor"))
    factor_approaches <- get_factor_approaches()
    # Collapse the feature names so that the message is emitted once, not once per factor feature
    stop(paste0(
      "The following feature(s) are factor(s): ", paste0(factor_features, collapse = ", "), ".\n",
      "approach = 'mec' does not support factor features.\n",
      "Please change approach to one of ", paste0(factor_approaches, collapse = ", "), "."
    ))
  }

  # List of parents (incoming neighbours) of every vertex, for each of the dags
  internal$parameters$parents_list_dags <- lapply(dags, function(dag) {
    lapply(igraph::V(dag), function(node) {
      igraph::neighbors(dag, node, mode = "in")
    })
  })

  # Topological order for each of the dags
  internal$parameters$topological_orders_dags <- lapply(dags, function(x) igraph::topo_sort(x))

  # Determine for each combination the unique interventional distributions; and keep track which DAG belongs to which
  # interventional distribution
  X <- internal$objects$X
  combination_idxs <- X[, id_combination]

  # One result vector per combination; entry `dag_idx` holds the index of the first DAG whose ordered interventional
  # formula equals that of DAG `dag_idx` (so DAGs sharing a value share an interventional distribution).
  # TODO: make more efficient
  results <- future.apply::future_lapply(seq_along(combination_idxs), function(i) {
    combination_idx <- combination_idxs[i]

    # This is the combination with no interventions: all DAGs are mapped to the first one
    if (combination_idx == 1) {
      return(rep(1, length(dags)))
    }

    intervened_feature_idxs <- X$features[[combination_idx]]
    intervened_feature_names <- feature_names[intervened_feature_idxs]
    not_intervened_feature_idxs <- setdiff(seq_along(feature_names), intervened_feature_idxs)
    not_intervened_feature_names <- feature_names[not_intervened_feature_idxs]

    # Maps an ordered formula to the index of the first DAG that produced it
    hashmap_formulas <- new.env()
    dag_results <- vector("numeric", length(dags))

    for (dag_idx in seq_along(dags)) {
      current_dag <- dags[[dag_idx]]

      interventional_formula <- causaleffect::causal.effect(
        y = not_intervened_feature_names,
        x = intervened_feature_names,
        G = current_dag,
        simp = TRUE,
        prune = TRUE
      )
      ordered_formula <- order_formula(interventional_formula)

      # inherits = FALSE: only look in the hash map itself, never in enclosing environments,
      # otherwise a formula string that happens to match an existing object name gives a false hit
      if (exists(ordered_formula, envir = hashmap_formulas, inherits = FALSE)) {
        dag_results[[dag_idx]] <- hashmap_formulas[[ordered_formula]]
      } else {
        hashmap_formulas[[ordered_formula]] <- dag_idx
        dag_results[[dag_idx]] <- dag_idx
      }
    }
    return(dag_results)
  }, future.seed = TRUE)

  # Rows follow `combination_idxs`, columns follow the order of the dags
  internal$parameters$intervention_equivalence_map <- do.call(rbind, results)

  # If data type (binary/continuous) has not been specified, assume continuous
  if (is.null(internal$parameters$data_type)) {
    internal$parameters$data_type <- "continuous"
  }

  # Data-specific set-up
  if (internal$parameters$data_type == "continuous") {
    # If gaussian.mu is not provided directly in internal list, use mean of training data
    if (is.null(internal$parameters$gaussian.mu)) {
      internal$parameters$gaussian.mu <- get_mu_vec(x_train)
    }
    # If gaussian.cov_mat is not provided directly in internal list, use sample covariance of training data
    if (is.null(internal$parameters$gaussian.cov_mat)) {
      internal$parameters$gaussian.cov_mat <- get_cov_mat(x_train)
    }
  } else if (internal$parameters$data_type == "binary") {
    # Not used but we need something there to prevent us from calculating one again
    internal$parameters$topo_order <- internal$parameters$topological_orders_dags[[1]]
    internal <- setup_approach.causal_binary(internal)
  }

  # Default to plain sampling when no approximation method was requested
  if (is.null(internal$parameters$causal_approximation_method)) {
    internal$parameters$causal_approximation_method <- "sample"
  }

  return(internal)
}



#' @rdname prepare_data
#'
#' @param current_dag_idx Index (or a list whose first element is the index) of the DAG in
#' `internal$parameters$dags` that the data should be generated for. Must be provided for `approach = 'mec'`.
#'
#' @export
prepare_data.mec <- function(internal, index_features = NULL, current_dag_idx = NULL, ...) {
  x_train <- internal$data$x_train
  x_explain <- internal$data$x_explain
  n_explain <- internal$parameters$n_explain
  n_samples <- internal$parameters$n_samples
  n_features <- internal$parameters$n_features
  data_type <- internal$parameters$data_type

  if (data_type == "continuous") {
    gaussian.cov_mat <- internal$parameters$gaussian.cov_mat
    gaussian.mu <- internal$parameters$gaussian.mu
  } else if (data_type == "binary") {
    joint_probability_dt <- internal$parameters$binary.joint_prob_dt
  }

  # Hack to deal with future: the index may arrive wrapped in a list
  if (is.list(current_dag_idx)) {
    current_dag_idx <- current_dag_idx[[1]]
  }
  if (is.null(current_dag_idx)) {
    stop("`current_dag_idx` must be provided when approach = 'mec'.")
  }

  # Get the current DAG and related information
  current_dag <- internal$parameters$dags[[current_dag_idx]]
  parents_list <- internal$parameters$parents_list_dags[[current_dag_idx]]
  topo_order <- internal$parameters$topological_orders_dags[[current_dag_idx]]
  # isTRUE() keeps this a proper scalar even if the method was never set
  use_iw <- isTRUE(internal$parameters$causal_approximation_method == "iw")

  X <- internal$objects$X
  feature_names <- internal$parameters$feature_names

  x_explain0 <- as.matrix(x_explain)
  dt_l <- vector("list", n_explain)

  if (is.null(index_features)) {
    features <- X$features
  } else {
    features <- X$features[index_features]
  }

  for (i in seq_len(n_explain)) {
    x_explain_i <- x_explain0[i, , drop = FALSE]

    if (data_type == "continuous") {
      if (!use_iw) {
        l <- lapply(
          X = features,
          FUN = sample_causal,
          n_samples = n_samples,
          mu = gaussian.mu,
          cov_mat = gaussian.cov_mat,
          m = n_features,
          x_explain = x_explain_i,
          causal_dag = current_dag,
          parents_list = parents_list,
          topo_order = topo_order
        )
        dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination")
        # Plain sampling: every sample gets the same weight
        dt_l[[i]][, w := 1 / n_samples]
      } else {
        l <- lapply(
          X = features,
          FUN = augment_and_weight_gaussian,
          training_data = x_train,
          x_to_explain = x_explain_i,
          mu = gaussian.mu,
          cov_mat = gaussian.cov_mat,
          parents_list = parents_list,
          dag = current_dag
        )
        dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination")
      }
    } else if (data_type == "binary") {
      if (!use_iw) {
        l <- lapply(
          X = features,
          FUN = sample_causal_binary,
          n_samples = n_samples,
          joint_probability_dt = joint_probability_dt,
          x_explain = x_explain_i,
          parents_list = parents_list,
          topo_order = topo_order,
          data = x_train
        )
        dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination")
        # Plain sampling: every sample gets the same weight
        dt_l[[i]][, w := 1 / n_samples]
      } else {
        l <- lapply(
          X = features,
          FUN = augment_and_weight_binary,
          training_data = x_train,
          x_to_explain = x_explain_i,
          joint_probability_dt = joint_probability_dt,
          parents_list = parents_list,
          dag = current_dag
        )
        dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination")
      }
    }

    dt_l[[i]][, id := i]

    if (data_type == "binary") {
      # Combine identical rows to save computation
      dt_l[[i]] <- dt_l[[i]][, .(w = sum(w)), by = c("id_combination", feature_names, "id")]
    }

    # rbindlist() numbered the combinations 1..length(features); map back to the requested indices
    if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]]
  }

  dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE)
  return(dt)
}



#' Sample from the interventional distribution given a graph, assuming the variables follow a Gaussian distribution,
#' code adapted from Heskes et al.
#'
#' Features in `index_given` are fixed to their values in `x_explain`. The remaining features are sampled node by
#' node in the order given by `topo_order`: a node without parents is drawn from its marginal Gaussian, a node with
#' parents from the Gaussian conditional on the already filled-in parent columns.
#'
#' @inheritParams sample_gaussian
#'
#' @param causal_dag An igraph DAG that represents the underlying data generation process; the order of the vertices
#' should be equal to the column order of the data (this is checked in setup).
#'
#' @param parents_list List indexed by node index; element `j` holds the parents of node `j` in `causal_dag`.
#'
#' @param topo_order Node indices of `causal_dag` in the order in which they are visited for sampling
#' (presumably a topological order, so that parents are filled in before their children).
#'
#' @return data.table. A single row (`x_explain` itself) when no or all features are given, otherwise `n_samples`
#' rows with the same column names as `x_explain`.
#'
#' @keywords internal
sample_causal <- function(index_given, n_samples, mu, cov_mat, m, x_explain,
                                causal_dag, parents_list, topo_order) {
  # Check input
  stopifnot(is.matrix(x_explain))
  # NOTE(review): recent igraph versions deprecate is.dag() in favour of is_dag() - confirm the minimum
  # supported igraph version before switching
  stopifnot(igraph::is.dag(causal_dag))

  # Handles the unconditional and full conditional separately when predicting
  cnms <- colnames(x_explain)
  if (length(index_given) %in% c(0, m)) {
    return(data.table::as.data.table(x_explain))
  }

  x_explain_causal <- x_explain[index_given]

  # Initialize matrix of samples; intervened columns are constant across all samples
  samples <- matrix(NA, ncol = m, nrow = n_samples)
  samples[, index_given] <- matrix(x_explain_causal, nrow = n_samples, ncol = length(index_given), byrow = TRUE)

  for (idx_node in topo_order) {
    if (idx_node %in% index_given) {
      # If the node is intervened on, we do not need to sample from it
      # todo: remove before using intersection() or something
      next
    }

    node_parents <- parents_list[[idx_node]]

    if (length(node_parents) == 0) {
      # Sample from marginal if node has no parents
      samples[, idx_node] <- mvnfast::rmvn(
        n_samples,
        mu = mu[idx_node],
        sigma = as.matrix(cov_mat[idx_node, idx_node])
      )
    } else {
      # Sample from conditional distribution
      to_be_sampled <- idx_node # for consistency
      to_be_conditioned <- as.numeric(node_parents)

      # drop = FALSE keeps everything a matrix, also when the node has a single parent
      C <- cov_mat[to_be_sampled, to_be_conditioned, drop = FALSE]
      D <- cov_mat[to_be_conditioned, to_be_conditioned, drop = FALSE]
      CDinv <- C %*% solve(D)
      cVar <- cov_mat[to_be_sampled, to_be_sampled] - CDinv %*% t(C)

      # Conditional mean for every sample, given the parent values already present in `samples`
      mu_sample <- matrix(mu[to_be_sampled], nrow = n_samples, ncol = length(to_be_sampled), byrow = TRUE)
      mu_cond <- matrix(mu[to_be_conditioned], nrow = n_samples, ncol = length(to_be_conditioned), byrow = TRUE)
      cMU <- mu_sample + t(CDinv %*% t(samples[, to_be_conditioned, drop = FALSE] - mu_cond))

      # Draw zero-mean noise with the conditional variance and shift it by the conditional mean
      newsamples <- mvnfast::rmvn(
        n_samples,
        mu = matrix(0, 1, length(to_be_sampled)),
        sigma = as.matrix(cVar)
      )
      samples[, idx_node] <- newsamples + cMU
    }
  }

  colnames(samples) <- cnms
  return(data.table::as.data.table(samples))
}


#' Compute importance weights for augmented data under one DAG of the MEC
#'
#' Looks up the parent structure of DAG `dag_idx` and the feature set belonging to `id_combination`, and passes them
#' to the binary or Gaussian importance-weight routine depending on `internal$parameters$data_type`.
#'
#' @param internal List. Must hold `parameters$parents_list_dags`, `parameters$data_type`, `objects$X` and either
#' `parameters$binary.joint_prob_dt` (binary data) or `parameters$gaussian.mu` and `parameters$gaussian.cov_mat`.
#'
#' @param dag_idx Index of the DAG in `internal$parameters$parents_list_dags`.
#'
#' @param id_combination Value of the `id_combination` column of `internal$objects$X` identifying the feature
#' combination.
#'
#' @param augmented_data Data passed on unchanged to the weight calculation.
#'
#' @return The output of `calculate_importance_weights_binary()` or `calculate_importance_weights_gaussian()`.
#'
#' @keywords internal
mec_importance_weights <- function(internal, dag_idx, id_combination, augmented_data) {
  # Parent structure of the requested DAG
  parents_list <- internal$parameters$parents_list_dags[[dag_idx]]

  X <- internal$objects$X

  # Get S/index_features/index_given from X
  # The argument is copied under another name because `id_combination` is also a column name of X
  this_id <- id_combination
  index_given <- X[X$id_combination == this_id]$features[[1]]

  if (internal$parameters$data_type == "binary") {
    joint_probability_dt <- internal$parameters$binary.joint_prob_dt
    new_weights <- calculate_importance_weights_binary(
      augmented_data, index_given, joint_probability_dt, parents_list
    )
  } else {
    # Every data type other than "binary" is treated as Gaussian
    gaussian.cov_mat <- internal$parameters$gaussian.cov_mat
    gaussian.mu <- internal$parameters$gaussian.mu
    new_weights <- calculate_importance_weights_gaussian(
      augmented_data, index_given, gaussian.mu, gaussian.cov_mat, parents_list
    )
  }

  return(new_weights)
}