#' @rdname setup_approach
#'
#' @param binary.joint_prob_dt Data.table. (Optional)
#' Containing the joint probability distribution for each combination of feature
#' values.
#' `NULL` means it is estimated from the `x_train` and `x_explain`.
#'
#' @param binary.epsilon Numeric value. (Optional)
#' If `binary.joint_prob_dt` is not supplied, probabilities/frequencies are
#' estimated using `x_train`. If certain feature combinations occur in `x_explain` and NOT in `x_train`,
#' then epsilon is used as the (unnormalised) count assigned to each such combination.
#' In theory, this count should be zero, but this causes an error later in the Shapley computation.
#'
#' @inheritParams default_doc_explain
#'
#' @export
setup_approach.causal_binary <- function(internal,
                                         binary.joint_prob_dt = NULL,
                                         binary.epsilon = 0.001,
                                         ...) {
  defaults <- mget(c("binary.joint_prob_dt", "binary.epsilon"))
  internal <- insert_defaults(internal, defaults)

  joint_probability_dt <- internal$parameters$binary.joint_prob_dt

  # Use the value stored in `internal` (which may have been set by the caller),
  # falling back to the function argument if it is not present there.
  epsilon <- internal$parameters$binary.epsilon
  if (is.null(epsilon)) {
    epsilon <- binary.epsilon
  }

  feature_names <- internal$parameters$feature_names

  x_train <- internal$data$x_train
  x_explain <- internal$data$x_explain

  # As in the original implementation, x_explain is handed on as a matrix
  # in its original row order.
  x_explain_save <- as.matrix(x_explain)

  # Estimate the joint probability table if it is not passed to the function
  if (is.null(joint_probability_dt)) {
    # Count how often every observed feature combination occurs in the training data
    joint_prob_dt0 <- x_train[, .N, eval(feature_names)]

    # Anti-join on the feature columns: rows of x_explain with no match in x_train.
    # Using `on =` (instead of setting a key) leaves the caller's x_explain untouched.
    x_explain_dt <- data.table::as.data.table(x_explain)
    explain_not_in_train <- x_explain_dt[!x_train, on = feature_names]

    # Each unseen combination is added exactly once with pseudo-count epsilon
    unseen_combinations <- unique(explain_not_in_train[, feature_names, with = FALSE])

    if (nrow(unseen_combinations) > 0) {
      unseen_combinations[, N := epsilon]
      joint_prob_dt0 <- data.table::rbindlist(
        list(joint_prob_dt0, unseen_combinations),
        use.names = TRUE
      )
    }

    # Normalise the counts so that the probabilities sum to one
    joint_prob_dt0[, joint_prob := N / sum(N)]
    data.table::setkeyv(joint_prob_dt0, feature_names)

    joint_probability_dt <- joint_prob_dt0[, N := NULL][, id_all := .I]
  }

  internal$parameters$binary.joint_prob_dt <- joint_probability_dt
  internal$data$x_explain <- x_explain_save

  # Prepare the parents list and the topological order of the causal DAG once;
  # they are re-used for every sample.
  causal_dag <- internal$parameters$causal_dag

  if (is.null(internal$parameters$parents_list)) {
    internal$parameters$parents_list <- lapply(igraph::V(causal_dag), function(node) {
      igraph::neighbors(causal_dag, node, mode = "in")
    })
  }

  if (is.null(internal$parameters$topo_order)) {
    internal$parameters$topo_order <- igraph::topo_sort(causal_dag)
  }

  if (is.null(internal$parameters$causal_approximation_method)) {
    internal$parameters$causal_approximation_method <- "sample"
  }

  internal
}


#' @inheritParams default_doc
#'
#' @rdname prepare_data
#' @export
#' @keywords internal
prepare_data.causal_binary <- function(internal, index_features = NULL, ...) {
  x_train <- internal$data$x_train
  # Each observation to explain is passed on as a 1-row matrix
  x_explain0 <- as.matrix(internal$data$x_explain)
  n_explain <- nrow(x_explain0)

  n_samples <- internal$parameters$n_samples
  feature_names <- internal$parameters$feature_names
  joint_probability_dt <- internal$parameters$binary.joint_prob_dt

  causal_dag <- internal$parameters$causal_dag
  parents_list <- internal$parameters$parents_list
  topo_order <- internal$parameters$topo_order

  approximation_method <- internal$parameters$causal_approximation_method
  method_is_valid <- is.character(approximation_method) &&
    length(approximation_method) == 1L &&
    approximation_method %in% c("sample", "iw")
  if (!method_is_valid) {
    stop(
      "`causal_approximation_method` must be either \"sample\" or \"iw\".",
      call. = FALSE
    )
  }

  # The coalitions to generate data for (all of them unless `index_features` is given)
  X <- internal$objects$X
  if (is.null(index_features)) {
    features <- X$features
  } else {
    features <- X$features[index_features]
  }

  dt_l <- vector("list", n_explain)

  for (i in seq_len(n_explain)) {
    x_explain_i <- x_explain0[i, , drop = FALSE]

    if (approximation_method == "sample") {
      # Draw samples for every coalition
      l <- lapply(
        X = features,
        FUN = sample_causal_binary,
        n_samples = n_samples,
        joint_probability_dt = joint_probability_dt,
        x_explain = x_explain_i,
        parents_list = parents_list,
        topo_order = topo_order,
        data = x_train
      )

      dt_i <- data.table::rbindlist(l, idcol = "id_combination")
      dt_i[, w := 1 / n_samples]
      dt_i[, id := i]
      # Sum weights of rows with equal features
      dt_i <- dt_i[, .(w = sum(w)), by = c("id_combination", feature_names, "id")]
    } else {
      # "iw": re-use the training rows and attach a weight to each of them
      l <- lapply(
        X = features,
        FUN = augment_and_weight_binary,
        training_data = x_train,
        x_to_explain = x_explain_i,
        joint_probability_dt = joint_probability_dt,
        parents_list = parents_list,
        dag = causal_dag
      )

      dt_i <- data.table::rbindlist(l, idcol = "id_combination")
      dt_i[, id := i]
    }

    # Map the local list position back to the requested coalition index
    if (!is.null(index_features)) dt_i[, id_combination := index_features[id_combination]]

    dt_l[[i]] <- dt_i
  }

  data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE)
}

# Sample binary feature vectors for one observation and one coalition.
#
# The features in `index_given` are fixed to the values in `x_explain`; the
# remaining features are sampled node by node in `topo_order`, each from its
# distribution given the already filled-in values of its parents.
#
# Arguments:
#   index_given           Column indices of the features that are fixed.
#   n_samples             Number of rows to sample.
#   x_explain             1-row matrix with column names (the observation).
#   joint_probability_dt  data.table with the feature columns and `joint_prob`.
#   parents_list          Per node, its parents; `names()` of an entry are used
#                         as column names and its values as column indices.
#                         NOTE(review): this assumes the DAG's vertex order
#                         matches the column order of `x_explain` - confirm.
#   topo_order            Node indices in the order they are sampled.
#   method                "logistic" fits a logistic regression of the node on
#                         its parents using `data`; any other value uses the
#                         empirical conditional probabilities.
#   data                  Training data, only needed for method = "logistic".
#
# Returns a data.table with one column per feature. For the empty and the full
# coalition the single row of `x_explain` is returned unchanged.
sample_causal_binary <- function(index_given, n_samples, x_explain, joint_probability_dt,
                                 parents_list, topo_order, method = 'empirical', data = NULL) {
  cnms <- colnames(x_explain)
  m <- length(cnms)

  # Handles the unconditional and full conditional separately when predicting
  if (length(index_given) %in% c(0, length(x_explain))) {
    return(data.table::as.data.table(x_explain))
  }

  intervention <- x_explain[index_given]

  # Initialize matrix of samples and fill in the fixed (intervened) columns
  samples <- matrix(NA, ncol = m, nrow = n_samples)
  samples[, index_given] <- matrix(intervention, nrow = n_samples, ncol = length(index_given), byrow = TRUE)
  colnames(samples) <- cnms

  for (idx_current_node in topo_order) {
    # Intervened nodes keep their fixed value
    if (idx_current_node %in% index_given) {
      next
    }

    current_node_name <- cnms[idx_current_node]
    current_node_parents <- parents_list[[idx_current_node]]
    current_node_parent_names <- names(current_node_parents)
    current_node_parent_idxs <- as.numeric(current_node_parents)

    if (length(current_node_parents) == 0) {
      # Root node: sample from the marginal p(node = 1)
      marginal_probs <- joint_probability_dt[, .(marg_prob = sum(joint_prob)), by = current_node_name]
      marginal_prob_one <- marginal_probs[get(current_node_name) == 1, marg_prob]
      if (length(marginal_prob_one) == 0) {
        # The value 1 never occurs for this node in the probability table
        marginal_prob_one <- 0
      }
      samples[, idx_current_node] <- stats::rbinom(n_samples, 1, marginal_prob_one)
    } else {
      # Values already filled in for the parents (guaranteed by the topological order)
      parent_samples <- samples[, current_node_parent_idxs, drop = FALSE]

      if (method == "logistic") {
        if (is.null(data)) {
          stop("`data` must be supplied when method = \"logistic\".", call. = FALSE)
        }

        formula_string <- paste0(current_node_name, " ~ ", paste(current_node_parent_names, collapse = " + "))
        logit_node_given_parents <- stats::glm(
          stats::as.formula(formula_string),
          family = stats::binomial(link = "logit"),
          data = as.data.frame(data)[, c(idx_current_node, current_node_parent_idxs)]
        )

        p_one <- stats::predict(logit_node_given_parents, as.data.frame(parent_samples), type = "response")
      } else {
        # Conditional probabilities p(node | parents), obtained by marginalizing
        # the joint probability table over all other variables
        current_node_and_parents_names <- c(current_node_name, current_node_parent_names)
        joint_probs_current_node_and_parents <- joint_probability_dt[
          , .(joint_prob = sum(joint_prob)),
          by = current_node_and_parents_names
        ]
        marginal_probs_current_parents <- joint_probability_dt[
          , .(marg_prob = sum(joint_prob)),
          by = current_node_parent_names
        ]
        conditional_prob_dt <- merge(
          joint_probs_current_node_and_parents,
          marginal_probs_current_parents,
          by = current_node_parent_names
        )
        conditional_prob_dt[, cond_prob := joint_prob / marg_prob]

        # p(node = 1 | parents): at most one row per parent configuration
        cond_prob_one <- conditional_prob_dt[conditional_prob_dt[[current_node_name]] == 1]

        # Look up the probability for every sample. `X[i, on = ]` returns one row
        # per row of `i`, in the order of `i`, so p_one stays aligned with `samples`
        # (a plain merge() would sort the rows and drop unmatched samples).
        p_one <- cond_prob_one[
          data.table::as.data.table(parent_samples),
          on = current_node_parent_names
        ][["cond_prob"]]

        # No entry with node = 1 for this parent configuration: probability zero
        p_one[is.na(p_one)] <- 0
      }

      samples[, idx_current_node] <- stats::runif(length(p_one)) <= p_one
    }
  }

  data.table::as.data.table(samples)
}

# Work-in-progress helper that starts computing log-probabilities for
# importance weights of the augmented data.
#
# NOTE(review): this function looks unfinished. It has no explicit return value,
# `parents_list` is never used, `log_prob_numerator` and `log_prob_observed_not_S`
# are computed but not combined, and nothing is computed at all when the
# intervention covers at most half of the features.
#
# Arguments:
#   augmented_data        Table that is joined against `joint_probability_dt`.
#   index_given           Indices of the features in the intervention set S.
#   x_explain             Observation; its length and column names are used.
#   joint_probability_dt  data.table with the feature columns and `joint_prob`.
#   parents_list          Currently unused.
calculate_ip_weights_causal_binary <- function(augmented_data, index_given, x_explain,
                                               joint_probability_dt, parents_list) {
  intervention_size <- length(index_given)
  # Indices of the features that are not part of the intervention set S
  index_not_given <- setdiff(seq_len(length(x_explain)), index_given)

  # We use different characterizations of the interventional distribution
  # based on the size of the intervention set.
  # Only the case "more than half of the features are given" is handled below.
  if (intervention_size > as.integer(length(x_explain) / 2)) {

    # log numerator: p(x)

    # join the augmented data with the joint probability table
    # to get the joint probability of the observed features
    joint_prob_augmented_data <- joint_probability_dt[augmented_data, on = names(augmented_data), allow.cartesian = TRUE]
    log_prob_numerator <- log(joint_prob_augmented_data[["joint_prob"]])

    # log denominator: p(x_notS) prod_{i in S} p(x_i | x_notS)

    # sum over features in S
    # then join this table with the observed values not in S

    not_S_marginal_prob <- joint_probability_dt[, .(marg_prob = sum(joint_prob)), by = colnames(x_explain)[index_not_given]]
    # NOTE(review): `not_S_marginal_prob` only contains the not-given columns, but the
    # join below is on all columns of `augmented_data` - presumably it should be on
    # the not-given columns only; confirm before using this function.
    log_prob_observed_not_S <- log(not_S_marginal_prob[as.data.table(augmented_data), on = names(augmented_data), allow.cartesian = TRUE][["marg_prob"]])
  }

}


# Build the augmented training data for one coalition and attach a weight
# to every row.
#
# Arguments:
#   training_data         data.table with the training observations.
#   x_to_explain          The observation whose values replace the given columns.
#   index_given           Column indices of the features that are fixed.
#   joint_probability_dt  data.table with the feature columns and `joint_prob`.
#   parents_list          Per node, its parents in the causal DAG.
#   dag                   Kept for interface compatibility; not used here.
#
# Returns the augmented data (a copy of `training_data` with the given columns
# overwritten) with an additional weight column `w`.
augment_and_weight_binary <- function(training_data, x_to_explain, index_given, joint_probability_dt, parents_list, dag) {
  augmented_data <- augment_data(training_data, x_to_explain, index_given)
  importance_weights <- calculate_importance_weights_binary(
    augmented_data, index_given, joint_probability_dt, parents_list
  )

  # set() assigns by reference without data.table's non-standard evaluation, so a
  # feature column that happens to share a name with a local variable cannot
  # shadow the weights.
  data.table::set(augmented_data, j = "w", value = importance_weights)

  augmented_data
}

# Create a copy of the training data in which every column listed in
# `index_given` is overwritten with the corresponding value of `x_to_explain`.
# The input `training_data` itself is left unchanged.
augment_data <- function(training_data, x_to_explain, index_given) {
  augmented_data <- data.table::copy(training_data)

  for (position in seq_along(index_given)) {
    target_column <- index_given[position]
    # Single-column table that repeats the explained value once per training row
    augmented_data[, (target_column) := data.table::as.data.table(
      matrix(x_to_explain[target_column], nrow = nrow(augmented_data), ncol = 1, byrow = TRUE)
    )]
  }

  return(augmented_data)
}

# Importance weights for the rows of the augmented data.
#
# For every row the weight is
#   prod_{j not in S} p(x_j | pa_j) / p(x_notS),
# i.e. the interventional density of the non-intervened features divided by
# their marginal probability, where S = `index_given`. All probabilities are
# read off `joint_probability_dt`.
#
# Arguments:
#   augmented_data        Table with one column per feature (in node order).
#   index_given           Column indices of the intervened features S.
#   joint_probability_dt  data.table with the feature columns and `joint_prob`.
#   parents_list          Per node, its parents; `names()` of an entry are used
#                         as column names.
#
# Returns a numeric vector with one weight per row of `augmented_data`. Rows
# whose configuration has no support in the probability table get weight 0.
calculate_importance_weights_binary <- function(augmented_data, index_given, joint_probability_dt,
                                                parents_list) {
  augmented_dt <- data.table::as.data.table(augmented_data)
  column_names <- colnames(augmented_dt)
  n_samples <- nrow(augmented_dt)

  # Log p(x_notS|do(S)) = sum_{j in notS} log p(x_j|pa_j)
  log_numerator <- numeric(n_samples)
  for (idx_current_node in seq_len(ncol(augmented_dt))) {
    # Skip nodes that are intervened on
    if (idx_current_node %in% index_given) {
      next
    }

    current_node_name <- column_names[idx_current_node]
    current_node_parents <- parents_list[[idx_current_node]]
    current_node_parent_names <- names(current_node_parents)

    if (length(current_node_parents) == 0) {
      # Root node: marginal probability of the observed node value
      marginal_probs <- joint_probability_dt[, .(marg_prob = sum(joint_prob)), by = current_node_name]
      node_prob <- marginal_probs[augmented_dt, on = current_node_name][["marg_prob"]]
    } else {
      # Conditional probabilities p(node | parents) by marginalizing out other variables
      current_node_and_parents_names <- c(current_node_name, current_node_parent_names)
      joint_probs_current_node_and_parents <- joint_probability_dt[
        , .(joint_prob = sum(joint_prob)),
        by = current_node_and_parents_names
      ]
      marginal_probs_current_parents <- joint_probability_dt[
        , .(marg_prob = sum(joint_prob)),
        by = current_node_parent_names
      ]
      conditional_prob_dt <- merge(
        joint_probs_current_node_and_parents,
        marginal_probs_current_parents,
        by = current_node_parent_names
      )
      conditional_prob_dt[, cond_prob := joint_prob / marg_prob]

      # Look up p(node | parents) for every row. The join must use the node AND
      # its parents, and `X[i, on = ]` keeps one result row per row of `i` in the
      # order of `i`, so the result lines up with `log_numerator`.
      node_prob <- conditional_prob_dt[augmented_dt, on = current_node_and_parents_names][["cond_prob"]]
    }

    # Configurations that are missing from the probability table have probability zero
    node_prob[is.na(node_prob)] <- 0
    log_numerator <- log_numerator + log(node_prob)
  }

  # log p(x_notS): marginal probability of the non-intervened values of each row
  not_given_names <- column_names[setdiff(seq_along(column_names), index_given)]
  if (length(not_given_names) > 0) {
    not_given_marginal_probs <- joint_probability_dt[, .(marg_prob = sum(joint_prob)), by = not_given_names]
    log_denominator <- log(not_given_marginal_probs[augmented_dt, on = not_given_names][["marg_prob"]])
  } else {
    # Everything is intervened on: nothing left to reweight
    log_denominator <- 0
  }

  weights <- exp(log_numerator - log_denominator)
  # Rows without support in the probability table contribute nothing
  weights[is.na(weights)] <- 0
  weights
}