# Multi-CATE simulations
# Setup: large observational study (OS) and small randomized trial data (CT)
# Continuous outcome
# The simulation code below runs scenario s1d_2a.

# Scenarios as defined in setups.R:

# s1d_1:
# dim: 10 (normal)
# pscore [train, test]: osSparse2Linear
# pscore [rct]: rct5
# mu0: sparseNonLinear3
# tau: sparseNonLinear3
# external shift: SparseLogitLinear1

# s1d_2a:
# dim: 10 (normal)
# pscore [train, test]: osConfounding2
# pscore [rct]: rct5
# mu0: sparseLinearWeak
# tau: sparseLinearWeak
# external shift: SparseLogitLinear2
# Yobsconf in RCT and train, tauconf in test

# s1d_2ax:
# dim: 10 (unif)
# pscore [train, test]: osConfounding2
# pscore [rct]: rct5
# mu0: sparseLinearWeak
# tau: fullLocallyLinear
# external shift: SparseLogitLinear2
# Yobsconf in RCT and train, tauconf in test

# s1d_2b:
# dim: 10 (normal)
# pscore [train, test]: osConfounding2
# pscore [rct]: rct5
# mu0: sparseLinearWeak
# tau: sparseLinearWeak
# external shift: SparseLogitLinear2
# Yobs in RCT, Yobsconf in train, tau in test ("total shift")

# s1d_2bx:
# dim: 10 (unif)
# pscore [train, test]: osConfounding2
# pscore [rct]: rct5
# mu0: sparseLinearWeak
# tau: fullLocallyLinear
# external shift: SparseLogitLinear2
# Yobs in RCT, Yobsconf in train, tau in test ("total shift")

library(tidyverse)
library(ranger)
library(grf)
library(mcboost)
library(mlr3learners)
library(kableExtra)
library(ggh4x)
#library(causalToolbox)
source("setups_num.R")
source("dr_learner.R")
source("KL.R")

## Prepare loop

# Build a one-row result template. Every column named in `cols` starts as NA:
# columns listed in `scalar_cols` are plain (logical NA) columns, all others
# are list columns holding a one-column tibble named after the column, e.g.
# `ps = list(tibble(ps = NA))`. The loop below overwrites these cells on every
# repetition.
new_result_template <- function(cols,
                                scalar_cols = c("rep_i", "trsize_s", "shift_e",
                                                "flag", "ate")) {
  cells <- lapply(cols, function(col) {
    if (col %in% scalar_cols) {
      NA
    } else {
      list(tibble(!!col := NA))
    }
  })
  names(cells) <- cols
  tibble(!!!cells)
}

train_res <- tibble()
test_res <- tibble()
rct_res <- tibble()

# Column names are spelled out (not pasted together) so each one stays
# greppable from the code that fills it.
train_temp <- new_result_template(c(
  # repetition metadata, ground truth and diagnostics
  "rep_i", "trsize_s", "shift_e", "flag", "ps", "KL", "y", "ate", "tau",
  # outcome predictions of models trained on the training (OS) data
  "y_cforest_tr", "y_slearner_tr", "yp_slearner_tr_mcr", "y_slearner_tr_mcr",
  "y_tlearner_tr", "y_tclearner_tr",
  "yp_tlearner_tr_mcr", "y_tlearner_tr_mcr",
  "yp_tlearner_tr_mct", "y_tlearner_tr_mct",
  "yp_tlearner_tr_mclr", "y_tlearner_tr_mclr",
  "yp_tlearner_tr_mclt", "y_tlearner_tr_mclt",
  "yp_tlearner_tr_mcp", "y_tlearner_tr_mcp",
  # CATE predictions of models trained on the training (OS) data
  "tau_cforest_tr", "tau_slearner_tr", "tau_slearner_tr_mcr",
  "tau_drlearner_tr", "tau_drlearner_tr_mcr", "tau_drlearner_tr_mct",
  "tau_drlearner_tr_mcfr", "tau_drlearner_tr_mcft",
  "tau_drlearner_tr_mclr", "tau_drlearner_tr_mclt",
  "tau_tlearner_tr", "tau_tclearner_tr", "tau_tlearner_tr_mcr",
  "tau_tlearner_tr_mct", "tau_tlearner_tr_mclr", "tau_tlearner_tr_mclt",
  "tau_tlearner_tr_mcp",
  # *_rct columns (presumably models fit on the RCT sample)
  "y_cforest_rct", "y_cforestw_rct", "y_slearner_rct", "y_slearnerw_rct",
  "y_tlearner_rct", "y_tlearnerw_rct",
  "tau_cforest_rct", "tau_cforestw_rct", "tau_slearner_rct",
  "tau_slearnerw_rct", "tau_drlearner_rct", "tau_tlearner_rct",
  "tau_tlearnerw_rct"
))

# The test-set template has exactly the same layout as the training one;
# copy-on-modify keeps the two independent once the loop writes into them.
test_temp <- train_temp

rct_temp <- new_result_template(c(
  "rep_i", "trsize_s", "shift_e", "y", "yp", "ate", "tau",
  "tau_cforest_tr", "tau_slearner_tr", "tau_slearner_tr_mcr",
  "tau_drlearner_tr", "tau_drlearner_tr_mcr", "tau_drlearner_tr_mct",
  "tau_drlearner_tr_mcfr", "tau_drlearner_tr_mcft",
  "tau_drlearner_tr_mclr", "tau_drlearner_tr_mclt",
  "tau_tlearner_tr", "tau_tclearner_tr", "tau_tlearner_tr_mcr",
  "tau_tlearner_tr_mct", "tau_tlearner_tr_mclr", "tau_tlearner_tr_mclt",
  "tau_tlearner_tr_mcp",
  "tau_cforest_rct", "tau_cforestw_rct", "tau_slearner_rct",
  "tau_slearnerw_rct", "tau_tlearner_rct", "tau_tlearnerw_rct"
))

# Simulation grid: training-set sizes x shift amplifiers, n_reps each.
n_reps <- 25 # repetitions per (shift, size) cell
s_range <- seq(from = 500, to = 5000, by = 1500) # n obs in training set
e_range <- seq(from = 0, to = 3, by = 0.25) # shift amplifier (exponent on shiftw)

# Min-max scale `x` onto the range of `label`: min(label) maps to 0 and
# max(label) maps to 1; values of `x` outside that range land outside [0, 1].
# Inverse of rev_scale(). Note: this definition masks base::scale() for the
# rest of the script.
#
# NOTE(review): the loop calls this with label = Y_train*3 to leave headroom
# for RCT outcomes. Tripling only widens the range when Y_train spans zero;
# for all-positive (or all-negative) outcomes part of Y_train itself is mapped
# outside [0, 1].
#
# Errors if `label` has no finite, non-zero range (all values equal, empty,
# NA or infinite), which would otherwise silently produce NaN/Inf/NA.
scale <- function(x, label){
  rng <- range(label)
  width <- rng[2] - rng[1]
  if (!is.finite(width) || width == 0) {
    stop("`label` must have a finite, non-zero range", call. = FALSE)
  }
  (x - rng[1]) / width
}

# Undo scale(): map values expressed on the [0, 1] scale of `label` back to
# the original units, i.e. min(label) + preds * (max(label) - min(label)).
rev_scale <- function(preds, label){
  lo <- min(label)
  span <- max(label) - lo
  lo + preds * span
}

# mcboost auditor fitters. `ridge` wraps a glmnet learner with alpha = 0
# (ridge penalty); `tree` wraps a depth-3 rpart regression tree.
# NOTE(review): check whether `tree` and `subx` are still used; drop if not.
ridge <- LearnerAuditorFitter$new(
  lrn("regr.glmnet", alpha = 0, s = 1)
)
tree <- LearnerAuditorFitter$new(
  lrn("regr.rpart", maxdepth = 3)
)

# Two subpopulations: rows whose x1 (resp. x2) lies above that column's mean,
# where the mean is taken over the data passed in at audit time.
subx <- SubpopAuditorFitter$new(
  lapply(c("x1", "x2"), function(col) {
    force(col)
    function(data) {
      data[[col]] > mean(data[[col]])
    }
  })
)

## Simulation

## Generate RCT population
# One large draw under the RCT propensity setting ("rct5"). Each repetition in
# the loop samples its RCT from this pool with weights shiftw^e (shift_s).
# NOTE(review): no set.seed() is visible in this script, so results are not
# reproducible unless a seed is set elsewhere (e.g. in setups_num.R) - confirm.
init_rct <- simulate_causal_experiment(
  ntrain = 1e5,                  # n obs
  dim = 10,                      # n covars
  alpha = 0.1,                   # corr
  feat_distribution = "normal",
  pscore = "rct5",
  mu0 = "sparseLinearWeak",
  tau = "sparseLinearWeak",
  shiftfun = "SparseLogitLinear2"
)

# Population table: covariates plus latent u, propensity, treatment, the
# confounded CATE/outcome (tauconf/Yobsconf, as in scenario s1d_2a) and the
# external-shift quantities used for weighted sampling.
rct_pop <- tibble(
  init_rct$feat_tr,
  u = init_rct$u,
  ps = init_rct$Wp_tr,
  T = init_rct$W_tr,
  tau = init_rct$tauconf_tr,
  Y = init_rct$Yobsconf_tr,
  shift = init_rct$shift_tr,
  shiftw = init_rct$shiftw_tr
)

for(e in e_range) {
  
  ## Set external shift
  rct_pop$shift_s <- rct_pop$shiftw^e # weights source
  rct_pop$shift_t <- rct_pop$shiftw^-e # weights target
  
  train_temp$shift_e <- e
  test_temp$shift_e <- e
  rct_temp$shift_e <- e
  
  for(s in s_range) {
  
    ## Set training set size
    train_size <- s
    test_size <- 5000
    rct_size <- 500
    
    train_temp$trsize_s <- s
    test_temp$trsize_s <- s
    rct_temp$trsize_s <- s
    
    for(i in 1:n_reps) {
      
      train_temp$rep_i <- i
      test_temp$rep_i <- i
      rct_temp$rep_i <- i
      
      ## Sample train
      init_train <- simulate_causal_experiment(ntrain = train_size, # n obs
                                              dim = 10, # n covars
                                              alpha = .1, # corr
                                              feat_distribution = "normal", 
                                              pscore = "osConfounding2", 
                                              mu0 = "sparseLinearWeak",
                                              tau = "sparseLinearWeak",
                                              shiftfun = "SparseLogitLinear2") 
    
      train <- tibble(init_train$feat_tr, u = init_train$u, ps = init_train$Wp_tr, T = init_train$W_tr, 
                      tau = init_train$tauconf_tr, Y = init_train$Yobsconf_tr)
    
      ## Sample test
      init_test <- simulate_causal_experiment(ntrain = test_size, # n obs
                                              dim = 10, # n covars
                                              alpha = .1, # corr
                                              feat_distribution = "normal", 
                                              pscore = "osConfounding2", 
                                              mu0 = "sparseLinearWeak",
                                              tau = "sparseLinearWeak",
                                              shiftfun = "SparseLogitLinear2") 
    
      test <- tibble(init_test$feat_tr, u = init_test$u, ps = init_test$Wp_tr, T = init_test$W_tr, 
                     tau = init_test$tauconf_tr, Y = init_test$Yobsconf_tr)
    
      ## Sample RCT
      rct <- slice_sample(rct_pop, n = rct_size, weight_by = shift_s)
      
      ## Pre-process data
      train_temp$ps <- list(tibble(ps = train$ps)) # treatment propensities
      test_temp$ps <- list(tibble(ps = test$ps))
      train_temp$y <- list(tibble(y = train$Y)) # true Y
      test_temp$y <- list(tibble(y = test$Y))
      rct_temp$y <- list(tibble(y = rct$Y))
      train_temp$ate <- mean(train$tau) # true ATE
      test_temp$ate <- mean(test$tau)
      rct_temp$ate <- mean(rct$tau)
      train_temp$tau <- list(tibble(tau = train$tau)) # true tau
      test_temp$tau <- list(tibble(tau = test$tau))
      rct_temp$tau <- list(tibble(tau = rct$tau))

      train <- select(train, Y, T, x1:x10) # train data
      X_traint <- select(train, T, x1:x10) 
      Y_train <- train$Y
      Y_trains <- scale(Y_train, label = Y_train*3) # increase range of y
      T_train <- train$T
      rct <- select(rct, Y, T, x1:x10) # RCT data
      X_rctt <- select(rct, T, x1:x10)
      Y_rct <- rct$Y
      Y_rcts <- scale(Y_rct, label = Y_train*3)
      train_temp$flag <- ifelse(min(Y_rcts) < 0 | max(Y_rcts) > 1, 1, 0) # outside [0, 1]?
      test_temp$flag <- ifelse(min(Y_rcts) < 0 | max(Y_rcts) > 1, 1, 0)
      Y_rcts <- ifelse(Y_rcts < 0, 0, Y_rcts) # clip to [0, 1]
      Y_rcts <- ifelse(Y_rcts > 1, 1, Y_rcts)
      T_rct <- rct$T
      X_test <- select(test, x1:x10) # test data

      train_ut <- data.frame(X_traint[-T], T = 0) # Fix treated and untreated
      train_t <- data.frame(X_traint[-T], T = 1)
      test_ut <- data.frame(X_test, T = 0)
      test_t <- data.frame(X_test, T = 1)
      rct_ut <- data.frame(X_rctt[-T], T = 0)
      rct_t <- data.frame(X_rctt[-T], T = 1)
     
      ## KL divergence 
      
      Mref <- colMeans(X_test)
      Sref <- cov(X_test)
      Mtrain <- colMeans(X_rctt[,-1])
      Strain <- cov(X_rctt[,-1])
      train_temp$KL <- KLdiv(Mtrain, Mref, Strain, Sref, symmetric = F)
      test_temp$KL <- KLdiv(Mtrain, Mref, Strain, Sref, symmetric = F)
      
      ## Propensity score model - Train vs RCT
      
      stacked <- bind_rows(X_traint[-T], X_rctt[-T], .id = "rct")
      stacked$rct <- as.numeric(stacked$rct) - 1
      psm <- glm(rct ~ ., family = binomial, data = stacked)
      pscores <- predict(psm, newdata = X_rctt[-T], type = "response")
      pweights <- (1 - pscores) / pscores
      
      ## Train models - Train w. train data, post-process w. RCT data
      ### Causal Forest
      cforest_tr <- causal_forest(X_traint[-T], Y_train, T_train)

      train_temp$tau_cforest_tr <- list(tibble(tau_cforest_tr = predict(cforest_tr)$predictions))
      test_temp$tau_cforest_tr <- list(tibble(tau_cforest_tr = predict(cforest_tr, X_test)$predictions))

      ### S-learner
      slearner_tr <- ranger(y = Y_trains, x = X_traint)

      yp_sl_tr_t <- predict(slearner_tr, train_t)$predictions
      yp_sl_tr_ut <- predict(slearner_tr, train_ut)$predictions
      y_sl_tr_t <- rev_scale(yp_sl_tr_t, label = Y_train*3)
      y_sl_tr_ut <- rev_scale(yp_sl_tr_ut, label = Y_train*3)
      
      train_temp$y_slearner_tr <- list(tibble(y_slearner_tr = rev_scale(predict(slearner_tr, train)$predictions, label = Y_train*3)))
      train_temp$tau_slearner_tr <- list(tibble(tau_slearner_tr = y_sl_tr_t - y_sl_tr_ut))
      
      yp_sl_tr_t <- predict(slearner_tr, test_t)$predictions
      yp_sl_tr_ut <- predict(slearner_tr, test_ut)$predictions
      y_sl_tr_t <- rev_scale(yp_sl_tr_t, label = Y_train*3)
      y_sl_tr_ut <- rev_scale(yp_sl_tr_ut, label = Y_train*3)
      
      test_temp$y_slearner_tr <- list(tibble(y_slearner_tr = rev_scale(predict(slearner_tr, test)$predictions, label = Y_train*3)))
      test_temp$tau_slearner_tr <- list(tibble(tau_slearner_tr = y_sl_tr_t - y_sl_tr_ut))

      ### S-learner + MCBoost (ridge)
      init_preds_rf = function(data) {
        preds <- predict(slearner_tr, data)$predictions}
      slearner_tr_mc = MCBoost$new(init_predictor = init_preds_rf,
                                   auditor_fitter = "RidgeAuditorFitter",
                                 # iter_sampling = "bootstrap",
                                   max_iter = 10)
      slearner_tr_mc$multicalibrate(X_rctt, Y_rcts)
      
      train_temp$yp_slearner_tr_mcr <- list(tibble(yp_slearner_tr_mcr = slearner_tr_mc$predict_probs(train)))
      train_temp$y_slearner_tr_mcr <- list(tibble(y_slearner_tr_mcr = rev_scale(train_temp$yp_slearner_tr_mcr[[1]], label = Y_train*3)))
      test_temp$yp_slearner_tr_mcr <- list(tibble(yp_slearner_tr_mcr = slearner_tr_mc$predict_probs(test)))
      test_temp$y_slearner_tr_mcr <- list(tibble(y_slearner_tr_mcr = rev_scale(test_temp$yp_slearner_tr_mcr[[1]], label = Y_train*3)))

      yp_slearner_tr_t_mc <- slearner_tr_mc$predict_probs(train_t)
      yp_slearner_tr_ut_mc <- slearner_tr_mc$predict_probs(train_ut)
      y_slearner_tr_t_mc <- rev_scale(yp_slearner_tr_t_mc, label = Y_train*3)
      y_slearner_tr_ut_mc <- rev_scale(yp_slearner_tr_ut_mc, label = Y_train*3)
      train_temp$tau_slearner_tr_mcr <- list(tibble(tau_slearner_tr_mcr = y_slearner_tr_t_mc - y_slearner_tr_ut_mc))

      yp_slearner_tr_t_mc <- slearner_tr_mc$predict_probs(test_t)
      yp_slearner_tr_ut_mc <- slearner_tr_mc$predict_probs(test_ut)
      y_slearner_tr_t_mc <- rev_scale(yp_slearner_tr_t_mc, label = Y_train*3)
      y_slearner_tr_ut_mc <- rev_scale(yp_slearner_tr_ut_mc, label = Y_train*3)
      test_temp$tau_slearner_tr_mcr <- list(tibble(tau_slearner_tr_mcr = y_slearner_tr_t_mc - y_slearner_tr_ut_mc))

      ### DR-learner
      drlearner_tr <- dr_learner(X_traint[-T], Y_train, T_train, test, trunc = 0.025)
      
      train_temp$tau_drlearner_tr <- list(tibble(tau_drlearner_tr = drlearner_tr$tau.hat))
      test_temp$tau_drlearner_tr <- list(tibble(tau_drlearner_tr = drlearner_tr$tau.new))

      ### DR-learner + MCBoost (ridge)
      drlearner_tr_mcr <- dr_learnermc(X_traint[-T], X_traint, X_rctt, 
                                       Y_train, Y_trains, Y_rcts, 
                                       T_train, 
                                       test, trunc = 0.025)
      
      train_temp$tau_drlearner_tr_mcr <- list(tibble(tau_drlearner_tr_mcr = drlearner_tr_mcr$tau.hat))
      test_temp$tau_drlearner_tr_mcr <- list(tibble(tau_drlearner_tr_mcr = drlearner_tr_mcr$tau.new))

      ### DR-learner + MCBoost (tree)
      drlearner_tr_mct <- dr_learnermc(X_traint[-T], X_traint, X_rctt, 
                                        Y_train, Y_trains, Y_rcts, 
                                        T_train, 
                                        test, trunc = 0.025, auditor = "TreeAuditorFitter")
      
      train_temp$tau_drlearner_tr_mct <- list(tibble(tau_drlearner_tr_mct = drlearner_tr_mct$tau.hat))
      test_temp$tau_drlearner_tr_mct <- list(tibble(tau_drlearner_tr_mct = drlearner_tr_mct$tau.new))
      
      ### DR-learner + MCBoost (ridge)
      drlearner_tr_mcfr <- dr_learnermc2(X_traint[-T], X_traint, X_rctt, 
                                         Y_train, Y_trains, Y_rct, Y_rcts,
                                         T_train, T_rct, 
                                         test, trunc = 0.025, eta = 0.01)
      
      train_temp$tau_drlearner_tr_mcfr <- list(tibble(tau_drlearner_tr_mcfr = drlearner_tr_mcfr$tau.hat))
      test_temp$tau_drlearner_tr_mcfr <- list(tibble(tau_drlearner_tr_mcfr = drlearner_tr_mcfr$tau.new))
      
      ### DR-learner + MCBoost (tree) 
      drlearner_tr_mcft <- dr_learnermc2(X_traint[-T], X_traint, X_rctt, 
                                         Y_train, Y_trains, Y_rct, Y_rcts, 
                                         T_train, T_rct, 
                                         test, trunc = 0.025, eta = 0.01, auditor = "TreeAuditorFitter")
      
      train_temp$tau_drlearner_tr_mcft <- list(tibble(tau_drlearner_tr_mcft = drlearner_tr_mcft$tau.hat))
      test_temp$tau_drlearner_tr_mcft <- list(tibble(tau_drlearner_tr_mcft = drlearner_tr_mcft$tau.new))
      
      ### DR-learner + MCBoost (ridge)
      drlearner_tr_mclr <- dr_learnermc3(X_traint[-T], X_traint, X_rctt, 
                                         Y_train, Y_rct, 
                                         T_train, T_rct, 
                                         test, trunc = 0.025, eta = 0.01)
      
      train_temp$tau_drlearner_tr_mclr <- list(tibble(tau_drlearner_tr_mclr = drlearner_tr_mclr$tau.hat))
      test_temp$tau_drlearner_tr_mclr <- list(tibble(tau_drlearner_tr_mclr = drlearner_tr_mclr$tau.new))
      
      ### DR-learner + MCBoost (tree) 
      drlearner_tr_mclt <- dr_learnermc3(X_traint[-T], X_traint, X_rctt, 
                                         Y_train, Y_rct, 
                                         T_train, T_rct, 
                                         test, trunc = 0.025, eta = 0.01, auditor = "TreeAuditorFitter")
      
      train_temp$tau_drlearner_tr_mclt <- list(tibble(tau_drlearner_tr_mclt = drlearner_tr_mclt$tau.hat))
      test_temp$tau_drlearner_tr_mclt <- list(tibble(tau_drlearner_tr_mclt = drlearner_tr_mclt$tau.new))
      
      ### T-learner
      tlearner_tr_t <- ranger(y = Y_trains[X_traint$T == 1], 
                              x = X_traint[X_traint$T == 1, ])
      tlearner_tr_ut <- ranger(y = Y_trains[X_traint$T == 0], 
                               x = X_traint[X_traint$T == 0, ])

      yp_tl_tr_t <- predict(tlearner_tr_t, train_t)$predictions
      yp_tl_tr_ut <- predict(tlearner_tr_ut, train_ut)$predictions
      y_tl_tr_t <- rev_scale(yp_tl_tr_t, label = Y_train*3)
      y_tl_tr_ut <- rev_scale(yp_tl_tr_ut, label = Y_train*3)
      train_temp$y_tlearner_tr <- list(tibble(y_tlearner_tr = ifelse(train$T == 1, 
                                              rev_scale(predict(tlearner_tr_t, train)$predictions, label = Y_train*3), 
                                              rev_scale(predict(tlearner_tr_ut, train)$predictions, label = Y_train*3))))
      train_temp$tau_tlearner_tr <- list(tibble(tau_tlearner_tr = y_tl_tr_t - y_tl_tr_ut))
      
      yp_tl_tr_t <- predict(tlearner_tr_t, test_t)$predictions
      yp_tl_tr_ut <- predict(tlearner_tr_ut, test_ut)$predictions
      y_tl_tr_t <- rev_scale(yp_tl_tr_t, label = Y_train*3)
      y_tl_tr_ut <- rev_scale(yp_tl_tr_ut, label = Y_train*3)
      test_temp$y_tlearner_tr <- list(tibble(y_tlearner_tr = ifelse(test$T == 1, 
                                             rev_scale(predict(tlearner_tr_t, test)$predictions, label = Y_train*3), 
                                             rev_scale(predict(tlearner_tr_ut, test)$predictions, label = Y_train*3))))
      test_temp$tau_tlearner_tr <- list(tibble(tau_tlearner_tr = y_tl_tr_t - y_tl_tr_ut))
      
      yp_tl_tr_t <- predict(tlearner_tr_t, rct_t)$predictions
      yp_tl_tr_ut <- predict(tlearner_tr_ut, rct_ut)$predictions
      y_tl_tr_t <- rev_scale(yp_tl_tr_t, label = Y_train*3)
      y_tl_tr_ut <- rev_scale(yp_tl_tr_ut, label = Y_train*3)
      rct_temp$tau_tlearner_tr <- list(tibble(tau_tlearner_tr = y_tl_tr_t - y_tl_tr_ut))
      
      ### T-learner + MCBoost (ridge)
      init_preds = function(data) {
        preds <- predict(tlearner_tr_t, data)$predictions}
      tlearner_tr_t_mc = MCBoost$new(init_predictor = init_preds,
                                     auditor_fitter = "RidgeAuditorFitter",
                                     alpha = 1e-06,
                                   # iter_sampling = "bootstrap",
                                     weight_degree = 2,
                                     eta = 0.5,
                                     max_iter = 5)
      tlearner_tr_t_mc$multicalibrate(X_rctt[X_rctt$T == 1, ], Y_rcts[X_rctt$T == 1])

      yp_tlearner_tr_t_mc_train <- tlearner_tr_t_mc$predict_probs(select(train, -Y))
      yp_tlearner_tr_t_mc_test <- tlearner_tr_t_mc$predict_probs(select(test, x1:x10, T))
      
      yp_tlearner_tr_t_mc_trt <- tlearner_tr_t_mc$predict_probs(train_t)
      y_tlearner_tr_t_mc_trt <- rev_scale(yp_tlearner_tr_t_mc_trt, label = Y_train*3)

      yp_tlearner_tr_t_mc_tst <- tlearner_tr_t_mc$predict_probs(test_t)
      y_tlearner_tr_t_mc_tst <- rev_scale(yp_tlearner_tr_t_mc_tst, label = Y_train*3)

      yp_tlearner_tr_t_mc_rct <- tlearner_tr_t_mc$predict_probs(rct_t)
      y_tlearner_tr_t_mc_rct <- rev_scale(yp_tlearner_tr_t_mc_rct, label = Y_train*3)

      init_preds = function(data) {
        preds <- predict(tlearner_tr_ut, data)$predictions}
      tlearner_tr_ut_mc = MCBoost$new(init_predictor = init_preds,
                                      auditor_fitter = "RidgeAuditorFitter",
                                      alpha = 1e-06,
                                    # iter_sampling = "bootstrap",
                                      weight_degree = 2,
                                      eta = 0.5,
                                      max_iter = 5)
      tlearner_tr_ut_mc$multicalibrate(X_rctt[X_rctt$T == 0, ], Y_rcts[X_rctt$T == 0])
    
      yp_tlearner_tr_ut_mc_train <- tlearner_tr_ut_mc$predict_probs(select(train, -Y))
      yp_tlearner_tr_ut_mc_test <- tlearner_tr_ut_mc$predict_probs(select(test, x1:x10, T))
      
      train_temp$yp_tlearner_tr_mcr <- list(tibble(yp_tlearner_tr_mcr = ifelse(train$T == 1, yp_tlearner_tr_t_mc_train, yp_tlearner_tr_ut_mc_train)))
      train_temp$y_tlearner_tr_mcr <- list(tibble(y_tlearner_tr_mcr = rev_scale(train_temp$yp_tlearner_tr_mcr[[1]], label = Y_train*3)))
      test_temp$yp_tlearner_tr_mcr <- list(tibble(yp_tlearner_tr_mcr = ifelse(test$T == 1, yp_tlearner_tr_t_mc_test, yp_tlearner_tr_ut_mc_test)))
      test_temp$y_tlearner_tr_mcr <- list(tibble(y_tlearner_tr_mcr = rev_scale(test_temp$yp_tlearner_tr_mcr[[1]], label = Y_train*3)))
      
      yp_tlearner_tr_ut_mc_trt <- tlearner_tr_ut_mc$predict_probs(train_ut)
      y_tlearner_tr_ut_mc_trt <- rev_scale(yp_tlearner_tr_ut_mc_trt, label = Y_train*3)
      train_temp$tau_tlearner_tr_mcr <- list(tibble(tau_tlearner_tr_mcr = y_tlearner_tr_t_mc_trt - y_tlearner_tr_ut_mc_trt))
      
      yp_tlearner_tr_ut_mc_tst <- tlearner_tr_ut_mc$predict_probs(test_ut)
      y_tlearner_tr_ut_mc_tst <- rev_scale(yp_tlearner_tr_ut_mc_tst, label = Y_train*3)
      test_temp$tau_tlearner_tr_mcr <- list(tibble(tau_tlearner_tr_mcr = y_tlearner_tr_t_mc_tst - y_tlearner_tr_ut_mc_tst))

      yp_tlearner_tr_ut_mc_rct <- tlearner_tr_ut_mc$predict_probs(rct_ut)
      y_tlearner_tr_ut_mc_rct <- rev_scale(yp_tlearner_tr_ut_mc_rct, label = Y_train*3)
      rct_temp$tau_tlearner_tr_mcr <- list(tibble(tau_tlearner_tr_mcr = y_tlearner_tr_t_mc_rct - y_tlearner_tr_ut_mc_rct))
 
      ### T-learner + MCBoost (tree)
      init_preds = function(data) {
        preds <- predict(tlearner_tr_t, data)$predictions}
      tlearner_tr_t_mc = MCBoost$new(init_predictor = init_preds,
                                     auditor_fitter = "TreeAuditorFitter",
                                     alpha = 1e-06,
                                   # iter_sampling = "bootstrap",
                                     weight_degree = 2,
                                     eta = 0.5,
                                     max_iter = 5)
      tlearner_tr_t_mc$multicalibrate(X_rctt[X_rctt$T == 1, ], Y_rcts[X_rctt$T == 1])
      
      yp_tlearner_tr_t_mc_train <- tlearner_tr_t_mc$predict_probs(select(train, -Y))
      yp_tlearner_tr_t_mc_test <- tlearner_tr_t_mc$predict_probs(select(test, x1:x10, T))
      
      yp_tlearner_tr_t_mc_trt <- tlearner_tr_t_mc$predict_probs(train_t)
      y_tlearner_tr_t_mc_trt <- rev_scale(yp_tlearner_tr_t_mc_trt, label = Y_train*3)
      
      yp_tlearner_tr_t_mc_tst <- tlearner_tr_t_mc$predict_probs(test_t)
      y_tlearner_tr_t_mc_tst <- rev_scale(yp_tlearner_tr_t_mc_tst, label = Y_train*3)
      
      yp_tlearner_tr_t_mc_rct <- tlearner_tr_t_mc$predict_probs(rct_t)
      y_tlearner_tr_t_mc_rct <- rev_scale(yp_tlearner_tr_t_mc_rct, label = Y_train*3)
      
      init_preds = function(data) {
        preds <- predict(tlearner_tr_ut, data)$predictions}
      tlearner_tr_ut_mc = MCBoost$new(init_predictor = init_preds,
                                      auditor_fitter = "TreeAuditorFitter",
                                      alpha = 1e-06,
                                    # iter_sampling = "bootstrap",
                                      weight_degree = 2,
                                      eta = 0.5,
                                      max_iter = 5)
      tlearner_tr_ut_mc$multicalibrate(X_rctt[X_rctt$T == 0, ], Y_rcts[X_rctt$T == 0])
      
      yp_tlearner_tr_ut_mc_train <- tlearner_tr_ut_mc$predict_probs(select(train, -Y))
      yp_tlearner_tr_ut_mc_test <- tlearner_tr_ut_mc$predict_probs(select(test, x1:x10, T))
      
      train_temp$yp_tlearner_tr_mct <- list(tibble(yp_tlearner_tr_mct = ifelse(train$T == 1, yp_tlearner_tr_t_mc_train, yp_tlearner_tr_ut_mc_train)))
      train_temp$y_tlearner_tr_mct <- list(tibble(y_tlearner_tr_mct = rev_scale(train_temp$yp_tlearner_tr_mct[[1]], label = Y_train*3)))
      test_temp$yp_tlearner_tr_mct <- list(tibble(yp_tlearner_tr_mct = ifelse(test$T == 1, yp_tlearner_tr_t_mc_test, yp_tlearner_tr_ut_mc_test)))
      test_temp$y_tlearner_tr_mct <- list(tibble(y_tlearner_tr_mct = rev_scale(test_temp$yp_tlearner_tr_mct[[1]], label = Y_train*3)))
      
      yp_tlearner_tr_ut_mc_trt <- tlearner_tr_ut_mc$predict_probs(train_ut)
      y_tlearner_tr_ut_mc_trt <- rev_scale(yp_tlearner_tr_ut_mc_trt, label = Y_train*3)
      train_temp$tau_tlearner_tr_mct <- list(tibble(tau_tlearner_tr_mct = y_tlearner_tr_t_mc_trt - y_tlearner_tr_ut_mc_trt))
      
      yp_tlearner_tr_ut_mc_tst <- tlearner_tr_ut_mc$predict_probs(test_ut)
      y_tlearner_tr_ut_mc_tst <- rev_scale(yp_tlearner_tr_ut_mc_tst, label = Y_train*3)
      test_temp$tau_tlearner_tr_mct <- list(tibble(tau_tlearner_tr_mct = y_tlearner_tr_t_mc_tst - y_tlearner_tr_ut_mc_tst))
      
      yp_tlearner_tr_ut_mc_rct <- tlearner_tr_ut_mc$predict_probs(rct_ut)
      y_tlearner_tr_ut_mc_rct <- rev_scale(yp_tlearner_tr_ut_mc_rct, label = Y_train*3)
      rct_temp$tau_tlearner_tr_mct <- list(tibble(tau_tlearner_tr_mct = y_tlearner_tr_t_mc_rct - y_tlearner_tr_ut_mc_rct))
 
      ### T-learner + MCBoost (ridge max_iter 10)
      init_preds = function(data) {
        preds <- predict(tlearner_tr_t, data)$predictions}
      tlearner_tr_t_mc = MCBoost$new(init_predictor = init_preds,
                                     auditor_fitter = ridge,
                                     alpha = 1e-06,
                                     # iter_sampling = "bootstrap",
                                     weight_degree = 2,
                                     eta = 0.5,
                                     max_iter = 10)
      tlearner_tr_t_mc$multicalibrate(X_rctt[X_rctt$T == 1, ], Y_rcts[X_rctt$T == 1])
      
      yp_tlearner_tr_t_mc_train <- tlearner_tr_t_mc$predict_probs(select(train, -Y))
      yp_tlearner_tr_t_mc_test <- tlearner_tr_t_mc$predict_probs(select(test, x1:x10, T))
      
      yp_tlearner_tr_t_mc_trt <- tlearner_tr_t_mc$predict_probs(train_t)
      y_tlearner_tr_t_mc_trt <- rev_scale(yp_tlearner_tr_t_mc_trt, label = Y_train*3)
      
      yp_tlearner_tr_t_mc_tst <- tlearner_tr_t_mc$predict_probs(test_t)
      y_tlearner_tr_t_mc_tst <- rev_scale(yp_tlearner_tr_t_mc_tst, label = Y_train*3)
      
      yp_tlearner_tr_t_mc_rct <- tlearner_tr_t_mc$predict_probs(rct_t)
      y_tlearner_tr_t_mc_rct <- rev_scale(yp_tlearner_tr_t_mc_rct, label = Y_train*3)
      
      init_preds = function(data) {
        preds <- predict(tlearner_tr_ut, data)$predictions}
      tlearner_tr_ut_mc = MCBoost$new(init_predictor = init_preds,
                                      auditor_fitter = ridge,
                                      alpha = 1e-06,
                                      # iter_sampling = "bootstrap",
                                      weight_degree = 2,
                                      eta = 0.5,
                                      max_iter = 10)
      tlearner_tr_ut_mc$multicalibrate(X_rctt[X_rctt$T == 0, ], Y_rcts[X_rctt$T == 0])
      
      yp_tlearner_tr_ut_mc_train <- tlearner_tr_ut_mc$predict_probs(select(train, -Y))
      yp_tlearner_tr_ut_mc_test <- tlearner_tr_ut_mc$predict_probs(select(test, x1:x10, T))
      
      train_temp$yp_tlearner_tr_mclr <- list(tibble(yp_tlearner_tr_mclr = ifelse(train$T == 1, yp_tlearner_tr_t_mc_train, yp_tlearner_tr_ut_mc_train)))
      train_temp$y_tlearner_tr_mclr <- list(tibble(y_tlearner_tr_mclr = rev_scale(train_temp$yp_tlearner_tr_mclr[[1]], label = Y_train*3)))
      test_temp$yp_tlearner_tr_mclr <- list(tibble(yp_tlearner_tr_mclr = ifelse(test$T == 1, yp_tlearner_tr_t_mc_test, yp_tlearner_tr_ut_mc_test)))
      test_temp$y_tlearner_tr_mclr <- list(tibble(y_tlearner_tr_mclr = rev_scale(test_temp$yp_tlearner_tr_mclr[[1]], label = Y_train*3)))
      
      yp_tlearner_tr_ut_mc_trt <- tlearner_tr_ut_mc$predict_probs(train_ut)
      y_tlearner_tr_ut_mc_trt <- rev_scale(yp_tlearner_tr_ut_mc_trt, label = Y_train*3)
      train_temp$tau_tlearner_tr_mclr <- list(tibble(tau_tlearner_tr_mclr = y_tlearner_tr_t_mc_trt - y_tlearner_tr_ut_mc_trt))
      
      yp_tlearner_tr_ut_mc_tst <- tlearner_tr_ut_mc$predict_probs(test_ut)
      y_tlearner_tr_ut_mc_tst <- rev_scale(yp_tlearner_tr_ut_mc_tst, label = Y_train*3)
      test_temp$tau_tlearner_tr_mclr <- list(tibble(tau_tlearner_tr_mclr = y_tlearner_tr_t_mc_tst - y_tlearner_tr_ut_mc_tst))
      
      yp_tlearner_tr_ut_mc_rct <- tlearner_tr_ut_mc$predict_probs(rct_ut)
      y_tlearner_tr_ut_mc_rct <- rev_scale(yp_tlearner_tr_ut_mc_rct, label = Y_train*3)
      rct_temp$tau_tlearner_tr_mclr <- list(tibble(tau_tlearner_tr_mclr = y_tlearner_tr_t_mc_rct - y_tlearner_tr_ut_mc_rct))

      ### T-learner + MCBoost (tree max_iter 10)
      init_preds = function(data) {
        preds <- predict(tlearner_tr_t, data)$predictions}
      tlearner_tr_t_mc = MCBoost$new(init_predictor = init_preds,
                                     auditor_fitter = tree,
                                     alpha = 1e-06,
                                     # iter_sampling = "bootstrap",
                                     weight_degree = 2,
                                     eta = 0.5,
                                     max_iter = 10)
      tlearner_tr_t_mc$multicalibrate(X_rctt[X_rctt$T == 1, ], Y_rcts[X_rctt$T == 1])
      
      yp_tlearner_tr_t_mc_train <- tlearner_tr_t_mc$predict_probs(select(train, -Y))
      yp_tlearner_tr_t_mc_test <- tlearner_tr_t_mc$predict_probs(select(test, x1:x10, T))
      
      yp_tlearner_tr_t_mc_trt <- tlearner_tr_t_mc$predict_probs(train_t)
      y_tlearner_tr_t_mc_trt <- rev_scale(yp_tlearner_tr_t_mc_trt, label = Y_train*3)
      
      yp_tlearner_tr_t_mc_tst <- tlearner_tr_t_mc$predict_probs(test_t)
      y_tlearner_tr_t_mc_tst <- rev_scale(yp_tlearner_tr_t_mc_tst, label = Y_train*3)
      
      yp_tlearner_tr_t_mc_rct <- tlearner_tr_t_mc$predict_probs(rct_t)
      y_tlearner_tr_t_mc_rct <- rev_scale(yp_tlearner_tr_t_mc_rct, label = Y_train*3)
      
      init_preds = function(data) {
        preds <- predict(tlearner_tr_ut, data)$predictions}
      tlearner_tr_ut_mc = MCBoost$new(init_predictor = init_preds,
                                      auditor_fitter = tree,
                                      alpha = 1e-06,
                                      # iter_sampling = "bootstrap",
                                      weight_degree = 2,
                                      eta = 0.5,
                                      max_iter = 10)
      tlearner_tr_ut_mc$multicalibrate(X_rctt[X_rctt$T == 0, ], Y_rcts[X_rctt$T == 0])
      
      yp_tlearner_tr_ut_mc_train <- tlearner_tr_ut_mc$predict_probs(select(train, -Y))
      yp_tlearner_tr_ut_mc_test <- tlearner_tr_ut_mc$predict_probs(select(test, x1:x10, T))
      
      train_temp$yp_tlearner_tr_mclt <- list(tibble(yp_tlearner_tr_mclt = ifelse(train$T == 1, yp_tlearner_tr_t_mc_train, yp_tlearner_tr_ut_mc_train)))
      train_temp$y_tlearner_tr_mclt <- list(tibble(y_tlearner_tr_mclt = rev_scale(train_temp$yp_tlearner_tr_mclt[[1]], label = Y_train*3)))
      test_temp$yp_tlearner_tr_mclt <- list(tibble(yp_tlearner_tr_mclt = ifelse(test$T == 1, yp_tlearner_tr_t_mc_test, yp_tlearner_tr_ut_mc_test)))
      test_temp$y_tlearner_tr_mclt <- list(tibble(y_tlearner_tr_mclt = rev_scale(test_temp$yp_tlearner_tr_mclt[[1]], label = Y_train*3)))
      
      yp_tlearner_tr_ut_mc_trt <- tlearner_tr_ut_mc$predict_probs(train_ut)
      y_tlearner_tr_ut_mc_trt <- rev_scale(yp_tlearner_tr_ut_mc_trt, label = Y_train*3)
      train_temp$tau_tlearner_tr_mclt <- list(tibble(tau_tlearner_tr_mclt = y_tlearner_tr_t_mc_trt - y_tlearner_tr_ut_mc_trt))
      
      yp_tlearner_tr_ut_mc_tst <- tlearner_tr_ut_mc$predict_probs(test_ut)
      y_tlearner_tr_ut_mc_tst <- rev_scale(yp_tlearner_tr_ut_mc_tst, label = Y_train*3)
      test_temp$tau_tlearner_tr_mclt <- list(tibble(tau_tlearner_tr_mclt = y_tlearner_tr_t_mc_tst - y_tlearner_tr_ut_mc_tst))
      
      yp_tlearner_tr_ut_mc_rct <- tlearner_tr_ut_mc$predict_probs(rct_ut)
      y_tlearner_tr_ut_mc_rct <- rev_scale(yp_tlearner_tr_ut_mc_rct, label = Y_train*3)
      rct_temp$tau_tlearner_tr_mclt <- list(tibble(tau_tlearner_tr_mclt = y_tlearner_tr_t_mc_rct - y_tlearner_tr_ut_mc_rct))
      
      ### T-learner + MCBoost (subpop)
      init_preds = function(data) {
        preds <- predict(tlearner_tr_t, data)$predictions}
      tlearner_tr_t_mc = MCBoost$new(init_predictor = init_preds,
                                     auditor_fitter = subx,
                                     # alpha = 1e-06,
                                     # iter_sampling = "bootstrap",
                                     max_iter = 10)
      tlearner_tr_t_mc$multicalibrate(X_rctt[X_rctt$T == 1, ], Y_rcts[X_rctt$T == 1])
      
      yp_tlearner_tr_t_mc_train <- tlearner_tr_t_mc$predict_probs(select(train, -Y))
      yp_tlearner_tr_t_mc_test <- tlearner_tr_t_mc$predict_probs(select(test, x1:x10, T))
      
      yp_tlearner_tr_t_mc_trt <- tlearner_tr_t_mc$predict_probs(train_t)
      y_tlearner_tr_t_mc_trt <- rev_scale(yp_tlearner_tr_t_mc_trt, label = Y_train*3)
      
      yp_tlearner_tr_t_mc_tst <- tlearner_tr_t_mc$predict_probs(test_t)
      y_tlearner_tr_t_mc_tst <- rev_scale(yp_tlearner_tr_t_mc_tst, label = Y_train*3)
      
      yp_tlearner_tr_t_mc_rct <- tlearner_tr_t_mc$predict_probs(rct_t)
      y_tlearner_tr_t_mc_rct <- rev_scale(yp_tlearner_tr_t_mc_rct, label = Y_train*3)
      
      init_preds = function(data) {
        preds <- predict(tlearner_tr_ut, data)$predictions}
      tlearner_tr_ut_mc = MCBoost$new(init_predictor = init_preds,
                                      auditor_fitter = subx,
                                      # alpha = 1e-06,
                                      # iter_sampling = "bootstrap",
                                      max_iter = 10)
      tlearner_tr_ut_mc$multicalibrate(X_rctt[X_rctt$T == 0, ], Y_rcts[X_rctt$T == 0])
      
      yp_tlearner_tr_ut_mc_train <- tlearner_tr_ut_mc$predict_probs(select(train, -Y))
      yp_tlearner_tr_ut_mc_test <- tlearner_tr_ut_mc$predict_probs(select(test, x1:x10, T))
      
      train_temp$yp_tlearner_tr_mcp <- list(tibble(yp_tlearner_tr_mcp = ifelse(train$T == 1, yp_tlearner_tr_t_mc_train, yp_tlearner_tr_ut_mc_train)))
      train_temp$y_tlearner_tr_mcp <- list(tibble(y_tlearner_tr_mcp = rev_scale(train_temp$yp_tlearner_tr_mcp[[1]], label = Y_train*3)))
      test_temp$yp_tlearner_tr_mcp <- list(tibble(yp_tlearner_tr_mcp = ifelse(test$T == 1, yp_tlearner_tr_t_mc_test, yp_tlearner_tr_ut_mc_test)))
      test_temp$y_tlearner_tr_mcp <- list(tibble(y_tlearner_tr_mcp = rev_scale(test_temp$yp_tlearner_tr_mcp[[1]], label = Y_train*3)))
      
      yp_tlearner_tr_ut_mc_trt <- tlearner_tr_ut_mc$predict_probs(train_ut)
      y_tlearner_tr_ut_mc_trt <- rev_scale(yp_tlearner_tr_ut_mc_trt, label = Y_train*3)
      train_temp$tau_tlearner_tr_mcp <- list(tibble(tau_tlearner_tr_mcp = y_tlearner_tr_t_mc_trt - y_tlearner_tr_ut_mc_trt))
      
      yp_tlearner_tr_ut_mc_tst <- tlearner_tr_ut_mc$predict_probs(test_ut)
      y_tlearner_tr_ut_mc_tst <- rev_scale(yp_tlearner_tr_ut_mc_tst, label = Y_train*3)
      test_temp$tau_tlearner_tr_mcp <- list(tibble(tau_tlearner_tr_mcp = y_tlearner_tr_t_mc_tst - y_tlearner_tr_ut_mc_tst))
      
      yp_tlearner_tr_ut_mc_rct <- tlearner_tr_ut_mc$predict_probs(rct_ut)
      y_tlearner_tr_ut_mc_rct <- rev_scale(yp_tlearner_tr_ut_mc_rct, label = Y_train*3)
      rct_temp$tau_tlearner_tr_mcp <- list(tibble(tau_tlearner_tr_mcp = y_tlearner_tr_t_mc_rct - y_tlearner_tr_ut_mc_rct))

      ### T-learner (grf)
      tclearner_tr_t <- regression_forest(Y = Y_trains[X_traint$T == 1], 
                                          X = X_traint[X_traint$T == 1, ])
      tclearner_tr_ut <- regression_forest(Y = Y_trains[X_traint$T == 0], 
                                           X = X_traint[X_traint$T == 0, ])
      
      yp_tl_tr_t <- predict(tclearner_tr_t, train_t)$predictions
      yp_tl_tr_ut <- predict(tclearner_tr_ut, train_ut)$predictions
      y_tl_tr_t <- rev_scale(yp_tl_tr_t, label = Y_train*3)
      y_tl_tr_ut <- rev_scale(yp_tl_tr_ut, label = Y_train*3)
      train_temp$y_tclearner_tr <- list(tibble(y_tclearner_tr = ifelse(train$T == 1, 
                                               rev_scale(predict(tclearner_tr_t, select(train, -Y))$predictions, label = Y_train*3), 
                                               rev_scale(predict(tclearner_tr_ut, select(train, -Y))$predictions, label = Y_train*3))))
      train_temp$tau_tclearner_tr <- list(tibble(tau_tclearner_tr = y_tl_tr_t - y_tl_tr_ut))
      
      yp_tl_tr_t <- predict(tclearner_tr_t, test_t)$predictions
      yp_tl_tr_ut <- predict(tclearner_tr_ut, test_ut)$predictions
      y_tl_tr_t <- rev_scale(yp_tl_tr_t, label = Y_train*3)
      y_tl_tr_ut <- rev_scale(yp_tl_tr_ut, label = Y_train*3)
      test_temp$y_tclearner_tr <- list(tibble(y_tclearner_tr = ifelse(test$T == 1, 
                                              rev_scale(predict(tclearner_tr_t, select(test, x1:x10, T))$predictions, label = Y_train*3), 
                                              rev_scale(predict(tclearner_tr_ut, select(test, x1:x10, T))$predictions, label = Y_train*3))))
      test_temp$tau_tclearner_tr <- list(tibble(tau_tclearner_tr = y_tl_tr_t - y_tl_tr_ut))
      
      yp_tl_tr_t <- predict(tclearner_tr_t, rct_t)$predictions
      yp_tl_tr_ut <- predict(tclearner_tr_ut, rct_ut)$predictions
      y_tl_tr_t <- rev_scale(yp_tl_tr_t, label = Y_train*3)
      y_tl_tr_ut <- rev_scale(yp_tl_tr_ut, label = Y_train*3)
      rct_temp$tau_tclearner_tr <- list(tibble(tau_tclearner_tr = y_tl_tr_t - y_tl_tr_ut))
      
      ## Train models - Train w. RCT data, post-process w. train data
      ### Causal Forest
      cforest_rct <- causal_forest(X_rctt[-T], Y_rct, T_rct)
    
      train_temp$tau_cforest_rct <- list(tibble(tau_cforest_rct = predict(cforest_rct, X_traint[-T])$predictions))
      test_temp$tau_cforest_rct <- list(tibble(tau_cforest_rct = predict(cforest_rct, X_test)$predictions))
    
      ### Causal Forest (weighted)
      cforestw_rct <- causal_forest(X_rctt[-T], Y_rct, T_rct, sample.weights = pweights)
      
      train_temp$tau_cforestw_rct <- list(tibble(tau_cforestw_rct = predict(cforestw_rct, X_traint[-T])$predictions))
      test_temp$tau_cforestw_rct <- list(tibble(tau_cforestw_rct = predict(cforestw_rct, X_test)$predictions))
      
      ### S-learner
      slearner_rct <- ranger(Y ~ ., 
                             data = rct)
    
      train_temp$y_slearner_rct <- list(tibble(y_slearner_rct = predict(slearner_rct, train)$predictions))
      test_temp$y_slearner_rct <- list(tibble(y_slearner_rct = predict(slearner_rct, test)$predictions))
    
      train_temp$tau_slearner_rct <- list(tibble(tau_slearner_rct = predict(slearner_rct, train_t)$predictions - predict(slearner_rct, train_ut)$predictions))
      test_temp$tau_slearner_rct <- list(tibble(tau_slearner_rct = predict(slearner_rct, test_t)$predictions - predict(slearner_rct, test_ut)$predictions))

      ### S-learner (weighted)
      slearnerw_rct <- ranger(Y ~ ., 
                             case.weights = pweights,
                             data = rct)

      train_temp$y_slearnerw_rct <- list(tibble(y_slearnerw_rct = predict(slearnerw_rct, train)$predictions))
      test_temp$y_slearnerw_rct <- list(tibble(y_slearnerw_rct = predict(slearnerw_rct, test)$predictions))
      
      train_temp$tau_slearnerw_rct <- list(tibble(tau_slearnerw_rct = predict(slearnerw_rct, train_t)$predictions - predict(slearnerw_rct, train_ut)$predictions))
      test_temp$tau_slearnerw_rct <- list(tibble(tau_slearnerw_rct = predict(slearnerw_rct, test_t)$predictions - predict(slearnerw_rct, test_ut)$predictions))
      
      ### DR-learner
      drlearner_rct <- dr_learner(X_rctt[-T], Y_rct, T_rct, train)
      train_temp$tau_drlearner_rct <- list(tibble(tau_drlearner_rct = drlearner_rct$tau.new))
      
      drlearner_rct <- dr_learner(X_rctt[-T], Y_rct, T_rct, test)
      test_temp$tau_drlearner_rct <- list(tibble(tau_drlearner_rct = drlearner_rct$tau.new))

      ### T-learner
      tlearner_rct_t <- ranger(Y ~ . -T, 
                               data = rct[rct$T == 1, ])
      tlearner_rct_ut <- ranger(Y ~ . -T,
                                data = rct[rct$T == 0, ])
    
      train_temp$y_tlearner_rct <- list(tibble(y_tlearner_rct = ifelse(train$T == 1, 
                                               predict(tlearner_rct_t, train)$predictions,
                                               predict(tlearner_rct_ut, train)$predictions)))
      test_temp$y_tlearner_rct <- list(tibble(y_tlearner_rct = ifelse(test$T == 1, 
                                              predict(tlearner_rct_t, test)$predictions,
                                              predict(tlearner_rct_ut, test)$predictions)))
    
      train_temp$tau_tlearner_rct <- list(tibble(tau_tlearner_rct = predict(tlearner_rct_t, train_t)$predictions - predict(tlearner_rct_ut, train_ut)$predictions))
      test_temp$tau_tlearner_rct <- list(tibble(tau_tlearner_rct = predict(tlearner_rct_t, test_t)$predictions - predict(tlearner_rct_ut, test_ut)$predictions))
      rct_temp$tau_tlearner_rct <- list(tibble(tau_tlearner_rct = predict(tlearner_rct_t, rct_t)$predictions - predict(tlearner_rct_ut, rct_ut)$predictions))
      
      ### T-learner (weighted)
      tlearnerw_rct_t <- ranger(Y ~ . -T, 
                               case.weights = pweights[rct$T == 1],
                               data = rct[rct$T == 1, ])
      tlearnerw_rct_ut <- ranger(Y ~ . -T, 
                                case.weights = pweights[rct$T == 0],
                                data = rct[rct$T == 0, ])
      
      train_temp$y_tlearnerw_rct <- list(tibble(y_tlearnerw_rct = ifelse(train$T == 1, 
                                                predict(tlearnerw_rct_t, train)$predictions,
                                                predict(tlearnerw_rct_ut, train)$predictions)))
      test_temp$y_tlearnerw_rct <- list(tibble(y_tlearnerw_rct = ifelse(test$T == 1, 
                                               predict(tlearnerw_rct_t, test)$predictions,
                                               predict(tlearnerw_rct_ut, test)$predictions)))

      train_temp$tau_tlearnerw_rct <- list(tibble(tau_tlearnerw_rct = predict(tlearnerw_rct_t, train_t)$predictions - predict(tlearnerw_rct_ut, train_ut)$predictions))
      test_temp$tau_tlearnerw_rct <- list(tibble(tau_tlearnerw_rct = predict(tlearnerw_rct_t, test_t)$predictions - predict(tlearnerw_rct_ut, test_ut)$predictions))
      rct_temp$tau_tlearnerw_rct <- list(tibble(tau_tlearnerw_rct = predict(tlearnerw_rct_t, rct_t)$predictions - predict(tlearnerw_rct_ut, rct_ut)$predictions))
      
      ## Save results
      train_res <- rbind(train_res, train_temp)
      test_res <- rbind(test_res, test_temp)
      rct_res <- rbind(rct_res, rct_temp)
      
      print(paste("e =", e, "s =", s, "i =", i))
    }
  }
}

## Combine results

# Per-(replication, training size, shift) summary of the estimates stored for
# the training sample. Metric columns are named "<metric>_<method>":
#   bias_*  : ate minus the mean estimated tau of that method
#   mse_*   : mean squared difference between tau and the estimated tau
#   prerr_* : mean squared difference between y and the predicted outcome
# Methods ending in "TR" summarise the *_tr columns and those ending in "RCT"
# the *_rct columns written in the simulation loop. The metric list must stay
# in sync with the one used for test_r.
# unnest() gets `cols` explicitly (every list-column -- what the argument-less
# call used to pick) because calling it without `cols` is deprecated in tidyr;
# `.groups = "drop_last"` is the default and only silences the grouping message.
train_r <- train_res %>%
  unnest(cols = where(is.list)) %>%
  group_by(rep_i, trsize_s, shift_e) %>%
  summarise(flag = mean(flag),
            mean_ps = mean(ps),
            mean_KL = mean(KL),
            mean_y = mean(y),
            bias_cforestTR = mean(ate - mean(tau_cforest_tr)),
            bias_slearnerTR = mean(ate - mean(tau_slearner_tr)),
            bias_slearnerTRmcr = mean(ate - mean(tau_slearner_tr_mcr)),
            bias_drlearnerTR = mean(ate - mean(tau_drlearner_tr)),
            bias_drlearnerTRmcr = mean(ate - mean(tau_drlearner_tr_mcr)),
            bias_drlearnerTRmct = mean(ate - mean(tau_drlearner_tr_mct)),
            bias_drlearnerTRmcfr = mean(ate - mean(tau_drlearner_tr_mcfr)),
            bias_drlearnerTRmcft = mean(ate - mean(tau_drlearner_tr_mcft)),
            bias_drlearnerTRmclr = mean(ate - mean(tau_drlearner_tr_mclr)),
            bias_drlearnerTRmclt = mean(ate - mean(tau_drlearner_tr_mclt)),
            bias_tlearnerTR = mean(ate - mean(tau_tlearner_tr)),
            bias_tlearnerTRmcr = mean(ate - mean(tau_tlearner_tr_mcr)),
            bias_tlearnerTRmct = mean(ate - mean(tau_tlearner_tr_mct)),
            bias_tlearnerTRmclr = mean(ate - mean(tau_tlearner_tr_mclr)),
            bias_tlearnerTRmclt = mean(ate - mean(tau_tlearner_tr_mclt)),
            bias_tlearnerTRmcp = mean(ate - mean(tau_tlearner_tr_mcp)),
            bias_tclearnerTR = mean(ate - mean(tau_tclearner_tr)),
            mse_cforestTR = mean((tau - tau_cforest_tr)^2),
            mse_slearnerTR = mean((tau - tau_slearner_tr)^2),
            mse_slearnerTRmcr = mean((tau - tau_slearner_tr_mcr)^2),
            mse_drlearnerTR = mean((tau - tau_drlearner_tr)^2),
            mse_drlearnerTRmcr = mean((tau - tau_drlearner_tr_mcr)^2),
            mse_drlearnerTRmct = mean((tau - tau_drlearner_tr_mct)^2),
            mse_drlearnerTRmcfr = mean((tau - tau_drlearner_tr_mcfr)^2),
            mse_drlearnerTRmcft = mean((tau - tau_drlearner_tr_mcft)^2),
            mse_drlearnerTRmclr = mean((tau - tau_drlearner_tr_mclr)^2),
            mse_drlearnerTRmclt = mean((tau - tau_drlearner_tr_mclt)^2),
            mse_tlearnerTR = mean((tau - tau_tlearner_tr)^2),
            mse_tlearnerTRmcr = mean((tau - tau_tlearner_tr_mcr)^2),
            mse_tlearnerTRmct = mean((tau - tau_tlearner_tr_mct)^2),
            mse_tlearnerTRmclr = mean((tau - tau_tlearner_tr_mclr)^2),
            mse_tlearnerTRmclt = mean((tau - tau_tlearner_tr_mclt)^2),
            mse_tlearnerTRmcp = mean((tau - tau_tlearner_tr_mcp)^2),
            mse_tclearnerTR = mean((tau - tau_tclearner_tr)^2),
            prerr_slearnerRCT = mean((y - y_slearner_rct)^2),
            prerr_tlearnerRCT = mean((y - y_tlearner_rct)^2),
            bias_cforestRCT = mean(ate - mean(tau_cforest_rct)),
            bias_cforestwRCT = mean(ate - mean(tau_cforestw_rct)),
            bias_slearnerRCT = mean(ate - mean(tau_slearner_rct)),
            bias_slearnerwRCT = mean(ate - mean(tau_slearnerw_rct)),
            bias_drlearnerRCT = mean(ate - mean(tau_drlearner_rct)),
            bias_tlearnerRCT = mean(ate - mean(tau_tlearner_rct)),
            bias_tlearnerwRCT = mean(ate - mean(tau_tlearnerw_rct)),
            mse_cforestRCT = mean((tau - tau_cforest_rct)^2),
            mse_cforestwRCT = mean((tau - tau_cforestw_rct)^2),
            mse_slearnerRCT = mean((tau - tau_slearner_rct)^2),
            mse_slearnerwRCT = mean((tau - tau_slearnerw_rct)^2),
            mse_drlearnerRCT = mean((tau - tau_drlearner_rct)^2),
            mse_tlearnerRCT = mean((tau - tau_tlearner_rct)^2),
            mse_tlearnerwRCT = mean((tau - tau_tlearnerw_rct)^2),
            .groups = "drop_last")

# Per-(replication, training size, shift) summary of the estimates stored for
# the test sample; same metrics as train_r (keep the two lists in sync):
#   bias_*  : ate minus the mean estimated tau of that method
#   mse_*   : mean squared difference between tau and the estimated tau
#   prerr_* : mean squared difference between y and the predicted outcome
# unnest() gets `cols` explicitly (every list-column -- what the argument-less
# call used to pick) because calling it without `cols` is deprecated in tidyr;
# `.groups = "drop_last"` is the default and only silences the grouping message.
test_r <- test_res %>%
  unnest(cols = where(is.list)) %>%
  group_by(rep_i, trsize_s, shift_e) %>%
  summarise(flag = mean(flag),
            mean_ps = mean(ps),
            mean_KL = mean(KL),
            mean_y = mean(y),
            bias_cforestTR = mean(ate - mean(tau_cforest_tr)),
            bias_slearnerTR = mean(ate - mean(tau_slearner_tr)),
            bias_slearnerTRmcr = mean(ate - mean(tau_slearner_tr_mcr)),
            bias_drlearnerTR = mean(ate - mean(tau_drlearner_tr)),
            bias_drlearnerTRmcr = mean(ate - mean(tau_drlearner_tr_mcr)),
            bias_drlearnerTRmct = mean(ate - mean(tau_drlearner_tr_mct)),
            bias_drlearnerTRmcfr = mean(ate - mean(tau_drlearner_tr_mcfr)),
            bias_drlearnerTRmcft = mean(ate - mean(tau_drlearner_tr_mcft)),
            bias_drlearnerTRmclr = mean(ate - mean(tau_drlearner_tr_mclr)),
            bias_drlearnerTRmclt = mean(ate - mean(tau_drlearner_tr_mclt)),
            bias_tlearnerTR = mean(ate - mean(tau_tlearner_tr)),
            bias_tlearnerTRmcr = mean(ate - mean(tau_tlearner_tr_mcr)),
            bias_tlearnerTRmct = mean(ate - mean(tau_tlearner_tr_mct)),
            bias_tlearnerTRmclr = mean(ate - mean(tau_tlearner_tr_mclr)),
            bias_tlearnerTRmclt = mean(ate - mean(tau_tlearner_tr_mclt)),
            bias_tlearnerTRmcp = mean(ate - mean(tau_tlearner_tr_mcp)),
            bias_tclearnerTR = mean(ate - mean(tau_tclearner_tr)),
            mse_cforestTR = mean((tau - tau_cforest_tr)^2),
            mse_slearnerTR = mean((tau - tau_slearner_tr)^2),
            mse_slearnerTRmcr = mean((tau - tau_slearner_tr_mcr)^2),
            mse_drlearnerTR = mean((tau - tau_drlearner_tr)^2),
            mse_drlearnerTRmcr = mean((tau - tau_drlearner_tr_mcr)^2),
            mse_drlearnerTRmct = mean((tau - tau_drlearner_tr_mct)^2),
            mse_drlearnerTRmcfr = mean((tau - tau_drlearner_tr_mcfr)^2),
            mse_drlearnerTRmcft = mean((tau - tau_drlearner_tr_mcft)^2),
            mse_drlearnerTRmclr = mean((tau - tau_drlearner_tr_mclr)^2),
            mse_drlearnerTRmclt = mean((tau - tau_drlearner_tr_mclt)^2),
            mse_tlearnerTR = mean((tau - tau_tlearner_tr)^2),
            mse_tlearnerTRmcr = mean((tau - tau_tlearner_tr_mcr)^2),
            mse_tlearnerTRmct = mean((tau - tau_tlearner_tr_mct)^2),
            mse_tlearnerTRmclr = mean((tau - tau_tlearner_tr_mclr)^2),
            mse_tlearnerTRmclt = mean((tau - tau_tlearner_tr_mclt)^2),
            mse_tlearnerTRmcp = mean((tau - tau_tlearner_tr_mcp)^2),
            mse_tclearnerTR = mean((tau - tau_tclearner_tr)^2),
            prerr_slearnerRCT = mean((y - y_slearner_rct)^2),
            prerr_tlearnerRCT = mean((y - y_tlearner_rct)^2),
            bias_cforestRCT = mean(ate - mean(tau_cforest_rct)),
            bias_cforestwRCT = mean(ate - mean(tau_cforestw_rct)),
            bias_slearnerRCT = mean(ate - mean(tau_slearner_rct)),
            bias_slearnerwRCT = mean(ate - mean(tau_slearnerw_rct)),
            bias_drlearnerRCT = mean(ate - mean(tau_drlearner_rct)),
            bias_tlearnerRCT = mean(ate - mean(tau_tlearner_rct)),
            bias_tlearnerwRCT = mean(ate - mean(tau_tlearnerw_rct)),
            mse_cforestRCT = mean((tau - tau_cforest_rct)^2),
            mse_cforestwRCT = mean((tau - tau_cforestw_rct)^2),
            mse_slearnerRCT = mean((tau - tau_slearner_rct)^2),
            mse_slearnerwRCT = mean((tau - tau_slearnerw_rct)^2),
            mse_drlearnerRCT = mean((tau - tau_drlearner_rct)^2),
            mse_tlearnerRCT = mean((tau - tau_tlearner_rct)^2),
            mse_tlearnerwRCT = mean((tau - tau_tlearnerw_rct)^2),
            .groups = "drop_last")

# Per-(replication, training size, shift) summary on the RCT sample:
# d2bias_* = mean of (tau - estimated tau) * estimated tau for each T-learner
# variant. NOTE(review): presumably a check that the CATE error is uncorrelated
# with the estimate -- confirm the intended interpretation.
# unnest() gets `cols` explicitly (every list-column -- what the argument-less
# call used to pick) because calling it without `cols` is deprecated in tidyr;
# `.groups = "drop_last"` is the default and only silences the grouping message.
rct_r <- rct_res %>%
  unnest(cols = where(is.list)) %>%
  group_by(rep_i, trsize_s, shift_e) %>%
  summarise(d2bias_tlearnerTR = mean((tau - tau_tlearner_tr)*tau_tlearner_tr),
            d2bias_tlearnerTRmcr = mean((tau - tau_tlearner_tr_mcr)*tau_tlearner_tr_mcr),
            d2bias_tlearnerTRmct = mean((tau - tau_tlearner_tr_mct)*tau_tlearner_tr_mct),
            d2bias_tlearnerTRmcp = mean((tau - tau_tlearner_tr_mcp)*tau_tlearner_tr_mcp),
            d2bias_tlearnerRCT = mean((tau - tau_tlearner_rct)*tau_tlearner_rct),
            d2bias_tlearnerwRCT = mean((tau - tau_tlearnerw_rct)*tau_tlearnerw_rct),
            .groups = "drop_last")

# Long format of the summaries: every metric column is named
# "<Metric>_<Method>" (e.g. "bias_cforestTR", "d2bias_tlearnerRCT"), so
# splitting the name at "_" yields one row per replication, Metric and Method.
# The bias_cforestTR:mse_tlearnerwRCT range also covers the prerr_* columns.
train_r_long <- pivot_longer(
  select(train_r, rep_i, trsize_s, shift_e, flag,
         bias_cforestTR:mse_tlearnerwRCT),
  cols = bias_cforestTR:mse_tlearnerwRCT,
  names_to = c("Metric", "Method"),
  names_sep = "_"
)

test_r_long <- pivot_longer(
  select(test_r, rep_i, trsize_s, shift_e, flag,
         bias_cforestTR:mse_tlearnerwRCT),
  cols = bias_cforestTR:mse_tlearnerwRCT,
  names_to = c("Metric", "Method"),
  names_sep = "_"
)

rct_r_long <- pivot_longer(
  select(rct_r, rep_i, trsize_s, shift_e,
         d2bias_tlearnerTR:d2bias_tlearnerwRCT),
  cols = d2bias_tlearnerTR:d2bias_tlearnerwRCT,
  names_to = c("Metric", "Method"),
  names_sep = "_"
)

# Persist both the wide summaries and their long versions in one file.
save(list = c("train_r", "test_r", "rct_r",
              "train_r_long", "test_r_long", "rct_r_long"),
     file = "simu.RData")

## Evaluation

### KL divergence

# Average over replications of the per-replication mean KL divergence
# (mean_KL), for each training size and shift intensity up to 2, rounded to
# two decimals and written out as a LaTeX table.
# NOTE(review): this file uses the "s1b_" prefix while the bias/MSE tables use
# "s1d_2_" -- confirm which scenario prefix is intended.
kl_tab <- test_r %>%
  filter(shift_e <= 2) %>%
  group_by(trsize_s, shift_e) %>%
  summarize(kl = mean(mean_KL), .groups = "drop") %>%
  # Round directly: forwarding extra arguments through across(...) is
  # deprecated since dplyr 1.1.0.
  mutate(kl = round(kl, 2))

kl_ftab <- knitr::kable(kl_tab, format = "latex")
writeLines(kl_ftab, "s1b_kl.tex")

### Bias and MSE

# Average (over replications) bias of each estimator on the test sample, per
# training size and shift intensity up to 2, rounded to two decimals.
bias_tab <- test_r %>%
  filter(shift_e <= 2) %>%
  group_by(trsize_s, shift_e) %>%
  summarize(bias_cforestTR = mean(bias_cforestTR),
            bias_slearnerTR = mean(bias_slearnerTR),
            bias_drlearnerTR = mean(bias_drlearnerTR),
            bias_drlearnerTRmcfr = mean(bias_drlearnerTRmcfr),
            bias_drlearnerTRmcft = mean(bias_drlearnerTRmcft),
            bias_tlearnerTR = mean(bias_tlearnerTR),
            bias_tlearnerTRmcr = mean(bias_tlearnerTRmcr),
            bias_tlearnerTRmct = mean(bias_tlearnerTRmct),
            bias_cforestRCT = mean(bias_cforestRCT),
            bias_cforestwRCT = mean(bias_cforestwRCT),
            bias_slearnerRCT = mean(bias_slearnerRCT),
            bias_slearnerwRCT = mean(bias_slearnerwRCT),
            bias_drlearnerRCT = mean(bias_drlearnerRCT),
            bias_tlearnerRCT = mean(bias_tlearnerRCT),
            bias_tlearnerwRCT = mean(bias_tlearnerwRCT),
            .groups = "drop") %>%
  # A lambda instead of forwarding `2` through across(...), which is
  # deprecated since dplyr 1.1.0.
  mutate(across(bias_cforestTR:bias_tlearnerwRCT, ~ round(.x, 2)))

# LaTeX-formatted copy of bias_tab: in each row the estimator whose bias is
# closest to zero is set in bold and the runner-up in italics. Columns 1-2 are
# the keys (trsize_s, shift_e), hence the offset of 2 between rank and column.
cols <- seq_len(ncol(bias_tab))
bias_ftab <- data.frame(matrix(NA, nrow = nrow(bias_tab), ncol = ncol(bias_tab)))
names(bias_ftab) <- names(bias_tab)

for (i in seq_len(nrow(bias_tab))) {
  ranks <- order(abs(as.numeric(bias_tab[i, 3:ncol(bias_tab)])))
  minf <- ranks[1]
  mins <- ranks[2]
  bias_ftab[i, ] <- cell_spec(bias_tab[i, ], "latex",
                              bold = cols == minf + 2,
                              italic = cols == mins + 2)
}

bias_ftab <- knitr::kable(bias_ftab, format = "latex")
writeLines(bias_ftab, "s1d_2_bias.tex")

# Average (over replications) MSE between tau and each estimator's CATE on the
# test sample, per training size and shift intensity up to 2, rounded to two
# decimals. Estimator columns follow the same order as bias_tab (drlearnerRCT
# before tlearnerRCT) so the two LaTeX tables line up column for column.
mse_tab <- test_r %>%
  filter(shift_e <= 2) %>%
  group_by(trsize_s, shift_e) %>%
  summarize(mse_cforestTR = mean(mse_cforestTR),
            mse_slearnerTR = mean(mse_slearnerTR),
            mse_drlearnerTR = mean(mse_drlearnerTR),
            mse_drlearnerTRmcfr = mean(mse_drlearnerTRmcfr),
            mse_drlearnerTRmcft = mean(mse_drlearnerTRmcft),
            mse_tlearnerTR = mean(mse_tlearnerTR),
            mse_tlearnerTRmcr = mean(mse_tlearnerTRmcr),
            mse_tlearnerTRmct = mean(mse_tlearnerTRmct),
            mse_cforestRCT = mean(mse_cforestRCT),
            mse_cforestwRCT = mean(mse_cforestwRCT),
            mse_slearnerRCT = mean(mse_slearnerRCT),
            mse_slearnerwRCT = mean(mse_slearnerwRCT),
            mse_drlearnerRCT = mean(mse_drlearnerRCT),
            mse_tlearnerRCT = mean(mse_tlearnerRCT),
            mse_tlearnerwRCT = mean(mse_tlearnerwRCT),
            .groups = "drop") %>%
  # A lambda instead of forwarding `2` through across(...), which is
  # deprecated since dplyr 1.1.0.
  mutate(across(mse_cforestTR:mse_tlearnerwRCT, ~ round(.x, 2)))

# LaTeX-formatted copy of mse_tab: in each row the smallest MSE is set in bold
# and the second smallest in italics. Columns 1-2 are the keys
# (trsize_s, shift_e), hence the offset of 2 between rank and column index.
cols <- seq_len(ncol(mse_tab))
mse_ftab <- data.frame(matrix(NA, nrow = nrow(mse_tab), ncol = ncol(mse_tab)))
names(mse_ftab) <- names(mse_tab)

for (i in seq_len(nrow(mse_tab))) {
  ranks <- order(as.numeric(mse_tab[i, 3:ncol(mse_tab)]))
  minf <- ranks[1]
  mins <- ranks[2]
  mse_ftab[i, ] <- cell_spec(mse_tab[i, ], "latex",
                             bold = cols == minf + 2,
                             italic = cols == mins + 2)
}

mse_ftab <- knitr::kable(mse_ftab, format = "latex")
writeLines(mse_ftab, "s1d_2_mse.tex")

# Display labels for the method codes in the Method column, kept in a single
# lookup so the train and test tables are always relabelled identically. The
# vector order fixes the factor level order produced by recode_factor().
# NOTE(review): codes such as drlearnerTRmcr, drlearnerTRmct, drlearnerTRmclr,
# drlearnerTRmclt, tlearnerTRmclr and tlearnerTRmclt are summarised above but
# have no entry here, so they keep their raw code as label.
method_labels <- c("cforestTR" = "CForest-OS",
                   "slearnerTR" = "S-learner-OS",
                   "slearnerTRmcr" = "slearnerTRmcr",
                   "drlearnerTR" = "DR-learner-OS",
                   "drlearnerTRmcfr" = "DR-learner-MC-Ridge",
                   "drlearnerTRmcft" = "DR-learner-MC-Tree",
                   "tlearnerTR" = "T-learner-OS",
                   "tlearnerTRmcr" = "T-learner-MC-Ridge",
                   "tlearnerTRmct" = "T-learner-MC-Tree",
                   "tlearnerTRmcp" = "tlearnerTRmcp",
                   "tclearnerTR" = "tclearnerTR",
                   "cforestRCT" = "CForest-CT",
                   "cforestwRCT" = "CForest-wCT",
                   "slearnerRCT" = "S-learner-CT",
                   "slearnerwRCT" = "S-learner-wCT",
                   "drlearnerRCT" = "DR-learner-CT",
                   "tlearnerRCT" = "T-learner-CT",
                   "tlearnerwRCT" = "T-learner-wCT")

train_r_long$Method <- recode_factor(train_r_long$Method, !!!method_labels)
test_r_long$Method <- recode_factor(test_r_long$Method, !!!method_labels)

### Plots

# Smoothed trend of one metric against shift intensity: one curve per method,
# one facet column per training size. The plot is printed and then saved.
#   data     long-format results with columns Metric, Method, shift_e,
#            trsize_s and value (train_r_long / test_r_long)
#   metric   value of `Metric` to plot ("prerr", "bias" or "mse")
#   y_label  y-axis title
#   file     output path passed to ggsave()
# Returns the ggplot object invisibly.
plot_shift_trend <- function(data, metric, y_label, file) {
  p <- data %>%
    filter(Metric == .env$metric) %>%
    ggplot(aes(x = shift_e, y = value, group = Method, color = Method)) +
    geom_smooth(alpha = 0.25, linewidth = 0.75) +
    labs(x = "Shift Intensity", y = y_label) +
    facet_grid(cols = vars(trsize_s))
  print(p)
  ggsave(file, plot = p, width = 10, height = 6)
  invisible(p)
}

# File names use the same "s1d_2_" prefix as this script's other outputs.
# The "-smooth" suffix keeps the test CATE-MSE trend plot distinct from the
# aggregated bar plot saved as "s1d_2_test_cate-mse.pdf".
plot_shift_trend(train_r_long, "prerr", "Prediction MSE",
                 "s1d_2_train_pred-error-smooth.pdf")
plot_shift_trend(train_r_long, "bias", "Bias (ATE)",
                 "s1d_2_train_ate-bias-smooth.pdf")
plot_shift_trend(train_r_long, "mse", "MSE (CATE)",
                 "s1d_2_train_cate-mse-smooth.pdf")

plot_shift_trend(test_r_long, "prerr", "Prediction MSE",
                 "s1d_2_test_pred-error-smooth.pdf")
plot_shift_trend(test_r_long, "bias", "Bias (ATE)",
                 "s1d_2_test_ate-bias-smooth.pdf")
plot_shift_trend(test_r_long, "mse", "MSE (CATE)",
                 "s1d_2_test_cate-mse-smooth.pdf")

# Boxplots of test-set prediction MSE for shift levels 0, 1 and 2 (facet
# columns) by training size (facet rows). Method levels are reversed with
# fct_rev() and the legend is reversed back with guide_legend(reverse = TRUE).
test_r_long %>%
  mutate(method = fct_rev(Method)) %>%
  # filter(flag == 0) %>%
  filter(Metric == "prerr",
         shift_e %in% 0:2) %>%
  ggplot(aes(y = value, group = method, color = method)) +
  geom_boxplot(linewidth = 0.4, outlier.size = 0.1) +
  # A constant reference line: a parameter, not a per-row aesthetic.
  geom_hline(yintercept = 0) +
  labs(y = "Prediction MSE") +
  guides(color = guide_legend(reverse = TRUE)) +
  scale_color_hue(direction = -1) +
  facet_grid(rows = vars(trsize_s), cols = vars(shift_e)) +
  coord_flip() +
  theme(axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank())

# Same "s1d_2_" prefix as the other outputs of this script.
ggsave("s1d_2_test_pred-error-box.pdf", width = 10, height = 7)

# Facet-strip labels for the shift-intensity columns, keyed by shift_e value.
shift_labs <- c("no shift", "moderate shift", "strong shift")
names(shift_labs) <- c("0", "1", "2")

# Boxplots of test-set ATE bias for shift levels 0, 1 and 2 (columns) by
# training size (rows). The axis is zoomed to [-20, 20] through
# coord_flip(ylim), which keeps every observation in the boxplot statistics.
# Methods still carrying raw codes (e.g. "tclearnerTR") are left out.
# Colours come only from the manual 15-colour palette (12 "Paired" +
# 3 "Dark2"); adding another colour scale would merely be replaced by it.
test_r_long %>%
  mutate(method = fct_rev(Method)) %>%
  # filter(flag == 0) %>%
  # filter(!Method %in% c("S-learner-OS", "S-learner-CT", "S-learner-wCT")) %>%
  filter(!Method %in% c("slearnerTRmcr", "slearnerTRmct", "slearnerTRmcp",
                        "tlearnerTRmclr", "tlearnerTRmclt", "tlearnerTRmcp",
                        "tclearnerTR",
                        "drlearnerTRmcr", "drlearnerTRmct",
                        "drlearnerTRmclr", "drlearnerTRmclt"),
         Metric == "bias",
         shift_e %in% 0:2) %>%
  ggplot(aes(y = value, group = method, color = method)) +
  geom_boxplot(linewidth = 0.4, outlier.size = 0.1) +
  geom_hline(yintercept = 0) +
  labs(y = "Bias (ATE)") +
  guides(color = guide_legend(reverse = TRUE)) +
  facet_grid(rows = vars(trsize_s), cols = vars(shift_e),
             labeller = labeller(shift_e = shift_labs)) +
  # scale_colour_brewer(palette = "Paired") +
  scale_colour_manual(values = c(RColorBrewer::brewer.pal(12, "Paired"),
                                 RColorBrewer::brewer.pal(3, "Dark2"))) +
  coord_flip(ylim = c(-20, 20)) +
  theme(axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        legend.title = element_blank(),
        text = element_text(size = 13))

ggsave("s1d_2_test_ate-bias-box.pdf", width = 10, height = 7)

# Out-of-bounds handler that leaves values untouched (equivalent to
# scales::oob_keep): data beyond a scale's limits stay in the statistics and
# are only clipped visually by the panel.
keep_oob <- function(x, ...) x

# Boxplots of test-set CATE MSE for shift levels 0, 1 and 2 (columns) by
# training size (rows), with a separate axis zoom per shift level. Methods
# still carrying raw codes (e.g. "tclearnerTR") are left out. Colours come
# only from the manual 15-colour palette (12 "Paired" + 3 "Dark2").
test_r_long %>%
  mutate(method = fct_rev(Method)) %>%
  # filter(flag == 0) %>%
  # filter(!Method %in% c("S-learner-OS", "S-learner-CT", "S-learner-wCT")) %>%
  filter(!Method %in% c("slearnerTRmcr", "slearnerTRmct", "slearnerTRmcp",
                        "tlearnerTRmclr", "tlearnerTRmclt", "tlearnerTRmcp",
                        "tclearnerTR",
                        "drlearnerTRmcr", "drlearnerTRmct",
                        "drlearnerTRmclr", "drlearnerTRmclt"),
         Metric == "mse",
         shift_e %in% 0:2) %>%
  ggplot(aes(y = value, group = method, color = method)) +
  geom_boxplot(linewidth = 0.4, outlier.size = 0.1) +
  geom_hline(yintercept = 0) +
  labs(y = "MSE (CATE)") +
  guides(color = guide_legend(reverse = TRUE)) +
  facet_grid(rows = vars(trsize_s), cols = vars(shift_e), scales = "free",
             labeller = labeller(shift_e = shift_labs)) +
  # scale_colour_brewer(palette = "Paired") +
  scale_colour_manual(values = c(RColorBrewer::brewer.pal(12, "Paired"),
                                 RColorBrewer::brewer.pal(3, "Dark2"))) +
  coord_flip() +
  # Scale limits censor out-of-range values to NA *before* the boxplot
  # statistics are computed by default, which would drop the largest MSEs
  # and pull quartiles/medians down. keep_oob makes these limits a pure zoom.
  facetted_pos_scales(y = list(
    shift_e == 0 ~ scale_y_continuous(limits = c(0, 200), oob = keep_oob),
    shift_e == 1 ~ scale_y_continuous(limits = c(0, 200), oob = keep_oob),
    shift_e == 2 ~ scale_y_continuous(limits = c(0, 350), oob = keep_oob)
  )) +
  theme(axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        legend.title = element_blank(),
        text = element_text(size = 13))

ggsave("s1d_2_test_cate-mse-box.pdf", width = 10, height = 7)

## Aggregated bar plots

# NOTE: attaching plyr after the tidyverse masks several dplyr verbs
# (mutate, rename, summarise, arrange, count, ...) for the rest of the script.
library(plyr)

# Mean and standard deviation of one column within each group.
#   data        data frame to summarise
#   varname     name (string) of the column to summarise
#   groupnames  character vector of grouping column names
# Returns a data frame with one row per group: the grouping columns, the
# group mean (stored under the name `varname`) and `sd`. NAs are ignored.
data_summary <- function(data, varname, groupnames) {
  summary_func <- function(x, col) {
    c(mean = mean(x[[col]], na.rm = TRUE), sd = sd(x[[col]], na.rm = TRUE))
  }
  # Namespace-qualified so the result does not depend on package attach
  # order: plyr::rename() takes old = new pairs, whereas dplyr::rename()
  # takes new = old, so an unqualified call would break if dplyr won.
  data_sum <- plyr::ddply(data, groupnames, .fun = summary_func, col = varname)
  data_sum <- plyr::rename(data_sum, c("mean" = varname))
  data_sum
}

# Per-group mean and sd of the test-set metric values, one row per
# (shift_e, trsize_s, Metric, Method) combination.
d_sum <- data_summary(
  test_r_long,
  varname = "value",
  groupnames = c("shift_e", "trsize_s", "Metric", "Method")
)

# Facet-strip labels keyed by shift_e value.
shift_labs <- c("0" = "no shift", "1" = "moderate shift", "2" = "strong shift")

# Out-of-bounds handler that leaves values untouched (equivalent to
# scales::oob_keep), so the per-panel limits below only zoom.
keep_oob <- function(x, ...) x

# Grouped bar chart of mean test-set CATE MSE (+/- 1 sd) for six selected
# methods, by training size, with one panel per shift level.
d_sum %>%
  # dplyr:: explicitly, because plyr (attached above) masks mutate().
  dplyr::mutate(
    trsize_s = factor(trsize_s),
    # Maps any remaining raw method codes to their labels and, because
    # recode_factor() orders levels as its replacements are listed, puts the
    # six plotted methods in this order, which fixes their legend position
    # and fill colour.
    Method = recode_factor(Method,
                           "cforestTR" = "CForest-OS",
                           "tlearnerTR" = "T-learner-OS",
                           "tlearnerTRmcr" = "T-learner-MC-Ridge",
                           "drlearnerTR" = "DR-learner-OS",
                           "drlearnerTRmcfr" = "DR-learner-MC-Ridge",
                           "cforestRCT" = "CForest-CT")
  ) %>%
  filter(Metric == "mse",
         Method %in% c("CForest-OS", "T-learner-OS", "T-learner-MC-Ridge",
                       "DR-learner-OS", "DR-learner-MC-Ridge", "CForest-CT"),
         shift_e %in% 0:2) %>%
  ggplot(aes(x = trsize_s, y = value, fill = Method)) +
  geom_col(color = "black", position = position_dodge()) +
  geom_errorbar(aes(ymin = value - sd, ymax = value + sd),
                linewidth = 0.2, width = 0.2, position = position_dodge(0.9)) +
  labs(y = "MSE (CATE)", x = "Train size") +
  #scale_fill_brewer(palette = "Paired") +
  scale_fill_manual(values = c(RColorBrewer::brewer.pal(8, "Paired")[7],
                               RColorBrewer::brewer.pal(6, "Paired")[1:6])) +
  #coord_cartesian(ylim = c(0, 250)) +
  facet_wrap(. ~ shift_e, scales = "free",
             labeller = labeller(shift_e = shift_labs)) +
  # With the default censoring, an error bar whose ymin < 0 or ymax above the
  # limit (or a bar taller than the limit) is turned into NA and silently
  # dropped. keep_oob draws it and lets the panel clip it instead.
  facetted_pos_scales(y = list(
    shift_e == 0 ~ scale_y_continuous(limits = c(0, 275), oob = keep_oob),
    shift_e == 1 ~ scale_y_continuous(limits = c(0, 200), oob = keep_oob),
    shift_e == 2 ~ scale_y_continuous(limits = c(0, 200), oob = keep_oob)
  )) +
  theme(legend.title = element_blank(),
        text = element_text(size = 10))

ggsave("s1d_2_test_cate-mse.pdf", width = 10, height = 2.75)
