#' Computes `v(S)` for all features subsets `S`.
#'
#' Evaluates every batch stored in `internal$objects$S_batch` with
#' `batch_compute_vS()`, either through `future_compute_vS_batch()` or through
#' a plain sequential pass in the current R session.
#'
#' @inheritParams default_doc
#' @inheritParams explain
#'
#' @param method Character.
#' `"future"` (default) hands the batches to `future_compute_vS_batch()`.
#' Any other value evaluates the batches one after another in the current
#' session, without futures and without progress reporting.
#'
#' @return A list with one element per batch in `internal$objects$S_batch`,
#' each element being the value returned by `batch_compute_vS()` for that
#' batch.
#'
#' @export
compute_vS <- function(internal, model, predict_model, method = "future") {
  S_batch <- internal$objects$S_batch

  if (method == "future") {
    return(
      future_compute_vS_batch(
        S_batch = S_batch,
        internal = internal,
        model = model,
        predict_model = predict_model
      )
    )
  }

  # Sequential fallback: same per-batch work as above, but without future and
  # without a progress bar (`p` keeps its NULL default in batch_compute_vS()).
  ret <- lapply(
    S_batch,
    batch_compute_vS,
    internal = internal,
    model = model,
    predict_model = predict_model
  )

  # The sequential path has always returned a positionally indexed list, so
  # any names carried over from `S_batch` are dropped.
  unname(ret)
}

# Evaluates `batch_compute_vS()` on every element of `S_batch` through
# `future.apply::future_lapply()`.
#
# S_batch        List of batches; each element is handed to
#                `batch_compute_vS()` as its `S` argument.
# internal, model, predict_model
#                Forwarded unchanged to `batch_compute_vS()`.
#
# Returns the list produced by `future_lapply()`, one element per batch.
future_compute_vS_batch <- function(S_batch, internal, model, predict_model) {
  # Progress reporting is optional: the progressor gets one step per element
  # across all batches, and is replaced by NULL when progressr is missing.
  p <- if (requireNamespace("progressr", quietly = TRUE)) {
    progressr::progressor(sum(lengths(S_batch)))
  } else {
    NULL
  }

  # Names of the intervention generator and its methods, passed on as
  # `future.globals`.
  intervention_globals <- c(
    "generate_data_under_intervention",
    "generate_data_under_intervention.binary_scm",
    "generate_data_under_intervention.linear_gaussian_scm",
    "generate_data_under_intervention.additive_noise_scm"
  )

  future.apply::future_lapply(
    X = S_batch,
    FUN = batch_compute_vS,
    internal = internal,
    model = model,
    predict_model = predict_model,
    p = p,
    future.seed = internal$parameters$seed,
    future.globals = intervention_globals
  )
}


#' Computes `v(S)` for one batch of feature combinations
#'
#' When `internal$parameters$approach` is the single string `"mec"`, the whole
#' batch is delegated to `batch_compute_vS_mec()`. Otherwise the data for the
#' batch is built with `batch_prepare_vS()`, predictions are added by
#' reference with `compute_preds()`, and the predictions are averaged with
#' `compute_MCint()`.
#'
#' @param S Vector of feature combination ids making up the batch.
#' @param internal List. The entries read here are
#' `parameters$approach`, `parameters$keep_samp_for_vS`,
#' `parameters$feature_names`, `parameters$type`, `parameters$horizon`,
#' `parameters$output_size`, `parameters$explain_idx`,
#' `parameters$explain_lags`, `data$n_endo`, `data$y` and `data$xreg`.
#' @param model Model object, forwarded to `predict_model`.
#' @param predict_model Prediction function, forwarded to `compute_preds()`.
#' @param p Optional progress function, or `NULL` (default) for no progress
#' reporting. When supplied it is called once per batch as
#' `p(amount = length(S), message = "Estimating v(S)")`.
#'
#' @return For the `"mec"` approach, the value of `batch_compute_vS_mec()`.
#' Otherwise the output of `compute_MCint()`, or, when
#' `internal$parameters$keep_samp_for_vS` is `TRUE`, a list with elements
#' `dt_vS` (that output) and `dt_samp_for_vS` (the prepared data including the
#' prediction columns).
#'
#' @keywords internal
batch_compute_vS <- function(S, internal, model, predict_model, p = NULL) {
  library(igraph, quietly = TRUE) # Temporary fix to handle future::future_lapply and devtools::load_all
  approach <- internal$parameters$approach

  # For now this is the easiest way to integrate our approach into existing code.
  # The length check keeps `if` away from a non-scalar condition: a vector of
  # approaches (or a missing one) simply takes the default path below instead
  # of raising "the condition has length > 1" / "argument is of length zero".
  if (length(approach) == 1L && approach == "mec") {
    return(
      batch_compute_vS_mec(
        S = S,
        internal = internal,
        model = model,
        predict_model = predict_model,
        p = p
      )
    )
  }

  keep_samp_for_vS <- internal$parameters$keep_samp_for_vS
  output_size <- internal$parameters$output_size

  dt <- batch_prepare_vS(S = S, internal = internal) # Make it optional to store and return the dt_list

  pred_cols <- paste0("p_hat", seq_len(output_size))

  compute_preds(
    dt, # Updating dt by reference
    feature_names = internal$parameters$feature_names,
    predict_model = predict_model,
    model = model,
    pred_cols = pred_cols,
    type = internal$parameters$type,
    horizon = internal$parameters$horizon,
    n_endo = internal$data$n_endo,
    explain_idx = internal$parameters$explain_idx,
    explain_lags = internal$parameters$explain_lags,
    y = internal$data$y,
    xreg = internal$data$xreg
  )
  dt_vS <- compute_MCint(dt, pred_cols)

  if (!is.null(p)) {
    p(
      amount = length(S),
      message = "Estimating v(S)"
    ) # TODO: Add a message to state what batch has been computed
  }

  if (keep_samp_for_vS) {
    return(list(dt_vS = dt_vS, dt_samp_for_vS = dt))
  }
  dt_vS
}

#' Computes `v(S)` for one batch of feature combinations with the "mec" approach
#'
#' For every feature combination id in `S`, the rows of
#' `internal$parameters$intervention_equivalence_map` at the positions where
#' `internal$objects$X[["id_combination"]]` equals that id are looked up, and
#' `v(S)` is estimated once for each distinct DAG index found there.
#'
#' For each DAG index the data is built with `batch_prepare_vS_mec()`,
#' predictions are added with `compute_preds()` and averaged with
#' `compute_MCint()`. When `internal$parameters$causal_approximation_method`
#' is `"iw"`, data and predictions are only generated for the first DAG index
#' of a combination; for the remaining DAG indices the weight column `w` is
#' recomputed with `mec_importance_weights()` and the existing predictions
#' are averaged again.
#'
#' `internal$parameters$keep_samp_for_vS` is not honoured here: only the
#' averaged values are returned.
#'
#' @param S Vector of feature combination ids making up the batch.
#' @param internal List. The entries read here are
#' `parameters$feature_names`, `parameters$type`, `parameters$output_size`,
#' `parameters$causal_approximation_method`,
#' `parameters$intervention_equivalence_map` and `objects$X`.
#' @param model Model object, forwarded to `predict_model`.
#' @param predict_model Prediction function, forwarded to `compute_preds()`.
#' @param p Optional progress function, or `NULL` (default) for no progress
#' reporting. When supplied it is called once per batch as
#' `p(amount = length(S), message = "Estimating v(S)")`, matching
#' `batch_compute_vS()`.
#'
#' @return A data.table obtained by row-binding the `compute_MCint()` outputs
#' for all feature combinations in `S` and all their DAG indices, with an
#' extra column `dag_idx` identifying the DAG index each row belongs to.
#'
#' @export
batch_compute_vS_mec <- function(S, internal, model, predict_model, p = NULL) {
  feature_names <- internal$parameters$feature_names
  type <- internal$parameters$type
  pred_cols <- paste0("p_hat", seq_len(internal$parameters$output_size))
  # isTRUE() keeps the flag a strict scalar even if the setting is missing.
  use_iw <- isTRUE(internal$parameters$causal_approximation_method == "iw")

  compute_for_one_s <- function(id_combination) {
    idx_id_combination <- which(internal$objects$X[["id_combination"]] == id_combination)
    dags_to_use <- unique(internal$parameters$intervention_equivalence_map[idx_id_combination, ])
    list_dt_vS <- vector("list", length(dags_to_use))

    for (i in seq_along(dags_to_use)) {
      dag_idx <- dags_to_use[i]

      if (!use_iw || i == 1L) {
        # Fresh data and predictions for this DAG index.
        dt <- batch_prepare_vS_mec(S = id_combination, internal = internal, dag_idx = dag_idx)
        compute_preds(dt, feature_names, predict_model, model, pred_cols, type)
      } else {
        # Importance weighting: keep the data and predictions from the first
        # DAG index and only replace the weights, by reference.
        dt[, w := mec_importance_weights(internal, dag_idx, id_combination, .SD),
           by = .(id_combination, id),
           .SDcols = feature_names]
      }

      dt_vS <- compute_MCint(dt, pred_cols)
      dt_vS[, dag_idx := dag_idx]
      list_dt_vS[[i]] <- dt_vS
    }
    data.table::rbindlist(list_dt_vS, use.names = TRUE, fill = TRUE)
  }

  # One result table per feature combination, stacked into a single table.
  # This also covers batches with a single id, and empty batches.
  results <- lapply(S, compute_for_one_s)
  dt_vS <- data.table::rbindlist(results, use.names = TRUE, fill = TRUE)

  # Report the whole batch as done, so the progress bar advances for the
  # "mec" approach as it does for the other approaches.
  if (!is.null(p)) {
    p(
      amount = length(S),
      message = "Estimating v(S)"
    )
  }

  dt_vS
}

#' Builds the data used to estimate `v(S)` for one batch
#'
#' Feature combinations other than the largest id
#' (`internal$parameters$n_combinations`) are sent to `prepare_data()`. The
#' largest id is handled here directly: its rows are
#' `internal$data$x_explain` with weight `w = 1` and `id` running from 1 to
#' `internal$parameters$n_explain`.
#'
#' @param S Vector of feature combination ids making up the batch.
#' @param internal List holding `parameters$n_combinations`,
#' `parameters$n_explain` and `data$x_explain`; it is also passed on to
#' `prepare_data()`.
#'
#' @return The value of `prepare_data()` when the largest id is not in `S`.
#' Otherwise a data.table keyed on `id` and `id_combination` that row-binds
#' the `prepare_data()` output (if any ids remain) with the rows for the
#' largest id.
#'
#' @keywords internal
batch_prepare_vS <- function(S, internal) {
  full_id <- internal$parameters$n_combinations

  # TODO: Check what is the fastest approach to deal with the last observation.
  # Not doing this for the largest id combination (should check if this is faster or slower, actually)
  # An alternative would be to delete rows from the dt which is provided by prepare_data.
  if (!(full_id %in% S)) {
    # TODO: Need to handle the need for model for the AIC-versions here (skip for Python)
    return(prepare_data(internal, index_features = S))
  }

  x_explain <- internal$data$x_explain
  n_explain <- internal$parameters$n_explain

  # Stays NULL when the batch consists of the largest id only.
  dt_rest <- NULL
  if (length(S) > 1) {
    dt_rest <- prepare_data(internal, index_features = S[S != full_id])
  }

  dt_full <- data.table(id_combination = full_id, x_explain, w = 1, id = seq_len(n_explain))
  dt_all <- rbind(dt_rest, dt_full)
  setkey(dt_all, id, id_combination)
  dt_all
}

#' Builds the data used to estimate `v(S)` for one batch and one DAG index
#'
#' Feature combinations other than the largest id
#' (`internal$parameters$n_combinations`) are sent to `prepare_data()`
#' together with `dag_idx`. The largest id is handled here directly: its rows
#' are `internal$data$x_explain` with weight `w = 1` and `id` running from 1
#' to `internal$parameters$n_explain`.
#'
#' @param S Vector of feature combination ids making up the batch.
#' @param internal List holding `parameters$n_combinations`,
#' `parameters$n_explain` and `data$x_explain`; it is also passed on to
#' `prepare_data()`.
#' @param dag_idx DAG index, forwarded to `prepare_data()` as
#' `current_dag_idx`. It is not used for the rows of the largest id.
#'
#' @return The value of `prepare_data()` when the largest id is not in `S`.
#' Otherwise a data.table keyed on `id` and `id_combination` that row-binds
#' the `prepare_data()` output (if any ids remain) with the rows for the
#' largest id.
#'
#' @keywords internal
batch_prepare_vS_mec <- function(S, internal, dag_idx) {
  full_id <- internal$parameters$n_combinations

  # TODO: Check what is the fastest approach to deal with the last observation.
  # Not doing this for the largest id combination (should check if this is faster or slower, actually)
  # An alternative would be to delete rows from the dt which is provided by prepare_data.
  if (!(full_id %in% S)) {
    # TODO: Need to handle the need for model for the AIC-versions here (skip for Python)
    return(prepare_data(internal, index_features = S, current_dag_idx = dag_idx))
  }

  x_explain <- internal$data$x_explain
  n_explain <- internal$parameters$n_explain

  # Stays NULL when the batch consists of the largest id only.
  dt_rest <- NULL
  if (length(S) > 1) {
    dt_rest <- prepare_data(
      internal,
      index_features = S[S != full_id],
      current_dag_idx = dag_idx
    )
  }

  dt_full <- data.table(id_combination = full_id, x_explain, w = 1, id = seq_len(n_explain))
  dt_all <- rbind(dt_rest, dt_full)
  setkey(dt_all, id, id_combination)
  dt_all
}

#' Adds model predictions to `dt` by reference
#'
#' Calls `predict_model` on the `feature_names` columns of `dt` and stores the
#' result in the columns named by `pred_cols`.
#'
#' @param dt data.table that is modified by reference. For
#' `type == "forecast"` it must also hold an `id` column, which is used to
#' index `explain_idx`.
#' @param feature_names Character vector naming the columns of `dt` that are
#' passed to `predict_model`.
#' @param predict_model Prediction function.
#' @param model Model object handed to `predict_model`.
#' @param pred_cols Character vector with the names of the prediction columns
#' to create in `dt`.
#' @param type Character. `"forecast"` selects the forecasting call described
#' below; any other value selects the plain
#' `predict_model(model, newdata = ...)` call.
#' @param horizon,explain_lags,y,xreg Only used for `type == "forecast"`,
#' where they are forwarded unchanged to `predict_model`.
#' @param n_endo Only used for `type == "forecast"`: the first `n_endo`
#' feature columns are passed as `newdata` and the remaining ones as `newreg`.
#' @param explain_idx Only used for `type == "forecast"`: passed to
#' `predict_model` as `explain_idx[id]`.
#'
#' @return `dt`, with the prediction columns added.
#'
#' @keywords internal
compute_preds <- function(
  dt,
  feature_names,
  predict_model,
  model,
  pred_cols,
  type,
  horizon = NULL,
  n_endo = NULL,
  explain_idx = NULL,
  explain_lags = NULL,
  y = NULL,
  xreg = NULL) {
  if (type != "forecast") {
    # Standard case: all feature columns go in as `newdata`.
    dt[, (pred_cols) := predict_model(model, newdata = .SD), .SDcols = feature_names]
    return(dt)
  }

  # Forecast case: the feature columns are split into the leading `n_endo`
  # columns and the rest before being handed to the prediction function.
  dt[, (pred_cols) := predict_model(
    x = model,
    newdata = .SD[, 1:n_endo],
    newreg = .SD[, -(1:n_endo)],
    horizon = horizon,
    explain_idx = explain_idx[id],
    explain_lags = explain_lags,
    y = y,
    xreg = xreg
  ), .SDcols = feature_names]

  dt
}

# Averages the prediction columns of `dt` within each (id, id_combination)
# group, weighting every row by `w` normalised to sum to one in its group,
# and reshapes the result to one row per `id_combination` with the `id`
# values spread across columns.
#
# dt         data.table with columns `id`, `id_combination`, `w` and the
#            columns named in `pred_cols`.
# pred_cols  Character vector of prediction column names.
#
# Returns the reshaped data.table. With a single prediction column, the
# value columns are renamed to `<pred_cols>_<id>`; with several, the column
# names produced by `dcast()` are kept as they are.
compute_MCint <- function(dt, pred_cols = "p_hat") {
  weighted_mean <- function(x, w) sum((x * w) / sum(w))

  # Calculate contributions
  dt_res <- dt[,
    lapply(.SD, weighted_mean, w = w),
    by = c("id", "id_combination"),
    .SDcols = pred_cols
  ]
  data.table::setkeyv(dt_res, c("id", "id_combination"))

  dt_mat <- data.table::dcast(dt_res, id_combination ~ id, value.var = pred_cols)

  if (length(pred_cols) == 1) {
    # Prefix every column except the leading `id_combination` one.
    new_names <- names(dt_mat)
    new_names[-1] <- paste0(pred_cols, "_", new_names[-1])
    data.table::setnames(dt_mat, new_names)
  }

  dt_mat
}

