# Simulation of heterogeneous treatment effect estimation
# Setup: External shift between source (train, audit) and target (test) data
# Binary outcome

# Scenarios as defined in setups.R:

# external shift: sparseLinearLogit (s1a_1, s1a_2)
# pcscore: osSparse1Beta (s1a_1, s1a_2)
# mu0: sparseLinearStrong (s1a_1), sparseLinearWeak (s1a_2)
# tau: sparseLinearWeak (s1a_1), fullLocallyLinear (s1a_2)

library(tidyverse)
library(randomForest)
library(glmnet)
library(grf)
library(mcboost)
library(mlr3learners)
#library(causalToolbox)
source("setups.R")

## Prepare loop

# Per-run row template. Each simulation run fills one row: the loop indices,
# the true ATE, and list-columns that hold per-observation vectors
# (propensities, outcomes, predicted outcome probabilities, true and estimated
# effects). Train and test rows share exactly this layout, so define it once.
run_template <- tibble(
  rep_i = NA, trsize_s = NA, shift_e = NA,
  ps = NA, shift = NA,
  y = NA, yp = NA,
  yp_cforest = NA,
  yp_slearner_rf = NA, yp_tlearner_rf = NA, yp_tlearner_rf_mc = NA,
  yp_slearner_l = NA, yp_tlearner_l = NA, yp_tlearner_l_mc = NA,
  ate = NA, tau = NA,
  tau_cforest = NA,
  tau_slearner_rf = NA, tau_tlearner_rf = NA, tau_tlearner_rf_mc = NA,
  tau_slearner_l = NA, tau_tlearner_l = NA, tau_tlearner_l_mc = NA
)

train_temp <- run_template
test_temp <- run_template

# Accumulators: one filled template row is appended per run.
train_res <- tibble()
test_res <- tibble()

# Simulation grid
s_range <- seq(500, 5000, by = 1500) # n obs training set (500, 2000, 3500, 5000)
e_range <- seq(0, 3, by = 0.25) # shift amplifier
n_reps <- 20 # n repetitions

# Auditor fitters for MCBoost (`tree` is currently not used in this script)
ridge <- LearnerAuditorFitter$new(lrn("regr.glmnet", alpha = 0, lambda = 1))
tree <- LearnerAuditorFitter$new(lrn("regr.rpart", maxdepth = 3))

## Simulation

# Seed once, before any random draw. This makes the population, the weighted
# samples, and all randomised learners (forests, CV folds, bootstrap auditing)
# reproducible across runs.
set.seed(1)

## Generate population
# NOTE(review): according to the header, mu0 = "sparseLinearWeak" with
# tau = "fullLocallyLinear" is scenario s1a_2, but the output files at the end
# of the script are named "s1a_1_*". The header also lists the shift function
# as "sparseLinearLogit". Confirm which scenario is intended.
init_sim <- simulate_causal_experiment(ntrain = 100000, # n obs population
                                       dim = 5, # n covars
                                       alpha = .1, # corr
                                       feat_distribution = "normal",
                                       pscore = "osSparse1Beta",
                                       mu0 = "sparseLinearWeak",
                                       tau = "fullLocallyLinear",
                                       shiftfun = "SparseLogitLinear1")

# One row per population unit: covariates, treatment propensity (ps),
# treatment (T), true effect (tau), true outcome probability (Yp), observed
# outcome (Y), source-sample propensity (shift) and shift weight (shiftw).
pop <- tibble(init_sim$feat_tr,
              ps = init_sim$Wp_tr,
              T = init_sim$W_tr,
              tau = init_sim$tau_tr,
              Yp = init_sim$Yp_tr,
              Y = init_sim$Yobs_tr,
              shift = init_sim$shift_tr,
              shiftw = init_sim$shiftw_tr)

## Helpers used inside the simulation loop

# Probability predictions of a cv.glmnet fit at lambda.min, returned as a plain
# numeric vector. glmnet matches `newx` to the coefficients by column POSITION,
# not by name, so `data` must use the training column order: T, x1, ..., x5.
predict_lasso <- function(fit, data) {
  as.vector(predict(fit, newx = as.matrix(data), s = "lambda.min",
                    type = "response"))
}

# Build prediction functions (also used as MCBoost initial predictors). The
# fitted model is captured when the function is created. This avoids a global
# name lookup at call time, which broke once the name was reused for another model.
rf_prob_predictor <- function(fit) {
  force(fit)
  function(data) predict(fit, data, type = "prob")[, 2]
}

lasso_prob_predictor <- function(fit) {
  force(fit)
  function(data) predict_lasso(fit, data)
}

# Wrap an initial predictor in MCBoost (ridge auditor, bootstrap sampling,
# at most 10 iterations) and multicalibrate it on the given audit data.
fit_mcboost <- function(init_predictor, X, y) {
  mc <- MCBoost$new(init_predictor = init_predictor,
                    auditor_fitter = ridge,
                    iter_sampling = "bootstrap",
                    max_iter = 10)
  mc$multicalibrate(X, y)
  mc
}

for (e in e_range) {

  ## Set external shift: source rows (train, audit) are drawn with weights
  ## shiftw^e, target rows (test) with weights shiftw^-e.
  pop$shift_s <- pop$shiftw^e # weights source
  pop$shift_t <- pop$shiftw^-e # weights target

  train_temp$shift_e <- e
  test_temp$shift_e <- e

  for (s in s_range) {

    ## Set sample sizes (only the training size is varied)
    train_size <- s
    test_size <- 5000
    audit_size <- 500

    train_temp$trsize_s <- s
    test_temp$trsize_s <- s

    for (i in seq_len(n_reps)) {

      train_temp$rep_i <- i
      test_temp$rep_i <- i

      ## Sample from population
      train <- slice_sample(pop, n = train_size, weight_by = shift_s)
      audit <- slice_sample(pop, n = audit_size, weight_by = shift_s)
      test <- slice_sample(pop, n = test_size, weight_by = shift_t)

      ## Ground truth
      train_temp$ps <- list(train$ps) # treatment propensities
      test_temp$ps <- list(test$ps)
      train_temp$shift <- list(train$shift) # source sample propensities
      test_temp$shift <- list(test$shift)
      train_temp$y <- list(train$Y) # true Y
      test_temp$y <- list(test$Y)
      train_temp$yp <- list(train$Yp) # true Y p's
      test_temp$yp <- list(test$Yp)
      train_temp$ate <- mean(train$tau) # true ATE
      test_temp$ate <- mean(test$tau)
      train_temp$tau <- list(train$tau) # true tau
      test_temp$tau <- list(test$tau)

      ## Model inputs. All feature frames passed to the lasso models and to
      ## MCBoost use the same column order, T, x1, ..., x5, because glmnet
      ## aligns columns by position.
      train_sub <- select(train, Y, T, x1:x5) # Original train data
      X_train <- select(train, x1:x5)
      X_train_obs <- select(train_sub, -Y) # observed T + covariates
      Y_train <- train$Y
      T_train <- train$T
      X_audit <- select(audit, T, x1:x5) # Original audit data
      Y_audit <- audit$Y
      test_sub <- select(test, Y, T, x1:x5) # Original test data
      X_test <- select(test, x1:x5)
      X_test_obs <- select(test_sub, -Y) # observed T + covariates

      # Counterfactual frames with treatment fixed to 0 (ut) or 1 (t). T goes
      # first so they line up column-by-column with X_*_obs and X_audit.
      train_ut <- data.frame(T = 0, X_train)
      train_t <- data.frame(T = 1, X_train)
      test_ut <- data.frame(T = 0, X_test)
      test_t <- data.frame(T = 1, X_test)

      audit_t <- X_audit$T == 1
      audit_ut <- X_audit$T == 0

      ## Train models
      ### Causal Forest
      cforest <- causal_forest(X_train, Y_train, T_train)

      train_temp$tau_cforest <- list(predict(cforest)$predictions)
      test_temp$tau_cforest <- list(predict(cforest, X_test)$predictions)

      # average_treatment_effect(cforest, target.sample = 'all')
      # average_treatment_effect(cforest, target.sample = 'treated')

      ### S-learner [random forest]
      slearner_rf <- randomForest(as.factor(Y) ~ ., data = train_sub)
      p_slearner_rf <- rf_prob_predictor(slearner_rf)

      train_temp$yp_slearner_rf <- list(p_slearner_rf(train))
      test_temp$yp_slearner_rf <- list(p_slearner_rf(test))

      train_temp$tau_slearner_rf <- list(p_slearner_rf(train_t) - p_slearner_rf(train_ut))
      test_temp$tau_slearner_rf <- list(p_slearner_rf(test_t) - p_slearner_rf(test_ut))

      ### S-learner [lasso]
      slearner_l <- cv.glmnet(as.matrix(X_train_obs), as.factor(Y_train),
                              alpha = 1, family = "binomial")
      p_slearner_l <- lasso_prob_predictor(slearner_l)

      train_temp$yp_slearner_l <- list(p_slearner_l(X_train_obs))
      test_temp$yp_slearner_l <- list(p_slearner_l(X_test_obs))

      train_temp$tau_slearner_l <- list(p_slearner_l(train_t) - p_slearner_l(train_ut))
      test_temp$tau_slearner_l <- list(p_slearner_l(test_t) - p_slearner_l(test_ut))

      ### T-learner [random forest]
      tlearner_rf_t <- randomForest(as.factor(Y) ~ . - T,
                                    data = train_sub[train_sub$T == 1, ])
      tlearner_rf_ut <- randomForest(as.factor(Y) ~ . - T,
                                     data = train_sub[train_sub$T == 0, ])
      p_rf_t <- rf_prob_predictor(tlearner_rf_t)
      p_rf_ut <- rf_prob_predictor(tlearner_rf_ut)

      train_temp$yp_tlearner_rf <- list(ifelse(train$T == 1, p_rf_t(train), p_rf_ut(train)))
      test_temp$yp_tlearner_rf <- list(ifelse(test$T == 1, p_rf_t(test), p_rf_ut(test)))

      train_temp$tau_tlearner_rf <- list(p_rf_t(train_t) - p_rf_ut(train_ut))
      test_temp$tau_tlearner_rf <- list(p_rf_t(test_t) - p_rf_ut(test_ut))

      ### T-learner [random forest] + MCBoost
      # Each arm's model is calibrated on the audit units of that arm.
      mc_rf_t <- fit_mcboost(p_rf_t, X_audit[audit_t, ], Y_audit[audit_t])
      mc_rf_ut <- fit_mcboost(p_rf_ut, X_audit[audit_ut, ], Y_audit[audit_ut])

      train_temp$yp_tlearner_rf_mc <- list(ifelse(train$T == 1,
                                                  mc_rf_t$predict_probs(X_train_obs),
                                                  mc_rf_ut$predict_probs(X_train_obs)))
      test_temp$yp_tlearner_rf_mc <- list(ifelse(test$T == 1,
                                                 mc_rf_t$predict_probs(X_test_obs),
                                                 mc_rf_ut$predict_probs(X_test_obs)))

      train_temp$tau_tlearner_rf_mc <- list(mc_rf_t$predict_probs(train_t) - mc_rf_ut$predict_probs(train_ut))
      test_temp$tau_tlearner_rf_mc <- list(mc_rf_t$predict_probs(test_t) - mc_rf_ut$predict_probs(test_ut))

      ### T-learner [lasso]
      # Fit on T, x1..x5 (T is constant within each arm) so that the same
      # column layout can be used for every prediction frame.
      train_treated <- T_train == 1
      train_untreated <- T_train == 0
      tlearner_l_t <- cv.glmnet(as.matrix(X_train_obs[train_treated, ]),
                                as.factor(Y_train[train_treated]),
                                alpha = 1, family = "binomial")
      tlearner_l_ut <- cv.glmnet(as.matrix(X_train_obs[train_untreated, ]),
                                 as.factor(Y_train[train_untreated]),
                                 alpha = 1, family = "binomial")
      p_l_t <- lasso_prob_predictor(tlearner_l_t)
      p_l_ut <- lasso_prob_predictor(tlearner_l_ut)

      train_temp$yp_tlearner_l <- list(ifelse(train$T == 1, p_l_t(X_train_obs), p_l_ut(X_train_obs)))
      test_temp$yp_tlearner_l <- list(ifelse(test$T == 1, p_l_t(X_test_obs), p_l_ut(X_test_obs)))

      train_temp$tau_tlearner_l <- list(p_l_t(train_t) - p_l_ut(train_ut))
      test_temp$tau_tlearner_l <- list(p_l_t(test_t) - p_l_ut(test_ut))

      ### T-learner [lasso] + MCBoost
      mc_l_t <- fit_mcboost(p_l_t, X_audit[audit_t, ], Y_audit[audit_t])
      mc_l_ut <- fit_mcboost(p_l_ut, X_audit[audit_ut, ], Y_audit[audit_ut])

      train_temp$yp_tlearner_l_mc <- list(ifelse(train$T == 1,
                                                 mc_l_t$predict_probs(X_train_obs),
                                                 mc_l_ut$predict_probs(X_train_obs)))
      test_temp$yp_tlearner_l_mc <- list(ifelse(test$T == 1,
                                                mc_l_t$predict_probs(X_test_obs),
                                                mc_l_ut$predict_probs(X_test_obs)))

      train_temp$tau_tlearner_l_mc <- list(mc_l_t$predict_probs(train_t) - mc_l_ut$predict_probs(train_ut))
      test_temp$tau_tlearner_l_mc <- list(mc_l_t$predict_probs(test_t) - mc_l_ut$predict_probs(test_ut))

      ## Save results
      train_res <- rbind(train_res, train_temp)
      test_res <- rbind(test_res, test_temp)

      print(paste("e =", e, "s =", s, "i =", i))
    }
  }
}

## Combine results

# Reduce each simulation run (one row whose list-columns hold per-observation
# vectors) to scalar summaries:
#   mean_*  - averages of the true propensities, shift propensities, outcomes
#   prerr_* - mean squared error of predicted vs. true outcome probabilities
#   bias_*  - true ATE minus the mean estimated CATE
#   mse_*   - mean squared error of the estimated CATE
# The per-observation list-columns are then dropped.
summarise_run_metrics <- function(res) {
  res %>%
    rowwise() %>%
    mutate(mean_ps = mean(ps),
           mean_shift = mean(shift),
           mean_y = mean(y),
           prerr_slearnerRF = mean((yp - yp_slearner_rf)^2),
           prerr_tlearnerRF = mean((yp - yp_tlearner_rf)^2),
           prerr_tlearnerRFmc = mean((yp - yp_tlearner_rf_mc)^2),
           prerr_slearnerL = mean((yp - yp_slearner_l)^2),
           prerr_tlearnerL = mean((yp - yp_tlearner_l)^2),
           prerr_tlearnerLmc = mean((yp - yp_tlearner_l_mc)^2),
           bias_cforest = ate - mean(tau_cforest),
           bias_slearnerRF = ate - mean(tau_slearner_rf),
           bias_tlearnerRF = ate - mean(tau_tlearner_rf),
           bias_tlearnerRFmc = ate - mean(tau_tlearner_rf_mc),
           bias_slearnerL = ate - mean(tau_slearner_l),
           bias_tlearnerL = ate - mean(tau_tlearner_l),
           bias_tlearnerLmc = ate - mean(tau_tlearner_l_mc),
           mse_cforest = mean((tau - tau_cforest)^2),
           mse_slearnerRF = mean((tau - tau_slearner_rf)^2),
           mse_tlearnerRF = mean((tau - tau_tlearner_rf)^2),
           mse_tlearnerRFmc = mean((tau - tau_tlearner_rf_mc)^2),
           mse_slearnerL = mean((tau - tau_slearner_l)^2),
           mse_tlearnerL = mean((tau - tau_tlearner_l)^2),
           mse_tlearnerLmc = mean((tau - tau_tlearner_l_mc)^2)) %>%
    select(-c(ps:yp_tlearner_l_mc), -c(tau:tau_tlearner_l_mc)) %>%
    ungroup()
}

train_r <- summarise_run_metrics(train_res)
test_r <- summarise_run_metrics(test_res)

# Long format: metric columns are named <metric>_<method>, so splitting on
# "_" gives one row per run, Metric (prerr / bias / mse) and Method.
lengthen_metrics <- function(r) {
  id_cols <- r %>%
    select(rep_i, trsize_s, shift_e, prerr_slearnerRF:mse_tlearnerLmc)
  pivot_longer(id_cols,
               cols = -c(rep_i, trsize_s, shift_e),
               names_to = c("Metric", "Method"),
               names_sep = "_")
}

train_r_long <- lengthen_metrics(train_r)
test_r_long <- lengthen_metrics(test_r)

save(train_r, test_r,
     train_r_long, test_r_long,
     file = "simu.RData")

## Evaluation
### Assumptions/ Overlap

# Min / mean / max of the treatment propensities (ps) and the source-sample
# propensities (shift) for the first repetition of every setting.
overlap_summary <- function(res) {
  res %>%
    filter(rep_i == 1) %>%
    group_by(trsize_s, shift_e) %>%
    rowwise() %>%
    summarize(min(ps), mean(ps), max(ps),
              min(shift), mean(shift), max(shift))
}

# Histogram of the source-sample propensities (first repetition), faceted by
# shift intensity (rows) and training-set size (columns).
shift_histogram <- function(res) {
  res %>%
    filter(rep_i == 1) %>%
    select(trsize_s, shift_e, shift) %>%
    unnest(cols = c(shift)) %>%
    ggplot(aes(x = shift)) +
    geom_histogram() +
    facet_grid(rows = vars(shift_e), cols = vars(trsize_s))
}

overlap_summary(train_res)

overlap_summary(test_res)

shift_histogram(train_res)

shift_histogram(test_res)

### Bias and MSE

# Average one metric family (columns starting with `prefix`) over the
# repetitions of each training size / shift setting. The .names pattern
# reproduces the "mean(<column>)" headers of a plain summarize(mean(col)) call.
average_metric <- function(r, prefix) {
  r %>%
    group_by(trsize_s, shift_e) %>%
    summarize(across(starts_with(prefix), mean, .names = "mean({.col})"))
}

average_metric(test_r, "prerr_")

average_metric(test_r, "bias_")

average_metric(test_r, "mse_")

# Smoothed trend of one metric against shift intensity, coloured by method and
# faceted by training-set size. The plot is built (and auto-printed) at top
# level, and each bare ggsave() call that follows saves it via last_plot().
metric_plot <- function(r_long, metric, y_label) {
  r_long %>%
    filter(Metric == .env$metric) %>%
    ggplot(aes(x = shift_e, y = value, group = Method, color = Method)) +
    geom_smooth(alpha = 0.25, size = 0.75) +
    labs(x = "Shift Intensity", y = y_label) +
    facet_grid(cols = vars(trsize_s))
}

# NOTE(review): the population setup (mu0 = sparseLinearWeak,
# tau = fullLocallyLinear) matches s1a_2 in the header comment, but these files
# are labelled s1a_1 — confirm the intended scenario tag.

metric_plot(train_r_long, "prerr", "Prediction MSE")

ggsave("s1a_1_train_pred-error.pdf", width = 10, height = 6)

metric_plot(train_r_long, "bias", "Bias (ATE)")

ggsave("s1a_1_train_ate-bias.pdf", width = 10, height = 6)

metric_plot(train_r_long, "mse", "MSE (CATE)")

ggsave("s1a_1_train_cate-mse.pdf", width = 10, height = 6)

metric_plot(test_r_long, "prerr", "Prediction MSE")

ggsave("s1a_1_test_pred-error.pdf", width = 10, height = 6)

metric_plot(test_r_long, "bias", "Bias (ATE)")

ggsave("s1a_1_test_ate-bias.pdf", width = 10, height = 6)

metric_plot(test_r_long, "mse", "MSE (CATE)")

ggsave("s1a_1_test_cate-mse.pdf", width = 10, height = 6)
