# Simulation of heterogeneous treatment effect estimation
# Setup: External shift between murky source (train), target (test) and small RCT data (rct)
# Binary outcome

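# Scenarios as defined in setups.R:
# NOTE(review): the simulate_causal_experiment() calls in this script use
# pscore = "osConfounding1" (train, test) / "rct5" (rct),
# mu0 = "fullLocallyLinear" and tau = "sparseLinearWeak", which match neither
# scenario listed below, while the output files are named s1c_1_*. Confirm
# which scenario this script is meant to run.

# s1c_1:
# dim: 10
# pscore [train, test]: osSparse2Linear
# pscore [rct]: rct5
# mu0: sparseNonLinear3
# tau: sparseNonLinear3
# external shift: SparseLogitLinear1

# s1c_2:
# dim: 10
# pscore [train, test]: osConfounding2
# pscore [rct]: rct5
# mu0: sparseNonLinear3
# tau: sparseNonLinear3
# external shift: SparseLogitLinear1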

library(tidyverse)
library(ranger)
library(grf)
library(mcboost)
library(mlr3learners)
#library(causalToolbox)
source("setups.R")

## Prepare loop

# Build a one-row tibble whose columns (named by `col_names`, in that order)
# all hold a logical NA. Inside the simulation loop these cells are
# overwritten per replication and the row is appended to the *_res tibbles.
make_result_template <- function(col_names) {
  as_tibble(setNames(rep(list(NA), length(col_names)), col_names))
}

# Columns stored for the train and test sets: replication/setting ids, true
# quantities (ps, y, yp, ate, tau), then predicted outcome probabilities (yp_*)
# and CATEs (tau_*) of every estimator. Suffix _tr = fitted on training data
# (_mcr / _mct = MCBoost-post-processed with ridge / tree auditor); suffix
# _rct = fitted on RCT data ("w" = propensity-weighted variant).
result_cols_full <- c(
  "rep_i", "trsize_s", "shift_e",
  "ps",
  "y", "yp",
  "ate", "tau",
  "yp_cforest_tr",
  "yp_slearner_tr", "yp_slearner_tr_mcr", "yp_slearner_tr_mct",
  "yp_tlearner_tr", "yp_tlearner_tr_mcr", "yp_tlearner_tr_mct",
  "tau_cforest_tr",
  "tau_slearner_tr", "tau_slearner_tr_mcr", "tau_slearner_tr_mct",
  "tau_tlearner_tr", "tau_tlearner_tr_mcr", "tau_tlearner_tr_mct",
  "yp_cforest_rct", "yp_cforestw_rct",
  "yp_slearner_rct", "yp_slearnerw_rct",
  "yp_tlearner_rct", "yp_tlearnerw_rct",
  "tau_cforest_rct", "tau_cforestw_rct",
  "tau_slearner_rct", "tau_slearnerw_rct",
  "tau_tlearner_rct", "tau_tlearnerw_rct"
)

# The RCT sample stores no propensity score and no yp_* predictions (the bare
# true-probability column "yp" is kept).
result_cols_rct <- result_cols_full[result_cols_full != "ps" &
                                      !startsWith(result_cols_full, "yp_")]

train_res <- tibble()
train_temp <- make_result_template(result_cols_full)

test_res <- tibble()
test_temp <- make_result_template(result_cols_full)

rct_res <- tibble()
rct_temp <- make_result_template(result_cols_rct)

s_range <- seq(500, 5000, by = 1500) # n obs training set
e_range <- seq(0, 3, by = 0.25) # shift amplifier
n_reps <- 20 # n repetitions

# Seed R's RNG so the simulated data, the weighted RCT draws and the forests
# (ranger / grf draw their default seeds from R's RNG) are reproducible.
sim_seed <- 1
set.seed(sim_seed)

# MCBoost auditors used to post-process the training-data learners.
ridge <- LearnerAuditorFitter$new(lrn("regr.glmnet", alpha = 0, lambda = 1))
tree <- LearnerAuditorFitter$new(lrn("regr.rpart", maxdepth = 3))

## Simulation

## Generate RCT population (RCT samples are later drawn from it with
## shift-dependent weights)
init_rct <- simulate_causal_experiment(ntrain = 100000, # n obs
                                       dim = 10, # n covars
                                       alpha = .1, # corr
                                       feat_distribution = "normal", 
                                       pscore = "rct5", 
                                       mu0 = "fullLocallyLinear",
                                       tau = "sparseLinearWeak",
                                       shiftfun = "SparseLogitLinear1") 

rct_pop <- tibble(init_rct$feat_tr, ps = init_rct$Wp_tr, T = init_rct$W_tr, 
                  tau = init_rct$tau_tr, Yp = init_rct$Yp_tr, Y = init_rct$Yobs_tr,
                  shift = init_rct$shift_tr, shiftw = init_rct$shiftw_tr)

## Main simulation loop: for every shift intensity e, training-set size s and
## replication i, simulate fresh train/test data, draw an RCT sample from
## rct_pop, fit all estimators, and append one row of per-unit predictions to
## train_res / test_res / rct_res.
for (e in e_range) {
  
  ## Set external shift. RCT units are drawn with weight shiftw^e, so e = 0
  ## gives an unweighted sample of rct_pop.
  rct_pop$shift_s <- rct_pop$shiftw^e # weights source
  rct_pop$shift_t <- rct_pop$shiftw^-e # weights target (not used in this loop)
  
  train_temp$shift_e <- e
  test_temp$shift_e <- e
  rct_temp$shift_e <- e
  
  for (s in s_range) {
  
    ## Set training set size
    train_size <- s
    test_size <- 5000
    rct_size <- 500
    
    train_temp$trsize_s <- s
    test_temp$trsize_s <- s
    rct_temp$trsize_s <- s
    
    for (i in seq_len(n_reps)) {
      
      train_temp$rep_i <- i
      test_temp$rep_i <- i
      rct_temp$rep_i <- i

      ## Sample train (murky source)
      init_train <- simulate_causal_experiment(ntrain = train_size, # n obs
                                               dim = 10, # n covars
                                               alpha = .1, # corr
                                               feat_distribution = "normal", 
                                               pscore = "osConfounding1", 
                                               mu0 = "fullLocallyLinear",
                                               tau = "sparseLinearWeak",
                                               shiftfun = "SparseLogitLinear1") 
    
      train <- tibble(init_train$feat_tr, ps = init_train$Wp_tr, T = init_train$W_tr, 
                      tau = init_train$tau_tr, Yp = init_train$Yp_tr, Y = init_train$Yobs_tr)
    
      ## Sample test (target)
      init_test <- simulate_causal_experiment(ntrain = test_size, # n obs
                                              dim = 10, # n covars
                                              alpha = .1, # corr
                                              feat_distribution = "normal", 
                                              pscore = "osConfounding1", 
                                              mu0 = "fullLocallyLinear",
                                              tau = "sparseLinearWeak",
                                              shiftfun = "SparseLogitLinear1") 
    
      test <- tibble(init_test$feat_tr, ps = init_test$Wp_tr, T = init_test$W_tr, 
                     tau = init_test$tau_tr, Yp = init_test$Yp_tr, Y = init_test$Yobs_tr)
    
      ## Sample RCT (weighted draw without replacement from rct_pop)
      rct <- slice_sample(rct_pop, n = rct_size, weight_by = shift_s)
      
      ## Pre-process data
      train_temp$ps <- list(train$ps) # treatment propensities
      test_temp$ps <- list(test$ps)
      train_temp$y <- list(train$Y) # true Y
      test_temp$y <- list(test$Y)
      rct_temp$y <- list(rct$Y)
      train_temp$yp <- list(train$Yp) # true Y p's
      test_temp$yp <- list(test$Yp)
      rct_temp$yp <- list(rct$Yp)
      train_temp$ate <- mean(train$tau) # true ATE
      test_temp$ate <- mean(test$tau)
      rct_temp$ate <- mean(rct$tau)
      train_temp$tau <- list(train$tau) # true tau
      test_temp$tau <- list(test$tau)
      rct_temp$tau <- list(rct$tau)

      train_f <- select(train, Y, T, x1:x10) # Complete train data
      train_r <- select(train, Y, T, x1:x10) # Incomplete train data
      X_traint_f <- select(train, T, x1:x10) 
      X_traint_r <- select(train, T, x1:x10) 
      Y_train <- train$Y
      T_train <- train$T
      rct_f <- select(rct, Y, T, x1:x10) # Complete RCT data
      rct_r <- select(rct, Y, T, x1:x10) # Incomplete RCT data
      X_rctt_f <- select(rct, T, x1:x10)
      X_rctt_r <- select(rct, T, x1:x10)
      Y_rct <- rct$Y
      T_rct <- rct$T
      X_test_f <- select(test, x1:x10) # Complete test data
      X_test_r <- select(test, x1:x10) # Incomplete test data
      
      # Counterfactual design matrices with the treatment fixed to 0 / 1.
      # The observed T is dropped by name with select(): the former `X[-T]`
      # only worked because the global `T` is TRUE (-T == -1). For the RCT
      # frames, data.frame(X_rctt_f, T = 0) used to keep the observed T and add
      # the constant as "T.1", so rct_t and rct_ut were identical and the
      # S-learner CATEs on the RCT sample were always 0.
      train_ut <- data.frame(select(X_traint_f, -T), T = 0) # Fix treated and untreated
      train_t <- data.frame(select(X_traint_f, -T), T = 1)
      test_ut <- data.frame(X_test_f, T = 0)
      test_t <- data.frame(X_test_f, T = 1)
      rct_ut <- data.frame(select(X_rctt_f, -T), T = 0)
      rct_t <- data.frame(select(X_rctt_f, -T), T = 1)
      
      ## Propensity score model - Train vs RCT
      ## Logistic model for membership in the RCT sample (rct = 1) vs the
      ## training sample (rct = 0); the odds (1 - p) / p are used as weights
      ## for the RCT units in the weighted learners below.
      stacked <- bind_rows(select(X_traint_r, -T), select(X_rctt_r, -T), .id = "rct")
      stacked$rct <- as.numeric(stacked$rct) - 1
      psm <- glm(rct ~ ., family = binomial, data = stacked)
      pscores <- predict(psm, newdata = select(X_rctt_r, -T), type = "response")
      pweights <- (1 - pscores) / pscores
      
      ## Train models - Train w. train data, post-process w. RCT data
      ### Causal Forest
      cforest_tr <- causal_forest(select(X_traint_r, -T), Y_train, T_train)

      train_temp$tau_cforest_tr <- list(predict(cforest_tr)$predictions)
      test_temp$tau_cforest_tr <- list(predict(cforest_tr, X_test_r)$predictions)

      ### S-learner (single probability forest with T as a feature)
      slearner_tr <- ranger(as.factor(Y) ~ ., 
                            probability = TRUE, 
                            data = train_r)

      train_temp$yp_slearner_tr <- list(predict(slearner_tr, train)$predictions[, 2])
      test_temp$yp_slearner_tr <- list(predict(slearner_tr, test)$predictions[, 2])
      
      train_temp$tau_slearner_tr <- list(predict(slearner_tr, train_t)$predictions[, 2] - predict(slearner_tr, train_ut)$predictions[, 2])
      test_temp$tau_slearner_tr <- list(predict(slearner_tr, test_t)$predictions[, 2] - predict(slearner_tr, test_ut)$predictions[, 2])

      ### S-learner + MCBoost (ridge)
      init_preds_rf <- function(data) {predict(slearner_tr, data)$predictions[, 2]}
      slearner_tr_mc <- MCBoost$new(init_predictor = init_preds_rf,
                                    auditor_fitter = ridge,
                                    iter_sampling = "bootstrap",
                                    max_iter = 20)
      slearner_tr_mc$multicalibrate(X_rctt_f, Y_rct)
      
      train_temp$yp_slearner_tr_mcr <- list(slearner_tr_mc$predict_probs(train))
      test_temp$yp_slearner_tr_mcr <- list(slearner_tr_mc$predict_probs(test))
      
      train_temp$tau_slearner_tr_mcr <- list(slearner_tr_mc$predict_probs(train_t)- slearner_tr_mc$predict_probs(train_ut))
      test_temp$tau_slearner_tr_mcr <- list(slearner_tr_mc$predict_probs(test_t) - slearner_tr_mc$predict_probs(test_ut))
      rct_temp$tau_slearner_tr_mcr <- list(slearner_tr_mc$predict_probs(rct_t) - slearner_tr_mc$predict_probs(rct_ut))

      ### S-learner + MCBoost (tree)
      init_preds_rf <- function(data) {predict(slearner_tr, data)$predictions[, 2]}
      slearner_tr_mc <- MCBoost$new(init_predictor = init_preds_rf,
                                    auditor_fitter = tree,
                                    iter_sampling = "bootstrap",
                                    max_iter = 20)
      slearner_tr_mc$multicalibrate(X_rctt_f, Y_rct)
      
      train_temp$yp_slearner_tr_mct <- list(slearner_tr_mc$predict_probs(train))
      test_temp$yp_slearner_tr_mct <- list(slearner_tr_mc$predict_probs(test))
      
      train_temp$tau_slearner_tr_mct <- list(slearner_tr_mc$predict_probs(train_t)- slearner_tr_mc$predict_probs(train_ut))
      test_temp$tau_slearner_tr_mct <- list(slearner_tr_mc$predict_probs(test_t) - slearner_tr_mc$predict_probs(test_ut))
      rct_temp$tau_slearner_tr_mct <- list(slearner_tr_mc$predict_probs(rct_t) - slearner_tr_mc$predict_probs(rct_ut))

      ### T-learner (separate forests for treated / untreated; T excluded)
      tlearner_tr_t <- ranger(as.factor(Y) ~ . -T, 
                              probability = TRUE,      
                              data = train_r[train_r$T == 1, ])
      tlearner_tr_ut <- ranger(as.factor(Y) ~ . -T, 
                               probability = TRUE,      
                               data = train_r[train_r$T == 0, ])

      train_temp$yp_tlearner_tr <- list(ifelse(train$T == 1, 
                                               predict(tlearner_tr_t, train)$predictions[, 2],
                                               predict(tlearner_tr_ut, train)$predictions[, 2]))
      test_temp$yp_tlearner_tr <- list(ifelse(test$T == 1, 
                                              predict(tlearner_tr_t, test)$predictions[, 2],
                                              predict(tlearner_tr_ut, test)$predictions[, 2]))
      
      train_temp$tau_tlearner_tr <- list(predict(tlearner_tr_t, train_t)$predictions[, 2] - predict(tlearner_tr_ut, train_ut)$predictions[, 2])
      test_temp$tau_tlearner_tr <- list(predict(tlearner_tr_t, test_t)$predictions[, 2] - predict(tlearner_tr_ut, test_ut)$predictions[, 2])
      rct_temp$tau_tlearner_tr <- list(predict(tlearner_tr_t, rct_t)$predictions[, 2] - predict(tlearner_tr_ut, rct_ut)$predictions[, 2])
      
      ### T-learner + MCBoost (ridge)
      init_preds <- function(data) {predict(tlearner_tr_t, data)$predictions[, 2]}
      tlearner_tr_t_mc <- MCBoost$new(init_predictor = init_preds,
                                      auditor_fitter = ridge,
                                      iter_sampling = "bootstrap",
                                      max_iter = 10)
      tlearner_tr_t_mc$multicalibrate(X_rctt_f[X_rctt_f$T == 1, ], Y_rct[X_rctt_f$T == 1])

      init_preds <- function(data) {predict(tlearner_tr_ut, data)$predictions[, 2]}
      tlearner_tr_ut_mc <- MCBoost$new(init_predictor = init_preds,
                                       auditor_fitter = ridge,
                                       iter_sampling = "bootstrap",
                                       max_iter = 10)
      tlearner_tr_ut_mc$multicalibrate(X_rctt_f[X_rctt_f$T == 0, ], Y_rct[X_rctt_f$T == 0])
    
      train_temp$yp_tlearner_tr_mcr <- list(ifelse(train$T == 1, 
                                                   tlearner_tr_t_mc$predict_probs(train),
                                                   tlearner_tr_ut_mc$predict_probs(train)))
      test_temp$yp_tlearner_tr_mcr <- list(ifelse(test$T == 1, 
                                                  tlearner_tr_t_mc$predict_probs(test),
                                                  tlearner_tr_ut_mc$predict_probs(test)))
      
      train_temp$tau_tlearner_tr_mcr <- list(tlearner_tr_t_mc$predict_probs(train_t) - tlearner_tr_ut_mc$predict_probs(train_ut))
      test_temp$tau_tlearner_tr_mcr <- list(tlearner_tr_t_mc$predict_probs(test_t) - tlearner_tr_ut_mc$predict_probs(test_ut))
      rct_temp$tau_tlearner_tr_mcr <- list(tlearner_tr_t_mc$predict_probs(rct_t) - tlearner_tr_ut_mc$predict_probs(rct_ut))

      ### T-learner + MCBoost (tree)
      init_preds <- function(data) {predict(tlearner_tr_t, data)$predictions[, 2]}
      tlearner_tr_t_mc <- MCBoost$new(init_predictor = init_preds,
                                      auditor_fitter = tree,
                                      iter_sampling = "bootstrap",
                                      max_iter = 10)
      tlearner_tr_t_mc$multicalibrate(X_rctt_f[X_rctt_f$T == 1, ], Y_rct[X_rctt_f$T == 1])
      
      init_preds <- function(data) {predict(tlearner_tr_ut, data)$predictions[, 2]}
      tlearner_tr_ut_mc <- MCBoost$new(init_predictor = init_preds,
                                       auditor_fitter = tree,
                                       iter_sampling = "bootstrap",
                                       max_iter = 10)
      tlearner_tr_ut_mc$multicalibrate(X_rctt_f[X_rctt_f$T == 0, ], Y_rct[X_rctt_f$T == 0])
      
      train_temp$yp_tlearner_tr_mct <- list(ifelse(train$T == 1, 
                                                   tlearner_tr_t_mc$predict_probs(train),
                                                   tlearner_tr_ut_mc$predict_probs(train)))
      test_temp$yp_tlearner_tr_mct <- list(ifelse(test$T == 1, 
                                                  tlearner_tr_t_mc$predict_probs(test),
                                                  tlearner_tr_ut_mc$predict_probs(test)))
      
      train_temp$tau_tlearner_tr_mct <- list(tlearner_tr_t_mc$predict_probs(train_t) - tlearner_tr_ut_mc$predict_probs(train_ut))
      test_temp$tau_tlearner_tr_mct <- list(tlearner_tr_t_mc$predict_probs(test_t) - tlearner_tr_ut_mc$predict_probs(test_ut))
      rct_temp$tau_tlearner_tr_mct <- list(tlearner_tr_t_mc$predict_probs(rct_t) - tlearner_tr_ut_mc$predict_probs(rct_ut))
      
      ## Train models - Train w. RCT data, post-process w. train data
      ### Causal Forest
      cforest_rct <- causal_forest(select(X_rctt_f, -T), Y_rct, T_rct)
    
      train_temp$tau_cforest_rct <- list(predict(cforest_rct, select(X_traint_f, -T))$predictions)
      test_temp$tau_cforest_rct <- list(predict(cforest_rct, X_test_f)$predictions)
    
      ### Causal Forest (weighted)
      cforestw_rct <- causal_forest(select(X_rctt_f, -T), Y_rct, T_rct, sample.weights = pweights)
      
      train_temp$tau_cforestw_rct <- list(predict(cforestw_rct, select(X_traint_f, -T))$predictions)
      test_temp$tau_cforestw_rct <- list(predict(cforestw_rct, X_test_f)$predictions)
      
      ### S-learner
      slearner_rct <- ranger(as.factor(Y) ~ ., 
                             probability = TRUE,
                             data = rct_f)
    
      train_temp$yp_slearner_rct <- list(predict(slearner_rct, train)$predictions[, 2])
      test_temp$yp_slearner_rct <- list(predict(slearner_rct, test)$predictions[, 2])
    
      train_temp$tau_slearner_rct <- list(predict(slearner_rct, train_t)$predictions[, 2] - predict(slearner_rct, train_ut)$predictions[, 2])
      test_temp$tau_slearner_rct <- list(predict(slearner_rct, test_t)$predictions[, 2] - predict(slearner_rct, test_ut)$predictions[, 2])

      ### S-learner (weighted)
      slearnerw_rct <- ranger(as.factor(Y) ~ ., 
                              probability = TRUE,
                              case.weights = pweights,
                              data = rct_f)

      train_temp$yp_slearnerw_rct <- list(predict(slearnerw_rct, train)$predictions[, 2])
      test_temp$yp_slearnerw_rct <- list(predict(slearnerw_rct, test)$predictions[, 2])
      
      train_temp$tau_slearnerw_rct <- list(predict(slearnerw_rct, train_t)$predictions[, 2] - predict(slearnerw_rct, train_ut)$predictions[, 2])
      test_temp$tau_slearnerw_rct <- list(predict(slearnerw_rct, test_t)$predictions[, 2] - predict(slearnerw_rct, test_ut)$predictions[, 2])
      
      ### T-learner
      tlearner_rct_t <- ranger(as.factor(Y) ~ . -T, 
                               probability = TRUE,
                               data = rct_f[rct_f$T == 1, ])
      tlearner_rct_ut <- ranger(as.factor(Y) ~ . -T,
                                probability = TRUE,
                                data = rct_f[rct_f$T == 0, ])
    
      train_temp$yp_tlearner_rct <- list(ifelse(train$T == 1, 
                                                predict(tlearner_rct_t, train)$predictions[, 2],
                                                predict(tlearner_rct_ut, train)$predictions[, 2]))
      test_temp$yp_tlearner_rct <- list(ifelse(test$T == 1, 
                                               predict(tlearner_rct_t, test)$predictions[, 2],
                                               predict(tlearner_rct_ut, test)$predictions[, 2]))
    
      train_temp$tau_tlearner_rct <- list(predict(tlearner_rct_t, train_t)$predictions[, 2] - predict(tlearner_rct_ut, train_ut)$predictions[, 2])
      test_temp$tau_tlearner_rct <- list(predict(tlearner_rct_t, test_t)$predictions[, 2] - predict(tlearner_rct_ut, test_ut)$predictions[, 2])
      rct_temp$tau_tlearner_rct <- list(predict(tlearner_rct_t, rct_t)$predictions[, 2] - predict(tlearner_rct_ut, rct_ut)$predictions[, 2])
      
      ### T-learner (weighted)
      tlearnerw_rct_t <- ranger(as.factor(Y) ~ . -T, 
                                probability = TRUE,
                                case.weights = pweights[rct_f$T == 1],
                                data = rct_f[rct_f$T == 1, ])
      tlearnerw_rct_ut <- ranger(as.factor(Y) ~ . -T, 
                                 probability = TRUE,
                                 case.weights = pweights[rct_f$T == 0],
                                 data = rct_f[rct_f$T == 0, ])
      
      train_temp$yp_tlearnerw_rct <- list(ifelse(train$T == 1, 
                                                 predict(tlearnerw_rct_t, train)$predictions[, 2],
                                                 predict(tlearnerw_rct_ut, train)$predictions[, 2]))
      test_temp$yp_tlearnerw_rct <- list(ifelse(test$T == 1, 
                                                predict(tlearnerw_rct_t, test)$predictions[, 2],
                                                predict(tlearnerw_rct_ut, test)$predictions[, 2]))

      train_temp$tau_tlearnerw_rct <- list(predict(tlearnerw_rct_t, train_t)$predictions[, 2] - predict(tlearnerw_rct_ut, train_ut)$predictions[, 2])
      test_temp$tau_tlearnerw_rct <- list(predict(tlearnerw_rct_t, test_t)$predictions[, 2] - predict(tlearnerw_rct_ut, test_ut)$predictions[, 2])
      rct_temp$tau_tlearnerw_rct <- list(predict(tlearnerw_rct_t, rct_t)$predictions[, 2] - predict(tlearnerw_rct_ut, rct_ut)$predictions[, 2])
      
      ## Save results
      train_res <- rbind(train_res, train_temp)
      test_res <- rbind(test_res, test_temp)
      rct_res <- rbind(rct_res, rct_temp)
      
      print(paste("e =", e, "s =", s, "i =", i))
    }
  }
}

## Combine results

# Collapse each replication row's list-columns into scalar metrics:
#   prerr_* : mean squared difference between true (yp) and predicted outcome
#             probabilities
#   bias_*  : true ATE minus the mean of the predicted CATEs
#   mse_*   : mean squared difference between true and predicted CATEs
# The raw list-columns are dropped afterwards; the id columns and the new
# mean_ps / mean_y summaries are kept. Used for both train and test results.
summarise_replications <- function(res) {
  res %>%
    rowwise() %>%
    mutate(mean_ps = mean(ps),
           mean_y = mean(y),
           # Prediction error, learners fitted on training data
           prerr_slearnerTR = mean((yp - yp_slearner_tr)^2),
           prerr_slearnerTRmcr = mean((yp - yp_slearner_tr_mcr)^2),
           prerr_slearnerTRmct = mean((yp - yp_slearner_tr_mct)^2),
           prerr_tlearnerTR = mean((yp - yp_tlearner_tr)^2),
           prerr_tlearnerTRmcr = mean((yp - yp_tlearner_tr_mcr)^2),
           prerr_tlearnerTRmct = mean((yp - yp_tlearner_tr_mct)^2),
           # ATE bias, learners fitted on training data
           bias_cforestTR = ate - mean(tau_cforest_tr),
           bias_slearnerTR = ate - mean(tau_slearner_tr),
           bias_slearnerTRmcr = ate - mean(tau_slearner_tr_mcr),
           bias_slearnerTRmct = ate - mean(tau_slearner_tr_mct),
           bias_tlearnerTR = ate - mean(tau_tlearner_tr),
           bias_tlearnerTRmcr = ate - mean(tau_tlearner_tr_mcr),
           bias_tlearnerTRmct = ate - mean(tau_tlearner_tr_mct),
           # CATE MSE, learners fitted on training data
           mse_cforestTR = mean((tau - tau_cforest_tr)^2),
           mse_slearnerTR = mean((tau - tau_slearner_tr)^2),
           mse_slearnerTRmcr = mean((tau - tau_slearner_tr_mcr)^2),
           mse_slearnerTRmct = mean((tau - tau_slearner_tr_mct)^2),
           mse_tlearnerTR = mean((tau - tau_tlearner_tr)^2),
           mse_tlearnerTRmcr = mean((tau - tau_tlearner_tr_mcr)^2),
           mse_tlearnerTRmct = mean((tau - tau_tlearner_tr_mct)^2),
           # Prediction error, learners fitted on RCT data
           prerr_slearnerRCT = mean((yp - yp_slearner_rct)^2),
           prerr_slearnerwRCT = mean((yp - yp_slearnerw_rct)^2),
           prerr_tlearnerRCT = mean((yp - yp_tlearner_rct)^2),
           prerr_tlearnerwRCT = mean((yp - yp_tlearnerw_rct)^2),
           # ATE bias, learners fitted on RCT data
           bias_cforestRCT = ate - mean(tau_cforest_rct),
           bias_cforestwRCT = ate - mean(tau_cforestw_rct),
           bias_slearnerRCT = ate - mean(tau_slearner_rct),
           bias_slearnerwRCT = ate - mean(tau_slearnerw_rct),
           bias_tlearnerRCT = ate - mean(tau_tlearner_rct),
           bias_tlearnerwRCT = ate - mean(tau_tlearnerw_rct),
           # CATE MSE, learners fitted on RCT data
           mse_cforestRCT = mean((tau - tau_cforest_rct)^2),
           mse_cforestwRCT = mean((tau - tau_cforestw_rct)^2),
           mse_slearnerRCT = mean((tau - tau_slearner_rct)^2),
           mse_slearnerwRCT = mean((tau - tau_slearnerw_rct)^2),
           mse_tlearnerRCT = mean((tau - tau_tlearner_rct)^2),
           mse_tlearnerwRCT = mean((tau - tau_tlearnerw_rct)^2)) %>%
    select(-c(ps:yp_tlearner_tr_mcr), -c(tau:tau_tlearnerw_rct)) %>%
    ungroup()
}

train_r <- summarise_replications(train_res)
test_r <- summarise_replications(test_res)

# Per-replication diagnostic on the RCT sample for each T-learner variant:
# the mean over RCT units of (true tau - predicted tau) * predicted tau.
# The raw list-columns (y through tau_tlearnerw_rct) are dropped afterwards.
rct_r <- rct_res %>%
  rowwise() %>%
  mutate(
    d2bias_tlearnerTR = mean((tau - tau_tlearner_tr) * tau_tlearner_tr),
    d2bias_tlearnerTRmcr = mean((tau - tau_tlearner_tr_mcr) * tau_tlearner_tr_mcr),
    d2bias_tlearnerTRmct = mean((tau - tau_tlearner_tr_mct) * tau_tlearner_tr_mct),
    d2bias_tlearnerRCT = mean((tau - tau_tlearner_rct) * tau_tlearner_rct),
    d2bias_tlearnerwRCT = mean((tau - tau_tlearnerw_rct) * tau_tlearnerw_rct)
  ) %>%
  ungroup() %>%
  select(-c(y:tau_tlearnerw_rct))

# Reshape per-replication metrics to long format. Metric columns are named
# "<Metric>_<Method>" (e.g. "bias_cforestTR") and are split at the underscore
# into a Metric and a Method column; the id columns are kept as-is.
to_long_metrics <- function(df, metric_cols) {
  df %>%
    select(rep_i, trsize_s, shift_e, {{ metric_cols }}) %>%
    pivot_longer(cols = {{ metric_cols }},
                 names_to = c("Metric", "Method"),
                 names_sep = "_")
}

train_r_long <- to_long_metrics(train_r, prerr_slearnerTR:mse_tlearnerwRCT)
test_r_long <- to_long_metrics(test_r, prerr_slearnerTR:mse_tlearnerwRCT)
rct_r_long <- to_long_metrics(rct_r, d2bias_tlearnerTR:d2bias_tlearnerwRCT)

save(train_r, test_r, rct_r,
     train_r_long, test_r_long, rct_r_long,
     file = "simu.RData")

## Evaluation
### Bias and MSE

# Mean test-set ATE bias per (training size, shift intensity) cell for every
# estimator, including the weighted RCT variants (cforestw / slearnerw /
# tlearnerw) that were previously left out. Output columns keep the
# "mean(<column>)" naming; .groups = "drop" returns an ungrouped tibble.
test_r %>%
  group_by(trsize_s, shift_e) %>%
  summarize(across(starts_with("bias_"), mean, .names = "mean({.col})"),
            .groups = "drop")

# Same for the test-set CATE MSE.
test_r %>%
  group_by(trsize_s, shift_e) %>%
  summarize(across(starts_with("mse_"), mean, .names = "mean({.col})"),
            .groups = "drop")

# Display order of the estimators in legends and plots: RCT-fitted variants
# first, then the training-data learners with their MCBoost variants.
method_order <- c("cforestRCT", "cforestwRCT", "cforestTR",
                  "slearnerRCT", "slearnerwRCT", "slearnerTR", "slearnerTRmcr", "slearnerTRmct",
                  "tlearnerRCT", "tlearnerwRCT",
                  "tlearnerTR", "tlearnerTRmcr", "tlearnerTRmct")

train_r_long <- train_r_long %>%
  mutate(Method = fct_relevel(Method, method_order))
test_r_long <- test_r_long %>%
  mutate(Method = fct_relevel(Method, method_order))

# The RCT diagnostics only exist for the T-learners; keep their relative order.
rct_r_long <- rct_r_long %>%
  mutate(Method = fct_relevel(Method, grep("^tlearner", method_order, value = TRUE)))

# Match each replication's RCT d2bias diagnostic with the same estimator's
# test-set ATE bias and CATE MSE.
join_keys <- c("rep_i", "trsize_s", "shift_e", "Method")

test_r_long_sub1 <- test_r_long %>%
  filter(Metric == "bias") %>%
  select(-Metric) %>%
  rename(test_bias = value)

test_r_long_sub2 <- test_r_long %>%
  filter(Metric == "mse") %>%
  select(-Metric) %>%
  rename(test_mse = value)

rct_test_sub <- rct_r_long %>%
  select(-Metric) %>%
  rename(rct_d2bias = value) %>%
  left_join(test_r_long_sub1, by = join_keys) %>%
  left_join(test_r_long_sub2, by = join_keys)

### Plots

# Plot one smoothed curve per method of `metric` against shift intensity,
# faceted by training-set size. The plot is printed, then saved to `file`.
plot_metric_trend <- function(long_df, metric, y_label, file) {
  p <- long_df %>%
    filter(Metric == .env$metric) %>%
    ggplot(aes(x = shift_e, y = value, group = Method, color = Method)) +
    geom_smooth(alpha = 0.25, size = 0.75) +
    labs(x = "Shift Intensity", y = y_label) +
    facet_grid(cols = vars(trsize_s))
  print(p)
  invisible(ggsave(file, plot = p, width = 10, height = 6))
}

plot_metric_trend(train_r_long, "prerr", "Prediction MSE", "s1c_1_train_pred-error.pdf")
plot_metric_trend(train_r_long, "bias", "Bias (ATE)", "s1c_1_train_ate-bias.pdf")
plot_metric_trend(train_r_long, "mse", "MSE (CATE)", "s1c_1_train_cate-mse.pdf")

plot_metric_trend(test_r_long, "prerr", "Prediction MSE", "s1c_1_test_pred-error.pdf")
plot_metric_trend(test_r_long, "bias", "Bias (ATE)", "s1c_1_test_ate-bias.pdf")
plot_metric_trend(test_r_long, "mse", "MSE (CATE)", "s1c_1_test_cate-mse.pdf")

# Draw horizontal boxplots of one metric per Method, faceted by training size
# (rows) and shift intensity (columns), with a reference line at 0, and save
# them to `file`.
#
# data:    long results table with columns Method, Metric, value, shift_e,
#          trsize_s (e.g. test_r_long / rct_r_long).
# metric:  value of `Metric` to plot.
# y_label: axis title for the metric (drawn horizontally after coord_flip()).
# file:    output path passed to ggsave().
# shifts:  shift_e values to keep (one facet column each).
# width, height: output size passed to ggsave().
#
# Returns the ggplot object invisibly.
save_shift_box <- function(data, metric, y_label, file,
                           shifts = c(0, 1, 2), width = 10, height = 7) {
  p <- data %>%
    # Reverse the levels so that, combined with the reversed legend and hue
    # direction below, methods appear top-to-bottom in the original order.
    mutate(method = fct_rev(Method)) %>%
    filter(Metric == .env$metric, shift_e %in% .env$shifts) %>%
    ggplot(aes(y = value, group = method, color = method)) +
    geom_boxplot(lwd = 0.4, outlier.size = 0.1) +
    # Constant reference line given as a parameter, not an aesthetic: inside
    # aes() it was drawn once per data row and inherited the method colours.
    geom_hline(yintercept = 0) +
    labs(y = y_label) +
    guides(color = guide_legend(reverse = TRUE)) +
    scale_color_hue(direction = -1) +
    facet_grid(rows = vars(trsize_s), cols = vars(shift_e)) +
    coord_flip() +
    theme(axis.title.y = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks.y = element_blank())

  # Pass the plot explicitly rather than relying on ggplot2's last_plot().
  ggsave(filename = file, plot = p, width = width, height = height)
  invisible(p)
}

save_shift_box(test_r_long, "prerr", "Prediction MSE",
               "s1c_1_test_pred-error-box.pdf")
save_shift_box(test_r_long, "bias", "Bias (ATE)",
               "s1c_1_test_ate-bias-box.pdf")
save_shift_box(test_r_long, "mse", "MSE (CATE)",
               "s1c_1_test_cate-mse-box.pdf")
save_shift_box(rct_r_long, "d2bias", "Degree-2 Bias",
               "s1c_1_rct_d2-bias-box.pdf")

# Scatter a test-set performance measure against the absolute RCT degree-2
# bias (one colour per Method), faceted by training size (rows) and shift
# intensity (columns), and save the plot to `file`.
#
# data:    joined table with rct_d2bias, test metrics, Method, shift_e,
#          trsize_s (e.g. rct_test_sub).
# y:       expression for the y aesthetic (tidy-evaluated, e.g. test_mse).
# y_label: y-axis title.
# file:    output path passed to ggsave().
# shifts:  shift_e values to keep (one facet column each).
# width, height: output size passed to ggsave().
#
# Returns the ggplot object invisibly.
save_rct_test_scatter <- function(data, y, y_label, file,
                                  shifts = c(0, 1, 2),
                                  width = 10, height = 7) {
  p <- data %>%
    filter(shift_e %in% .env$shifts) %>%
    ggplot(aes(y = {{ y }}, x = abs(rct_d2bias),
               group = Method, color = Method)) +
    geom_point(size = 1) +
    # The x values are absolute values, so the axis is labelled accordingly.
    labs(y = y_label, x = "Absolute Degree-2 Bias (RCT)") +
    facet_grid(rows = vars(trsize_s), cols = vars(shift_e))

  # Pass the plot explicitly rather than relying on ggplot2's last_plot().
  ggsave(filename = file, plot = p, width = width, height = height)
  invisible(p)
}

save_rct_test_scatter(rct_test_sub, abs(test_bias), "Absolute Bias ATE (Test)",
                      "s1c_1_rct_d2-test_bias.pdf")
save_rct_test_scatter(rct_test_sub, test_mse, "MSE CATE (Test)",
                      "s1c_1_rct_d2-test_mse.pdf")
