# Simulation of heterogeneous treatment effect estimation
# Scenarios as defined in setups.R:

# shift: sparseLinearLogit (simu1) 
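# pscore: rct5 (simu1)
# NOTE(review): the simulate_causal_experiment() call in this script passes
# pscore = "osSparse1Beta" -- confirm which propensity setup is intended.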
# mu0: sparseLinearStrong (simu1)
# tau: sparseLinearWeak (simu1)

library(tidyverse)
library(randomForest)
library(grf)
library(mcboost)
library()
library(mlr3learners)
#library(causalToolbox)
source("setups.R")

library(future.apply) 
library(parallel)


## Prepare loop

train_res <- tibble()
train_temp <- tibble(rep_i = NA, trsize_s = NA, shift_e = NA,
                     ps = NA, shift = NA,
                     tibble(y = NA), tibble(yp = NA),
                     tibble(yp_cforest = NA),
                     tibble(yp_slearner = NA),
                     tibble(yp_slearner_mc = NA),
                     tibble(yp_tlearner = NA),
                     tibble(yp_tlearner_mc = NA),
                     ate = NA, tibble(tau = NA),
                     tibble(tau_cforest = NA),
                     tibble(tau_slearner = NA),
                     tibble(tau_slearner_mc = NA),
                     tibble(tau_tlearner = NA),
                     tibble(tau_tlearner_mc = NA))

test_res <- tibble()
test_temp <- tibble(rep_i = NA, trsize_s = NA, shift_e = NA,
                    ps = NA, shift = NA,
                    tibble(y = NA), tibble(yp = NA),
                    tibble(yp_cforest = NA),
                    tibble(yp_slearner = NA),
                    tibble(yp_slearner_mc = NA),
                    tibble(yp_tlearner = NA),
                    tibble(yp_tlearner_mc = NA),
                    ate = NA, tibble(tau = NA),
                    tibble(tau_cforest = NA),
                    tibble(tau_slearner = NA),
                    tibble(tau_slearner_mc = NA),
                    tibble(tau_tlearner = NA),
                    tibble(tau_tlearner_mc = NA))

s_range <- c(500,1000)
#s_range <- seq(500, 5000, by = 1000) # n obs training set
e_range <- seq(0, 4, by = 0.2) # shift amplifier
n_reps <- 20 # n repetitions



init_preds = function(data) {
  predict(model, data, type = "prob")[, 2]
}

ridge <- LearnerAuditorFitter$new(lrn("regr.glmnet", alpha = 0, lambda = 1))
tree <- LearnerAuditorFitter$new(lrn("regr.rpart", maxdepth = 2))

## Simulation

## Generate population
init_sim <- simulate_causal_experiment(ntrain = 100000, # n obs population
                                       dim = 5, # n covars
                                       alpha = .1, # corr
                                       feat_distribution = "normal", 
                                       pscore = "osSparse1Beta", 
                                       mu0 = "sparseLinearStrong",
                                       tau = "sparseLinearWeak") 

pop <- tibble(init_sim$feat_tr, 
              ps = init_sim$Wp_tr,
              T = init_sim$W_tr, 
              tau = init_sim$tau_tr, 
              Yp = init_sim$Yp_tr,
              Y = init_sim$Yobs_tr,
              shift = init_sim$shift_tr,
              shiftw = init_sim$shiftw_tr)


params <- expand.grid(s_range, e_range, 1:n_reps)
names(params) <- c("s","e", "n")
ss <- as.list(params$s)
es <- as.list(params$e)
ns <- as.list(params$n)

inds <- split(params, seq(nrow(params)))


mc_st_learn_replication <- function(inds) {
  train_temp <- tibble(rep_i = NA, trsize_s = NA, shift_e = NA,
                       ps = NA, shift = NA,
                       tibble(y = NA), tibble(yp = NA),
                       tibble(yp_cforest = NA),
                       tibble(yp_slearner = NA),
                       tibble(yp_slearner_mc = NA),
                       tibble(yp_tlearner = NA),
                       tibble(yp_tlearner_mc = NA),
                       ate = NA, tibble(tau = NA),
                       tibble(tau_cforest = NA),
                       tibble(tau_slearner = NA),
                       tibble(tau_slearner_mc = NA),
                       tibble(tau_tlearner = NA),
                       tibble(tau_tlearner_mc = NA))
  test_temp <- tibble(rep_i = NA, trsize_s = NA, shift_e = NA,
                      ps = NA, shift = NA,
                      tibble(y = NA), tibble(yp = NA),
                      tibble(yp_cforest = NA),
                      tibble(yp_slearner = NA),
                      tibble(yp_slearner_mc = NA),
                      tibble(yp_tlearner = NA),
                      tibble(yp_tlearner_mc = NA),
                      ate = NA, tibble(tau = NA),
                      tibble(tau_cforest = NA),
                      tibble(tau_slearner = NA),
                      tibble(tau_slearner_mc = NA),
                      tibble(tau_tlearner = NA),
                      tibble(tau_tlearner_mc = NA))
  
  s <- inds$s
  e <- inds$e
  n <- inds$n
  ## Set shift
  pop$shift_s <- pop$shiftw^e # weights source
  pop$shift_t <- pop$shiftw^-e # weights target
  
  train_temp$shift_e <- e
  test_temp$shift_e <- e
  
  ## Set training set size
  train_size <- s
  test_size <- 5000
  audit_size <- 1000
  
  train_temp$trsize_s <- s
  test_temp$trsize_s <- s

  train_temp$rep_i <- i
  test_temp$rep_i <- i
  
  ## Sample from population
  train <- slice_sample(pop, n = train_size, weight_by = shift_s)
  audit <- slice_sample(pop, n = audit_size, weight_by = shift_s)
  test <- slice_sample(pop, n = test_size, weight_by = shift_t)
  
  train_temp$ps <- list(train$ps) # treatment propensities
  test_temp$ps <- list(test$ps)
  train_temp$shift <- list(train$shift) # source sample propensities
  test_temp$shift <- list(test$shift)
  train_temp$y <- list(train$Y) # true Y
  test_temp$y <- list(test$Y)
  train_temp$yp <- list(train$Yp) # true Y p's
  test_temp$yp <- list(test$Yp)
  train_temp$ate <- mean(train$tau) # true ATE
  test_temp$ate <- mean(test$tau)
  train_temp$tau <- list(train$tau) # true tau
  test_temp$tau <- list(test$tau)
  
  train_sub <- select(train, Y, T, x1:x5) # Original train data     
  X_train <- select(train, x1:x5) 
  Y_train <- train$Y
  T_train <- train$T
  X_audit <- select(audit, T, x1:x5) # Original audit data
  Y_audit <- audit$Y
  X_test <- select(test, x1:x5) # Original test data
  
  train_ut <- data.frame(X_train, T = 0) # Fix treated and untreated
  train_t <- data.frame(X_train, T = 1)
  audit_ut <- data.frame(X_audit, T = 0)
  audit_t <- data.frame(X_audit, T = 1)
  test_ut <- data.frame(X_test, T = 0)
  test_t <- data.frame(X_test, T = 1)
  
  ## Train models
  ### Causal Forest
  cforest <- causal_forest(X_train, Y_train, T_train)
  
  train_temp$tau_cforest <- list(predict(cforest)$predictions)
  test_temp$tau_cforest <- list(predict(cforest, X_test)$predictions)
  
  # average_treatment_effect(cforest, target.sample = 'all')
  # average_treatment_effect(cforest, target.sample = 'treated')
  
  ### S-learner
  slearner <- randomForest(as.factor(Y) ~ ., 
                           data = train_sub)
  
  train_temp$yp_slearner <- list(predict(slearner, train, "prob")[, 2])
  test_temp$yp_slearner <- list(predict(slearner, test, "prob")[, 2])
  
  train_temp$tau_slearner <- list(predict(slearner, train_t, "prob")[, 2] - predict(slearner, train_ut, "prob")[, 2])
  test_temp$tau_slearner <- list(predict(slearner, test_t, "prob")[, 2] - predict(slearner, test_ut, "prob")[, 2])
  
  ### S-learner + MCBoost
  model <- slearner
  slearner_mc = MCBoost$new(init_predictor = init_preds,
                            auditor_fitter = ridge,
                            max_iter = 10)
  slearner_mc$multicalibrate(X_audit, Y_audit)
  
  train_temp$yp_slearner_mc <- list(slearner_mc$predict_probs(train))
  test_temp$yp_slearner_mc <- list(slearner_mc$predict_probs(test))
  
  train_temp$tau_slearner_mc <- list(slearner_mc$predict_probs(train_t) - slearner_mc$predict_probs(train_ut))
  test_temp$tau_slearner_mc <- list(slearner_mc$predict_probs(test_t) - slearner_mc$predict_probs(test_ut))
  
  ### T-learner
  tlearner_t <- randomForest(as.factor(Y) ~ . -T, 
                             data = train_sub[train_sub$T==1, ])
  tlearner_ut <- randomForest(as.factor(Y) ~ . -T, 
                              data = train_sub[train_sub$T==0, ])
  
  train_temp$yp_tlearner <- list(ifelse(train$T==1, 
                                        predict(tlearner_t, train, "prob")[, 2],
                                        predict(tlearner_ut, train, "prob")[, 2]))
  test_temp$yp_tlearner <- list(ifelse(test$T==1, 
                                       predict(tlearner_t, test, "prob")[, 2],
                                       predict(tlearner_ut, test, "prob")[, 2]))
  
  train_temp$tau_tlearner <- list(predict(tlearner_t, train_t, "prob")[, 2] - predict(tlearner_ut, train_ut, "prob")[, 2])
  test_temp$tau_tlearner <- list(predict(tlearner_t, test_t, "prob")[, 2] - predict(tlearner_ut, test_ut, "prob")[, 2])
  
  ### T-learner + MCBoost
  model <- tlearner_t
  tlearner_t_mc = MCBoost$new(init_predictor = init_preds,
                              auditor_fitter = ridge,
                              max_iter = 5)
  tlearner_t_mc$multicalibrate(X_audit[X_audit$T == 1, ], Y_audit[X_audit$T == 1])
  model <- tlearner_ut
  tlearner_ut_mc = MCBoost$new(init_predictor = init_preds,
                               auditor_fitter = ridge,
                               max_iter = 5)
  tlearner_ut_mc$multicalibrate(X_audit[X_audit$T == 0, ], Y_audit[X_audit$T == 0])
  
  train_temp$yp_tlearner_mc <- list(ifelse(train$T==1, 
                                           tlearner_t_mc$predict_probs(train),
                                           tlearner_ut_mc$predict_probs(train)))
  test_temp$yp_tlearner_mc <- list(ifelse(test$T==1, 
                                          tlearner_t_mc$predict_probs(test),
                                          tlearner_ut_mc$predict_probs(test)))
  
  train_temp$tau_tlearner_mc <- list(tlearner_t_mc$predict_probs(train_t) - tlearner_ut_mc$predict_probs(train_ut))
  test_temp$tau_tlearner_mc <- list(tlearner_t_mc$predict_probs(test_t) - tlearner_ut_mc$predict_probs(test_ut))
  
  ## Save results
  #train_res <- rbind(train_res, train_temp)
  #test_res <- rbind(test_res, test_temp)
  
  
  print(paste("e =", e, "s =", s, "i =", i))
  return(list(train_temp,test_temp))
}

results <- future_lapply(inds,mc_st_learn_replication)


## Combine results

train_r <- train_res %>% 
  rowwise() %>% 
  mutate(mean_ps = mean(ps),
         mean_shift = mean(shift),
         mean_y = mean(y),
         prerr_slearner = mean((yp - yp_slearner)^2),
         prerr_slearnermc = mean((yp - yp_slearner_mc)^2),
         prerr_tlearner = mean((yp - yp_tlearner)^2),
         prerr_tlearnermc = mean((yp - yp_tlearner_mc)^2),
         bias_cforest = ate - mean(tau_cforest),
         bias_slearner = ate - mean(tau_slearner),
         bias_slearnermc = ate - mean(tau_slearner_mc),
         bias_tlearner = ate - mean(tau_tlearner),
         bias_tlearnermc = ate - mean(tau_tlearner_mc),
         mse_cforest = mean((tau - tau_cforest)^2),
         mse_slearner = mean((tau - tau_slearner)^2),
         mse_slearnermc = mean((tau - tau_slearner_mc)^2),
         mse_tlearner = mean((tau - tau_tlearner)^2),
         mse_tlearnermc = mean((tau - tau_tlearner_mc)^2)) %>% 
  select(-c(ps:yp_tlearner_mc), -c(tau:tau_tlearner_mc)) %>% 
  ungroup()

test_r <- test_res %>% 
  rowwise() %>% 
  mutate(mean_ps = mean(ps),
         mean_shift = mean(shift),
         mean_y = mean(y),
         prerr_slearner = mean((yp - yp_slearner)^2),
         prerr_slearnermc = mean((yp - yp_slearner_mc)^2),
         prerr_tlearner = mean((yp - yp_tlearner)^2),
         prerr_tlearnermc = mean((yp - yp_tlearner_mc)^2),
         bias_cforest = ate - mean(tau_cforest),
         bias_slearner = ate - mean(tau_slearner),
         bias_slearnermc = ate - mean(tau_slearner_mc),
         bias_tlearner = ate - mean(tau_tlearner),
         bias_tlearnermc = ate - mean(tau_tlearner_mc),
         mse_cforest = mean((tau - tau_cforest)^2),
         mse_slearner = mean((tau - tau_slearner)^2),
         mse_slearnermc = mean((tau - tau_slearner_mc)^2),
         mse_tlearner = mean((tau - tau_tlearner)^2),
         mse_tlearnermc = mean((tau - tau_tlearner_mc)^2)) %>% 
  select(-c(ps:yp_tlearner_mc), -c(tau:tau_tlearner_mc)) %>%
  ungroup()

train_r_long <- train_r %>%
  select(rep_i, trsize_s, shift_e, prerr_slearner:mse_tlearnermc) %>%
  pivot_longer(cols = prerr_slearner:mse_tlearnermc,
               names_to = c("Metric", "Method"),
               names_sep = "_")

test_r_long <- test_r %>%
  select(rep_i, trsize_s, shift_e, prerr_slearner:mse_tlearnermc) %>%
  pivot_longer(cols = prerr_slearner:mse_tlearnermc,
               names_to = c("Metric", "Method"),
               names_sep = "_")

save(train_r, test_r,
     train_r_long, test_r_long,
     file = "simu.RData")

## Evaluation
### Assumptions/ Overlap

train_res %>%
  filter(rep_i == 1) %>%
  group_by(trsize_s, shift_e) %>%
  rowwise() %>% 
  summarize(min(ps), mean(ps), max(ps),
            min(shift), mean(shift), max(shift))

test_res %>%
  filter(rep_i == 1) %>%
  group_by(trsize_s, shift_e) %>%
  rowwise() %>% 
  summarize(min(ps), mean(ps), max(ps),
            min(shift), mean(shift), max(shift))

train_res %>%
  filter(rep_i == 1) %>%
  select(trsize_s, shift_e, shift) %>%
  unnest(cols = c(shift)) %>%
  ggplot(aes(x = shift)) +
  geom_histogram() +
  facet_grid(rows = vars(shift_e), cols = vars(trsize_s))

test_res %>%
  filter(rep_i == 1) %>%
  select(trsize_s, shift_e, shift) %>%
  unnest(cols = c(shift)) %>%
  ggplot(aes(x = shift)) +
  geom_histogram() +
  facet_grid(rows = vars(shift_e), cols = vars(trsize_s))

### Bias and MSE

test_r %>%
  group_by(trsize_s, shift_e) %>%
  summarize(mean(prerr_slearner), mean(prerr_slearnermc),
            mean(prerr_tlearner), mean(prerr_tlearnermc))

test_r %>%
  group_by(trsize_s, shift_e) %>%
  summarize(mean(bias_cforest), 
            mean(bias_slearner), mean(bias_slearnermc),
            mean(bias_tlearner), mean(bias_tlearnermc))

test_r %>%
  group_by(trsize_s, shift_e) %>%
  summarize(mean(mse_cforest), 
            mean(mse_slearner), mean(mse_slearnermc),
            mean(mse_tlearner), mean(mse_tlearnermc))

train_r_long %>%
  filter(Metric == "prerr") %>%
  ggplot(aes(x = shift_e, y = value, group = Method, color = Method)) +
  geom_smooth() +
  labs(x = "Shift Intensity", y = "Prediction MSE") +
  facet_grid(cols = vars(trsize_s))

ggsave("s1_train_pred-error.pdf", width = 10, height = 6)

train_r_long %>%
  filter(Metric == "bias") %>%
  ggplot(aes(x = shift_e, y = value, group = Method, color = Method)) +
  geom_smooth() +
  labs(x = "Shift Intensity", y = "Bias (ATE)") +
  facet_grid(cols = vars(trsize_s))

ggsave("s1_train_ate-bias.pdf", width = 10, height = 6)

train_r_long %>%
  filter(Metric == "mse") %>%
  ggplot(aes(x = shift_e, y = value, group = Method, color = Method)) +
  geom_smooth() +
  labs(x = "Shift Intensity", y = "MSE (ATE)") +
  facet_grid(cols = vars(trsize_s))

ggsave("s1_train_cate-mse.pdf", width = 10, height = 6)

test_r_long %>%
  filter(Metric == "prerr") %>%
  ggplot(aes(x = shift_e, y = value, group = Method, color = Method)) +
  geom_smooth() +
  labs(x = "Shift Intensity", y = "Prediction MSE") +
  facet_grid(cols = vars(trsize_s))

ggsave("s1_test_pred-error.pdf", width = 10, height = 6)

test_r_long %>%
  filter(Metric == "bias") %>%
  ggplot(aes(x = shift_e, y = value, group = Method, color = Method)) +
  geom_smooth() +
  labs(x = "Shift Intensity", y = "Bias (ATE)") +
  facet_grid(cols = vars(trsize_s))

ggsave("s1_test_ate-bias.pdf", width = 10, height = 6)

test_r_long %>%
  filter(Metric == "mse") %>%
  ggplot(aes(x = shift_e, y = value, group = Method, color = Method)) +
  geom_smooth() +
  labs(x = "Shift Intensity", y = "MSE (CATE)") +
  facet_grid(cols = vars(trsize_s))

ggsave("s1_test_cate-mse.pdf", width = 10, height = 6)
