
# Decompose the disparity between groups X = 0 and X = 1 into natural and
# counterfactual direct / indirect / spurious effects, separately for the
# probability prediction ("base") and for the extra contribution introduced by
# thresholding that prediction ("mtg").
#
# Arguments:
#   data        data.frame holding all columns named below.
#   X           name of the binary (0/1) attribute column.
#   Z, W        character vectors of column names; X is regressed on Z and on
#               Z + W to obtain the propensity scores used for reweighting.
#   Y           name of the binary outcome column (logical is coerced to 0/1).
#   Shat        optional name of a column with scores from an existing
#               predictor. If given, an xgboost regressor is fit to those
#               scores; otherwise an xgboost classifier is fit to Y.
#   x0, x1      accepted for interface compatibility; the body always uses the
#               levels 0 and 1.
#   thr         decision threshold applied to the predicted probabilities; if
#               `quant` is TRUE it is interpreted as a quantile level of the
#               predictions instead.
#   quant       see `thr`.
#   nboot       number of bootstrap repetitions. With nboot > 1 every scalar
#               effect is returned as c(estimate, bootstrap sd); the estimate
#               comes from the first repetition, which uses the full data.
#   norm_method "adapt" normalises weighted sums by the sum of the weights; any
#               other value normalises by the number of samples.
#
# Returns a named list with the effects (`nde_*`, `nie_*`, `nse_*`, `ctfde_*`,
# `ctfie_*`, `ctfse_*`, each as `_base` and `_mtg`) and the per-sample
# quantities `u_de_base`, `u_de_tot`, `u_de_wgh` for the direct effect.
#
# NOTE: the bootstrap branch calls set.seed(i) for every repetition, which
# resets the caller's global RNG state.
mind_the_gap <- function(data, X, Z, W, Y, Shat = NULL, x0 = 0, x1 = 1, 
                         thr = 0.5, quant = FALSE, nboot = 1L, 
                         norm_method = "adapt") {
  
  # components that are per-sample vectors rather than scalar effects
  samp_cmps <- c("u_de_base", "u_de_tot", "u_de_wgh")
  
  if (nboot > 1L) {
    
    mtgs <- lapply(
      seq_len(nboot),
      function(i) {
        
        set.seed(i)
        # first repetition uses the original sample, the rest are resamples
        boot_idx <- if (i == 1) TRUE else sample(nrow(data), replace = TRUE)
        boot_data <- data[boot_idx, ]
        mind_the_gap(boot_data, X, Z, W, Y, Shat, x0, x1, thr, quant, 
                     nboot = 1L, norm_method = norm_method)
      }
    )

    ret <- list()
    for (cmp in setdiff(names(mtgs[[1]]), samp_cmps)) {

      # point estimate from the full data, sd across bootstrap repetitions
      ret[[cmp]] <- c(mtgs[[1]][[cmp]], 
                      sd(vapply(mtgs, function(x) x[[cmp]], numeric(1L))))  
    }
    for (cmp in samp_cmps) {
      ret[[cmp]] <- mtgs[[1]][[cmp]]
    }
    
    return(ret)
  }
  
  trim_to_01 <- function(x) pmax(0, pmin(x, 1))
  
  # Build a reweighted potential outcome.
  #   fx: which intervention on X feeds the predictor (0 -> probx0, 1 -> probx1)
  #   wx: which X-level the distribution of W given Z is taken from
  #   zx: which X-level the distribution of Z is taken from (NA = marginal)
  # Relies on px, px_z, px_zw, probx0, probx1 and thr from the enclosing scope.
  compute_po <- function(fx = 0, wx = 0, zx = 0, thresh = FALSE, 
                         norm_method = "adapt") {
    
    if (is.na(zx)) zx <- -1
    
    # P(w | x_wx, z) / P(w | z)
    if (wx == 0) wgh <- (1 - px_zw) / (1 - px_z) else wgh <- px_zw / px_z
    
    # P(z | x_zx) / P(z) = P(x_zx | z) / P(x_zx). The x0 factor must use the
    # Z-only propensity px_z, mirroring the x1 branch; using px_zw here would
    # weight by the W-level propensity twice.
    if (zx == 0) {
      wgh <- wgh * (1 - px_z) / (1 - px)
    } else if (zx == 1) {
      wgh <- wgh * px_z / px
    }
    
    if (fx == 0) po_samp <- probx0 else po_samp <- probx1
    if (norm_method == "adapt") {
      norm_const <- sum(wgh)
    } else {
      norm_const <- length(po_samp)
    }
    if (thresh) po_samp <- as.integer(po_samp > thr)
    list(po_samp = po_samp, wgh = wgh, norm_const = norm_const)
  }
  
  eval_po <- function(po) sum(po$po_samp * po$wgh) / po$norm_const
  diff_po <- function(po1, po2) eval_po(po1) - eval_po(po2)
  
  # per-sample contrast; only defined if both POs share the normalisation
  samp_imp <- function(po1, po2) {
    
    if (po1$norm_const != po2$norm_const) return(NA)
    po1$po_samp - po2$po_samp
  }
  
  if (is.logical(data[[Y]])) data[[Y]] <- as.integer(data[[Y]])
  assertthat::assert_that(all(data[[Y]] %in% c(0, 1)))
  
  # check if a predictor is already there
  if (!is.null(Shat)) {
    
    # build a predictor (xgboost) for the Shat (to determine the POs)
    dtrain <- xgb.DMatrix(data = as.matrix(data[, c(X, Z, W)]), 
                          label = data[[Shat]])
    
    # use the squared loss
    params <- list(objective = "reg:squarederror", eval_metric = "rmse")
  } else {

    # build a predictor (xgboost) for the binary outcome Y
    dtrain <- xgb.DMatrix(data = as.matrix(data[, c(X, Z, W)]), 
                          label = data[[Y]])
    
    # use the logistic loss
    params <- list(objective = "binary:logistic", eval_metric = "logloss")
  }
  
  # cross-validation to find optimal nrounds
  cv <- xgb.cv(
    params = params, data = dtrain, nfold = 5, nrounds = 100,
    early_stopping_rounds = 10, verbose = 0
  )
  optimal_nrounds <- cv$best_iteration
  
  # train the model
  mod <- xgb.train(
    params = params,
    data = dtrain,
    nrounds = optimal_nrounds
  )
  
  # predictions under the interventions X = 0 and X = 1
  data_do_x0 <- data_do_x1 <- data
  data_do_x0[[X]] <- 0
  data_do_x1[[X]] <- 1
  
  probx0 <- predict(mod, xgb.DMatrix(data = as.matrix(data_do_x0[, c(X, Z, W)])))
  probx1 <- predict(mod, xgb.DMatrix(data = as.matrix(data_do_x1[, c(X, Z, W)])))
  probx0 <- trim_to_01(probx0)
  probx1 <- trim_to_01(probx1)
  
  if (quant) {
    
    # interpret thr as a quantile level of the observational predictions
    prob <- if (is.null(Shat)) predict(mod, dtrain) else data[[Shat]]
    thr <- quantile(prob, probs = thr)
  }
  
  # get the propensity scores - regress X on Z, Z+W
  mod_xz <- glm(paste(X, "~", paste(Z, collapse = "+")), family = "binomial",
                data = data)
  px_z <- mod_xz$fitted.values
  mod_xzw <- glm(paste(X, "~", paste(c(Z, W), collapse = "+")), 
                 family = "binomial", data = data)
  px_zw <- mod_xzw$fitted.values
  px <- mean(data[[X]])
  
  # get f_{x_1, W_{x_0}}
  fx1wx0 <- compute_po(fx = 1, wx = 0, zx = NA, FALSE, norm_method)
  fx1wx0_x0 <- compute_po(fx = 1, wx = 0, zx = 0, FALSE, norm_method)
  rx1wx0 <- compute_po(fx = 1, wx = 0, zx = NA, TRUE, norm_method)
  rx1wx0_x0 <- compute_po(fx = 1, wx = 0, zx = 0, TRUE, norm_method)
  
  # get f_{x_0, W_{x_0}}
  fx0wx0 <- compute_po(fx = 0, wx = 0, zx = NA, FALSE, norm_method)
  fx0wx0_x0 <- compute_po(fx = 0, wx = 0, zx = 0, FALSE, norm_method)
  rx0wx0 <- compute_po(fx = 0, wx = 0, zx = NA, TRUE, norm_method)
  rx0wx0_x0 <- compute_po(fx = 0, wx = 0, zx = 0, TRUE, norm_method)
  
  # get f_{x_1, W_{x_1}}
  fx1wx1 <- compute_po(fx = 1, wx = 1, zx = NA, FALSE, norm_method)
  fx1wx1_x0 <- compute_po(fx = 1, wx = 1, zx = 0, FALSE, norm_method)
  fx1wx1_x1 <- compute_po(fx = 1, wx = 1, zx = 1, FALSE, norm_method)
  rx1wx1 <- compute_po(fx = 1, wx = 1, zx = NA, TRUE, norm_method)
  rx1wx1_x0 <- compute_po(fx = 1, wx = 1, zx = 0, TRUE, norm_method)
  rx1wx1_x1 <- compute_po(fx = 1, wx = 1, zx = 1, TRUE, norm_method)
  
  # natural effects; each *_mtg is the thresholded contrast minus its base
  nde_base <- diff_po(fx1wx0, fx0wx0)
  nde_mtg <- diff_po(rx1wx0, rx0wx0) - nde_base
  nie_base <- diff_po(fx1wx0, fx1wx1)
  nie_mtg <- diff_po(rx1wx0, rx1wx1) - nie_base
  nse_base <- diff_po(fx1wx1_x1, fx1wx1) - diff_po(fx0wx0_x0, fx0wx0)
  nse_mtg <- diff_po(rx1wx1_x1, rx1wx1) - diff_po(rx0wx0_x0, rx0wx0) - nse_base
  
  # counterfactual effects
  ctfde_base <- diff_po(fx1wx0_x0, fx0wx0_x0)
  ctfde_mtg <- diff_po(rx1wx0_x0, rx0wx0_x0) - ctfde_base
  ctfie_base <- diff_po(fx1wx0_x0, fx1wx1_x0)
  ctfie_mtg <- diff_po(rx1wx0_x0, rx1wx1_x0) - ctfie_base
  ctfse_base <- diff_po(fx1wx1_x0, fx1wx1_x1)
  ctfse_mtg <- diff_po(rx1wx1_x0, rx1wx1_x1) - ctfse_base
  
  list(
    nde_base = nde_base, nde_mtg = nde_mtg,
    nie_base = nie_base, nie_mtg = nie_mtg,
    nse_base = nse_base, nse_mtg = nse_mtg,
    ctfde_base = ctfde_base, ctfde_mtg = ctfde_mtg,
    ctfie_base = ctfie_base, ctfie_mtg = ctfie_mtg,
    ctfse_base = ctfse_base, ctfse_mtg = ctfse_mtg,
    u_de_base = samp_imp(fx1wx0, fx0wx0),
    u_de_tot = samp_imp(rx1wx0, rx0wx0),
    u_de_wgh = fx1wx0$wgh
  )
}

# Stacked bar chart of weighted per-sample contributions, with samples ranked
# by their weighted total (ties broken by the weighted base part), largest
# first. Each bar is split into the weighted `base` part and the weighted
# remainder `total - base`.
#
#   total, base  per-sample contributions (numeric vectors)
#   wgh          per-sample weights multiplied into both parts
#   effect       not used by the body
#
# Returns a ggplot object.
sample_influence <- function(total, base, wgh, effect = "DE") {
  
  contrib <- data.frame(
    total = total * wgh, 
    base = base * wgh, 
    mtg = (total - base) * wgh, 
    sample = seq_along(total)
  )
  
  # rank by the weighted total, then drop it so only the two parts are stacked
  ranking <- order(contrib$total, contrib$base, decreasing = TRUE)
  contrib <- contrib[ranking, c("base", "mtg", "sample")]
  contrib$order <- seq_len(nrow(contrib))
  
  long <- reshape2::melt(contrib, id.vars = c("sample", "order"))
  
  fill_scale <- scale_fill_manual(
    values = c("#00BFC4", "#F8766D"),
    labels = c(TeX("Outcome $Y$"), TeX("Margin Complement $M$")),
    name = "Component"
  )
  
  ggplot(long, aes(x = order, y = value, fill = variable)) +
    theme_bw() +
    geom_col(position = position_stack(reverse = TRUE), alpha = 0.5) +
    fill_scale +
    xlab("Sample Ranking") + 
    ylab("Sample Influence") +
    theme(legend.position = "bottom") +
    guides(fill = guide_legend(override.aes = list(colour = "black")))
}

# Plot a decomposition returned by mind_the_gap() alongside the direct-effect
# sample influences.
#
#   mtg      list of effects; every scalar effect must be a length-two vector
#            (value, sd), as produced with nboot > 1
#   type     "natural" (NDE / NIE / NSE) or "counterfactual" (Ctf-DE / -IE / -SE)
#   dataset  "mimic" or "compas" select their titles; anything else is titled
#            "Census 2018"
#   ret      "plot" returns a two-panel cowplot grid; any other value returns
#            list(effect plot, sample-influence plot)
vis_mtg <- function(mtg, type = c("natural", "counterfactual"), dataset = "",
                    ret = "plot") {
  
  type <- match.arg(type, c("natural", "counterfactual"))
  
  # element-name prefixes in `mtg` and the matching axis labels
  if (type == "natural") {
    prefixes <- c("nde", "nie", "nse")
    eff_labels <- c("NDE", "NIE", "NSE")
  } else {
    prefixes <- c("ctfde", "ctfie", "ctfse")
    eff_labels <- c("Ctf-DE", "Ctf-IE", "Ctf-SE")
  }
  
  # rows alternate base / mtg within each effect
  fields <- paste0(rep(prefixes, each = 2), c("_base", "_mtg"))
  estimates <- do.call(rbind, lapply(fields, function(fld) mtg[[fld]]))
  
  eff_dt <- data.frame(
    estimates,
    rep(c("Base", "MTG"), 3),
    rep(eff_labels, each = 2)
  )
  names(eff_dt) <- c("value", "sd", "Component", "effect")
  setDT(eff_dt)
  eff_dt <- eff_dt[order(effect, Component)]
  
  # error bars sit at the running total of the stacked components
  eff_dt[, csum := cumsum(value), by = effect]
  eff_dt[, c("ymin", "ymax") := list(csum - sd, csum + sd)]
  
  src_label <- ifelse(
    dataset == "mimic", "MIMIC-IV", 
    ifelse(dataset == "compas", "COMPAS", "Census 2018")
  )

  p_eff <- ggplot(eff_dt, aes(x = effect, y = value, fill = Component)) +
    theme_bw() +
    geom_col(alpha = 0.5, color = "black", 
             position = position_stack(reverse = TRUE)) +
    geom_errorbar(aes(ymin = ymin, ymax = ymax), width = 0.4,
                  color = "darkgrey") +
    scale_fill_manual(values = c("#00BFC4", "#F8766D"),
                      labels = c(TeX("Outcome $Y$"), 
                                 TeX("Margin Complement $M$"))) +
    xlab("Causal Effect") + 
    ylab("Effect Size") +
    theme(legend.position = "bottom") +
    ggtitle(paste0(src_label, " Decomposition"))
  
  p_samps <- sample_influence(mtg$u_de_tot, mtg$u_de_base, mtg$u_de_wgh) +
    ggtitle("DE Sample Influences")
  
  if (ret != "plot") {
    return(list(p_eff, p_samps))
  }
  cowplot::plot_grid(p_eff, p_samps, ncol = 2L)
}

# Load one of the supported datasets and run mind_the_gap() on it with 50
# bootstrap repetitions.
#
#   dataset  one of "compas", "census", "mimic"; any other value is an error.
#
# The census and mimic data are restricted to their first 10^4 rows (or all
# rows, if there are fewer). For "mimic" the threshold 0.5 is interpreted as a
# quantile of the predictions (quant = TRUE); otherwise it is used directly.
#
# Returns the list produced by mind_the_gap().
mtg_wrap <- function(dataset) {
  
  max_rows <- 10^4
  
  if (dataset == "compas") {
    
    c(data, sfm) %<-% preproc_compas()
    
  } else if (dataset == "census") {
    
    c(data, sfm) %<-% preproc_census(gov_census, SFM_proj("census"))
    # head() never pads with NA rows when fewer than max_rows are available
    data <- utils::head(as.data.frame(data), max_rows)
    
  } else if (dataset == "mimic") {
    
    c(data, sfm) %<-% preproc_miiv()
    data <- utils::head(as.data.frame(data), max_rows)
    
  } else {
    
    stop("unknown dataset: ", dataset, 
         " (expected \"compas\", \"census\" or \"mimic\")", call. = FALSE)
  }
  
  mind_the_gap(X = sfm$X, Z = sfm$Z, W = sfm$W, Y = sfm$Y, 
               data = data, x0 = 0, x1 = 1, nboot = 50L, 
               thr = 0.5, quant = (dataset == "mimic"))
}
