# WHI data application
# Setup: Large observational study (OS) and RCT data (CT)

library(tidyverse)
library(ranger)
library(grf)
library(mcboost)
library(mlr3learners)
library(kableExtra)
library(ggh4x)
# library(plyr)
# devtools::install_github("soerenkuenzel/causalToolbox")
library(causalToolbox)
# devtools::install_github("xnie/rlearner")
library(rlearner)
source("dr_learner.R")

## Prepare loop

load("whi_data.RData")

# Names of the list-columns that hold one nested tibble per result row.
# Each nested tibble starts with a single NA column named like the list-column.
result_list_cols <- c(
  # 'true' CATE benchmarks learned on the RCT test data
  "tau_xrf", "tau_drrf", "tau_rlasso", "tau_tlasso",
  # outcome predictions, models trained w. observational data (_tr)
  "y_cforest_tr", "y_slearner_tr", "y_tlearner_tr", "y_tclearner_tr",
  "yp_tlearner_tr_mcr", "y_tlearner_tr_mcr",
  "yp_tlearner_tr_mct", "y_tlearner_tr_mct",
  "yp_tlearner_tr_mclt", "y_tlearner_tr_mclt",
  "yp_tlearner_tr_mclr", "y_tlearner_tr_mclr",
  # CATE estimates, models trained w. observational data (_tr)
  "tau_cforest_tr", "tau_slearner_tr", "tau_tlearner_tr", "tau_tclearner_tr",
  "tau_tlearner_tr_mcr", "tau_tlearner_tr_mct",
  "tau_tlearner_tr_mclt", "tau_tlearner_tr_mclr",
  "tau_drlearner_tr", "tau_drclearner_tr",
  "tau_drlearner_tr_mcr", "tau_drlearner_tr_mct",
  "tau_drlearner_tr_mcfr", "tau_drlearner_tr_mcft",
  "tau_drlearner_tr_mclr", "tau_drlearner_tr_mclt",
  # outcome predictions and CATE estimates, models trained w. RCT data (_rct)
  "y_cforest_rct", "y_cforestw_rct", "y_slearner_rct", "y_slearnerw_rct",
  "y_tlearner_rct", "y_tlearnerw_rct",
  "tau_cforest_rct", "tau_cforestw_rct", "tau_slearner_rct",
  "tau_slearnerw_rct", "tau_tlearner_rct", "tau_tlearnerw_rct"
)

# Build an empty one-row result template.
#
# Columns, in order: rep_i, rctsize_s, flag (all NA), y (list-column),
# ate (NA), followed by one list-column per entry of `list_cols`.
# The train and test templates share exactly this layout.
make_result_template <- function(list_cols = result_list_cols) {
  # list(<tibble with one NA column called nm>) -> a length-1 list-column
  na_list_col <- function(nm) list(as_tibble(setNames(list(NA), nm)))
  head_cols <- list(rep_i = NA, rctsize_s = NA,
                    flag = NA,
                    y = na_list_col("y"),
                    ate = NA)
  nested_cols <- lapply(setNames(list_cols, list_cols), na_list_col)
  as_tibble(c(head_cols, nested_cols))
}

train_res <- tibble()
train_temp <- make_result_template()

test_res <- tibble()
test_temp <- make_result_template()

s_range <- seq(250, 1500, by = 250) # n obs rct data
n_reps <- 25 # n repetitions

# Min-max rescale `x` using the range of `label`:
# values equal to min(label) map to 0, values equal to max(label) map to 1.
# `x` values outside the range of `label` map outside [0, 1].
# NOTE: this definition masks base::scale() for the rest of the script.
scale <- function(x, label){
  lower <- min(label)
  upper <- max(label)
  (x - lower)/(upper - lower)
}

# Inverse of the min-max rescaling: map `preds` from the [0, 1] scale back to
# the scale spanned by min(label) and max(label).
rev_scale <- function(preds, label){
  lower <- min(label)
  width <- max(label) - lower
  preds*width + lower
}

ridge <- LearnerAuditorFitter$new(lrn("regr.glmnet", alpha = 0, s = 10))
tree <- LearnerAuditorFitter$new(lrn("regr.rpart", maxdepth = 3))
subT <- SubpopAuditorFitter$new(list("T"))

# Reduced set of covars
# NOTE: the "small set" selection below is applied on top of this one and
# keeps only a subset of these columns, so this step has no effect on the
# final data unless the small-set lines are commented out.
ct_num <- select(ct_num, Y, T, age, contains("ethnic"), contains("cigsday"), syst_bl, dias_bl, bmix_bl)
os_num <- select(os_num, Y, T, age, contains("ethnic"), contains("cigsday"), syst_bl, dias_bl, bmix_bl)

# Small set of covars (the selection actually in effect for the loop below):
# outcome Y, treatment T, age and the ethnicity columns
ct_num <- select(ct_num, Y, T, age, contains("ethnic"))
os_num <- select(os_num, Y, T, age, contains("ethnic"))


## Simulation

for(s in s_range) {
  
  ## Set RCT set size
  train_size <- s
  train_temp$rctsize_s <- s
  test_temp$rctsize_s <- s
  
  for(i in 1:n_reps) {
    
    train_temp$rep_i <- i
    test_temp$rep_i <- i

    # OS train and test: random 75% / 25% split of the observational data
    # os_num <- slice_sample(os_num, n = 20000)
    n_os <- nrow(os_num)
    os_ids <- sample(seq_len(n_os), floor(0.75*n_os)) # explicit integer sample size
    os_train <- os_num[os_ids,]
    os_test <- os_num[-os_ids,]
    
    # RCT full train and test: random 50% / 50% split of the trial data
    n_ct <- nrow(ct_num)
    ct_ids <- sample(seq_len(n_ct), floor(0.5*n_ct))
    ct_train_full <- ct_num[ct_ids,]
    ct_test <- ct_num[-ct_ids,]

    ## Learn 'true' CATE in RCT test
    feat <- select(ct_test, -c(Y, T)) # covariates only: no outcome, no treatment
    tr <- ct_test$T
    yobs <- ct_test$Y
    
    xl_rf <- X_RF(feat, tr, yobs) # RF X-learner
    # Covariates only, like the other benchmark learners: the treatment column
    # must not be part of the feature matrix (ct_test[,-1] drops Y but keeps T).
    dr_rf <- dr_learner_grf(feat, yobs, tr, feat, trunc = 0.01) # RF DR-learner 
    rl_lasso <- rlasso(as.matrix(feat), tr, yobs, alpha = 0) # glmnet R-learner
    tl_lasso <- tlasso(as.matrix(feat), tr, yobs) # glmnet T-learner

    ## Sample RCT train
    ct_train <- slice_sample(ct_train_full, n = train_size)

    ## Store 'true' ATE and CATE 
    train_temp$y <- list(tibble(y = os_test$Y)) # true Y
    test_temp$y <- list(tibble(y = ct_test$Y))

    # 'True' ATE: difference in mean outcome between the arms of the RCT test
    # set; the same value is stored in both result rows
    ate_ct <- mean(ct_test$Y[ct_test$T==1]) - mean(ct_test$Y[ct_test$T==0])
    train_temp$ate <- ate_ct
    test_temp$ate <- ate_ct
    
    # 'True' CATE: benchmark learners evaluated on the OS test covariates
    # (train_temp) and on the RCT test covariates (test_temp)
    os_feat <- select(os_test, -c(Y, T))
    train_temp$tau_xrf <- list(tibble(tau_xrf = EstimateCate(xl_rf, as.data.frame(os_feat))))
    # NOTE(review): dr_rf$tau.hat is stored in both result rows. It presumably
    # holds estimates for the ct_test rows, so in train_temp its length would
    # differ from the other tau_* columns (which use os_test) -- confirm.
    train_temp$tau_drrf <- list(tibble(tau_drrf = dr_rf$tau.hat))
    train_temp$tau_rlasso <- list(tibble(tau_rlasso = predict(rl_lasso, as.matrix(os_feat))))
    train_temp$tau_tlasso <- list(tibble(tau_tlasso = predict(tl_lasso, as.matrix(os_feat))))
    test_temp$tau_xrf <- list(tibble(tau_xrf = EstimateCate(xl_rf, as.data.frame(feat))))
    test_temp$tau_drrf <- list(tibble(tau_drrf = dr_rf$tau.hat))
    test_temp$tau_rlasso <- list(tibble(tau_rlasso = predict(rl_lasso, as.matrix(feat))))
    test_temp$tau_tlasso <- list(tibble(tau_tlasso = predict(tl_lasso, as.matrix(feat))))
    
    ## Pre-process data
    # OS train data: covariates incl. treatment, raw and scaled outcome
    X_traint <- select(os_train, -Y)
    Y_train <- os_train$Y
    # Reference vector whose range defines the outcome scaling (used for both
    # the OS and the RCT outcomes)
    y_ref <- -2*min(Y_train) + 2*Y_train
    Y_trains <- scale(Y_train, label = y_ref)
    T_train <- as.numeric(os_train$T)
    
    # RCT train data, scaled with the same OS-based reference
    X_rctt <- select(ct_train, -Y)
    Y_rct <- ct_train$Y
    Y_rcts <- scale(Y_rct, label = y_ref)
    # Record whether any scaled RCT outcome fell outside [0, 1] before clipping.
    # NOTE(review): only test_temp$flag is set; train_temp$flag is never
    # assigned in the visible code.
    rct_range <- c(min(Y_rcts), max(Y_rcts))
    test_temp$flag <- as.numeric(rct_range[1] < 0 | rct_range[2] > 1)
    # Clip to [0, 1]
    Y_rcts <- ifelse(Y_rcts < 0, 0, ifelse(Y_rcts > 1, 1, Y_rcts))
    T_rct <- as.numeric(ct_train$T)
    
    X_os_test <- select(os_test, -c(Y, T)) # OS test data
    Xt_os_test <- select(os_test, -Y)
    X_ct_test <- select(ct_test, -c(Y, T)) # RCT test data
    Xt_ct_test <- select(ct_test, -Y)

    # Counterfactual copies of the test data with treatment fixed to 0 / 1.
    # T is overwritten in place rather than appended as the last column, so the
    # column order matches the training data (T first, as in X_traint). This
    # matters for the grf forests, which match newdata to the training matrix
    # by column position; ranger matches by name and is unaffected.
    os_ut <- as.data.frame(mutate(Xt_os_test, T = 0)) # Fix treated and untreated
    os_t <- as.data.frame(mutate(Xt_os_test, T = 1))
    test_ut <- as.data.frame(mutate(Xt_ct_test, T = 0))
    test_t <- as.data.frame(mutate(Xt_ct_test, T = 1))
      
    ## Propensity score model - Train vs RCT train
      
    # Drop the treatment column by name. Indexing with [-T] only works while T
    # evaluates to TRUE (= 1) and the treatment happens to be the first column.
    X_train_cov <- select(X_traint, -T)
    X_rct_cov <- select(X_rctt, -T)
    # Stack OS train (rct = 0) on top of RCT train (rct = 1)
    stacked <- bind_rows(X_train_cov, X_rct_cov, .id = "rct")
    stacked$rct <- as.numeric(stacked$rct) - 1
    # Logistic model for membership in the RCT sample given the covariates
    psm <- glm(rct ~ ., family = binomial, data = stacked)
    pscores <- predict(psm, newdata = X_rct_cov, type = "response")
    # Odds of belonging to the OS sample, one weight per RCT train row
    pweights <- (1 - pscores) / pscores
      
    ## Train models - Train w. observational data, post-process w. RCT train data
    ### Causal Forest
    # Covariates without the treatment column, selected by name instead of the
    # position-dependent X_traint[-T]
    cforest_tr <- causal_forest(select(X_traint, -T), Y_train, T_train)

    train_temp$tau_cforest_tr <- list(tibble(tau_cforest_tr = predict(cforest_tr, X_os_test)$predictions))
    test_temp$tau_cforest_tr <- list(tibble(tau_cforest_tr = predict(cforest_tr, X_ct_test)$predictions))

    ### S-learner
    slearner_tr <- ranger(y = Y_trains, x = X_traint)

    yp_sl_tr_t <- predict(slearner_tr, os_t)$predictions
    yp_sl_tr_ut <- predict(slearner_tr, os_ut)$predictions
    y_sl_tr_t <- rev_scale(yp_sl_tr_t, label = -2*min(Y_train) + 2*Y_train)
    y_sl_tr_ut <- rev_scale(yp_sl_tr_ut, label = -2*min(Y_train) + 2*Y_train)
    train_temp$y_slearner_tr <- list(tibble(y_slearner_tr = rev_scale(predict(slearner_tr, os_test)$predictions, label = -2*min(Y_train) + 2*Y_train)))
    train_temp$tau_slearner_tr <- list(tibble(tau_slearner_tr = y_sl_tr_t - y_sl_tr_ut))
    
    yp_sl_tr_t <- predict(slearner_tr, test_t)$predictions
    yp_sl_tr_ut <- predict(slearner_tr, test_ut)$predictions
    y_sl_tr_t <- rev_scale(yp_sl_tr_t, label = -2*min(Y_train) + 2*Y_train)
    y_sl_tr_ut <- rev_scale(yp_sl_tr_ut, label = -2*min(Y_train) + 2*Y_train)
    test_temp$y_slearner_tr <- list(tibble(y_slearner_tr = rev_scale(predict(slearner_tr, ct_test)$predictions, label = -2*min(Y_train) + 2*Y_train)))
    test_temp$tau_slearner_tr <- list(tibble(tau_slearner_tr = y_sl_tr_t - y_sl_tr_ut))

    ### T-learner (ranger)
    tlearner_tr_t <- ranger(y = Y_trains[X_traint$T == 1], 
                            x = X_traint[X_traint$T == 1, ])
    tlearner_tr_ut <- ranger(y = Y_trains[X_traint$T == 0], 
                             x = X_traint[X_traint$T == 0, ])

    yp_tl_tr_t <- predict(tlearner_tr_t, os_t)$predictions
    yp_tl_tr_ut <- predict(tlearner_tr_ut, os_ut)$predictions
    y_tl_tr_t <- rev_scale(yp_tl_tr_t, label = -2*min(Y_train) + 2*Y_train)
    y_tl_tr_ut <- rev_scale(yp_tl_tr_ut, label = -2*min(Y_train) + 2*Y_train)
    train_temp$y_tlearner_tr <- list(tibble(y_tlearner_tr = ifelse(os_test$T == 1, 
                                            rev_scale(predict(tlearner_tr_t, os_test)$predictions, label = -2*min(Y_train) + 2*Y_train), 
                                            rev_scale(predict(tlearner_tr_ut, os_test)$predictions, label = -2*min(Y_train) + 2*Y_train))))
    train_temp$tau_tlearner_tr <- list(tibble(tau_tlearner_tr = y_tl_tr_t - y_tl_tr_ut))
    
    yp_tl_tr_t <- predict(tlearner_tr_t, test_t)$predictions
    yp_tl_tr_ut <- predict(tlearner_tr_ut, test_ut)$predictions
    y_tl_tr_t <- rev_scale(yp_tl_tr_t, label = -2*min(Y_train) + 2*Y_train)
    y_tl_tr_ut <- rev_scale(yp_tl_tr_ut, label = -2*min(Y_train) + 2*Y_train)
    test_temp$y_tlearner_tr <- list(tibble(y_tlearner_tr = ifelse(ct_test$T == 1, 
                                           rev_scale(predict(tlearner_tr_t, ct_test)$predictions, label = -2*min(Y_train) + 2*Y_train), 
                                           rev_scale(predict(tlearner_tr_ut, ct_test)$predictions, label = -2*min(Y_train) + 2*Y_train))))
    test_temp$tau_tlearner_tr <- list(tibble(tau_tlearner_tr = y_tl_tr_t - y_tl_tr_ut))
    
    ### T-learner (grf)
    tclearner_tr_t <- regression_forest(Y = Y_trains[X_traint$T == 1], 
                                        X = X_traint[X_traint$T == 1, ])
    tclearner_tr_ut <- regression_forest(Y = Y_trains[X_traint$T == 0], 
                                         X = X_traint[X_traint$T == 0, ])
    
    yp_tl_tr_t <- predict(tclearner_tr_t, os_t)$predictions
    yp_tl_tr_ut <- predict(tclearner_tr_ut, os_ut)$predictions
    y_tl_tr_t <- rev_scale(yp_tl_tr_t, label = -2*min(Y_train) + 2*Y_train)
    y_tl_tr_ut <- rev_scale(yp_tl_tr_ut, label = -2*min(Y_train) + 2*Y_train)
    train_temp$y_tclearner_tr <- list(tibble(y_tclearner_tr = ifelse(os_test$T == 1, 
                                            rev_scale(predict(tclearner_tr_t, Xt_os_test)$predictions, label = -2*min(Y_train) + 2*Y_train), 
                                            rev_scale(predict(tclearner_tr_ut, Xt_os_test)$predictions, label = -2*min(Y_train) + 2*Y_train))))
    train_temp$tau_tclearner_tr <- list(tibble(tau_tclearner_tr = y_tl_tr_t - y_tl_tr_ut))
    
    yp_tl_tr_t <- predict(tclearner_tr_t, test_t)$predictions
    yp_tl_tr_ut <- predict(tclearner_tr_ut, test_ut)$predictions
    y_tl_tr_t <- rev_scale(yp_tl_tr_t, label = -2*min(Y_train) + 2*Y_train)
    y_tl_tr_ut <- rev_scale(yp_tl_tr_ut, label = -2*min(Y_train) + 2*Y_train)
    test_temp$y_tclearner_tr <- list(tibble(y_tclearner_tr = ifelse(ct_test$T == 1, 
                                           rev_scale(predict(tclearner_tr_t, Xt_ct_test)$predictions, label = -2*min(Y_train) + 2*Y_train), 
                                           rev_scale(predict(tclearner_tr_ut, Xt_ct_test)$predictions, label = -2*min(Y_train) + 2*Y_train))))
    test_temp$tau_tclearner_tr <- list(tibble(tau_tclearner_tr = y_tl_tr_t - y_tl_tr_ut))
    
    ### T-learner + MCBoost (ridge)
    init_preds = function(data) {
      preds <- predict(tclearner_tr_t, data)$predictions}
    tlearner_tr_t_mc = MCBoost$new(init_predictor = init_preds,
                                   auditor_fitter = ridge,
                                   alpha = 1e-06,
                                   # iter_sampling = "bootstrap",
                                   weight_degree = 2,
                                   eta = 0.1,
                                   max_iter = 10)
    tlearner_tr_t_mc$multicalibrate(X_rctt[X_rctt$T == 1, ], Y_rcts[X_rctt$T == 1])

    yp_tlearner_tr_t_mc_os_test <- tlearner_tr_t_mc$predict_probs(Xt_os_test)
    yp_tlearner_tr_t_mc_os_t <- tlearner_tr_t_mc$predict_probs(os_t)
    y_tlearner_tr_t_mc_os_t <- rev_scale(yp_tlearner_tr_t_mc_os_t, label = -2*min(Y_train) + 2*Y_train)

    yp_tlearner_tr_t_mc_ct_test <- tlearner_tr_t_mc$predict_probs(Xt_ct_test)
    yp_tlearner_tr_t_mc_ct_t <- tlearner_tr_t_mc$predict_probs(test_t)
    y_tlearner_tr_t_mc_ct_t <- rev_scale(yp_tlearner_tr_t_mc_ct_t, label = -2*min(Y_train) + 2*Y_train)
    
    init_preds = function(data) {
      preds <- predict(tclearner_tr_ut, data)$predictions}
    tlearner_tr_ut_mc = MCBoost$new(init_predictor = init_preds,
                                    auditor_fitter = ridge,
                                    alpha = 1e-06,
                                    # iter_sampling = "bootstrap",
                                    weight_degree = 2,
                                    eta = 0.1,
                                    max_iter = 10)
    tlearner_tr_ut_mc$multicalibrate(X_rctt[X_rctt$T == 0, ], Y_rcts[X_rctt$T == 0])

    yp_tlearner_tr_ut_mc_os_test <- tlearner_tr_ut_mc$predict_probs(Xt_os_test)
    yp_tlearner_tr_ut_mc_os_ut <- tlearner_tr_ut_mc$predict_probs(os_ut)
    y_tlearner_tr_ut_mc_os_ut <- rev_scale(yp_tlearner_tr_ut_mc_os_ut, label = -2*min(Y_train) + 2*Y_train)
    
    yp_tlearner_tr_ut_mc_ct_test <- tlearner_tr_ut_mc$predict_probs(Xt_ct_test)
    yp_tlearner_tr_ut_mc_ct_ut <- tlearner_tr_ut_mc$predict_probs(test_ut)
    y_tlearner_tr_ut_mc_ct_ut <- rev_scale(yp_tlearner_tr_ut_mc_ct_ut, label = -2*min(Y_train) + 2*Y_train)
    
    train_temp$yp_tlearner_tr_mcr <- list(tibble(yp_tlearner_tr_mcr = ifelse(os_test$T == 1, yp_tlearner_tr_t_mc_os_test, yp_tlearner_tr_ut_mc_os_test)))
    train_temp$y_tlearner_tr_mcr <- list(tibble(y_tlearner_tr_mcr = rev_scale(train_temp$yp_tlearner_tr_mcr[[1]], label = -2*min(Y_train) + 2*Y_train)))
    train_temp$tau_tlearner_tr_mcr <- list(tibble(tau_tlearner_tr_mcr = y_tlearner_tr_t_mc_os_t - y_tlearner_tr_ut_mc_os_ut))
    
    test_temp$yp_tlearner_tr_mcr <- list(tibble(yp_tlearner_tr_mcr = ifelse(ct_test$T == 1, yp_tlearner_tr_t_mc_ct_test, yp_tlearner_tr_ut_mc_ct_test)))
    test_temp$y_tlearner_tr_mcr <- list(tibble(y_tlearner_tr_mcr = rev_scale(test_temp$yp_tlearner_tr_mcr[[1]], label = -2*min(Y_train) + 2*Y_train)))
    test_temp$tau_tlearner_tr_mcr <- list(tibble(tau_tlearner_tr_mcr = y_tlearner_tr_t_mc_ct_t - y_tlearner_tr_ut_mc_ct_ut))

    ### T-learner + MCBoost (tree)
    init_preds = function(data) {
      preds <- predict(tclearner_tr_t, data)$predictions}
    tlearner_tr_t_mc = MCBoost$new(init_predictor = init_preds,
                                   auditor_fitter = tree,
                                   alpha = 1e-06,
                                   # iter_sampling = "bootstrap",
                                   weight_degree = 2,
                                   eta = 0.1,
                                   max_iter = 10)
    tlearner_tr_t_mc$multicalibrate(X_rctt[X_rctt$T == 1, ], Y_rcts[X_rctt$T == 1])
    
    yp_tlearner_tr_t_mc_os_test <- tlearner_tr_t_mc$predict_probs(Xt_os_test)
    yp_tlearner_tr_t_mc_os_t <- tlearner_tr_t_mc$predict_probs(os_t)
    y_tlearner_tr_t_mc_os_t <- rev_scale(yp_tlearner_tr_t_mc_os_t, label = -2*min(Y_train) + 2*Y_train)
    
    yp_tlearner_tr_t_mc_ct_test <- tlearner_tr_t_mc$predict_probs(Xt_ct_test)
    yp_tlearner_tr_t_mc_ct_t <- tlearner_tr_t_mc$predict_probs(test_t)
    y_tlearner_tr_t_mc_ct_t <- rev_scale(yp_tlearner_tr_t_mc_ct_t, label = -2*min(Y_train) + 2*Y_train)
    
    init_preds = function(data) {
      preds <- predict(tclearner_tr_ut, data)$predictions}
    tlearner_tr_ut_mc = MCBoost$new(init_predictor = init_preds,
                                    auditor_fitter = tree,
                                    alpha = 1e-06,
                                    # iter_sampling = "bootstrap",
                                    weight_degree = 2,
                                    eta = 0.1,
                                    max_iter = 10)
    tlearner_tr_ut_mc$multicalibrate(X_rctt[X_rctt$T == 0, ], Y_rcts[X_rctt$T == 0])

    yp_tlearner_tr_ut_mc_os_test <- tlearner_tr_ut_mc$predict_probs(Xt_os_test)
    yp_tlearner_tr_ut_mc_os_ut <- tlearner_tr_ut_mc$predict_probs(os_ut)
    y_tlearner_tr_ut_mc_os_ut <- rev_scale(yp_tlearner_tr_ut_mc_os_ut, label = -2*min(Y_train) + 2*Y_train)
    
    yp_tlearner_tr_ut_mc_ct_test <- tlearner_tr_ut_mc$predict_probs(Xt_ct_test)
    yp_tlearner_tr_ut_mc_ct_ut <- tlearner_tr_ut_mc$predict_probs(test_ut)
    y_tlearner_tr_ut_mc_ct_ut <- rev_scale(yp_tlearner_tr_ut_mc_ct_ut, label = -2*min(Y_train) + 2*Y_train)
    
    train_temp$yp_tlearner_tr_mct <- list(tibble(yp_tlearner_tr_mct = ifelse(os_test$T == 1, yp_tlearner_tr_t_mc_os_test, yp_tlearner_tr_ut_mc_os_test)))
    train_temp$y_tlearner_tr_mct <- list(tibble(y_tlearner_tr_mct = rev_scale(train_temp$yp_tlearner_tr_mct[[1]], label = -2*min(Y_train) + 2*Y_train)))
    train_temp$tau_tlearner_tr_mct <- list(tibble(tau_tlearner_tr_mct = y_tlearner_tr_t_mc_os_t - y_tlearner_tr_ut_mc_os_ut))
    
    test_temp$yp_tlearner_tr_mct <- list(tibble(yp_tlearner_tr_mct = ifelse(ct_test$T == 1, yp_tlearner_tr_t_mc_ct_test, yp_tlearner_tr_ut_mc_ct_test)))
    test_temp$y_tlearner_tr_mct <- list(tibble(y_tlearner_tr_mct = rev_scale(test_temp$yp_tlearner_tr_mct[[1]], label = -2*min(Y_train) + 2*Y_train)))
    test_temp$tau_tlearner_tr_mct <- list(tibble(tau_tlearner_tr_mct = y_tlearner_tr_t_mc_ct_t - y_tlearner_tr_ut_mc_ct_ut))
    
    ### T-learner + MCBoost (tree max_iter 20)
    init_preds = function(data) {
      preds <- predict(tclearner_tr_t, data)$predictions}
    tlearner_tr_t_mc = MCBoost$new(init_predictor = init_preds,
                                   auditor_fitter = tree,
                                   alpha = 1e-06,
                                   # iter_sampling = "bootstrap",
                                   weight_degree = 2,
                                   eta = 0.5,
                                   max_iter = 20)
    tlearner_tr_t_mc$multicalibrate(X_rctt[X_rctt$T == 1, ], Y_rcts[X_rctt$T == 1])
    
    yp_tlearner_tr_t_mc_os_test <- tlearner_tr_t_mc$predict_probs(Xt_os_test)
    yp_tlearner_tr_t_mc_os_t <- tlearner_tr_t_mc$predict_probs(os_t)
    y_tlearner_tr_t_mc_os_t <- rev_scale(yp_tlearner_tr_t_mc_os_t, label = -2*min(Y_train) + 2*Y_train)
    
    yp_tlearner_tr_t_mc_ct_test <- tlearner_tr_t_mc$predict_probs(Xt_ct_test)
    yp_tlearner_tr_t_mc_ct_t <- tlearner_tr_t_mc$predict_probs(test_t)
    y_tlearner_tr_t_mc_ct_t <- rev_scale(yp_tlearner_tr_t_mc_ct_t, label = -2*min(Y_train) + 2*Y_train)
    
    init_preds = function(data) {
      preds <- predict(tclearner_tr_ut, data)$predictions}
    tlearner_tr_ut_mc = MCBoost$new(init_predictor = init_preds,
                                    auditor_fitter = tree,
                                    alpha = 1e-06,
                                    # iter_sampling = "bootstrap",
                                    weight_degree = 2,
                                    eta = 0.5,
                                    max_iter = 20)
    tlearner_tr_ut_mc$multicalibrate(X_rctt[X_rctt$T == 0, ], Y_rcts[X_rctt$T == 0])

    yp_tlearner_tr_ut_mc_os_test <- tlearner_tr_ut_mc$predict_probs(Xt_os_test)
    yp_tlearner_tr_ut_mc_os_ut <- tlearner_tr_ut_mc$predict_probs(os_ut)
    y_tlearner_tr_ut_mc_os_ut <- rev_scale(yp_tlearner_tr_ut_mc_os_ut, label = -2*min(Y_train) + 2*Y_train)
    
    yp_tlearner_tr_ut_mc_ct_test <- tlearner_tr_ut_mc$predict_probs(Xt_ct_test)
    yp_tlearner_tr_ut_mc_ct_ut <- tlearner_tr_ut_mc$predict_probs(test_ut)
    y_tlearner_tr_ut_mc_ct_ut <- rev_scale(yp_tlearner_tr_ut_mc_ct_ut, label = -2*min(Y_train) + 2*Y_train)
    
    train_temp$yp_tlearner_tr_mclt <- list(tibble(yp_tlearner_tr_mclt = ifelse(os_test$T == 1, yp_tlearner_tr_t_mc_os_test, yp_tlearner_tr_ut_mc_os_test)))
    train_temp$y_tlearner_tr_mclt <- list(tibble(y_tlearner_tr_mclt = rev_scale(train_temp$yp_tlearner_tr_mclt[[1]], label = -2*min(Y_train) + 2*Y_train)))
    train_temp$tau_tlearner_tr_mclt <- list(tibble(tau_tlearner_tr_mclt = y_tlearner_tr_t_mc_os_t - y_tlearner_tr_ut_mc_os_ut))
    
    test_temp$yp_tlearner_tr_mclt <- list(tibble(yp_tlearner_tr_mclt = ifelse(ct_test$T == 1, yp_tlearner_tr_t_mc_ct_test, yp_tlearner_tr_ut_mc_ct_test)))
    test_temp$y_tlearner_tr_mclt <- list(tibble(y_tlearner_tr_mclt = rev_scale(test_temp$yp_tlearner_tr_mclt[[1]], label = -2*min(Y_train) + 2*Y_train)))
    test_temp$tau_tlearner_tr_mclt <- list(tibble(tau_tlearner_tr_mclt = y_tlearner_tr_t_mc_ct_t - y_tlearner_tr_ut_mc_ct_ut))

    ### T-learner + MCBoost (ridge max_iter 20)
    init_preds = function(data) {
      preds <- predict(tclearner_tr_t, data)$predictions}
    tlearner_tr_t_mc = MCBoost$new(init_predictor = init_preds,
                                   auditor_fitter = ridge,
                                   alpha = 1e-06,
                                   # iter_sampling = "bootstrap",
                                   weight_degree = 2,
                                   eta = 0.5,
                                   max_iter = 20)
    tlearner_tr_t_mc$multicalibrate(X_rctt[X_rctt$T == 1, ], Y_rcts[X_rctt$T == 1])
    
    yp_tlearner_tr_t_mc_os_test <- tlearner_tr_t_mc$predict_probs(Xt_os_test)
    yp_tlearner_tr_t_mc_os_t <- tlearner_tr_t_mc$predict_probs(os_t)
    y_tlearner_tr_t_mc_os_t <- rev_scale(yp_tlearner_tr_t_mc_os_t, label = -2*min(Y_train) + 2*Y_train)
    
    yp_tlearner_tr_t_mc_ct_test <- tlearner_tr_t_mc$predict_probs(Xt_ct_test)
    yp_tlearner_tr_t_mc_ct_t <- tlearner_tr_t_mc$predict_probs(test_t)
    y_tlearner_tr_t_mc_ct_t <- rev_scale(yp_tlearner_tr_t_mc_ct_t, label = -2*min(Y_train) + 2*Y_train)
    
    init_preds = function(data) {
      preds <- predict(tclearner_tr_ut, data)$predictions}
    tlearner_tr_ut_mc = MCBoost$new(init_predictor = init_preds,
                                    auditor_fitter = ridge,
                                    alpha = 1e-06,
                                    # iter_sampling = "bootstrap",
                                    weight_degree = 2,
                                    eta = 0.5,
                                    max_iter = 20)
    tlearner_tr_ut_mc$multicalibrate(X_rctt[X_rctt$T == 0, ], Y_rcts[X_rctt$T == 0])
    
    yp_tlearner_tr_ut_mc_os_test <- tlearner_tr_ut_mc$predict_probs(Xt_os_test)
    yp_tlearner_tr_ut_mc_os_ut <- tlearner_tr_ut_mc$predict_probs(os_ut)
    y_tlearner_tr_ut_mc_os_ut <- rev_scale(yp_tlearner_tr_ut_mc_os_ut, label = -2*min(Y_train) + 2*Y_train)
    
    yp_tlearner_tr_ut_mc_ct_test <- tlearner_tr_ut_mc$predict_probs(Xt_ct_test)
    yp_tlearner_tr_ut_mc_ct_ut <- tlearner_tr_ut_mc$predict_probs(test_ut)
    y_tlearner_tr_ut_mc_ct_ut <- rev_scale(yp_tlearner_tr_ut_mc_ct_ut, label = -2*min(Y_train) + 2*Y_train)
    
    train_temp$yp_tlearner_tr_mclr <- list(tibble(yp_tlearner_tr_mclr = ifelse(os_test$T == 1, yp_tlearner_tr_t_mc_os_test, yp_tlearner_tr_ut_mc_os_test)))
    train_temp$y_tlearner_tr_mclr <- list(tibble(y_tlearner_tr_mclr = rev_scale(train_temp$yp_tlearner_tr_mclr[[1]], label = -2*min(Y_train) + 2*Y_train)))
    train_temp$tau_tlearner_tr_mclr <- list(tibble(tau_tlearner_tr_mclr = y_tlearner_tr_t_mc_os_t - y_tlearner_tr_ut_mc_os_ut))
    
    test_temp$yp_tlearner_tr_mclr <- list(tibble(yp_tlearner_tr_mclr = ifelse(ct_test$T == 1, yp_tlearner_tr_t_mc_ct_test, yp_tlearner_tr_ut_mc_ct_test)))
    test_temp$y_tlearner_tr_mclr <- list(tibble(y_tlearner_tr_mclr = rev_scale(test_temp$yp_tlearner_tr_mclr[[1]], label = -2*min(Y_train) + 2*Y_train)))
    test_temp$tau_tlearner_tr_mclr <- list(tibble(tau_tlearner_tr_mclr = y_tlearner_tr_t_mc_ct_t - y_tlearner_tr_ut_mc_ct_ut))
    
    ### DR-learner
    # Covariates without the treatment column, selected by name instead of the
    # position-dependent X_traint[,-T]
    drlearner_tr <- dr_learner(select(X_traint, -T), Y_train, T_train, X_ct_test, trunc = 0.02)
    
    # tau.hat goes to the train results, tau.new (for X_ct_test) to the test results
    train_temp$tau_drlearner_tr <- list(tibble(tau_drlearner_tr = drlearner_tr$tau.hat))
    test_temp$tau_drlearner_tr <- list(tibble(tau_drlearner_tr = drlearner_tr$tau.new))
    
    drclearner_tr <- dr_learner_grf(select(X_traint, -T), Y_train, T_train, X_ct_test, trunc = 0.02)
    
    train_temp$tau_drclearner_tr <- list(tibble(tau_drclearner_tr = drclearner_tr$tau.hat))
    test_temp$tau_drclearner_tr <- list(tibble(tau_drclearner_tr = drclearner_tr$tau.new))
    
    ### DR-learner + MCBoost (ridge)
    # This fit is guarded because it may fail. On error, fall back to NA
    # results so both rows keep the list-column layout of a successful run
    # instead of receiving a try-error object.
    drlearner_tr_mcr <- tryCatch(
      dr_learnermc_grf(select(X_traint, -T), X_traint, X_rctt,
                       Y_train, Y_trains, Y_rcts,
                       T_train,
                       X_ct_test, iter = 10),
      error = function(e) {
        message("dr_learnermc_grf (ridge) failed: ", conditionMessage(e))
        list(tau.hat = NA, tau.new = NA)
      })
    
    train_temp$tau_drlearner_tr_mcr <- list(tibble(tau_drlearner_tr_mcr = drlearner_tr_mcr$tau.hat))
    test_temp$tau_drlearner_tr_mcr <- list(tibble(tau_drlearner_tr_mcr = drlearner_tr_mcr$tau.new))
    
    ### DR-learner + MCBoost (tree)
    drlearner_tr_mct <- dr_learnermc_grf(select(X_traint, -T), X_traint, X_rctt,
                                         Y_train, Y_trains, Y_rcts,
                                         T_train,
                                         X_ct_test, iter = 10, auditor = "TreeAuditorFitter")
    
    train_temp$tau_drlearner_tr_mct <- list(tibble(tau_drlearner_tr_mct = drlearner_tr_mct$tau.hat))
    test_temp$tau_drlearner_tr_mct <- list(tibble(tau_drlearner_tr_mct = drlearner_tr_mct$tau.new))

    ### DR-learner + MCBoost (ridge)
    # This fit is guarded because it may fail. On error, fall back to NA
    # results so both rows keep the list-column layout of a successful run
    # instead of receiving a try-error object.
    drlearner_tr_mcfr <- tryCatch(
      dr_learnermc2_grf(select(X_traint, -T), X_traint, X_rctt,
                        Y_train, Y_trains, Y_rct, Y_rcts,
                        T_train, T_rct,
                        X_ct_test, eta = 0.01), # 0.1
      error = function(e) {
        message("dr_learnermc2_grf (ridge) failed: ", conditionMessage(e))
        list(tau.hat = NA, tau.new = NA)
      })
    
    train_temp$tau_drlearner_tr_mcfr <- list(tibble(tau_drlearner_tr_mcfr = drlearner_tr_mcfr$tau.hat))
    test_temp$tau_drlearner_tr_mcfr <- list(tibble(tau_drlearner_tr_mcfr = drlearner_tr_mcfr$tau.new))
    
    ### DR-learner + MCBoost (tree)
    drlearner_tr_mcft <- dr_learnermc2_grf(select(X_traint, -T), X_traint, X_rctt,
                                           Y_train, Y_trains, Y_rct, Y_rcts,
                                           T_train, T_rct,
                                           X_ct_test, eta = 0.01, auditor = "TreeAuditorFitter") # 0.1
    
    train_temp$tau_drlearner_tr_mcft <- list(tibble(tau_drlearner_tr_mcft = drlearner_tr_mcft$tau.hat))
    test_temp$tau_drlearner_tr_mcft <- list(tibble(tau_drlearner_tr_mcft = drlearner_tr_mcft$tau.new))
    
    ### DR-learner + MCBoost (ridge)
    # `dr_learnermc3_grf` variant. Wrapped in try() so a failed fit does not
    # abort the loop.
    drlearner_tr_mclr <- try(dr_learnermc3_grf(X_traint[,-T], X_traint, X_rctt,
                                               Y_train, Y_rct,
                                               T_train, T_rct,
                                               X_ct_test, eta = 0.1))

    # NA placeholder in list-of-tibble shape; replaced only on success so that a
    # try-error object never ends up inside the result columns.
    train_temp$tau_drlearner_tr_mclr <- list(tibble(tau_drlearner_tr_mclr = NA))
    test_temp$tau_drlearner_tr_mclr <- list(tibble(tau_drlearner_tr_mclr = NA))
    if (!inherits(drlearner_tr_mclr, "try-error")) {
      train_temp$tau_drlearner_tr_mclr <- list(tibble(tau_drlearner_tr_mclr = drlearner_tr_mclr$tau.hat))
      test_temp$tau_drlearner_tr_mclr <- list(tibble(tau_drlearner_tr_mclr = drlearner_tr_mclr$tau.new))
    }

    ### DR-learner + MCBoost (tree)
    # Tree-auditor counterpart; not wrapped in try().
    drlearner_tr_mclt <- dr_learnermc3_grf(X_traint[,-T], X_traint, X_rctt,
                                           Y_train, Y_rct,
                                           T_train, T_rct,
                                           X_ct_test, eta = 0.1, auditor = "TreeAuditorFitter")

    train_temp$tau_drlearner_tr_mclt <- list(tibble(tau_drlearner_tr_mclt = drlearner_tr_mclt$tau.hat))
    test_temp$tau_drlearner_tr_mclt <- list(tibble(tau_drlearner_tr_mclt = drlearner_tr_mclt$tau.new))
    
    ## Train models - Train w. RCT train data
    # Fits and predictions are kept in the original order so the random number
    # stream consumed by the forests is unchanged.

    ### Causal Forest
    # NOTE(review): `X_rctt[-T]` is meant to drop the treatment column; this
    # only holds if `T` is that column's index -- confirm in the setup code.
    cforest_rct <- causal_forest(X_rctt[-T], Y_rct, T_rct)
    cforest_rct_tau <- predict(cforest_rct, X_ct_test)$predictions
    test_temp$tau_cforest_rct <- list(tibble(tau_cforest_rct = cforest_rct_tau))

    ### Causal Forest (weighted)
    cforestw_rct <- causal_forest(X_rctt[-T], Y_rct, T_rct,
                                  sample.weights = pweights)
    cforestw_rct_tau <- predict(cforestw_rct, X_ct_test)$predictions
    test_temp$tau_cforestw_rct <- list(tibble(tau_cforestw_rct = cforestw_rct_tau))

    ### S-learner
    # One outcome forest on all RCT rows; the CATE is the difference between
    # predictions on `test_t` and `test_ut`.
    slearner_rct <- ranger(y = Y_rct, x = X_rctt)

    slearner_rct_y <- predict(slearner_rct, ct_test)$predictions
    test_temp$y_slearner_rct <- list(tibble(y_slearner_rct = slearner_rct_y))
    slearner_rct_y1 <- predict(slearner_rct, test_t)$predictions
    slearner_rct_y0 <- predict(slearner_rct, test_ut)$predictions
    test_temp$tau_slearner_rct <- list(tibble(tau_slearner_rct = slearner_rct_y1 - slearner_rct_y0))

    ### S-learner (weighted)
    slearnerw_rct <- ranger(y = Y_rct, x = X_rctt,
                            case.weights = pweights)

    slearnerw_rct_y <- predict(slearnerw_rct, ct_test)$predictions
    test_temp$y_slearnerw_rct <- list(tibble(y_slearnerw_rct = slearnerw_rct_y))
    slearnerw_rct_y1 <- predict(slearnerw_rct, test_t)$predictions
    slearnerw_rct_y0 <- predict(slearnerw_rct, test_ut)$predictions
    test_temp$tau_slearnerw_rct <- list(tibble(tau_slearnerw_rct = slearnerw_rct_y1 - slearnerw_rct_y0))
      
    ### T-learner
    # Separate outcome forests for treated and untreated RCT rows. The same
    # logical masks are used for outcomes, covariates and (below) weights so the
    # three always refer to the same rows.
    rct_treated <- X_rctt$T == 1
    rct_untreated <- X_rctt$T == 0

    tlearner_rct_t <- ranger(y = Y_rct[rct_treated],
                             x = X_rctt[rct_treated, ])
    tlearner_rct_ut <- ranger(y = Y_rct[rct_untreated],
                              x = X_rctt[rct_untreated, ])

    # Outcome prediction uses the arm-specific model matching each test row.
    test_temp$y_tlearner_rct <- list(tibble(y_tlearner_rct = ifelse(ct_test$T == 1,
                                            predict(tlearner_rct_t, ct_test)$predictions,
                                            predict(tlearner_rct_ut, ct_test)$predictions)))

    test_temp$tau_tlearner_rct <- list(tibble(tau_tlearner_rct = predict(tlearner_rct_t, test_t)$predictions - predict(tlearner_rct_ut, test_ut)$predictions))

    ### T-learner (weighted)
    # `pweights` is passed alongside `X_rctt` to the weighted forests above, so
    # it is subset with the `X_rctt`-based masks rather than `ct_train$T`.
    tlearnerw_rct_t <- ranger(y = Y_rct[rct_treated],
                              x = X_rctt[rct_treated, ],
                              case.weights = pweights[rct_treated])
    tlearnerw_rct_ut <- ranger(y = Y_rct[rct_untreated],
                               x = X_rctt[rct_untreated, ],
                               case.weights = pweights[rct_untreated])

    test_temp$y_tlearnerw_rct <- list(tibble(y_tlearnerw_rct = ifelse(ct_test$T == 1,
                                             predict(tlearnerw_rct_t, ct_test)$predictions,
                                             predict(tlearnerw_rct_ut, ct_test)$predictions)))

    test_temp$tau_tlearnerw_rct <- list(tibble(tau_tlearnerw_rct = predict(tlearnerw_rct_t, test_t)$predictions - predict(tlearnerw_rct_ut, test_ut)$predictions))

    ## Save results
    # Append this iteration's one-row result tibbles. bind_rows() is used
    # instead of base rbind() because the rows carry list-columns of tibbles.
    train_res <- bind_rows(train_res, train_temp)
    test_res <- bind_rows(test_res, test_temp)

    # Progress indicator: current RCT size `s` and replication `i`.
    print(paste("s =", s, "i =", i))
  }
}

## Combine results

# Per-(replication, RCT size) summaries of the test-set predictions:
#   bias_<method>  = mean(ate - mean(tau_<method>))
#   <ref>_<method> = mean((tau_<ref> - tau_<method>)^2)
# The summary expressions are generated from the two lookup tables below so
# that every metric uses the same method list, in the same order.

# Output method label -> column holding that method's CATE predictions.
tau_cols <- c(
  cforestTR = "tau_cforest_tr",
  slearnerTR = "tau_slearner_tr",
  tlearnerTR = "tau_tlearner_tr",
  tclearnerTR = "tau_tclearner_tr",
  tclearnerTRmcr = "tau_tlearner_tr_mcr",
  tclearnerTRmct = "tau_tlearner_tr_mct",
  tclearnerTRmclt = "tau_tlearner_tr_mclt",
  tclearnerTRmclr = "tau_tlearner_tr_mclr",
  drlearnerTR = "tau_drlearner_tr",
  drclearnerTR = "tau_drclearner_tr",
  drlearnerTRmcr = "tau_drlearner_tr_mcr",
  drlearnerTRmct = "tau_drlearner_tr_mct",
  drlearnerTRmcfr = "tau_drlearner_tr_mcfr",
  drlearnerTRmcft = "tau_drlearner_tr_mcft",
  drlearnerTRmclr = "tau_drlearner_tr_mclr",
  drlearnerTRmclt = "tau_drlearner_tr_mclt",
  cforestRCT = "tau_cforest_rct",
  slearnerRCT = "tau_slearner_rct",
  tlearnerRCT = "tau_tlearner_rct"
)

# MSE metric prefix -> column holding the reference CATE estimate.
ref_cols <- c(
  mseXRF = "tau_xrf",
  mseDRRF = "tau_drrf",
  mseRL = "tau_rlasso",
  mseTL = "tau_tlasso"
)

# bias_<method> expressions, in `tau_cols` order.
bias_exprs <- lapply(tau_cols, function(col) {
  bquote(mean(ate - mean(.(as.name(col)))))
})
names(bias_exprs) <- paste0("bias_", names(tau_cols))

# <ref>_<method> expressions: all methods for the first reference, then all
# methods for the second, and so on.
mse_exprs <- do.call(c, lapply(names(ref_cols), function(ref) {
  exprs <- lapply(tau_cols, function(col) {
    bquote(mean((.(as.name(ref_cols[[ref]])) - .(as.name(col)))^2))
  })
  names(exprs) <- paste0(ref, "_", names(tau_cols))
  exprs
}))

# unnest() requires `cols`; naming every list-column reproduces what the
# deprecated zero-argument call did.
list_cols <- names(test_res)[vapply(test_res, is.list, logical(1))]

test_r <- test_res %>%
  unnest(cols = all_of(list_cols)) %>%
  group_by(rep_i, rctsize_s) %>%
  summarise(flag = mean(flag),
            mean_y = mean(y),
            !!!bias_exprs,
            !!!mse_exprs,
            # keep the result grouped by rep_i, as the default did
            .groups = "drop_last")

# Long format: one row per (replication, RCT size, metric, method). Column
# names such as "mseXRF_cforestTR" are split at the underscore into Metric and
# Method; the numbers land in the default `value` column.
test_r_long <- pivot_longer(
  select(test_r, rep_i, rctsize_s, flag, bias_cforestTR:mseTL_tlearnerRCT),
  cols = bias_cforestTR:mseTL_tlearnerRCT,
  names_to = c("Metric", "Method"),
  names_sep = "_"
)

# Persist both the wide and the long summaries.
save(test_r, test_r_long, file = "whi_results.RData")

## Evaluation
### Bias and MSE

# Average every `<prefix>_*` column of `results` within each RCT size and
# round to three decimals.
#   results    : per-replication summary table (here `test_r`)
#   prefix     : metric prefix of the columns to average, e.g. "bias", "mseXRF"
#   out_prefix : prefix used in the returned column names (defaults to `prefix`)
# Returns one row per `rctsize_s` with one column per method.
summarise_by_rctsize <- function(results, prefix, out_prefix = prefix) {
  results %>%
    group_by(rctsize_s) %>%
    summarize(across(starts_with(paste0(prefix, "_"), ignore.case = FALSE),
                     mean)) %>%
    rename_with(function(nm) {
      sub(paste0("^", prefix, "_"), paste0(out_prefix, "_"), nm)
    }, .cols = -rctsize_s) %>%
    # explicit function instead of the deprecated across(..., round, 3) form
    mutate(across(-rctsize_s, function(x) round(x, 3)))
}

# Bias of the ATE estimate
summarise_by_rctsize(test_r, "bias")

# MSE of the CATE estimates against each reference estimator
summarise_by_rctsize(test_r, "mseXRF", out_prefix = "mse")

summarise_by_rctsize(test_r, "mseDRRF", out_prefix = "mse")

summarise_by_rctsize(test_r, "mseRL", out_prefix = "mse")

summarise_by_rctsize(test_r, "mseTL", out_prefix = "mse")

### Plots

# Display labels for the plots. recode_factor() orders the factor levels by
# the order of these keys; codes not listed keep their raw name and follow.
method_labels <- c(
  "cforestTR" = "CForest-OS",
  "slearnerTR" = "S-learner-OS",
  "drlearnerTR" = "drlearnerTR",
  "drclearnerTR" = "DR-learner-OS",
  "drlearnerTRmcfr" = "DR-learner-MC-Ridge",
  "drlearnerTRmcft" = "DR-learner-MC-Tree",
  "tlearnerTR" = "tlearnerTR",
  "tclearnerTR" = "T-learner-OS",
  "tclearnerTRmcr" = "T-learner-MC-Ridge",
  "tclearnerTRmct" = "T-learner-MC-Tree",
  "cforestRCT" = "CForest-CT",
  "slearnerRCT" = "S-learner-CT",
  "tlearnerRCT" = "T-learner-CT"
)
metric_labels <- c(
  "mseRL" = "RL-NET",
  "mseTL" = "TL-NET",
  "mseXRF" = "XL-RF",
  "mseDRRF" = "DR-RF"
)

test_r_long$Method <- recode_factor(test_r_long$Method, !!!method_labels)
test_r_long$Metric <- recode_factor(test_r_long$Metric, !!!metric_labels)

# Box plots of the ATE bias per method, one facet row per RCT size.
test_r_long %>%
  # optional: add `flag == 0` to the filter
  filter(rctsize_s < 1750,
         !Method %in% c("tlearnerTR", "tclearnerTRmclr", "tclearnerTRmclt",
                        "drlearnerTR", "drlearnerTRmcr", "drlearnerTRmct",
                        "drlearnerTRmclr", "drlearnerTRmclt"),
         Metric == "bias") %>%
  # reversed level order so the flipped plot lists methods top to bottom
  mutate(method = fct_rev(Method)) %>%
  ggplot(aes(y = value, group = method, colour = method)) +
  geom_boxplot(lwd = 0.4, outlier.size = 0.1) +
  geom_hline(aes(yintercept = 0)) +
  facet_grid(rows = vars(rctsize_s)) +
  scale_colour_manual(values = RColorBrewer::brewer.pal(12, "Paired")) +
  ylim(-4, 5) +
  coord_flip() +
  labs(y = "Bias (ATE)") +
  guides(colour = guide_legend(reverse = TRUE)) +
  theme(text = element_text(size = 13),
        legend.title = element_blank(),
        axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank())

ggsave("whi_test_ate-bias-box.pdf", width = 7, height = 7)

# Box plots of the CATE MSE per method: facet rows are RCT sizes, facet columns
# are the reference estimators.
test_r_long %>%
  # optional: add `flag == 0` to the filter
  filter(rctsize_s < 1750,
         !Method %in% c("tlearnerTR", "tclearnerTRmclr", "tclearnerTRmclt",
                        "drlearnerTR", "drlearnerTRmcr", "drlearnerTRmct",
                        "drlearnerTRmclr", "drlearnerTRmclt"),
         Metric %in% c("RL-NET", "TL-NET", "XL-RF")) %>%
  mutate(method = fct_rev(Method)) %>%
  # mutate(type = fct_rev(ifelse(grepl("CT", method), "CT", "OS"))) %>%
  ggplot(aes(y = value, group = method, colour = method)) +
  geom_boxplot(lwd = 0.4, outlier.size = 0.1) +
  geom_hline(aes(yintercept = 0)) +
  facet_grid(rows = vars(rctsize_s), cols = vars(Metric), scales = "free") +
  # facet_wrap(type ~ rctsize_s + Metric,
  #            ncol = 3, strip.position = "left", scales = "free") +
  scale_colour_manual(values = RColorBrewer::brewer.pal(12, "Paired")) +
  # zoom to 0-75 on the MSE axis without dropping observations
  coord_flip(ylim = c(0, 75)) +
  labs(y = "MSE (CATE)") +
  guides(colour = guide_legend(reverse = TRUE)) +
  theme(text = element_text(size = 13),
        legend.title = element_blank(),
        axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank())

ggsave("whi_test_cate-mse-box.pdf", width = 10, height = 7)

# CATE MSE box plots with the "-CT" methods (and the hidden variants) removed.
test_r_long %>%
  # optional: add `flag == 0` to the filter
  filter(rctsize_s < 1750,
         !Method %in% c("tlearnerTR", "tclearnerTRmclr", "tclearnerTRmclt",
                        "drlearnerTR", "drlearnerTRmcr", "drlearnerTRmct",
                        "drlearnerTRmclr", "drlearnerTRmclt",
                        "CForest-CT", "S-learner-CT", "T-learner-CT"),
         Metric %in% c("RL-NET", "TL-NET", "XL-RF")) %>%
  mutate(method = fct_rev(Method)) %>%
  ggplot(aes(y = value, group = method, colour = method)) +
  geom_boxplot(lwd = 0.4, outlier.size = 0.1) +
  geom_hline(aes(yintercept = 0)) +
  facet_grid(rows = vars(rctsize_s), cols = vars(Metric), scales = "free") +
  # skip the first three palette colours, which the CT-only plot uses
  scale_colour_manual(values = RColorBrewer::brewer.pal(12, "Paired")[4:12]) +
  coord_flip() + # coord_flip(ylim = c(0, 40))
  labs(y = "MSE (CATE)") +
  guides(colour = guide_legend(reverse = TRUE)) +
  theme(text = element_text(size = 13),
        legend.title = element_blank(),
        axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank())

ggsave("whi_test_cate-mse-box-OS.pdf", width = 10, height = 5)

# CATE MSE box plots restricted to the three "-CT" methods.
test_r_long %>%
  # optional: add `flag == 0` to the filter
  filter(rctsize_s < 1750,
         Method %in% c("CForest-CT", "S-learner-CT", "T-learner-CT"),
         Metric %in% c("RL-NET", "TL-NET", "XL-RF")) %>%
  mutate(method = fct_rev(Method)) %>%
  ggplot(aes(y = value, group = method, colour = method)) +
  geom_boxplot(lwd = 0.4, outlier.size = 0.1) +
  geom_hline(aes(yintercept = 0)) +
  facet_grid(rows = vars(rctsize_s), cols = vars(Metric), scales = "free") +
  scale_colour_manual(values = RColorBrewer::brewer.pal(12, "Paired")[1:3]) +
  coord_flip() + # coord_flip(ylim = c(0, 75))
  labs(y = "MSE (CATE)") +
  guides(colour = guide_legend(reverse = TRUE)) +
  theme(text = element_text(size = 13),
        legend.title = element_blank(),
        axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank())

ggsave("whi_test_cate-mse-box-CT.pdf", width = 10, height = 4)

## Aggregated bar plots

# Group-wise mean and standard deviation of one column.
#   data       : data frame (or tibble) holding the values
#   varname    : name of the numeric column to summarise (string)
#   groupnames : character vector of grouping column names
# Returns a data frame with the grouping columns, the group mean stored under
# `varname`, and the group standard deviation in `sd` (NAs are removed before
# computing both). Rows are ordered by the grouping columns.
# Implemented with base R so that it does not depend on plyr, which this
# script does not attach. Rows with NA in a grouping column are dropped by
# aggregate().
data_summary <- function(data, varname, groupnames) {
  data <- as.data.frame(data)
  summary_func <- function(x) {
    c(mean = mean(x, na.rm = TRUE), sd = sd(x, na.rm = TRUE))
  }
  agg <- aggregate(data[[varname]], by = data[groupnames], FUN = summary_func)
  # aggregate() returns the two statistics as a matrix in its last column
  stats <- agg[[length(groupnames) + 1L]]
  data_sum <- agg[groupnames]
  data_sum[[varname]] <- unname(stats[, "mean"])
  data_sum$sd <- unname(stats[, "sd"])
  data_sum <- data_sum[do.call(order, unname(as.list(data_sum[groupnames]))), ,
                       drop = FALSE]
  rownames(data_sum) <- NULL
  data_sum
}

# Mean and SD of every metric per RCT size, metric and method.
d_sum <- data_summary(test_r_long, varname = "value",
                      groupnames = c("rctsize_s", "Metric", "Method"))

# Bar plot of the mean CATE MSE (+/- 1 SD) for five selected methods, one
# facet per reference estimator.
d_sum %>%
  filter(rctsize_s < 1250) %>%
  mutate(rctsize_s = factor(rctsize_s),
         # The recodes also fix the level (= legend/colour) order; they leave
         # the values untouched when the labels were already recoded upstream.
         Method = recode_factor(Method,
                                "cforestTR" = "CForest-OS",
                                "tclearnerTR" = "T-learner-OS",
                                "tclearnerTRmcr" = "T-learner-MC-Ridge",
                                "drclearnerTR" = "DR-learner-OS",
                                "drlearnerTRmcfr" = "DR-learner-MC-Ridge"),
         Metric = recode_factor(Metric,
                                "mseRL" = "RL-NET",
                                "mseTL" = "TL-NET",
                                "mseXRF" = "XL-RF")) %>%
  # Select by display label: the raw code "mseDRRF" no longer exists once
  # Metric has been recoded to "DR-RF", so excluding it by raw name would
  # leave the DR-RF facet in the plot.
  filter(Metric %in% c("RL-NET", "TL-NET", "XL-RF")) %>%
  filter(Method %in% c("CForest-OS", "T-learner-OS", "T-learner-MC-Ridge",
                       "DR-learner-OS", "DR-learner-MC-Ridge")) %>%
  ggplot(aes(x = rctsize_s, y = value, fill = Method)) +
  geom_bar(stat = "identity", color = "black", position = position_dodge()) +
  geom_errorbar(aes(ymin = value-sd, ymax = value+sd),
                linewidth = 0.2, width = .2, position = position_dodge(.9)) +
  labs(y = "MSE (CATE)", x = "RCT size") +
  # scale_fill_brewer(palette = "Paired") +
  scale_fill_manual(values = c(RColorBrewer::brewer.pal(8, "Paired")[7],
                               RColorBrewer::brewer.pal(6, "Paired")[1:5])) +
  facet_grid(cols = vars(Metric), scales = "free") +
  theme(legend.title = element_blank(),
        text = element_text(size = 10))

ggsave("whi_test_cate-mse.pdf", width = 10, height = 3)

# Bar plot of absolute mean ATE bias and mean CATE MSE (RL-NET reference),
# +/- 1 SD, for six selected methods.
d_sum %>%
  filter(rctsize_s < 1250) %>%
  mutate(rctsize_s = factor(rctsize_s),
         value = abs(value),
         # The Method recode also fixes the level (= legend/colour) order.
         Method = recode_factor(Method,
                                "cforestTR" = "CForest-OS",
                                "tclearnerTR" = "T-learner-OS",
                                "tclearnerTRmcr" = "T-learner-MC-Ridge",
                                "drclearnerTR" = "DR-learner-OS",
                                "drlearnerTRmcfr" = "DR-learner-MC-Ridge",
                                "cforestRCT" = "CForest-CT"),
         # Metric may already carry the display label "RL-NET" instead of the
         # raw code "mseRL" (it is recoded upstream), so both are mapped.
         Metric = recode_factor(Metric,
                                "bias" = "Bias (ATE)",
                                "mseRL" = "MSE (CATE)",
                                "RL-NET" = "MSE (CATE)")) %>%
  # Keep exactly the two facets that have y-scales defined below.
  filter(Metric %in% c("Bias (ATE)", "MSE (CATE)")) %>%
  filter(Method %in% c("CForest-OS", "T-learner-OS", "T-learner-MC-Ridge",
                       "DR-learner-OS", "DR-learner-MC-Ridge", "CForest-CT")) %>%
  ggplot(aes(x = rctsize_s, y = value, fill = Method)) +
  geom_bar(stat = "identity", color = "black", position = position_dodge()) +
  geom_errorbar(aes(ymin = value-sd, ymax = value+sd),
                linewidth = 0.2, width = .2, position = position_dodge(.9)) +
  labs(y = "", x = "CT size") +
  facet_wrap(Metric ~ ., scales = "free") +
  #scale_fill_brewer(palette = "Paired") +
  scale_fill_manual(values = c(RColorBrewer::brewer.pal(8, "Paired")[7],
                               RColorBrewer::brewer.pal(6, "Paired")[1:6])) +
  # per-facet y limits (ggh4x)
  facetted_pos_scales(y = list(Metric == "Bias (ATE)" ~ scale_y_continuous(limits = c(0, 3.75)),
                               Metric == "MSE (CATE)" ~ scale_y_continuous(limits = c(0, 22.25)))) +
  theme(legend.title = element_blank(),
        text = element_text(size = 10))

ggsave("whi_test_cate-bias-mse.pdf", width = 8.5, height = 2.75)
