# Remove every (non-hidden) object from the current environment before running.
# NOTE(review): when run interactively this also wipes objects the user had
# loaded; it does not detach packages or reset options. Running the script in
# a fresh R session would make this line unnecessary.
rm(list = ls())

##==============================================================================
## library and source files   
##==============================================================================
library(dplyr)
library(ggplot2)
library(foreign)
library(xgboost) # gradient boosting for model training
library(fastglm)
library(randomForest)
library(grf)
library(doMC) # parallel computing for cv.glmnet
library(glmnet)
library(ggpubr) # arrange multiple plots
library(doParallel)
library(haven) # import Stata file
library(RColorBrewer)
library(reshape2)
library(xtable)
# Source file
source( "functions_LDTE.R")
# Colors
cb_colors = brewer.pal(n = 8, name = "Dark2") # discrete colorblind palette

# Register cores for parallelization
num_cores = detectCores()
# detectCores() returns NA when the number of cores cannot be determined;
# fall back to a single worker so the backend registration below still works.
if (is.na(num_cores)) {
  num_cores = 1L
}
registerDoMC(cores = num_cores)
# NOTE(review): foreach keeps a single active backend, so this registration
# presumably supersedes the doMC one above -- confirm whether both are needed.
registerDoParallel(cores = num_cores) # Parallel backend

##==============================================================================
## load data
##==============================================================================
# Raw data: Y is outcome, W is random assignment, D is actual enrollment
df.full = read.csv("./data/data_oregon_listetal_full.csv")
summary(df.full)

# Build the analysis sample one step at a time
df = filter(df.full, numhh_list != 3)   # keep rows whose numhh_list is not 3
df = na.omit(df)                        # drop rows containing any missing value
df = mutate(df, S = numhh_list, Z = W)  # strata variable as S, IV variable as Z
df = data.frame(df)
summary(df)

# Covariate names are taken by position: columns 6 through 33 of df
covariate_names = names(df[, 6:33])

##------------------------------------------------------------------------------
## Estimation setup
##------------------------------------------------------------------------------
# Grid of locations at which the DTE is evaluated
vec.loc = seq(from = 0, to = 15, by = 1)
# Upper and lower endpoints for the PTE: the upper endpoints are the grid
# itself, the lower endpoints are the grid shifted by one position with -1
# prepended
vec.loc.up = vec.loc
vec.loc.low = c(-1, head(vec.loc, -1))

# Number of folds for cross-fitting
# (note: this assignment masks the built-in shorthand F for FALSE)
F = 2
# Number of bootstrap replications
boot.size = 500

##------------------------------------------------------------------------------
## Regression-adjusted DTE and PTE
##------------------------------------------------------------------------------
# Randomly assign every row of df to one of the F cross-fitting folds;
# the seed is fixed so the assignment is repeatable
set.seed(123)
df$folds = sample.int(F, size = nrow(df), replace = TRUE)

### DTE
# Point estimates on the full sample
start_time = Sys.time()
res.dte = local.DTE.ML.estimation(df, vec.loc, "gradient_boosting", 1)
end_time = Sys.time()
# format() keeps the difftime units (secs/mins/hours), which paste() alone drops
print(paste("Time spent:", format(end_time - start_time)))

# Simple and regression-adjusted (RA) estimates as numerator/denominator ratios
ldte = res.dte$numerator/res.dte$denominator
ldte.ra = res.dte$numerator.ra/res.dte$denominator.ra

# Make sure the output directory exists BEFORE the long bootstrap run, so the
# results are not lost to a failing saveRDS() afterwards
dir.create("./result/oregon", recursive = TRUE, showWarnings = FALSE)

# Run bootstrap in parallel
# NOTE(review): set.seed() earlier in the script does not by itself make
# %dopar% workers reproducible; presumably local.DTE.bootstrap() handles
# seeding via its argument -- confirm.
start_time = Sys.time()
res.dte.boot = foreach(i=1:boot.size) %dopar%  local.DTE.bootstrap(i)
end_time = Sys.time()
print(paste("Time spent:", format(end_time - start_time)))
saveRDS(res.dte.boot, file = "./result/oregon/ldte_boot_estimates.rds")

#res.dte.boot = readRDS(paste0("./result/oregon/ldte_boot_estimates.rds"))

# Calculate standard errors
# Stack the estimates across bootstrap replications (one column per
# replication; assumes every replication returns vectors of equal length so
# that sapply() yields a matrix -- TODO confirm)
ldte_boot_matrix = sapply(res.dte.boot, function(x) x$ldte) # Stack all LDTE estimates across bootstrap
ldte.ra_boot_matrix = sapply(res.dte.boot, function(x) x$ldte.ra ) # Stack all RA LDTE estimates across bootstrap

# Row-wise standard deviation across replications = bootstrap standard error
ldte.se = apply(ldte_boot_matrix, 1, sd)
ldte.ra.se = apply(ldte.ra_boot_matrix, 1, sd)


### PTE
# Point estimates on the full sample
start_time = Sys.time()
res.pte = local.PTE.ML.estimation(df, vec.loc.up, vec.loc.low, "gradient_boosting", 1)
end_time = Sys.time()
# format() keeps the difftime units (secs/mins/hours), which paste() alone drops
print(paste("Time spent:", format(end_time - start_time)))

# Simple and regression-adjusted (RA) estimates as numerator/denominator ratios
lpte = res.pte$numerator/res.pte$denominator
lpte.ra = res.pte$numerator.ra/res.pte$denominator.ra

# Make sure the output directory exists BEFORE the long bootstrap run, so the
# results are not lost to a failing saveRDS() afterwards
dir.create("./result/oregon", recursive = TRUE, showWarnings = FALSE)

# Run bootstrap in parallel
# NOTE(review): set.seed() earlier in the script does not by itself make
# %dopar% workers reproducible; presumably local.PTE.bootstrap() handles
# seeding via its argument -- confirm.
start_time = Sys.time()
res.pte.boot = foreach(i=1:boot.size) %dopar%  local.PTE.bootstrap(i)
end_time = Sys.time()
print(paste("Time spent:", format(end_time - start_time)))

saveRDS(res.pte.boot, file = "./result/oregon/lpte_boot_estimates.rds")
#res.pte.boot = readRDS(paste0("./result/oregon/lpte_boot_estimates.rds"))

# Calculate standard errors
# Stack the estimates across bootstrap replications (one column per
# replication; assumes every replication returns vectors of equal length so
# that sapply() yields a matrix -- TODO confirm)
lpte_boot_matrix = sapply(res.pte.boot, function(x) x$lpte) # Stack all LPTE estimates across bootstrap
lpte.ra_boot_matrix = sapply(res.pte.boot, function(x) x$lpte.ra ) # Stack all RA LPTE estimates across bootstrap

# Row-wise standard deviation across replications = bootstrap standard error
lpte.se = apply(lpte_boot_matrix, 1, sd)
lpte.ra.se = apply(lpte.ra_boot_matrix, 1, sd)




##------------------------------------------------------------------------------
## Plot LDTE (simple)
##------------------------------------------------------------------------------
# Point estimate (solid line) with a pointwise band of +/- 1.96 bootstrap
# standard errors (shaded ribbon, dashed borders).
# NOTE(review): ggplot2 >= 3.4.0 prefers `linewidth` over `size` for line
# thickness; `size` is kept in geom_hline() for compatibility with older versions.
ggplot() +
  geom_line( aes(vec.loc, ldte -1.96*ldte.se), color= cb_colors[4], linetype=2) +
  geom_line( aes(vec.loc, ldte +1.96*ldte.se), color= cb_colors[4], linetype=2) +
  geom_ribbon(aes(x    = vec.loc,
                  ymin = ldte -1.96*ldte.se,
                  ymax = ldte +1.96*ldte.se),
              fill = cb_colors[4], alpha = .3) +
  geom_line( aes(vec.loc, ldte), color = cb_colors[4]) +
  theme_bw() +
  #xlim(0, 200) +
  # `limits` spelled out in full instead of relying on partial matching of `limit`
  scale_x_continuous(breaks = seq(0, 15, by = 1), limits = c(-0.5, 15.5)) +
  geom_hline(yintercept=0, color="black", size=0.1, alpha = .3) +
  labs(title = "Empirical LDTE",
       x= "", y="Probability") +
  theme(text=element_text(size=17))+
  theme(axis.text.x = element_text(hjust=0.5),
        axis.text = element_text(size = 10),
        axis.title = element_text(size = 16, face = "bold"),
        plot.title = element_text(size = 16, face = "bold"),
        legend.text = element_text(size = 14, face = "bold"),
        legend.title = element_text(size = 16, face = "bold"),
        strip.text = element_text(size = 14, face = "bold"))

# ggsave() writes the most recently created plot; make sure the target
# directory exists first
dir.create("./result", recursive = TRUE, showWarnings = FALSE)
ggsave("./result/Oregon_ED_DTE_simple.pdf", width=5, height =3)  ## save 
ggsave("./result/Oregon_ED_DTE_simple.png", width=5, height =3)  ## save 


##------------------------------------------------------------------------------
## Plot LDTE (adjusted)
##------------------------------------------------------------------------------
# Regression-adjusted point estimate (solid line) with a pointwise band of
# +/- 1.96 bootstrap standard errors (shaded ribbon, dashed borders).
# NOTE(review): ggplot2 >= 3.4.0 prefers `linewidth` over `size` for line
# thickness; `size` is kept in geom_hline() for compatibility with older versions.
ggplot() +
  geom_line( aes(vec.loc, ldte.ra -1.96*ldte.ra.se), color= cb_colors[5], linetype=2) +
  geom_line( aes(vec.loc, ldte.ra +1.96*ldte.ra.se), color= cb_colors[5], linetype=2) +
  geom_ribbon(aes(x = vec.loc,
                  ymin = ldte.ra -1.96*ldte.ra.se,
                  ymax = ldte.ra +1.96*ldte.ra.se),
              fill = cb_colors[5], alpha = .4)+
  geom_line( aes(vec.loc, ldte.ra), color = cb_colors[5]) +
  theme_bw() +
  #xlim(0, 200) +
  # `limits` spelled out in full instead of relying on partial matching of `limit`
  scale_x_continuous(breaks = seq(0, 15, by = 1), limits = c(-0.5, 15.5)) +
  geom_hline(yintercept=0, color="black", size=0.1, alpha = 0.3) +
  labs(title = "Adjusted LDTE",
       x= "", y="") +
  theme(axis.text.x = element_text(hjust=0.5),
        axis.text = element_text(size = 10),
        axis.title = element_text(size = 16, face = "bold"),
        plot.title = element_text(size = 16, face = "bold"),
        legend.text = element_text(size = 14, face = "bold"),
        legend.title = element_text(size = 16, face = "bold"),
        strip.text = element_text(size = 14, face = "bold"))

# ggsave() writes the most recently created plot; make sure the target
# directory exists first
dir.create("./result", recursive = TRUE, showWarnings = FALSE)
ggsave("./result/Oregon_ED_DTE_adj.pdf", width=5, height =3)  ## save 
ggsave("./result/Oregon_ED_DTE_adj.png", width=5, height =3)  ## save 


##------------------------------------------------------------------------------
## Plot LPTE (simple)
##------------------------------------------------------------------------------
# Common y-axis range for the simple and adjusted LPTE plots: widest
# estimate +/- 2 standard errors across both, padded by 1e-5 so the
# +/- 1.96 SE error bars always fit inside ylim()
y.max = max(max(lpte + 2*lpte.se), max(lpte.ra + 2*lpte.ra.se)) + 1e-5
y.min = min(min(lpte - 2*lpte.se), min(lpte.ra - 2*lpte.ra.se)) - 1e-5

# Bars show the point estimates, error bars show +/- 1.96 bootstrap SEs.
# NOTE(review): ggplot2 >= 3.4.0 prefers `linewidth` over `size` for line
# thickness; `size` is kept in geom_hline() for compatibility with older versions.
ggplot() +
  geom_bar(aes(vec.loc, lpte), stat = "identity", color= cb_colors[4], fill=cb_colors[4]) +
  geom_errorbar(aes(x= vec.loc,
                    ymin = lpte - 1.96*lpte.se,
                    ymax = lpte + 1.96*lpte.se),
                    width= 0.5                   # Width of the error bars
                    #position=position_dodge(.9)
  ) +
  ylim(y.min, y.max) +
  geom_hline(yintercept=0, color="black", size=0.01, alpha = .7) +
  theme_bw() +
  labs(title="Empirical LPTE",
         x= "Number of ED visits", y="Probability")  +
  # `limits` spelled out in full instead of relying on partial matching of `limit`
  scale_x_continuous(breaks = seq(0, 15, by = 1), limits = c(-0.5, 15.5)) +
  theme(axis.text.x = element_text(hjust=0.5),
        axis.text = element_text(size = 10),
        axis.title = element_text(size = 16, face = "bold"),
        plot.title = element_text(size = 16, face = "bold"),
        legend.text = element_text(size = 14, face = "bold"),
        legend.title = element_text(size = 16, face = "bold"),
        strip.text = element_text(size = 14, face = "bold"))

# ggsave() writes the most recently created plot; make sure the target
# directory exists first
dir.create("./result", recursive = TRUE, showWarnings = FALSE)
ggsave("./result/Oregon_ED_PTE_simple.pdf", width=5, height =3)  ## save
ggsave("./result/Oregon_ED_PTE_simple.png", width=5, height =3)  ## save 


##------------------------------------------------------------------------------
## Plot LPTE (adjusted)
##------------------------------------------------------------------------------
# Bars show the regression-adjusted point estimates, error bars show
# +/- 1.96 bootstrap SEs. y.min / y.max are the y-axis limits computed in the
# simple LPTE plot section, so both LPTE figures share the same scale.
# NOTE(review): ggplot2 >= 3.4.0 prefers `linewidth` over `size` for line
# thickness; `size` is kept in geom_hline() for compatibility with older versions.
ggplot( ) +
  geom_bar(aes(vec.loc, lpte.ra), stat = "identity", color= cb_colors[5], fill=cb_colors[5]) +
  geom_errorbar(aes(x= vec.loc,
                    ymin = lpte.ra - 1.96*lpte.ra.se,
                    ymax = lpte.ra + 1.96*lpte.ra.se),
                width= 0.5                   # Width of the error bars
                #position=position_dodge(.9)
  ) +
  ylim(y.min, y.max) +
  geom_hline(yintercept=0, color="black", size=0.01, alpha = .7) +
  theme_bw() +
  labs(title= "Adjusted LPTE", 
       x= "Number of ED visits", y="")  +
  # `limits` spelled out in full instead of relying on partial matching of `limit`
  scale_x_continuous(breaks = seq(0, 15, by = 1), limits = c(-0.5, 15.5)) +
  theme(axis.text.x = element_text( hjust=0.5),
        axis.text = element_text(size = 10),
        axis.title = element_text(size = 16, face = "bold"),
        plot.title = element_text(size = 16, face = "bold"),
        legend.text = element_text(size = 14, face = "bold"),
        legend.title = element_text(size = 16, face = "bold"),
        strip.text = element_text(size = 14, face = "bold"))

# ggsave() writes the most recently created plot; make sure the target
# directory exists first
dir.create("./result", recursive = TRUE, showWarnings = FALSE)
ggsave("./result/Oregon_ED_PTE_adj.pdf", width=5, height =3) ## save
ggsave("./result/Oregon_ED_PTE_adj.png", width=5, height =3) ## save






