# Load libraries, register cores
# (registerDoParallel(8) registers a parallel backend with 8 workers; it is
# used by the %dopar% scenario grid further down this script.)
library(data.table)
library(mvnfast)
library(broom)
library(sisVIVE)
library(leakyIV)
library(ggplot2)
library(ggsci)
library(doParallel)
registerDoParallel(8)

# Load scripts
# Project helpers: presumably simulator.R defines sim_dat4() and MBE.R defines
# MBE(), both called in bnchmrk() -- confirm. NOTE(review): no function from
# bayesian_baseline.R appears to be used in this script; confirm it is needed.
source('bayesian_baseline.R')
source('MBE.R')
source('simulator.R')

# Set seed
# "L'Ecuyer-CMRG" is the RNG kind R's parallel package uses to give workers
# independent streams. NOTE(review): seeding alone may not make %dopar%
# results reproducible run-to-run; consider doRNG (%dorng%) -- confirm.
set.seed(123, kind = "L'Ecuyer-CMRG")

# Benchmark against backdoor adjustment, 2SLS, sisVIVE, MBE, and LeakyIV.
#
# Draws a large "population" sample with sim_dat4(), treats its sample
# covariance as ground truth, and then draws n_sim independent replicates of
# size n from a zero-mean multivariate normal with that covariance. Every
# estimator is fitted to each replicate.
#
# Args:
#   d_z:      number of candidate instruments (columns z1..z<d_z>).
#   z_rho:    instrument correlation parameter, forwarded to sim_dat4().
#   rho:      confounding coefficient, forwarded to sim_dat4().
#   r2_y:     outcome R^2, forwarded to sim_dat4().
#   pr_valid: proportion of valid instruments, forwarded to sim_dat4().
#   n:        sample size of each replicate.
#   n_sim:    number of replicates.
#
# Returns a data.table with six rows per replicate (columns b, method, theta):
# point estimates from 'Backdoor', 'TSLS', 'sisVIVE' and 'MBE', plus the
# LeakyIV ATE bounds as 'Lower' and 'Upper'. The scenario parameters
# (d_z, z_rho, rho, r2_y, pr_valid) are appended as constant columns.
bnchmrk <- function(d_z, z_rho, rho, r2_y, pr_valid, n, n_sim) {
  stopifnot(n >= 1, n_sim >= 1)

  # Generate data, extract "population" data
  sim <- sim_dat4(n = 1e5, d_z, z_cnt = TRUE, z_rho, rho,
                  theta = 1, r2_x = 2/3, r2_y, pr_valid)
  d <- d_z + 2
  # LeakyIV threshold: 10% above the L2 norm of sim$params$gamma (presumably
  # the instruments' direct effects on y -- confirm against simulator.R)
  l2 <- sqrt(sum(sim$params$gamma^2))
  tau <- 1.1 * l2

  # Treat this as ground truth. Assumes sim$dat columns are ordered
  # z1..z<d_z>, x, y, matching the names set below -- confirm in simulator.R
  Sigma <- cov(sim$dat)
  z_names <- paste0('z', seq_len(d_z))
  dat <- as.data.table(rmvn(n * n_sim, mu = rep(0, d), Sigma))
  setnames(dat, c(z_names, 'x', 'y'))

  # Fit every estimator on replicate b, i.e. rows (b - 1) * n + 1 to b * n
  inner_loop <- function(b) {

    # Draw data
    tmp <- dat[((b - 1) * n + 1):(b * n), ]
    # Candidate instruments only. Selected by name so that no derived column
    # (e.g. 2SLS fitted values) can leak into the instrument set.
    z_mat <- as.matrix(tmp[, z_names, with = FALSE])

    # Run backdoor adjustment: y ~ z + x; coefficient d_z + 2 is x
    f_bda <- lm(y ~ ., data = tmp)
    ate_bda <- as.numeric(tidy(f_bda)$estimate[d_z + 2])

    # Run 2SLS: first stage x ~ z, second stage y ~ x_hat. x_hat is kept out
    # of tmp on purpose: it is an exact linear combination of the z's and
    # must not be handed to sisVIVE / LeakyIV as an extra "instrument".
    f_x <- lm(x ~ ., data = tmp[, -c('y')])
    x_hat <- fitted(f_x)
    f_2sls <- lm(tmp$y ~ x_hat)
    ate_2sls <- as.numeric(tidy(f_2sls)$estimate[2])

    # Run sisVIVE
    sisvive <- cv.sisVIVE(tmp$y, tmp$x, z_mat)
    ate_sisvive <- sisvive$beta

    # Run MBE on per-instrument summary statistics: first-stage (x ~ z) and
    # reduced-form (y ~ z) coefficients, skipping the intercepts
    coef_x <- tidy(f_x)
    coef_y <- tidy(lm(y ~ ., data = tmp[, -c('x')]))
    idx <- seq_len(d_z) + 1
    beta_hat <- as.numeric(coef_x$estimate[idx])
    se_beta <- as.numeric(coef_x$std.error[idx])
    gamma_hat <- coef_y$estimate[idx]
    se_gamma <- coef_y$std.error[idx]
    # Only mbe$Estimate[2] is used below, hence a single bootstrap draw
    mbe <- MBE(beta_hat, gamma_hat, se_beta, se_gamma, phi = 1, n_boot = 1)
    ate_mbe <- mbe$Estimate[2]

    # LeakyIV bounds on the ATE (warnings from leakyIV are suppressed)
    ate_bnds <- suppressWarnings(
      leakyIV(tmp$x, tmp$y, z_mat, tau = tau, method = 'shrink')
      # Alternative passing the ground-truth covariance Sigma directly:
      # leakyIV(0, 0, matrix(0, ncol = d_z), tau = tau, Sigma = Sigma)
    )
    ate_lo <- ate_bnds$ATE_lo
    ate_hi <- ate_bnds$ATE_hi

    # Export
    data.table(b,
      method = c('Backdoor', 'TSLS', 'sisVIVE', 'MBE', 'Lower', 'Upper'),
      theta = c(ate_bda, ate_2sls, ate_sisvive, ate_mbe, ate_lo, ate_hi)
    )
  }
  out <- foreach(bb = seq_len(n_sim), .combine = rbind) %do% inner_loop(bb)

  # Export: tag every row with its scenario parameters
  out[, `:=`(d_z = d_z, z_rho = z_rho, rho = rho, r2_y = r2_y,
             pr_valid = pr_valid)]
  out
}

# Execute in parallel
# Scenario grid: d_z in {5, 10} x z_rho in {0, 0.5} x rho in
# {-0.9, -0.8, ..., 0.9} (19 values) x r2_y in {1/3, 1/2, 2/3}, i.e.
# 2 * 2 * 19 * 3 = 228 calls to bnchmrk(), each with pr_valid = 1/5,
# n = 1000 and n_sim = 50. The nested loops are merged with %:% and
# dispatched with %dopar%; each level row-binds its results, so `df` is a
# single data.table of all scenarios.
df <- foreach(dd = c(5, 10), .combine = rbind) %:%
  foreach(zz = c(0, 0.5), .combine = rbind) %:%
  foreach(rr = seq(-0.9, 0.9, 0.1), .combine = rbind) %:%
  foreach(pve = c(1/3, 1/2, 2/3), .combine = rbind) %dopar%
  bnchmrk(dd, zz, rr, pve, pr_valid = 1/5, n = 1000, n_sim = 50)

################################################################################

# Spot check: number of missing LeakyIV bounds (Lower + Upper rows) per scenario
df[method %in% c('Lower', 'Upper'), sum(is.na(theta)), 
   by = .(d_z, z_rho, rho, r2_y)]

# Summarise replicates for plotting: per method and scenario, the Monte Carlo
# mean (mu) and standard deviation (se) of the estimates, ignoring NAs
df[, c('mu', 'se') := list(mean(theta, na.rm = TRUE), sd(theta, na.rm = TRUE)),
   by = .(method, d_z, z_rho, rho, r2_y)]
tmp <- unique(df[, .(method, mu, se, d_z, z_rho, rho, r2_y)])
tmp[, z_rho := fifelse(z_rho == 0, 'Diagonal', 'Toeplitz')]

# Facet labels: SNR = r2_y / (1 - r2_y). r2_y is matched with a tolerance
# rather than exact floating-point equality; an unmatched value is an error
# because the plots filter and facet on these labels.
snr_tol <- sqrt(.Machine$double.eps)
tmp[abs(r2_y - 1/3) < snr_tol, SNR := 'SNR = 1/2']
tmp[abs(r2_y - 1/2) < snr_tol, SNR := 'SNR = 1']
tmp[abs(r2_y - 2/3) < snr_tol, SNR := 'SNR = 2']
tmp[, SNR := factor(SNR, levels = c('SNR = 1/2', 'SNR = 1', 'SNR = 2'))]
if (anyNA(tmp$SNR)) {
  stop('Unexpected r2_y value(s); cannot assign SNR labels.', call. = FALSE)
}
#tmp[, rho := fifelse(rho > 0.5, 'Strong Confounding', 'Weak Confounding')]
#tmp[, rho := factor(rho, levels = c('Weak Confounding', 'Strong Confounding'))]
tmp[method == 'TSLS', method := '2SLS']
setnames(tmp, 'method', 'Method')

# Attach each scenario's mean LeakyIV lower/upper bound to all of its rows
tmp[, `:=`(lo = mu[Method == 'Lower'], hi = mu[Method == 'Upper']),
    by = .(d_z, z_rho, rho, r2_y)]

# Point estimators only, in a fixed legend order
tmp2 <- tmp[!Method %in% c('Lower', 'Upper')]
tmp2[, Method := factor(Method, levels = c('Backdoor', '2SLS', 'sisVIVE', 'MBE'))]



# Build one benchmark figure: each method's mean estimate (line) +/- one
# Monte Carlo SD (ribbon), the mean LeakyIV bounds (shaded 'LeakyIV' ribbon),
# and the true effect theta = 1 (horizontal line).
#   dat:       rows of tmp2 to plot.
#   facets:    facet_grid() formula.
#   labeller:  facet_grid() labeller.
#   text_size: font size for axis titles, strip text and legend.
plot_benchmark <- function(dat, facets, labeller, text_size) {
  txt <- element_text(size = text_size)
  ggplot(dat, aes(rho, mu, fill = Method)) +
    geom_ribbon(aes(ymin = lo, ymax = hi, fill = 'LeakyIV'), alpha = 0.25) +
    geom_ribbon(aes(ymin = mu - se, ymax = mu + se), alpha = 0.5) +
    geom_line(aes(color = Method)) +
    geom_hline(yintercept = 1, linewidth = 0.5, color = 'black') +
    scale_color_d3(guide = 'none') +
    scale_fill_d3() +
    labs(x = expression(paste('Confounding Coefficient ', rho)),
         y = expression(paste('Average Treatment Effect ', theta))) +
    facet_grid(facets, scales = 'free', labeller = labeller) +
    theme_bw() +
    theme(axis.title = txt,
          strip.text.x = txt,
          strip.text.y = txt,
          legend.title = txt,
          legend.text = txt,
          legend.position = 'bottom')
}

# SNR = 2 only; rows: instrument correlation structure, columns: d_z
p1 <- plot_benchmark(tmp2[SNR == 'SNR = 2'], z_rho ~ d_z,
                     label_bquote(cols = italic(d[Z]) == .(d_z)),
                     text_size = 12)

# Diagonal instrument correlation only; rows: d_z, columns: SNR
p2 <- plot_benchmark(tmp2[z_rho == 'Diagonal'], d_z ~ SNR,
                     label_bquote(italic(d[Z]) == .(d_z)),
                     text_size = 16)

# Save p2 explicitly rather than relying on last_plot()
dir.create('./plots', showWarnings = FALSE)
ggsave('./plots/benchmarks.pdf', plot = p2, width = 10)
  


