#' @rdname setup_approach
#'
#' @param binary.joint_prob_dt Data.table. (Optional)
#' Containing the joint probability distribution for each combination of feature
#' values.
#' `NULL` means it is estimated from the `x_train` and `x_explain`.
#'
#' @param binary.epsilon Numeric value. (Optional)
#' If `binary.joint_prob_dt` is not supplied, probabilities/frequencies are
#' estimated using `x_train`. If certain feature combinations occur in `x_explain` and NOT in `x_train`,
#' each such combination is added to the frequency table with `binary.epsilon` as its (pseudo-)count
#' before the frequencies are normalised to probabilities.
#' In theory, this proportion should be zero, but this causes an error later in the Shapley computation.
#'
#' @inheritParams default_doc_explain
#'
#' @export
setup_approach.conditional_binary <- function(internal,
                                              binary.joint_prob_dt = NULL,
                                              binary.epsilon = 0.001,
                                              ...) {
  defaults <- mget(c("binary.joint_prob_dt", "binary.epsilon"))
  internal <- insert_defaults(internal, defaults)

  joint_probability_dt <- internal$parameters$binary.joint_prob_dt

  # Use the value resolved into `internal$parameters`; fall back to the function argument if it is absent
  epsilon <- internal$parameters$binary.epsilon
  if (is.null(epsilon)) {
    epsilon <- binary.epsilon
  }

  feature_names <- internal$parameters$feature_names

  x_train <- internal$data$x_train
  x_explain <- internal$data$x_explain

  # Estimate the joint distribution if it is not passed to the function
  if (is.null(joint_probability_dt)) {
    # Frequency of every feature combination observed in the training data
    joint_prob_dt0 <- x_train[, .N, by = feature_names]

    # Distinct feature combinations present in x_explain but absent from x_train. The explicit `on =` anti-join
    # needs no key, so x_explain is never reordered by reference.
    x_explain_dt <- data.table::as.data.table(x_explain)
    explain_not_in_train <- unique(x_explain_dt[!x_train, on = feature_names])

    if (nrow(explain_not_in_train) > 0) {
      joint_prob_dt0 <- rbind(joint_prob_dt0, cbind(explain_not_in_train, N = epsilon))
    }

    # Normalise the (pseudo-)counts so that the probabilities sum to one
    joint_prob_dt0[, joint_prob := N / sum(N)]
    data.table::setkeyv(joint_prob_dt0, feature_names)

    joint_probability_dt <- joint_prob_dt0[, N := NULL][, id_all := .I]
  }

  internal$parameters$binary.joint_prob_dt <- joint_probability_dt

  # x_explain is stored as a matrix in its original row order.
  # NOTE(review): confirm that downstream consumers expect a matrix here rather than a data.table.
  internal$data$x_explain <- as.matrix(x_explain)

  return(internal)
}


#' @inheritParams default_doc
#'
#' @rdname prepare_data
#' @export
#' @keywords internal
prepare_data.conditional_binary <- function(internal, index_features = NULL, ...) {
  train_data <- internal$data$x_train
  n_samples <- internal$parameters$n_samples
  joint_probability_dt <- internal$parameters$binary.joint_prob_dt
  feature_names <- internal$parameters$feature_names

  # Coalitions to sample for: every entry of X$features, or only those picked out by index_features
  all_features <- internal$objects$X$features
  features <- if (is.null(index_features)) all_features else all_features[index_features]

  explain_mat <- as.matrix(internal$data$x_explain)

  # Build the weighted samples for a single observation (row `obs_idx` of the explain data)
  sample_one_observation <- function(obs_idx) {
    sampled <- lapply(
      features,
      sample_conditional_binary,
      n_samples = n_samples,
      joint_probability_dt = joint_probability_dt,
      x_explain = explain_mat[obs_idx, , drop = FALSE],
      data = train_data
    )

    dt_obs <- data.table::rbindlist(sampled, idcol = "id_combination")
    dt_obs[, w := 1 / n_samples][, id := obs_idx]

    # Collapse identical sampled rows by summing their weights
    dt_obs <- dt_obs[, .(w = sum(w)), by = c("id_combination", feature_names, "id")]

    # Map the position within the selected subset back to the supplied coalition indices
    if (!is.null(index_features)) {
      dt_obs[, id_combination := index_features[id_combination]]
    }

    dt_obs
  }

  per_observation <- lapply(seq_len(nrow(internal$data$x_explain)), sample_one_observation)

  data.table::rbindlist(per_observation, use.names = TRUE, fill = TRUE)
}

# Draw samples of the features outside a coalition, conditional on the coalition's feature values.
#
# Arguments:
#   index_given           Column positions (in `x_explain`) of the features that are conditioned on.
#   n_samples             Number of samples to draw.
#   x_explain             The observation to explain; the caller in this file passes a one-row matrix
#                         with column names.
#   joint_probability_dt  data.table with one column per feature plus a `joint_prob` column.
#   method, data          Currently unused; kept so existing callers keep working.
#
# Returns a data.table with the same columns as `x_explain`. If `index_given` is empty or covers all
# features, `x_explain` itself is returned (one row). Otherwise `n_samples` rows are returned, where the
# given features are fixed at their observed values and the remaining features are sampled.
sample_conditional_binary <- function(index_given, n_samples, x_explain, joint_probability_dt,
                                      method = "empirical", data = NULL) {
  cnms <- colnames(x_explain)
  m <- length(cnms)

  # The unconditional and fully conditional cases are handled separately: nothing is sampled
  if (length(index_given) %in% c(0, length(x_explain))) {
    return(data.table::as.data.table(x_explain))
  }

  given_names <- cnms[index_given]
  not_given <- setdiff(seq_along(x_explain), index_given)
  not_given_names <- cnms[not_given]

  # Sample matrix with the conditioned-on features fixed at their observed values
  samples <- matrix(NA, ncol = m, nrow = n_samples)
  colnames(samples) <- cnms
  samples[, index_given] <- matrix(
    x_explain[index_given],
    nrow = n_samples,
    ncol = length(index_given),
    byrow = TRUE
  )

  # Keep only the rows of the joint distribution that agree with x_explain on the given features
  x_explain_df <- as.data.frame(x_explain)
  colnames(x_explain_df) <- cnms
  given_dt <- data.table::as.data.table(x_explain_df[index_given])
  matching_dt <- joint_probability_dt[given_dt, on = given_names, nomatch = NULL]

  # Without a match there is no conditional distribution to sample from; fail with a clear message
  # rather than with NA probabilities inside sample()
  if (nrow(matching_dt) == 0) {
    stop(
      "No rows in `joint_probability_dt` match the values of the conditioning features (",
      paste(given_names, collapse = ", "), ").",
      call. = FALSE
    )
  }

  # Conditional probability of each remaining row given the conditioned-on feature values
  matching_dt[, conditional_prob := joint_prob / sum(joint_prob)]

  # Sample rows from the conditional distribution and fill in the features that were not given
  drawn <- matching_dt[
    sample(.N, n_samples, prob = conditional_prob, replace = TRUE),
    .SD,
    .SDcols = not_given_names
  ]
  samples[, not_given] <- as.matrix(drawn)

  return(data.table::as.data.table(samples))
}
