# Multi-CATE simulations
# Setup: External shift between training and test data
# Continuous outcome

# Scenarios as defined in setups.R:

# s1b_1a
# dim: 10 (normal)
# pscore: osSparse1Beta
# mu0: fullLocallyLinear
# tau: sparseLinearWeak
# external shift: SparseNonLinear
# audit sample from source

# s1b_1b
# dim: 10 (normal)
# pscore: osSparse2Linear
# mu0: sparseLinearWeak
# tau: fullLinearWeak
# external shift: SparseLogitLinear2
# audit sample from source

# s1b_1c
# dim: 10 (normal)
# pscore: osSparse1Beta
# mu0: sparseLinearWeak
# tau: fullLinearWeak
# external shift: SparseNonLinear
# audit sample from source

# s1b_2a
# dim: 10 (normal)
# pscore: osSparse1Beta
# mu0: fullLocallyLinear
# tau: sparseLinearWeak
# external shift: SparseNonLinear
# audit sample from target

# s1b_2b
# dim: 10 (normal)
# pscore: osSparse2Linear
# mu0: sparseLinearWeak
# tau: fullLinearWeak
# external shift: SparseLogitLinear2
# audit sample from target

library(tidyverse)
library(ranger)
library(grf)
library(mcboost)
library(mlr3learners)
library(kableExtra)
library(ggh4x)
#library(causalToolbox)
source("setups_num.R")
source("dr_learner.R")
source("KL.R")

## Prepare loop

# Result containers. train_res/test_res start empty. train_temp/test_temp are
# one-row templates that the loop overwrites cell by cell for each
# (shift e, training size s, repetition i).
# Scalar columns hold loop indices, the out-of-range flag and the true ATE.
# Every other column is a list-column holding a one-column tibble, named like
# the column itself, that the loop replaces with per-observation vectors.
# Both templates come from a single column spec so they cannot drift apart.
# NOTE(review): KL is initialised as a list-column here, but the loop assigns
# it the scalar returned by KLdiv().
train_res <- tibble()
train_temp <- local({
  # Columns holding a single scalar per row.
  scalar_cols <- c("rep_i", "trsize_s", "shift_e", "flag", "ate")
  # Full column order of a result row.
  all_cols <- c(
    "rep_i", "trsize_s", "shift_e", "flag",
    "ps", "shift", "KL", "y", "ate", "tau",
    # Predicted outcomes (yp_* are on the rescaled [0, 1] scale).
    "y_cforest", "y_cforestw",
    "y_slearner_rf", "y_slearnerw_rf",
    "y_tlearner_rf", "y_tlearnerw_rf",
    "yp_tlearner_rf_mcr", "y_tlearner_rf_mcr",
    "yp_tlearner_rf_mct", "y_tlearner_rf_mct",
    "yp_tlearner_rf_mclr", "y_tlearner_rf_mclr",
    "yp_tlearner_rf_mclt", "y_tlearner_rf_mclt",
    "y_tlearner_crf",
    "y_cforest_audit", "y_slearner_audit", "y_tlearner_audit",
    # Estimated CATEs.
    "tau_cforest", "tau_cforestw",
    "tau_slearner_rf", "tau_slearnerw_rf",
    "tau_drlearner",
    "tau_drlearner_mcr", "tau_drlearner_mct",
    "tau_drlearner_mcfr", "tau_drlearner_mcft",
    "tau_drlearner_mclr", "tau_drlearner_mclt",
    "tau_tlearner_rf", "tau_tlearnerw_rf",
    "tau_tlearner_rf_mcr", "tau_tlearner_rf_mct",
    "tau_tlearner_rf_mclr", "tau_tlearner_rf_mclt",
    "tau_tlearner_crf",
    "tau_cforest_audit", "tau_slearner_audit", "tau_tlearner_audit"
  )
  cells <- lapply(all_cols, function(col) {
    if (col %in% scalar_cols) NA else list(tibble(!!col := NA))
  })
  tibble(!!!setNames(cells, all_cols))
})

test_res <- tibble()
# Same spec as the training template; R's copy-on-modify keeps them independent.
test_temp <- train_temp

# Simulation grid: every (shift amplifier, training-set size) cell is
# repeated n_reps times.
s_range <- seq(from = 500, to = 5000, by = 1500) # training-set sizes
e_range <- seq(from = 0, to = 3, by = 0.25)      # shift amplifier (exponent applied to shiftw)
n_reps <- 25                                     # repetitions per grid cell

# Min-max rescale `x` using the range of `label` (not the range of `x`), so
# that one reference range can be applied to several vectors.
# Values of `x` outside range(label) map outside [0, 1]; no clipping is done
# here.
# NOTE: this definition masks base::scale() for the rest of the script.
#
# x:     numeric vector to rescale.
# label: numeric vector whose min/max define the reference range.
# Returns a numeric vector the same length as `x`.
# Errors if `label` has zero width (all values equal). Without this check the
# division would silently yield NaN/Inf.
scale <- function(x, label){
  rng <- range(label)
  if (isTRUE(rng[1] == rng[2])) {
    stop("scale(): `label` has zero range; cannot rescale.", call. = FALSE)
  }
  (x - rng[1]) / (rng[2] - rng[1])
}

# Inverse of scale(): map values expressed on the range(label) unit scale back
# to the original units of `label`.
#
# preds: numeric vector on the rescaled scale.
# label: numeric vector whose min/max define the reference range.
# Returns a numeric vector the same length as `preds`.
rev_scale <- function(preds, label) {
  lims <- range(label)
  lims[1] + preds * (lims[2] - lims[1])
}

# mcboost auditor fitters wrapping mlr3 regression learners: a glmnet ridge
# (alpha = 0, penalty s = 1) and an rpart tree limited to depth 3.
# NOTE(review): the visible part of the script passes auditors to MCBoost by
# name ("RidgeAuditorFitter" / "TreeAuditorFitter"); confirm these objects are
# used further down.
ridge <- local({
  ridge_lrn <- lrn("regr.glmnet", alpha = 0, s = 1)
  LearnerAuditorFitter$new(ridge_lrn)
})
tree <- local({
  tree_lrn <- lrn("regr.rpart", maxdepth = 3)
  LearnerAuditorFitter$new(tree_lrn)
})

## Simulation

## Generate population
# Scenario s1b_1c from the header: 10 normal covariates,
# osSparse1Beta propensity, sparseLinearWeak mu0, fullLinearWeak tau and a
# SparseNonLinear external shift. The training draw is used as the finite
# population that every train/audit/test sample is drawn from.
init_sim <- simulate_causal_experiment(
  ntrain = 100000,              # population size
  dim = 10,                     # number of covariates
  alpha = 0.1,                  # covariate correlation
  feat_distribution = "normal",
  pscore = "osSparse1Beta",
  mu0 = "sparseLinearWeak",
  tau = "fullLinearWeak",
  shiftfun = "SparseNonLinear"
)

# One row per unit: covariates, treatment propensity (ps), treatment (T),
# true effect (tau), observed outcome (Y) and the shift quantities; shiftw is
# raised to +/-e in the loop to weight sampling.
pop <- with(init_sim, tibble(
  feat_tr,
  ps = Wp_tr,
  T = W_tr,
  tau = tau_tr,
  Y = Yobs_tr,
  shift = shift_tr,
  shiftw = shiftw_tr
))

for(e in e_range) {
  
  ## Set external shift
  pop$shift_s <- pop$shiftw^e # weights source
  pop$shift_t <- pop$shiftw^-e # weights target
  
  train_temp$shift_e <- e
  test_temp$shift_e <- e
  
  for(s in s_range) {
    
    ## Set training set size
    train_size <- s
    test_size <- 5000
    audit_size <- 500
    
    train_temp$trsize_s <- s
    test_temp$trsize_s <- s
    
    for(i in 1:n_reps) {
      
      train_temp$rep_i <- i
      test_temp$rep_i <- i

      ## Sample from population
      train <- slice_sample(pop, n = train_size)
      audit <- slice_sample(pop, n = audit_size) # audit in source/target
      test <- slice_sample(pop, n = test_size, weight_by = shift_s)

      train_temp$ps <- list(tibble(ps = train$ps)) # treatment propensities
      test_temp$ps <- list(tibble(ps = test$ps))
      train_temp$shift <- list(tibble(shift = train$shift)) # source sample propensities
      test_temp$shift <- list(tibble(shift = test$shift))
      train_temp$y <- list(tibble(y = train$Y)) # true Y
      test_temp$y <- list(tibble(y = test$Y))
      train_temp$ate <- mean(train$tau) # true ATE
      test_temp$ate <- mean(test$tau)
      train_temp$tau <- list(tibble(tau = train$tau)) # true tau
      test_temp$tau <- list(tibble(tau = test$tau))

      train <- select(train, Y, T, x1:x10) # train data  
      X_traint <- select(train, T, x1:x10) 
      Y_train <- train$Y
      Y_trains <- scale(Y_train, label = Y_train*3) # increase range of y
      T_train <- train$T
      X_auditt <- select(audit, T, x1:x10) # audit data
      Y_audit <- audit$Y
      Y_audits <- scale(Y_audit, label = Y_train*3)
      T_audit <- audit$T
      train_temp$flag <- ifelse(min(Y_audits) < 0 | max(Y_audits) > 1, 1, 0) # outside [0, 1]?
      test_temp$flag <- ifelse(min(Y_audits) < 0 | max(Y_audits) > 1, 1, 0)
      Y_audits <- ifelse(Y_audits < 0, 0, Y_audits) # clip to [0, 1]
      Y_audits <- ifelse(Y_audits > 1, 1, Y_audits)
      T_audit <- audit$T
      X_test <- select(test, x1:x10) # test data
      
      train_ut <- data.frame(X_traint[-T], T = 0) # Fix treated and untreated
      train_t <- data.frame(X_traint[-T], T = 1)
      test_ut <- data.frame(X_test, T = 0)
      test_t <- data.frame(X_test, T = 1)

      ## KL divergence 
      
      Mref <- colMeans(X_test)
      Sref <- cov(X_test)
      Mtrain <- colMeans(X_traint[,-1])
      Strain <- cov(X_traint[,-1])
      train_temp$KL <- KLdiv(Mtrain, Mref, Strain, Sref, symmetric = F)
      test_temp$KL <- KLdiv(Mtrain, Mref, Strain, Sref, symmetric = F)
      
      ## Propensity score model - Train vs Target
      
      stacked <- bind_rows(X_traint[-T], X_test, .id = "source")
      stacked$source <- as.numeric(ifelse(stacked$source == 1, 1, 0))
      psm <- glm(source ~ ., family = binomial, data = stacked)
      pscores <- predict(psm, newdata = X_traint[-T], type = "response")
      pweights <- (1 - pscores) / pscores
      
      ## Train models  (w. train data)
      ### Causal Forest
      cforest <- causal_forest(X_traint[-T], Y_train, T_train)

      train_temp$tau_cforest <- list(tibble(tau_cforest = predict(cforest)$predictions))
      test_temp$tau_cforest <- list(tibble(tau_cforest = predict(cforest, X_test)$predictions))

      ### Causal Forest (weighted)
      cforestw <- causal_forest(X_traint[-T], Y_train, T_train, sample.weights = pweights)
      
      train_temp$tau_cforestw <- list(tibble(tau_cforestw = predict(cforestw)$predictions))
      test_temp$tau_cforestw <- list(tibble(tau_cforestw = predict(cforestw, X_test)$predictions))
      
      ### S-learner
      slearner_rf <- ranger(y = Y_trains, x = X_traint)

      yp_sl_t <- predict(slearner_rf, train_t)$predictions
      yp_sl_ut <- predict(slearner_rf, train_ut)$predictions
      y_sl_t <- rev_scale(yp_sl_t, label = Y_train*3)
      y_sl_ut <- rev_scale(yp_sl_ut, label = Y_train*3)
      
      train_temp$y_slearner_rf <- list(tibble(y_slearner_rf = rev_scale(predict(slearner_rf, train)$predictions, label = Y_train*3)))
      train_temp$tau_slearner_rf <- list(tibble(tau_slearner_rf = y_sl_t - y_sl_ut))
      
      yp_sl_t <- predict(slearner_rf, test_t)$predictions
      yp_sl_ut <- predict(slearner_rf, test_ut)$predictions
      y_sl_t <- rev_scale(yp_sl_t, label = Y_train*3)
      y_sl_ut <- rev_scale(yp_sl_ut, label = Y_train*3)
      
      test_temp$y_slearner_rf <- list(tibble(y_slearner_rf = rev_scale(predict(slearner_rf, test)$predictions, label = Y_train*3)))
      test_temp$tau_slearner_rf <- list(tibble(tau_slearner_rf = y_sl_t - y_sl_ut))
      
      ### S-learner (weighted)
      slearnerw_rf <- ranger(y = Y_trains, x = X_traint, case.weights = pweights)
      
      yp_sl_t <- predict(slearnerw_rf, train_t)$predictions
      yp_sl_ut <- predict(slearnerw_rf, train_ut)$predictions
      y_sl_t <- rev_scale(yp_sl_t, label = Y_train*3)
      y_sl_ut <- rev_scale(yp_sl_ut, label = Y_train*3)
      
      train_temp$y_slearnerw_rf <- list(tibble(y_slearnerw_rf = rev_scale(predict(slearnerw_rf, train)$predictions, label = Y_train*3)))
      train_temp$tau_slearnerw_rf <- list(tibble(tau_slearnerw_rf = y_sl_t - y_sl_ut))
      
      yp_sl_t <- predict(slearnerw_rf, test_t)$predictions
      yp_sl_ut <- predict(slearnerw_rf, test_ut)$predictions
      y_sl_t <- rev_scale(yp_sl_t, label = Y_train*3)
      y_sl_ut <- rev_scale(yp_sl_ut, label = Y_train*3)
      
      test_temp$y_slearnerw_rf <- list(tibble(y_slearnerw_rf = rev_scale(predict(slearnerw_rf, test)$predictions, label = Y_train*3)))
      test_temp$tau_slearnerw_rf <- list(tibble(tau_slearnerw_rf = y_sl_t - y_sl_ut))
      
      ### DR-learner
      drlearner <- dr_learner(X_traint[-T], Y_train, T_train, test)
      
      train_temp$tau_drlearner <- list(tibble(tau_drlearner = drlearner$tau.hat))
      test_temp$tau_drlearner <- list(tibble(tau_drlearner = drlearner$tau.new))
      
      ### DR-learner + MCBoost (ridge)
      drlearnermcr <- dr_learnermc(X_traint[-T], X_traint, X_auditt, 
                                   Y_train, Y_trains, Y_audits, 
                                   T_train, 
                                   test)
      
      train_temp$tau_drlearner_mcr <- list(tibble(tau_drlearner_mcr = drlearnermcr$tau.hat))
      test_temp$tau_drlearner_mcr <- list(tibble(tau_drlearner_mcr = drlearnermcr$tau.new))

      ### DR-learner + MCBoost (tree)
      drlearnermct <- dr_learnermc(X_traint[-T], X_traint, X_auditt, 
                                   Y_train, Y_trains, Y_audits, 
                                   T_train, 
                                   test, auditor = "TreeAuditorFitter")
      
      train_temp$tau_drlearner_mct <- list(tibble(tau_drlearner_mct = drlearnermct$tau.hat))
      test_temp$tau_drlearner_mct <- list(tibble(tau_drlearner_mct = drlearnermct$tau.new))
      
      ### DR-learner + MCBoost (ridge)
      drlearnermcfr <- dr_learnermc2(X_traint[-T], X_traint, X_auditt, 
                                     Y_train, Y_trains, Y_audit, Y_audits,
                                     T_train, T_audit, 
                                     test, eta = 0.1)
      
      train_temp$tau_drlearner_mcfr <- list(tibble(tau_drlearner_mcfr = drlearnermcfr$tau.hat))
      test_temp$tau_drlearner_mcfr <- list(tibble(tau_drlearner_mcfr = drlearnermcfr$tau.new))
      
      ### DR-learner + MCBoost (tree) 
      drlearnermcft <- dr_learnermc2(X_traint[-T], X_traint, X_auditt, 
                                     Y_train, Y_trains, Y_audit, Y_audits, 
                                     T_train, T_audit, 
                                     test, eta = 0.1, auditor = "TreeAuditorFitter")
      
      train_temp$tau_drlearner_mcft <- list(tibble(tau_drlearner_mcft = drlearnermcft$tau.hat))
      test_temp$tau_drlearner_mcft <- list(tibble(tau_drlearner_mcft = drlearnermcft$tau.new))
      
      ### DR-learner + MCBoost (ridge)
      drlearnermclr <- dr_learnermc3(X_traint[-T], X_traint, X_auditt, 
                                     Y_train, Y_audit, 
                                     T_train, T_audit, 
                                     test, eta = 0.1)
      
      train_temp$tau_drlearner_mclr <- list(tibble(tau_drlearner_mclr = drlearnermclr$tau.hat))
      test_temp$tau_drlearner_mclr <- list(tibble(tau_drlearner_mclr = drlearnermclr$tau.new))
  
      ### DR-learner + MCBoost (tree) 
      drlearnermclt <- dr_learnermc3(X_traint[-T], X_traint, X_auditt, 
                                     Y_train, Y_audit, 
                                     T_train, T_audit, 
                                     test, eta = 0.1, auditor = "TreeAuditorFitter")
      
      train_temp$tau_drlearner_mclt <- list(tibble(tau_drlearner_mclt = drlearnermclt$tau.hat))
      test_temp$tau_drlearner_mclt <- list(tibble(tau_drlearner_mclt = drlearnermclt$tau.new))
      
      ### T-learner
      tlearner_t <- ranger(y = Y_trains[X_traint$T == 1], 
                           x = X_traint[X_traint$T == 1, ])
      tlearner_ut <- ranger(y = Y_trains[X_traint$T == 0], 
                            x = X_traint[X_traint$T == 0, ])
      
      yp_tl_t <- predict(tlearner_t, train_t)$predictions
      yp_tl_ut <- predict(tlearner_ut, train_ut)$predictions
      y_tl_t <- rev_scale(yp_tl_t, label = Y_train*3)
      y_tl_ut <- rev_scale(yp_tl_ut, label = Y_train*3)
      train_temp$y_tlearner_rf <- list(tibble(y_tlearner_rf = ifelse(train$T == 1, 
                                              rev_scale(predict(tlearner_t, train)$predictions, label = Y_train*3), 
                                              rev_scale(predict(tlearner_ut, train)$predictions, label = Y_train*3))))
      train_temp$tau_tlearner_rf <- list(tibble(tau_tlearner_rf = y_tl_t - y_tl_ut))
      
      yp_tl_t <- predict(tlearner_t, test_t)$predictions
      yp_tl_ut <- predict(tlearner_ut, test_ut)$predictions
      y_tl_t <- rev_scale(yp_tl_t, label = Y_train*3)
      y_tl_ut <- rev_scale(yp_tl_ut, label = Y_train*3)
      test_temp$y_tlearner_rf <- list(tibble(y_tlearner_rf = ifelse(test$T == 1, 
                                             rev_scale(predict(tlearner_t, test)$predictions, label = Y_train*3), 
                                             rev_scale(predict(tlearner_ut, test)$predictions, label = Y_train*3))))
      test_temp$tau_tlearner_rf <- list(tibble(tau_tlearner_rf = y_tl_t - y_tl_ut))
      
      ### T-learner (weighted)
      tlearnerw_t <- ranger(y = Y_trains[X_traint$T == 1], 
                            x = X_traint[X_traint$T == 1, ],
                            case.weights = pweights[X_traint$T == 1])
      tlearnerw_ut <- ranger(y = Y_trains[X_traint$T == 0], 
                             x = X_traint[X_traint$T == 0, ],
                             case.weights = pweights[X_traint$T == 0])

      yp_tl_t <- predict(tlearnerw_t, train_t)$predictions
      yp_tl_ut <- predict(tlearnerw_ut, train_ut)$predictions
      y_tl_t <- rev_scale(yp_tl_t, label = Y_train*3)
      y_tl_ut <- rev_scale(yp_tl_ut, label = Y_train*3)
      train_temp$y_tlearnerw_rf <- list(tibble(y_tlearnerw_rf = ifelse(train$T == 1, 
                                               rev_scale(predict(tlearnerw_t, train)$predictions, label = Y_train*3), 
                                               rev_scale(predict(tlearnerw_ut, train)$predictions, label = Y_train*3))))
      train_temp$tau_tlearnerw_rf <- list(tibble(tau_tlearnerw_rf = y_tl_t - y_tl_ut))
      
      yp_tl_t <- predict(tlearnerw_t, test_t)$predictions
      yp_tl_ut <- predict(tlearnerw_ut, test_ut)$predictions
      y_tl_t <- rev_scale(yp_tl_t, label = Y_train*3)
      y_tl_ut <- rev_scale(yp_tl_ut, label = Y_train*3)
      test_temp$y_tlearnerw_rf <- list(tibble(y_tlearnerw_rf = ifelse(test$T == 1, 
                                              rev_scale(predict(tlearnerw_t, test)$predictions, label = Y_train*3), 
                                              rev_scale(predict(tlearnerw_ut, test)$predictions, label = Y_train*3))))
      test_temp$tau_tlearnerw_rf <- list(tibble(tau_tlearnerw_rf = y_tl_t - y_tl_ut))
      
      ### T-learner + MCBoost (ridge)
      init_preds = function(data) {
        preds <- predict(tlearner_t, data)$predictions}
      tlearner_t_mc = MCBoost$new(init_predictor = init_preds,
                                  auditor_fitter = "RidgeAuditorFitter",
                                  alpha = 1e-06,
                                  # iter_sampling = "bootstrap",
                                  weight_degree = 2,
                                  eta = 0.5,
                                  max_iter = 5)
      tlearner_t_mc$multicalibrate(X_auditt[X_auditt$T == 1, ], Y_audits[X_auditt$T == 1])

      yp_tlearner_t_mc_train <- tlearner_t_mc$predict_probs(train)
      yp_tlearner_t_mc_test <- tlearner_t_mc$predict_probs(test) 
      
      yp_tlearner_t_mc_trt <- tlearner_t_mc$predict_probs(train_t)
      y_tlearner_t_mc_trt <- rev_scale(yp_tlearner_t_mc_trt, label = Y_train*3)
      
      yp_tlearner_t_mc_tst <- tlearner_t_mc$predict_probs(test_t)
      y_tlearner_t_mc_tst <- rev_scale(yp_tlearner_t_mc_tst, label = Y_train*3)
      
      init_preds = function(data) {
        preds <- predict(tlearner_ut, data)$predictions}
      tlearner_ut_mc = MCBoost$new(init_predictor = init_preds,
                                   auditor_fitter = "RidgeAuditorFitter",
                                   alpha = 1e-06,
                                   # iter_sampling = "bootstrap",
                                   weight_degree = 2,
                                   eta = 0.5,
                                   max_iter = 5)
      tlearner_ut_mc$multicalibrate(X_auditt[X_auditt$T == 0, ], Y_audits[X_auditt$T == 0])
    
      yp_tlearner_ut_mc_train <- tlearner_ut_mc$predict_probs(train)
      yp_tlearner_ut_mc_test <- tlearner_ut_mc$predict_probs(test) 
      
      train_temp$yp_tlearner_rf_mcr <- list(tibble(yp_tlearner_rf_mcr = ifelse(train$T == 1, yp_tlearner_t_mc_train, yp_tlearner_ut_mc_train)))
      train_temp$y_tlearner_rf_mcr <- list(tibble(y_tlearner_rf_mcr = rev_scale(train_temp$yp_tlearner_rf_mcr[[1]], label = Y_train*3)))
      test_temp$yp_tlearner_rf_mcr <- list(tibble(yp_tlearner_rf_mcr = ifelse(test$T == 1, yp_tlearner_t_mc_test, yp_tlearner_ut_mc_test)))
      test_temp$y_tlearner_rf_mcr <- list(tibble(y_tlearner_rf_mcr = rev_scale(test_temp$yp_tlearner_rf_mcr[[1]], label = Y_train*3)))
      
      yp_tlearner_ut_mc_trt <- tlearner_ut_mc$predict_probs(train_ut)
      y_tlearner_ut_mc_trt <- rev_scale(yp_tlearner_ut_mc_trt, label = Y_train*3)
      train_temp$tau_tlearner_rf_mcr <- list(tibble(tau_tlearner_rf_mcr = y_tlearner_t_mc_trt - y_tlearner_ut_mc_trt))
      
      yp_tlearner_ut_mc_tst <- tlearner_ut_mc$predict_probs(test_ut)      
      y_tlearner_ut_mc_tst <- rev_scale(yp_tlearner_ut_mc_tst, label = Y_train*3)
      test_temp$tau_tlearner_rf_mcr <- list(tibble(tau_tlearner_rf_mcr = y_tlearner_t_mc_tst - y_tlearner_ut_mc_tst))
      
      ### T-learner + MCBoost (tree)
      init_preds = function(data) {
        preds <- predict(tlearner_t, data)$predictions}
      tlearner_t_mc = MCBoost$new(init_predictor = init_preds,
                                  auditor_fitter = "TreeAuditorFitter",
                                  alpha = 1e-06,
                                  # iter_sampling = "bootstrap",
                                  weight_degree = 2,
                                  eta = 0.5,
                                  max_iter = 5)
      tlearner_t_mc$multicalibrate(X_auditt[X_auditt$T == 1, ], Y_audits[X_auditt$T == 1])
      
      yp_tlearner_t_mc_train <- tlearner_t_mc$predict_probs(train)
      yp_tlearner_t_mc_test <- tlearner_t_mc$predict_probs(test) 
      
      yp_tlearner_t_mc_trt <- tlearner_t_mc$predict_probs(train_t)
      y_tlearner_t_mc_trt <- rev_scale(yp_tlearner_t_mc_trt, label = Y_train*3)
      
      yp_tlearner_t_mc_tst <- tlearner_t_mc$predict_probs(test_t)
      y_tlearner_t_mc_tst <- rev_scale(yp_tlearner_t_mc_tst, label = Y_train*3)
      
      init_preds = function(data) {
        preds <- predict(tlearner_ut, data)$predictions}
      tlearner_ut_mc = MCBoost$new(init_predictor = init_preds,
                                   auditor_fitter = "TreeAuditorFitter",
                                   alpha = 1e-06,
                                   # iter_sampling = "bootstrap",
                                   weight_degree = 2,
                                   eta = 0.5,
                                   max_iter = 5)
      tlearner_ut_mc$multicalibrate(X_auditt[X_auditt$T == 0, ], Y_audits[X_auditt$T == 0])
      
      yp_tlearner_ut_mc_train <- tlearner_ut_mc$predict_probs(train)
      yp_tlearner_ut_mc_test <- tlearner_ut_mc$predict_probs(test) 
      
      train_temp$yp_tlearner_rf_mct <- list(tibble(yp_tlearner_rf_mct = ifelse(train$T == 1, yp_tlearner_t_mc_train, yp_tlearner_ut_mc_train)))
      train_temp$y_tlearner_rf_mct <- list(tibble(y_tlearner_rf_mct = rev_scale(train_temp$yp_tlearner_rf_mct[[1]], label = Y_train*3)))
      test_temp$yp_tlearner_rf_mct <- list(tibble(yp_tlearner_rf_mct = ifelse(test$T == 1, yp_tlearner_t_mc_test, yp_tlearner_ut_mc_test)))
      test_temp$y_tlearner_rf_mct <- list(tibble(y_tlearner_rf_mct = rev_scale(test_temp$yp_tlearner_rf_mct[[1]], label = Y_train*3)))
      
      yp_tlearner_ut_mc_trt <- tlearner_ut_mc$predict_probs(train_ut)
      y_tlearner_ut_mc_trt <- rev_scale(yp_tlearner_ut_mc_trt, label = Y_train*3)
      train_temp$tau_tlearner_rf_mct <- list(tibble(tau_tlearner_rf_mct = y_tlearner_t_mc_trt - y_tlearner_ut_mc_trt))
      
      yp_tlearner_ut_mc_tst <- tlearner_ut_mc$predict_probs(test_ut)      
      y_tlearner_ut_mc_tst <- rev_scale(yp_tlearner_ut_mc_tst, label = Y_train*3)
      test_temp$tau_tlearner_rf_mct <- list(tibble(tau_tlearner_rf_mct = y_tlearner_t_mc_tst - y_tlearner_ut_mc_tst))
      
      ### T-learner + MCBoost (ridge max_iter 10)
      init_preds = function(data) {
        preds <- predict(tlearner_t, data)$predictions}
      tlearner_t_mc = MCBoost$new(init_predictor = init_preds,
                                  auditor_fitter = "RidgeAuditorFitter",
                                  alpha = 1e-06,
                                  # iter_sampling = "bootstrap",
                                  weight_degree = 2,
                                  eta = 0.5,
                                  max_iter = 10)
      tlearner_t_mc$multicalibrate(X_auditt[X_auditt$T == 1, ], Y_audits[X_auditt$T == 1])
      
      yp_tlearner_t_mc_train <- tlearner_t_mc$predict_probs(train)
      yp_tlearner_t_mc_test <- tlearner_t_mc$predict_probs(test) 
      
      yp_tlearner_t_mc_trt <- tlearner_t_mc$predict_probs(train_t)
      y_tlearner_t_mc_trt <- rev_scale(yp_tlearner_t_mc_trt, label = Y_train*3)
      
      yp_tlearner_t_mc_tst <- tlearner_t_mc$predict_probs(test_t)
      y_tlearner_t_mc_tst <- rev_scale(yp_tlearner_t_mc_tst, label = Y_train*3)
      
      init_preds = function(data) {
        preds <- predict(tlearner_ut, data)$predictions}
      tlearner_ut_mc = MCBoost$new(init_predictor = init_preds,
                                   auditor_fitter = "RidgeAuditorFitter",
                                   alpha = 1e-06,
                                   # iter_sampling = "bootstrap",
                                   weight_degree = 2,
                                   eta = 0.5,
                                   max_iter = 10)
      tlearner_ut_mc$multicalibrate(X_auditt[X_auditt$T == 0, ], Y_audits[X_auditt$T == 0])
      
      yp_tlearner_ut_mc_train <- tlearner_ut_mc$predict_probs(train)
      yp_tlearner_ut_mc_test <- tlearner_ut_mc$predict_probs(test) 
      
      train_temp$yp_tlearner_rf_mclr <- list(tibble(yp_tlearner_rf_mclr = ifelse(train$T == 1, yp_tlearner_t_mc_train, yp_tlearner_ut_mc_train)))
      train_temp$y_tlearner_rf_mclr <- list(tibble(y_tlearner_rf_mclr = rev_scale(train_temp$yp_tlearner_rf_mclr[[1]], label = Y_train*3)))
      test_temp$yp_tlearner_rf_mclr <- list(tibble(yp_tlearner_rf_mclr = ifelse(test$T == 1, yp_tlearner_t_mc_test, yp_tlearner_ut_mc_test)))
      test_temp$y_tlearner_rf_mclr <- list(tibble(y_tlearner_rf_mclr = rev_scale(test_temp$yp_tlearner_rf_mclr[[1]], label = Y_train*3)))
      
      yp_tlearner_ut_mc_trt <- tlearner_ut_mc$predict_probs(train_ut)
      y_tlearner_ut_mc_trt <- rev_scale(yp_tlearner_ut_mc_trt, label = Y_train*3)
      train_temp$tau_tlearner_rf_mclr <- list(tibble(tau_tlearner_rf_mclr = y_tlearner_t_mc_trt - y_tlearner_ut_mc_trt))
      
      yp_tlearner_ut_mc_tst <- tlearner_ut_mc$predict_probs(test_ut)      
      y_tlearner_ut_mc_tst <- rev_scale(yp_tlearner_ut_mc_tst, label = Y_train*3)
      test_temp$tau_tlearner_rf_mclr <- list(tibble(tau_tlearner_rf_mclr = y_tlearner_t_mc_tst - y_tlearner_ut_mc_tst))
      
      ### T-learner + MCBoost (tree max_iter 10)
      init_preds = function(data) {
        preds <- predict(tlearner_t, data)$predictions}
      tlearner_t_mc = MCBoost$new(init_predictor = init_preds,
                                  auditor_fitter = "TreeAuditorFitter",
                                  alpha = 1e-06,
                                  # iter_sampling = "bootstrap",
                                  weight_degree = 2,
                                  eta = 0.5,
                                  max_iter = 10)
      tlearner_t_mc$multicalibrate(X_auditt[X_auditt$T == 1, ], Y_audits[X_auditt$T == 1])
      
      yp_tlearner_t_mc_train <- tlearner_t_mc$predict_probs(train)
      yp_tlearner_t_mc_test <- tlearner_t_mc$predict_probs(test) 
      
      yp_tlearner_t_mc_trt <- tlearner_t_mc$predict_probs(train_t)
      y_tlearner_t_mc_trt <- rev_scale(yp_tlearner_t_mc_trt, label = Y_train*3)
      
      yp_tlearner_t_mc_tst <- tlearner_t_mc$predict_probs(test_t)
      y_tlearner_t_mc_tst <- rev_scale(yp_tlearner_t_mc_tst, label = Y_train*3)
      
      init_preds = function(data) {
        preds <- predict(tlearner_ut, data)$predictions}
      tlearner_ut_mc = MCBoost$new(init_predictor = init_preds,
                                   auditor_fitter = "TreeAuditorFitter",
                                   alpha = 1e-06,
                                   # iter_sampling = "bootstrap",
                                   weight_degree = 2,
                                   eta = 0.5,
                                   max_iter = 10)
      tlearner_ut_mc$multicalibrate(X_auditt[X_auditt$T == 0, ], Y_audits[X_auditt$T == 0])
      
      yp_tlearner_ut_mc_train <- tlearner_ut_mc$predict_probs(train)
      yp_tlearner_ut_mc_test <- tlearner_ut_mc$predict_probs(test) 
      
      train_temp$yp_tlearner_rf_mclt <- list(tibble(yp_tlearner_rf_mclt = ifelse(train$T == 1, yp_tlearner_t_mc_train, yp_tlearner_ut_mc_train)))
      train_temp$y_tlearner_rf_mclt <- list(tibble(y_tlearner_rf_mclt = rev_scale(train_temp$yp_tlearner_rf_mclt[[1]], label = Y_train*3)))
      test_temp$yp_tlearner_rf_mclt <- list(tibble(yp_tlearner_rf_mclt = ifelse(test$T == 1, yp_tlearner_t_mc_test, yp_tlearner_ut_mc_test)))
      test_temp$y_tlearner_rf_mclt <- list(tibble(y_tlearner_rf_mclt = rev_scale(test_temp$yp_tlearner_rf_mclt[[1]], label = Y_train*3)))
      
      yp_tlearner_ut_mc_trt <- tlearner_ut_mc$predict_probs(train_ut)
      y_tlearner_ut_mc_trt <- rev_scale(yp_tlearner_ut_mc_trt, label = Y_train*3)
      train_temp$tau_tlearner_rf_mclt <- list(tibble(tau_tlearner_rf_mclt = y_tlearner_t_mc_trt - y_tlearner_ut_mc_trt))
      
      yp_tlearner_ut_mc_tst <- tlearner_ut_mc$predict_probs(test_ut)      
      y_tlearner_ut_mc_tst <- rev_scale(yp_tlearner_ut_mc_tst, label = Y_train*3)
      test_temp$tau_tlearner_rf_mclt <- list(tibble(tau_tlearner_rf_mclt = y_tlearner_t_mc_tst - y_tlearner_ut_mc_tst))
      
      ### T-learner (grf)
      tclearner_t <- regression_forest(Y = Y_trains[X_traint$T == 1], 
                                       X = X_traint[X_traint$T == 1, ])
      tclearner_ut <- regression_forest(Y = Y_trains[X_traint$T == 0], 
                                        X = X_traint[X_traint$T == 0, ])
      
      yp_tl_t <- predict(tclearner_t, train_t)$predictions
      yp_tl_ut <- predict(tclearner_ut, train_ut)$predictions
      y_tl_t <- rev_scale(yp_tl_t, label = Y_train*3)
      y_tl_ut <- rev_scale(yp_tl_ut, label = Y_train*3)
      train_temp$y_tlearner_crf <- list(tibble(y_tlearner_crf = ifelse(train$T == 1, 
                                               rev_scale(predict(tclearner_t, select(train, -Y))$predictions, label = Y_train*3), 
                                               rev_scale(predict(tclearner_ut, select(train, -Y))$predictions, label = Y_train*3))))
      train_temp$tau_tlearner_crf <- list(tibble(tau_tlearner_crf = y_tl_t - y_tl_ut))
      
      yp_tl_t <- predict(tclearner_t, test_t)$predictions
      yp_tl_ut <- predict(tclearner_ut, test_ut)$predictions
      y_tl_t <- rev_scale(yp_tl_t, label = Y_train*3)
      y_tl_ut <- rev_scale(yp_tl_ut, label = Y_train*3)
      test_temp$y_tlearner_crf <- list(tibble(y_tlearner_crf = ifelse(test$T == 1, 
                                              rev_scale(predict(tclearner_t, select(test, x1:x10, T))$predictions, label = Y_train*3), 
                                              rev_scale(predict(tclearner_ut, select(test, x1:x10, T))$predictions, label = Y_train*3))))
      test_temp$tau_tlearner_crf <- list(tibble(tau_tlearner_crf = y_tl_t - y_tl_ut))
      
      ## Train models (w. audit data)
      ### Causal Forest
      cforest_audit <- causal_forest(X_auditt[-T], Y_audit, T_audit)

      train_temp$tau_cforest_audit <- list(tibble(tau_cforest_audit = predict(cforest_audit, X_traint[-T])$predictions))
      test_temp$tau_cforest_audit <- list(tibble(tau_cforest_audit = predict(cforest_audit, X_test)$predictions))
      
      ### S-learner
      slearner_audit <- ranger(y = Y_audit, x = X_auditt)
      
      train_temp$y_slearner_audit <- list(tibble(y_slearner_audit = predict(slearner_audit, train)$predictions))
      test_temp$y_slearner_audit <- list(tibble(y_slearner_audit = predict(slearner_audit, test)$predictions))
      
      train_temp$tau_slearner_audit <- list(tibble(tau_slearner_audit = predict(slearner_audit, train_t)$predictions - predict(slearner_audit, train_ut)$predictions))
      test_temp$tau_slearner_audit <- list(tibble(tau_slearner_audit = predict(slearner_audit, test_t)$predictions - predict(slearner_audit, test_ut)$predictions))
      
      ### T-learner
      tlearner_audit_t <- ranger(y = Y_audit[X_auditt$T == 1], 
                                 x = X_auditt[X_auditt$T == 1, ])
      tlearner_audit_ut <- ranger(y = Y_audit[X_auditt$T == 0], 
                                  x = X_auditt[X_auditt$T == 0, ])
      
      train_temp$y_tlearner_audit <- list(tibble(y_tlearner_audit = ifelse(train$T == 1, 
                                                 predict(tlearner_audit_t, train)$predictions,
                                                 predict(tlearner_audit_ut, train)$predictions)))
      test_temp$y_tlearner_audit <- list(tibble(y_tlearner_audit = ifelse(test$T == 1, 
                                                predict(tlearner_audit_t, test)$predictions,
                                                predict(tlearner_audit_ut, test)$predictions)))
      
      train_temp$tau_tlearner_audit <- list(tibble(tau_tlearner_audit = predict(tlearner_audit_t, train_t)$predictions - predict(tlearner_audit_ut, train_ut)$predictions))
      test_temp$tau_tlearner_audit <- list(tibble(tau_tlearner_audit = predict(tlearner_audit_t, test_t)$predictions - predict(tlearner_audit_ut, test_ut)$predictions))

      ## Save results
      train_res <- rbind(train_res, train_temp)
      test_res <- rbind(test_res, test_temp)
    
      print(paste("e =", e, "s =", s, "i =", i))
    }
  }
}

## Combine results

# Collapse the per-observation results stored in `res` (train_res or test_res)
# into one row per (rep_i, trsize_s, shift_e):
#   - flag, mean_ps, mean_shift, mean_KL, mean_y: averages over observations;
#   - bias_<method>: ate minus the mean estimated CATE of that method
#     (positive = the method underestimates the ATE on average);
#   - mse_<method>: mean squared error of the estimated CATE against `tau`.
# train_res and test_res share the same layout, so a single helper keeps the
# two metric lists from drifting apart. The result stays grouped by
# rep_i and trsize_s (summarise()'s "drop_last").
summarise_replications <- function(res) {
  res %>%
    # unnest() without `cols` is deprecated; unnest every list column, which
    # is what the implicit default did.
    unnest(cols = where(is.list)) %>%
    group_by(rep_i, trsize_s, shift_e) %>%
    summarise(flag = mean(flag),
              mean_ps = mean(ps),
              mean_shift = mean(shift),
              mean_KL = mean(KL),
              mean_y = mean(y),
              bias_cforest = mean(ate - mean(tau_cforest)),
              bias_cforestw = mean(ate - mean(tau_cforestw)),
              bias_slearnerRF = mean(ate - mean(tau_slearner_rf)),
              bias_slearnerwRF = mean(ate - mean(tau_slearnerw_rf)),
              bias_drlearner = mean(ate - mean(tau_drlearner)),
              bias_drlearnerRFmcr = mean(ate - mean(tau_drlearner_mcr)),
              bias_drlearnerRFmct = mean(ate - mean(tau_drlearner_mct)),
              bias_drlearnerRFmcfr = mean(ate - mean(tau_drlearner_mcfr)),
              bias_drlearnerRFmcft = mean(ate - mean(tau_drlearner_mcft)),
              bias_drlearnerRFmclr = mean(ate - mean(tau_drlearner_mclr)),
              bias_drlearnerRFmclt = mean(ate - mean(tau_drlearner_mclt)),
              bias_tlearnerRF = mean(ate - mean(tau_tlearner_rf)),
              bias_tlearnerwRF = mean(ate - mean(tau_tlearnerw_rf)),
              bias_tlearnerRFmcr = mean(ate - mean(tau_tlearner_rf_mcr)),
              bias_tlearnerRFmct = mean(ate - mean(tau_tlearner_rf_mct)),
              bias_tlearnerRFmclr = mean(ate - mean(tau_tlearner_rf_mclr)),
              bias_tlearnerRFmclt = mean(ate - mean(tau_tlearner_rf_mclt)),
              bias_tlearnerCRF = mean(ate - mean(tau_tlearner_crf)),
              bias_cforestAUDIT = mean(ate - mean(tau_cforest_audit)),
              bias_slearnerAUDIT = mean(ate - mean(tau_slearner_audit)),
              bias_tlearnerAUDIT = mean(ate - mean(tau_tlearner_audit)),
              mse_cforest = mean((tau - tau_cforest)^2),
              mse_cforestw = mean((tau - tau_cforestw)^2),
              mse_slearnerRF = mean((tau - tau_slearner_rf)^2),
              mse_slearnerwRF = mean((tau - tau_slearnerw_rf)^2),
              mse_drlearner = mean((tau - tau_drlearner)^2),
              mse_drlearnerRFmcr = mean((tau - tau_drlearner_mcr)^2),
              mse_drlearnerRFmct = mean((tau - tau_drlearner_mct)^2),
              mse_drlearnerRFmcfr = mean((tau - tau_drlearner_mcfr)^2),
              mse_drlearnerRFmcft = mean((tau - tau_drlearner_mcft)^2),
              mse_drlearnerRFmclr = mean((tau - tau_drlearner_mclr)^2),
              mse_drlearnerRFmclt = mean((tau - tau_drlearner_mclt)^2),
              mse_tlearnerRF = mean((tau - tau_tlearner_rf)^2),
              mse_tlearnerwRF = mean((tau - tau_tlearnerw_rf)^2),
              mse_tlearnerRFmcr = mean((tau - tau_tlearner_rf_mcr)^2),
              mse_tlearnerRFmct = mean((tau - tau_tlearner_rf_mct)^2),
              mse_tlearnerRFmclr = mean((tau - tau_tlearner_rf_mclr)^2),
              mse_tlearnerRFmclt = mean((tau - tau_tlearner_rf_mclt)^2),
              mse_tlearnerCRF = mean((tau - tau_tlearner_crf)^2),
              mse_cforestAUDIT = mean((tau - tau_cforest_audit)^2),
              mse_slearnerAUDIT = mean((tau - tau_slearner_audit)^2),
              mse_tlearnerAUDIT = mean((tau - tau_tlearner_audit)^2),
              # Same grouping as the implicit default, without the message.
              .groups = "drop_last")
}

train_r <- summarise_replications(train_res)
test_r <- summarise_replications(test_res)

# Long format for plotting and tabulation: one row per replication, train
# size, shift intensity, metric and method. A column such as "bias_cforest"
# is split at its underscore into Metric = "bias" and Method = "cforest".
train_r_long <- pivot_longer(
  select(train_r, rep_i, trsize_s, shift_e, flag,
         bias_cforest:mse_tlearnerAUDIT),
  cols = bias_cforest:mse_tlearnerAUDIT,
  names_to = c("Metric", "Method"),
  names_sep = "_"
)

test_r_long <- pivot_longer(
  select(test_r, rep_i, trsize_s, shift_e, flag,
         bias_cforest:mse_tlearnerAUDIT),
  cols = bias_cforest:mse_tlearnerAUDIT,
  names_to = c("Metric", "Method"),
  names_sep = "_"
)

# Persist the wide and long per-replication summaries for later analysis.
save(list = c("train_r", "test_r", "train_r_long", "test_r_long"),
     file = "simu.RData")

## Evaluation
### Assumptions/ Overlap

# Min/mean/max of the propensity scores (`ps`) and shift scores (`shift`) in
# the first replication, one row per (trsize_s, shift_e).
# Each `ps`/`shift` cell holds a one-column tibble (columns `ps` and `shift`),
# so the numeric vector is pulled out with `$` first. Calling mean() on the
# tibble itself returns NA with a warning; only min()/max() happened to work.
summarise_overlap <- function(res) {
  res %>%
    filter(rep_i == 1) %>%
    rowwise(trsize_s, shift_e) %>%
    summarize(min_ps = min(ps$ps),
              mean_ps = mean(ps$ps),
              max_ps = max(ps$ps),
              min_shift = min(shift$shift),
              mean_shift = mean(shift$shift),
              max_shift = max(shift$shift),
              .groups = "drop")
}

# Histogram of the shift scores in the first replication, faceted by shift
# intensity (rows) and training-set size (columns).
plot_shift_hist <- function(res) {
  res %>%
    filter(rep_i == 1) %>%
    select(trsize_s, shift_e, shift) %>%
    unnest(cols = c(shift)) %>%
    ggplot(aes(x = shift)) +
    geom_histogram() +
    facet_grid(rows = vars(shift_e), cols = vars(trsize_s))
}

summarise_overlap(train_res)
summarise_overlap(test_res)

plot_shift_hist(train_res)
plot_shift_hist(test_res)

### KL divergence

# Average KL divergence per training size and shift intensity (0-2), rounded
# to two decimals and written out as a LaTeX table.
kl_tab <- test_r %>%
  filter(shift_e <= 2) %>%
  group_by(trsize_s, shift_e) %>%
  summarize(kl = mean(mean_KL)) %>%
  # across(kl, round, 2) relied on passing extra args via `...`, which is
  # deprecated since dplyr 1.1.0.
  mutate(kl = round(kl, 2))

kl_ftab <- knitr::kable(kl_tab, format = "latex")
writeLines(kl_ftab, "s1b_kl.tex")

### Bias and MSE

# Mean ATE bias per method over replications, for shift intensities 0-2.
bias_tab <- test_r %>%
  filter(shift_e <= 2) %>%
  group_by(trsize_s, shift_e) %>%
  summarize(bias_cforest = mean(bias_cforest),
            bias_cforestw = mean(bias_cforestw),
            bias_slearnerRF = mean(bias_slearnerRF),
            bias_slearnerwRF = mean(bias_slearnerwRF),
            bias_drlearner = mean(bias_drlearner),
            bias_drlearnerRFmcfr = mean(bias_drlearnerRFmcfr),
            bias_drlearnerRFmcft = mean(bias_drlearnerRFmcft),
            bias_tlearnerRF = mean(bias_tlearnerRF),
            bias_tlearnerwRF = mean(bias_tlearnerwRF),
            bias_tlearnerRFmcr = mean(bias_tlearnerRFmcr),
            bias_tlearnerRFmct = mean(bias_tlearnerRFmct)) %>%
  # Extra arguments through across()'s `...` are deprecated (dplyr 1.1.0).
  mutate(across(bias_cforest:bias_tlearnerRFmct, function(x) round(x, 2)))

# LaTeX table: per row, the smallest absolute bias is set in bold and the
# second smallest in italics. Columns 1-2 are trsize_s and shift_e, hence
# the offset of 2 when mapping ranks back to column positions.
cols <- seq_len(ncol(bias_tab))
bias_ftab <- data.frame(matrix(NA, nrow = nrow(bias_tab), ncol = ncol(bias_tab)))
names(bias_ftab) <- names(bias_tab)

for (i in seq_len(nrow(bias_tab))) {
  bias_rank <- order(as.numeric(abs(bias_tab[i, 3:ncol(bias_tab)])))
  minf <- bias_rank[1]
  mins <- bias_rank[2]
  bias_ftab[i, ] <- cell_spec(bias_tab[i, ], "latex",
                              bold = cols == minf + 2,
                              italic = cols == mins + 2)
}

bias_ftab <- knitr::kable(bias_ftab, format = "latex")
writeLines(bias_ftab, "s1b_1_bias.tex")

# Mean CATE MSE per method over replications, for shift intensities 0-2.
# The DR-learner MC variants use the mcfr/mcft columns, matching the bias
# table and the "DR-learner-MC-Ridge"/"DR-learner-MC-Tree" plot labels
# (previously this table reported the mcr/mct variants instead).
mse_tab <- test_r %>%
  filter(shift_e <= 2) %>%
  group_by(trsize_s, shift_e) %>%
  summarize(mse_cforest = mean(mse_cforest),
            mse_cforestw = mean(mse_cforestw),
            mse_slearnerRF = mean(mse_slearnerRF),
            mse_slearnerwRF = mean(mse_slearnerwRF),
            mse_drlearner = mean(mse_drlearner),
            mse_drlearnerRFmcfr = mean(mse_drlearnerRFmcfr),
            mse_drlearnerRFmcft = mean(mse_drlearnerRFmcft),
            mse_tlearnerRF = mean(mse_tlearnerRF),
            mse_tlearnerwRF = mean(mse_tlearnerwRF),
            mse_tlearnerRFmcr = mean(mse_tlearnerRFmcr),
            mse_tlearnerRFmct = mean(mse_tlearnerRFmct)) %>%
  # Extra arguments through across()'s `...` are deprecated (dplyr 1.1.0).
  mutate(across(mse_cforest:mse_tlearnerRFmct, function(x) round(x, 2)))

# LaTeX table: per row, the smallest MSE is set in bold and the second
# smallest in italics (columns 1-2 are trsize_s and shift_e, hence +2).
cols <- seq_len(ncol(mse_tab))
mse_ftab <- data.frame(matrix(NA, nrow = nrow(mse_tab), ncol = ncol(mse_tab)))
names(mse_ftab) <- names(mse_tab)

for (i in seq_len(nrow(mse_tab))) {
  mse_rank <- order(as.numeric(mse_tab[i, 3:ncol(mse_tab)]))
  minf <- mse_rank[1]
  mins <- mse_rank[2]
  mse_ftab[i, ] <- cell_spec(mse_tab[i, ], "latex",
                             bold = cols == minf + 2,
                             italic = cols == mins + 2)
}

mse_ftab <- knitr::kable(mse_ftab, format = "latex")
writeLines(mse_ftab, "s1b_1_mse.tex")

# Display labels for the raw method identifiers, shared by the train and test
# tables so the two mappings cannot diverge. recode_factor() orders the factor
# levels by these replacements; identifiers mapped to themselves keep their
# raw name, and identifiers not listed (e.g. drlearnerRFmcr) stay unchanged.
method_labels <- c("cforest" = "CForest-OS",
                   "cforestw" = "CForest-wOS",
                   "slearnerRF" = "S-learner-OS",
                   "slearnerwRF" = "S-learner-wOS",
                   "drlearner" = "DR-learner-OS",
                   "drlearnerRFmcfr" = "DR-learner-MC-Ridge",
                   "drlearnerRFmcft" = "DR-learner-MC-Tree",
                   "tlearnerRF" = "T-learner-OS",
                   "tlearnerwRF" = "T-learner-wOS",
                   "tlearnerRFmcr" = "T-learner-MC-Ridge",
                   "tlearnerRFmct" = "T-learner-MC-Tree",
                   "tlearnerRFmclr" = "tlearnerRFmclr",
                   "tlearnerRFmclt" = "tlearnerRFmclt",
                   "tlearnerCRF" = "tlearnerCRF",
                   "cforestAUDIT" = "CForest-AUDIT",
                   "slearnerAUDIT" = "S-learner-AUDIT",
                   "tlearnerAUDIT" = "T-learner-AUDIT")

train_r_long$Method <- recode_factor(train_r_long$Method, !!!method_labels)
test_r_long$Method <- recode_factor(test_r_long$Method, !!!method_labels)

### Plots

# Plot a smoothed trend of `metric` against shift intensity, with one curve per
# method and one panel per training-set size. Print the plot and save it to `file`.
# Metrics with no rows in `data` are skipped with a message. ggplot2 aborts
# when faceting a zero-row data frame, and the summaries built in this script
# only contain "bias" and "mse" metrics (no "prerr"), so an unguarded
# "prerr" plot stopped the whole script.
plot_metric_trend <- function(data, metric, y_label, file) {
  plot_data <- filter(data, Metric == .env$metric)
  if (nrow(plot_data) == 0) {
    message("No '", metric, "' rows to plot; skipping ", file)
    return(invisible(NULL))
  }
  p <- ggplot(plot_data,
              aes(x = shift_e, y = value, group = Method, color = Method)) +
    # `size` for lines is deprecated in ggplot2 >= 3.4; use `linewidth`.
    geom_smooth(alpha = 0.25, linewidth = 0.75) +
    labs(x = "Shift Intensity", y = y_label) +
    facet_grid(cols = vars(trsize_s))
  print(p)
  ggsave(file, plot = p, width = 10, height = 6)
  invisible(p)
}

plot_metric_trend(train_r_long, "prerr", "Prediction MSE",
                  "s1b_1_train_pred-error.pdf")
plot_metric_trend(train_r_long, "bias", "Bias (ATE)",
                  "s1b_1_train_ate-bias.pdf")
plot_metric_trend(train_r_long, "mse", "MSE (CATE)",
                  "s1b_1_train_cate-mse.pdf")

# To drop flagged replications from the test plots, pass
# filter(test_r_long, flag == 0) instead of test_r_long.
plot_metric_trend(test_r_long, "prerr", "Prediction MSE",
                  "s1b_1_test_pred-error.pdf")
plot_metric_trend(test_r_long, "bias", "Bias (ATE)",
                  "s1b_1_test_ate-bias.pdf")
# Saved as "-line" because the aggregated bar plot at the end of the script
# writes to "s1b_1_test_cate-mse.pdf" and used to overwrite this figure.
plot_metric_trend(test_r_long, "mse", "MSE (CATE)",
                  "s1b_1_test_cate-mse-line.pdf")

shift_labs <- c("no shift", "moderate shift", "strong shift")
names(shift_labs) <- c("0", "1", "2")

# Methods left out of the box plots: exploratory variants (which keep their
# raw identifiers) and the learners trained on the audit data.
box_excluded <- c("tlearnerCRF", "tlearnerRFmclr", "tlearnerRFmclt",
                  "drlearnerRFmcr", "drlearnerRFmct",
                  "drlearnerRFmclr", "drlearnerRFmclt",
                  "CForest-AUDIT", "S-learner-AUDIT", "T-learner-AUDIT")

# Test-set results for shift intensities 0-2, shared by both box plots.
# `method` is Method with reversed levels (fct_rev); the legend is reversed
# as well via guide_legend(reverse = TRUE).
box_data <- test_r_long %>%
  mutate(method = fct_rev(Method)) %>%
  # filter(flag == 0) %>%
  filter(!Method %in% box_excluded) %>%
  filter(shift_e %in% 0:2)

# Box-plot theme: no method labels on the (flipped) axis, legend only.
box_theme <- theme(axis.title.y = element_blank(),
                   axis.text.y = element_blank(),
                   axis.ticks.y = element_blank(),
                   legend.title = element_blank(),
                   text = element_text(size = 13))

# ATE bias per method with a reference line at zero. The intercept is a
# constant outside aes(); inside aes() one line was drawn per data row.
# Only scale_colour_manual() is used: a preceding scale_color_hue() was
# replaced by it anyway.
p_bias_box <- box_data %>%
  filter(Metric == "bias") %>%
  ggplot(aes(y = value, group = method, color = method)) +
  geom_boxplot(lwd = 0.4, outlier.size = 0.1) +
  geom_hline(yintercept = 0) +
  labs(y = "Bias (ATE)") +
  guides(color = guide_legend(reverse = TRUE)) +
  facet_grid(rows = vars(trsize_s), cols = vars(shift_e),
             labeller = labeller(shift_e = shift_labs)) +
  scale_colour_manual(values = RColorBrewer::brewer.pal(12, "Paired")) +
  coord_flip(ylim = c(-15, 10)) +
  box_theme

print(p_bias_box)
ggsave("s1b_1_test_ate-bias-box.pdf", plot = p_bias_box, width = 10, height = 7)

# CATE MSE per method; value limits are set per shift-intensity panel.
p_mse_box <- box_data %>%
  filter(Metric == "mse") %>%
  ggplot(aes(y = value, group = method, color = method)) +
  geom_boxplot(lwd = 0.4, outlier.size = 0.1) +
  geom_hline(yintercept = 0) +
  labs(y = "MSE (CATE)") +
  guides(color = guide_legend(reverse = TRUE)) +
  facet_grid(rows = vars(trsize_s), cols = vars(shift_e), scales = "free",
             labeller = labeller(shift_e = shift_labs)) +
  scale_colour_manual(values = RColorBrewer::brewer.pal(12, "Paired")) +
  coord_flip() +
  facetted_pos_scales(y = list(shift_e == 0 ~ scale_y_continuous(limits = c(0, 80)),
                               shift_e == 1 ~ scale_y_continuous(limits = c(0, 150)),
                               shift_e == 2 ~ scale_y_continuous(limits = c(0, 350)))) +
  box_theme

print(p_mse_box)
ggsave("s1b_1_test_cate-mse-box.pdf", plot = p_mse_box, width = 10, height = 7)

## Aggregated bar plots

# plyr is called through `plyr::` rather than attached with library(plyr).
# Attaching plyr after dplyr masks dplyr's summarise(), mutate(), rename(),
# arrange() and others. Any grouped dplyr code run afterwards in the same
# session would then silently ignore its groups.

# Mean and standard deviation of column `varname` within each combination of
# the `groupnames` columns, ignoring NAs. Returns a data frame with the
# grouping columns, the mean stored under the name `varname`, and `sd`.
data_summary <- function(data, varname, groupnames) {
  summary_func <- function(x, col) {
    c(mean = mean(x[[col]], na.rm = TRUE), sd = sd(x[[col]], na.rm = TRUE))
  }
  data_sum <- plyr::ddply(data, groupnames, .fun = summary_func, varname)
  # plyr::rename() takes c(old = new): the "mean" column becomes `varname`.
  plyr::rename(data_sum, c("mean" = varname))
}

d_sum <- data_summary(test_r_long, varname = "value",
                      groupnames = c("shift_e", "trsize_s", "Metric", "Method"))

shift_labs <- c("no shift", "moderate shift", "strong shift")
names(shift_labs) <- c("0", "1", "2")

# CATE MSE (mean +/- sd over replications) for five selected methods by
# training-set size, one panel per shift intensity.
p_mse_bar <- d_sum %>%
  mutate(trsize_s = factor(trsize_s),
         # NOTE(review): Method was already relabelled earlier in the script,
         # so these keys may not match. In that case recode_factor() mainly
         # puts the listed labels first in the level order — confirm intended.
         Method = recode_factor(Method,
                                "cforest" = "CForest-OS",
                                "tlearnerRF" = "T-learner-OS",
                                "tlearnerRFmcr" = "T-learner-MC-Ridge",
                                "drlearner" = "DR-learner-OS",
                                "drlearnerRFmcfr" = "DR-learner-MC-Ridge")) %>%
  filter(Metric == "mse") %>%
  filter(Method %in% c("CForest-OS", "T-learner-OS", "T-learner-MC-Ridge",
                       "DR-learner-OS", "DR-learner-MC-Ridge")) %>%
  filter(shift_e %in% 0:2) %>%
  ggplot(aes(x = trsize_s, y = value, fill = Method)) +
  geom_bar(stat = "identity", color = "black", position = position_dodge()) +
  geom_errorbar(aes(ymin = value - sd, ymax = value + sd),
                linewidth = 0.2, width = .02, position = position_dodge(.9)) +
  labs(y = "MSE (CATE)", x = "Train size") +
  scale_fill_manual(values = c(RColorBrewer::brewer.pal(8, "Paired")[7],
                               RColorBrewer::brewer.pal(6, "Paired")[1:5])) +
  facet_wrap(. ~ shift_e, scales = "free",
             labeller = labeller(shift_e = shift_labs)) +
  facetted_pos_scales(y = list(shift_e == 0 ~ scale_y_continuous(limits = c(0, 100)),
                               shift_e == 1 ~ scale_y_continuous(limits = c(0, 275)),
                               shift_e == 2 ~ scale_y_continuous(limits = c(0, 275)))) +
  theme(legend.title = element_blank(),
        text = element_text(size = 10))

print(p_mse_bar)
ggsave("s1b_1_test_cate-mse.pdf", plot = p_mse_bar, width = 10, height = 2.75)
