#### Introduction ####
# 1. This file contains the code to reproduce the results of the simulation studies in the paper
# ''A Non-parametric Direct Learning Approach to Heterogeneous Treatment Effect Estimation under Unmeasured Confounding''.

# 2. The authors ran these simulations on a MacBook with a 2.9 GHz Quad-Core Intel Core i7 processor and 16 GB of memory.
# The run times for Settings 1-4 were approximately 1 hr, 8 hrs, 1 hr, and 1 hr, respectively.

# 3. Execute the ''Packages'' and ''IV-DL functions'' sections before running any of the Setting sections.
# The Settings are independent of each other.

# 4. The RD learning method and its core functions are excluded from this file because we only have permission to use their algorithm for this research.

#### Packages ####
library(MASS)
library(glmnet)
library(pdist)
library(tidyverse)
library(WeightSVM)
library(bridgedist)
library(bartCause)
library(grf)
library(rpart)
library(caret)
library(locfit)

#### IV-DL functions ####

# Build the augmentation terms and the pseudo-outcome used by the estimators.
#
# Arguments:
#   Y, A, Z    outcome, treatment and instrument.
#   Dhat       divisor of the Wald-type ratio (Yp - Yn) / Dhat.
#   pi_Zhat    divisor applied to the residual term of Ymr.
#   cate_pre   preliminary effect estimate; defaults to the Wald ratio.
#   Yp, Yn     fitted outcome values for the two instrument levels.
#   Ap, An     reference treatment values for the two instrument levels.
#   Ym, Am     midpoint references; default to the averages of the above.
#
# Returns a list with the augmentations g, h1, h2, h3, the pseudo-outcome Ymr,
# and the inputs Yp, Yn, Ap, An passed through unchanged.
augments <- function(Y,A,Z,Dhat, pi_Zhat, cate_pre = (Yp-Yn)/Dhat,
                     Yp, Yn, Ap = 1, An, Ym = (Yp+Yn)/2, Am = (Ap+An)/2){
  # Simple average of the two fitted outcome surfaces.
  g <- (Yp + Yn) / 2

  # Shift a baseline outcome by half the preliminary effect times the
  # treatment deviation from its reference value.
  shifted <- function(base_outcome, base_treatment) {
    base_outcome + (A - base_treatment - Z * Dhat) * cate_pre / 2
  }
  h1 <- shifted(Yp, Ap)
  h2 <- shifted(Yn, An)
  h3 <- shifted(Ym, Am)

  # Weighted residual correction plus the Wald-type plug-in term.
  residual_part <- Z * (Y - Yn - cate_pre * (A - An) / 2) / Dhat / pi_Zhat
  plug_in_part <- (Yp - Yn) / Dhat
  Ymr <- residual_part + plug_in_part

  list(g = g, h1 = h1, h2 = h2, h3 = h3, Ymr = Ymr,
       Yp = Yp, Yn = Yn, Ap = Ap, An = An)
}

# Linear IV-DL fit: penalised regression of the modified outcome
# 2 * Z * (Y - g) / d on X, with the penalty chosen by 10-fold cv.glmnet().
#
# Arguments:
#   Y           numeric outcome, one entry per row of X.
#   X           covariate matrix.
#   Z           instrument, one entry per row of X.
#   pi_Z        scalar or length-n vector; observation weights are 1 / pi_Z.
#   d           divisor of the modified outcome.
#   g           augmentation subtracted from Y (0 gives the unaugmented fit).
#   lambda_grid candidate penalty values (any order).
#
# Returns a list with
#   coef        named coefficient vector, intercept first, at the CV-optimal lambda;
#   opt.lambda  the CV-optimal lambda itself.
IVDL.linear <- function(Y,X,Z,pi_Z=1/2,d,g,lambda_grid = 5^(-10:10)){
  # Check arguments
  n <- nrow(X)
  stopifnot(is.numeric(Y))
  if (length(Y) != n) stop("'Y' has different length with number of rows in X")
  if (length(Z) != n) stop("'Z' has different length with number of rows in X")
  if (!(length(pi_Z) %in% c(1L, n)))
    stop("'pi_Z' must have length 1 or one entry per row of X")

  # cv fit
  response <- 2*Z*(Y-g)/d
  # cv.glmnet() needs one weight per observation, so a scalar pi_Z is recycled.
  obs_weights <- rep_len(1/pi_Z, n)
  fit <- cv.glmnet(x=X, y=response, weights = obs_weights, lambda = lambda_grid,
                   nfolds = 10, family = "gaussian",
                   standardize = FALSE, intercept = TRUE)

  # Optimal lambda.
  # glmnet stores its path in decreasing order of lambda regardless of the
  # order of lambda_grid, so the optimum is located on the fitted path and
  # never by position in lambda_grid.
  path_lambda <- fit$glmnet.fit$lambda
  opt <- which.min(abs(path_lambda - fit$lambda.min))
  if (opt == 1L || opt == length(path_lambda))
    warning("The optimal lambda for fitting treatment effect may fall outside the window")

  coef <- c(fit$glmnet.fit$a0[opt], fit$glmnet.fit$beta[, opt])
  names(coef)[1] <- "(Intercept)"
  return(list(coef = coef, opt.lambda = path_lambda[opt]))
}

# Kernel IV-DL fit: weighted kernel ridge regression (via WLS_ker) of the
# modified outcome 2 * Z * (Y - g) / d on X.
#
# Arguments:
#   Y           numeric outcome, one entry per row of X.
#   X           covariate matrix.
#   Z           instrument, one entry per row of X.
#   pi_Z        scalar or length-n vector passed to WLS_ker() as observation weights.
#   d           divisor of the modified outcome.
#   g           augmentation subtracted from Y (0 gives the unaugmented fit).
#   qtile_grid  distance quantiles that define the candidate kernel bandwidths.
#   lambda_grid candidate ridge penalties.
#   n_folds     number of cross-validation folds.
#
# Returns a list with the fitted coefficients (coef) and the selected tuning
# parameters (opt.qtile, opt.lambda).
IVDL.kernel <- function(Y,X,Z,pi_Z=1/2,d,g,
                        qtile_grid = c(0.1,0.25,0.5,0.75,0.9),
                        lambda_grid = 10^(-5:5),
                        n_folds = 5){
  # Check arguments
  n <- nrow(X)
  stopifnot(is.numeric(Y))
  if (length(Y) != n) stop("'Y' has different length with number of rows in X")
  if (length(Z) != n) stop("'Z' has different length with number of rows in X")
  if (!(length(pi_Z) %in% c(1L, n)))
    stop("'pi_Z' must have length 1 or one entry per row of X")

  # WLS_ker() subsets the weights by fold, so a scalar pi_Z (the default)
  # has to be expanded to one weight per observation first.
  obs_weights <- rep_len(pi_Z, n)

  #cv KRR
  response <- 2*Z*(Y-g)/d

  fit <- WLS_ker(response, X, obs_weights, qtile_grid, lambda_grid, n_folds)
  return(list(coef = fit$coef,
              opt.qtile = fit$opt.qtile,
              opt.lambda = fit$opt.lambda))
}

# Weighted Gaussian-kernel ridge regression with cross-validated tuning.
#
# Arguments:
#   Y           response vector.
#   X           covariate matrix, one row per entry of Y.
#   weights     optional observation weights (NULL means equal weights).
#   qtile_grid  distance quantiles that define the candidate bandwidths.
#   lambda_grid candidate ridge penalties.
#   n_folds     number of CV folds; with n_folds <= 1 no tuning is done and an
#               unpenalised fit at the median-distance bandwidth is returned.
#
# Returns a list with
#   coef           coefficients (intercept first) at the minimum-CV-error setting;
#   coef1se        coefficients at opt.lambda1se;
#   opt.qtile      selected distance quantile;
#   opt.lambda     selected penalty;
#   opt.lambda1se  largest penalty (at opt.qtile) whose CV error is within
#                  half a standard error of the minimum.
WLS_ker <- function(Y, X, weights = NULL,
                    qtile_grid = c(0.1,0.25,0.5,0.75,0.9),
                    lambda_grid = 10^(-5:5),
                    n_folds = 5){

  X <- as.matrix(X)

  #size of sample and tuning grids
  n <- length(Y)
  n_q <- length(qtile_grid)
  n_l <- length(lambda_grid)

  #tuning
  if (n_folds > 1) {
    # One column per (qtile, lambda) pair, lambda varying fastest.
    MSE <- matrix(nrow = n_folds, ncol = n_l*n_q)
    folds <- createFolds(seq_len(n), k = n_folds)

    for (i in seq_len(n_folds)) {
      ind <- folds[[i]]
      # drop = FALSE keeps matrices even for one-row folds or one covariate.
      X_tr <- X[-ind, , drop = FALSE]
      X_te <- X[ind, , drop = FALSE]
      for (j in seq_len(n_q)) {
        ker <- genKernel(X_tr, kernel = 'gaussian', qtile = qtile_grid[j])
        # The held-out kernel must use the same bandwidth the coefficients
        # were fitted with, so the training epsilon is reused instead of
        # being recomputed from the test-to-train distances.
        ker_te <- genKernel(X_tr, X_te, kernel = 'gaussian',
                            epsilon = attr(ker, "epsilon"))
        K <- cbind(1, ker)
        G <- rbind(0, cbind(0, ker))
        fit <- ridgereg(K, Y[-ind], G, weights[-ind], lambda_grid)
        Yhat_te <- cbind(1,ker_te)%*%fit$coef
        MSE[i,((j-1)*n_l+1):(j*n_l)] <- colMeans((sweep(x = Yhat_te, MARGIN = 1, STATS = Y[ind], FUN = "-"))^2)
      }
    }
    mMSE <- as.numeric(colMeans(MSE))
    seMSE <- apply(MSE,2,sd)/sqrt(n_folds)

    best <- which.min(mMSE)
    # Despite the name this bound is the minimum plus HALF a standard error.
    onesebound <- mMSE[best] + seMSE[best]/2
    expand_grid_l <- rep(lambda_grid, times = n_q)
    expand_grid_q <- rep(qtile_grid, each = n_l)
    opt.qtile <- expand_grid_q[best]
    opt.lambda <- expand_grid_l[best]
    # "<=" guarantees the minimiser itself qualifies even when the standard
    # error is zero, so the selection can never be empty.
    opt.lambda1se <- lambda_grid[max(which(mMSE[which(expand_grid_q==opt.qtile)] <= onesebound))]

    ker <- genKernel(X, kernel = 'gaussian', qtile = opt.qtile)
    K <- cbind(1, ker)
    G <- rbind(0, cbind(0, ker))
    fit <- ridgereg(K, Y, G, weights, opt.lambda)
    fit1se <- ridgereg(K, Y, G, weights, opt.lambda1se)
  } else {
    # No tuning: report the settings actually used so that the return value
    # below is fully defined.
    opt.qtile <- 0.5
    opt.lambda <- 0
    opt.lambda1se <- 0
    ker <- genKernel(X, kernel = 'gaussian', qtile = opt.qtile)
    K <- cbind(1, ker)
    G <- rbind(0, cbind(0, ker))
    fit <- ridgereg(K, Y, G, weights, opt.lambda)
    fit1se <- fit
  }

  return(list(coef = as.vector(fit$coef),
              coef1se = as.vector(fit1se$coef),
              opt.qtile = opt.qtile,
              opt.lambda = opt.lambda,
              opt.lambda1se = opt.lambda1se))
}

# Kernel matrix between the rows of y (rows of the result) and the rows of x
# (columns of the result).
#
# Arguments:
#   x, y     matrices (or objects coercible by as.matrix); y defaults to x.
#   kernel   "gaussian" (exp(-epsilon * distance^2)) or "polynomial"
#            ((1 + <y, x>)^degree).
#   epsilon  Gaussian scale; by default 1 / (2 * q^2), where q is the `qtile`
#            quantile of the pairwise distances computed inside the function.
#   degree   polynomial degree.
#   qtile    quantile level used by the default epsilon.
#
# The result carries the attributes "type" and either "epsilon" or "degree".
genKernel <- function(x, y, kernel = c("gaussian", "polynomial"),
                      epsilon = 1/2/quantile(d, qtile)^2, degree = 2, qtile = 0.5) {

  kernel <- match.arg(kernel)
  x <- as.matrix(x)
  if (missing(y)) {
    y <- x
  } else {
    y <- as.matrix(y)
  }

  if (kernel != "gaussian") {
    ker <- (1 + y %*% t(x))^degree
    attr(ker, "degree") <- degree
  } else {
    # `d` must exist under exactly this name before `epsilon` is first
    # touched: the default value of `epsilon` is evaluated lazily and
    # refers to it.
    d <- as.matrix(suppressWarnings(pdist(y, x)))
    ker <- exp(-epsilon * d^2)
    attr(ker, "epsilon") <- epsilon
  }

  attr(ker, "type") <- kernel
  ker
}


# Weighted generalised ridge regression over a grid of penalties.
#
# Minimises sum(w * (y - X %*% beta)^2) / n + lambda * t(beta) %*% P %*% beta
# for every value in `lambda`.
#
# Arguments:
#   X        design matrix (n x p).
#   y        response of length n.
#   P        penalty matrix (p x p).
#   weights  optional observation weights of length n (NULL or empty means 1).
#   lambda   vector of penalty values.
#
# Returns a list with
#   coef    p x length(lambda) matrix of coefficients;
#   fitted  n x length(lambda) matrix of fitted values;
#   gcv     generalised cross-validation score for each lambda;
#   lambda  the penalty values as supplied.
ridgereg <- function(X, y, P, weights = NULL, lambda = 10^(-5:5)) {

  n <- length(y)
  N <- length(lambda)
  nlambda <- n * lambda
  if (!length(weights)) weights <- rep(1, n)
  weights <- as.vector(weights)
  if (length(weights) != n) stop("'weights' must have one entry per observation")

  # Normal equation (A + nlambda * P) %*% beta = B %*% y
  # Scaling the rows of X by the weights gives t(X) %*% diag(weights) without
  # building a dense n x n matrix, and it also avoids diag() reading a
  # length-one weight as a matrix size.
  B <- t(X * weights)
  A <- B %*% X

  f <- matrix(0, n, N)
  beta <- matrix(0, ncol(X), N)
  gcv <- numeric(N)

  for (i in seq_len(N)) {

    # Solve normal equation (pseudo-inverse, so singular systems are allowed)
    U <- ginv(A + nlambda[i] * P) %*% B
    beta[, i] <- U %*% y
    H <- X %*% U
    f[, i] <- H %*% y

    # GCV score
    rss <- sum(weights * (y - f[, i])^2)
    gcv[i] <- n * rss/(n - sum(diag(H)))^2

  }

  return(list(coef = beta, fitted = f, gcv = gcv, lambda = lambda))
}


# Local (kernel-weighted) linear IV-DL fit with cross-validated tuning.
#
# The modified outcome 2 * Z * (Y - g) / d is regressed locally on X by
# LocalReg(); the bandwidth setting and ridge penalty are chosen by
# n_folds-fold cross-validation on the modified outcome.
#
# Arguments:
#   Y            numeric outcome, one entry per row of X.
#   X            training covariate matrix.
#   X_test       covariate matrix at which predictions are wanted.
#   Z            instrument, one entry per row of X.
#   pi_Z         scalar or length-n vector multiplied into the kernel weights.
#   d            divisor of the modified outcome.
#   g            augmentation subtracted from Y (0 gives the unaugmented fit).
#   lambda_grid  candidate ridge penalties.
#   qtile_grid   candidate bandwidth settings (see LocalReg).
#   test_weights optional pre-computed weight matrix for X_test, forwarded to
#                LocalReg() as `weights`.
#   n_folds      number of cross-validation folds.
#
# Returns a list with coef and pred for X_test, the weights used for X_test
# (test_weights) and the in-sample fitted values (fitted).
IVDL.local <- function(Y,X, X_test, Z, pi_Z=1/2, d, g,
                       lambda_grid = 0,
                       qtile_grid = c(0.1,0.5,0.75,0.9, 1.5, 3, 10),
                       test_weights = NULL, n_folds = 5){
  # Check arguments (same checks and messages as IVDL.linear / IVDL.kernel)
  n <- nrow(X)
  stopifnot(is.numeric(Y))
  if (length(Y) != n) stop("'Y' has different length with number of rows in X")
  if (length(Z) != n) stop("'Z' has different length with number of rows in X")
  if (!(length(pi_Z) %in% c(1L, n)))
    stop("'pi_Z' must have length 1 or one entry per row of X")

  # pi_Z is subset by fold below, so a scalar (the default) is expanded to
  # one value per observation.
  pi_Z <- rep_len(pi_Z, n)

  #size of tuning grids
  n_q <- length(qtile_grid)
  n_l <- length(lambda_grid)

  # modified outcome
  response <- 2*Z*(Y-g)/d

  #tuning: one column per (qtile, lambda) pair, lambda varying fastest
  MSE <- matrix(nrow = n_folds, ncol = n_q*n_l)
  folds <- createFolds(seq_len(n), k = n_folds)
  for (i in seq_len(n_folds)) {
    ind <- folds[[i]]
    fit <- LocalReg(response[-ind], X[-ind, , drop = FALSE], X[ind, , drop = FALSE],
                    pi_Z[-ind], lambda_grid = lambda_grid, weights = NULL,
                    qtile_grid = qtile_grid)
    SE <- (sweep(fit$pred, 1, response[ind], FUN = "-"))^2
    MSE[i,] <- colMeans(SE)
  }

  mMSE <- as.numeric(colMeans(MSE))
  seMSE <- apply(MSE,2,sd)/sqrt(n_folds)
  best <- which.min(mMSE)
  onesebound <- mMSE[best] + seMSE[best]
  opt.ind.qtile <- (best-1)%/%n_l+1
  mMSE_opt.qtile <- mMSE[((opt.ind.qtile-1)*n_l+1):(opt.ind.qtile*n_l)]
  # One-standard-error rule: largest penalty within one SE of the minimum.
  opt.ind.lambda <- which(mMSE_opt.qtile <= onesebound)
  if (n_l == 1) {
    opt.lambda <- lambda_grid
  } else {
    opt.lambda <- max(lambda_grid[opt.ind.lambda])
  }
  opt.qtile <- qtile_grid[opt.ind.qtile]

  fit_test <- LocalReg(response, X, X_test, pi_Z, lambda_grid = opt.lambda, weights = test_weights, qtile_grid = opt.qtile)
  fit <- LocalReg(response, X, X, pi_Z, lambda_grid = opt.lambda, weights = NULL, qtile_grid = opt.qtile)

  return(list(coef = fit_test$coef, pred = fit_test$pred,
              test_weights = fit_test$weights, fitted = fit$pred))
}


# Locally weighted ridge regression evaluated at each row of X_new.
#
# For every new point a separate weighted linear model (with intercept) is
# fitted to the training data and evaluated at that point.
#
# Arguments:
#   Y           training response.
#   X_train     training covariates.
#   X_new       covariates at which to predict.
#   pi_Z        per-training-observation factor multiplied into the kernel
#               weights (used only when `weights` is NULL).
#   lambda_grid ridge penalties added to the diagonal of the normal equations.
#   weights     optional pre-computed weights with one column per row of X_new
#               and one row per training observation; when NULL, Gaussian
#               kernel weights are built from qtile_grid.
#   qtile_grid  bandwidth settings: a value q <= 1 uses the q-quantile of the
#               new-to-train distances as scale, a value q > 1 uses q times
#               the maximum distance.
#   epsilon     currently unused; kept for interface compatibility.
#
# Returns a list with
#   coef         coefficients from the last (qtile, lambda) combination fitted;
#   pred         predictions, one row per row of X_new and one column per
#                (qtile, lambda) pair with lambda varying fastest (one column
#                per lambda when `weights` is supplied);
#   weights      the weights used in the last fit;
#   lambda_grid  as supplied;
#   epsilon_grid Gaussian scales aligned with qtile_grid, or NULL when
#                `weights` is supplied.
LocalReg <- function(Y, X_train, X_new, pi_Z, lambda_grid = 10^(-5:5),
                     weights = NULL, qtile_grid = c(0.1,0.5,0.75,0.9, 1.5, 3, 10), epsilon = NULL){
  #size of tuning grids
  n_q <- length(qtile_grid)
  n_l <- length(lambda_grid)

  X <- cbind(1,X_train)
  X_new1 <- cbind(1, X_new)

  # Solve the weighted ridge normal equations once per column of `w_mat`
  # (i.e. once per new point); the result has one column per new point.
  fit_columns <- function(w_mat, lambda) {
    w_mat %>% as_tibble %>%
      reframe(across(.cols = everything(),
                     .fns = ~ginv(t(X)%*%diag(.x)%*%X+diag(lambda,ncol(X)))%*%t(X)%*%diag(.x)%*%Y))
  }

  if (is.null(weights)) {
    # The constant column cancels in the distance, so this is the distance
    # between the raw covariate rows.
    d <- as.matrix(suppressWarnings(pdist(as.matrix(X_new1), as.matrix(X))))
    d_max <- max(d)
    # Computed element by element so that epsilon_grid[j] always belongs to
    # qtile_grid[j], whatever the order of the grid.
    epsilon_grid <- vapply(qtile_grid, function(q) {
      scale <- if (q <= 1) quantile(d, q) else d_max * q
      unname(1/2/scale^2)
    }, numeric(1))

    pred <- matrix(nrow = nrow(X_new), ncol = n_l*n_q)
    for (j in seq_len(n_q)) {
      ker <- exp(-epsilon_grid[j] * d^2)
      weights <- sweep(t(ker), 1, pi_Z, FUN = "*")
      for (i in seq_len(n_l)) {
        beta <- fit_columns(weights, lambda_grid[i])
        pred[,n_l*(j-1)+i] <- colSums(t(X_new1)*as.matrix(beta))
      }
    }

  } else {
    # No kernel is built here; define epsilon_grid so that the return value
    # below does not fail on an undefined object.
    epsilon_grid <- NULL
    pred <- matrix(nrow = nrow(X_new), ncol = n_l)
    for (i in seq_len(n_l)) {
      beta <- fit_columns(weights, lambda_grid[i])
      pred[,i] <- colSums(t(X_new1)*as.matrix(beta))
    }
  }

  return(list(coef = beta, pred = pred,
              weights = weights,
              lambda_grid = lambda_grid,
              epsilon_grid = epsilon_grid))
}




#### Setting 1 Replication #####
# Number of Monte Carlo replications run by the loop below.
nrep <- 500
# Testing data: generated once with a fixed seed and shared by all replications.
set.seed(100)
n_test <- 10000
# Five covariates, uniform on [-1, 1].
X_test <- matrix(runif(n_test*5,-1,1),n_test,5)
# Instrument coded -1/+1, assigned with probability 1/2.
p_Z_test <- 1/2
Z_test <- rbinom(n_test,1,p_Z_test)*2-1
# U enters both the treatment model and the outcome but is never passed to
# any estimator.
U_test <- bridgedist::rbridge(n_test,1/2)
# expit() is not defined in this file; presumably it comes from one of the
# packages loaded above -- TODO confirm.
p_A_test <- expit(2*X_test[,1]+2.5*Z_test-0.5*U_test)
# Treatment coded -1/+1.
A_test <- rbinom(n_test,1,p_A_test)*2-1
# Outcome model: Y = h(X) + q(X) * A + 0.5 * U + e.
h_test <- 0.5+X_test%*%c(0.5,0.8,0.3,-0.5,0.7)
q_test <- 0.2+X_test%*%c(-0.6,-0.8,0,0,0)
e_test <- rnorm(n_test,0,1)
Y_test <- h_test+q_test*A_test+0.5*U_test+e_test
# Part of the mean outcome that does not depend on the treatment.
v_test <- mean(h_test+0.5*U_test+e_test)
# Printed for reference: the value obtained when every unit is assigned sign(q_test).
mean(abs(q_test)+v_test)


# Accumulators; the loop appends one row per replication.
AR1_500 <- tibble(NULL)
Value1_500 <- tibble(NULL)
MSE1_500 <- tibble(NULL)


# Monte Carlo loop for Setting 1.
# Each replication draws a training sample of size 500, estimates the nuisance
# quantities, fits every competing method and appends one row to each of
# AR1_500 (sign agreement with q_test), Value1_500 (value of the sign rule)
# and MSE1_500 (squared error against 2 * q_test).
for (i in seq_len(nrep)) {
  set.seed(i)
  n <- 500
  X <- matrix(runif(n*5,-1,1),n,5)
  p_Z <- 1/2
  Z <- rbinom(n,1,p_Z)*2-1
  U <- bridgedist::rbridge(n,1/2)
  p_A <- expit(2*X[,1]+2.5*Z-0.5*U)
  A <- rbinom(n,1,p_A)*2-1
  h <- 0.5+X%*%c(0.5,0.8,0.3,-0.5,0.7)
  q <- 0.2+X%*%c(-0.6,-0.8,0,0,0)
  e <- rnorm(n,0,1)
  Y <- h+q*A+0.5*U+e
  a <- (A+1)/2

  #### Estimation of nuisance ####
  #pi_Zhat: forest regression of Z (coded -1/+1) on X, mapped to [0, 1]
  reg_forest_Z.X <- regression_forest(X,Z,tune.parameters = "all",
                                      seed = i*1)
  Zhat <- reg_forest_Z.X$predictions[,1]
  pi_Zhat <- (Zhat+1)/2

  #conditional means of Y and A given (X, Z), evaluated at Z = 1 and Z = -1
  XZ <- data.frame(X,Z)
  reg_forest_Y.XZ <- regression_forest(XZ,Y,tune.parameters = "all",
                                       seed = i*3)
  reg_forest_A.XZ <- regression_forest(XZ,A,tune.parameters = "all",
                                       seed = i*4)
  Yp <- predict(reg_forest_Y.XZ, newdata = data.frame(X,Z=1))$predictions
  Yn <- predict(reg_forest_Y.XZ, newdata = data.frame(X,Z=-1))$predictions
  An <- predict(reg_forest_A.XZ, newdata = data.frame(X,Z=-1))$predictions

  set.seed(i*6)
  #Logistic Regression: difference of fitted treatment probabilities at Z = 1 and Z = -1
  modA <- glm(factor(A)~X+Z, family = 'binomial')
  Dhat_LogReg <- predict.glm(modA, newdata = data.frame(X=X,Z=1), type = 'response') -
    predict(modA,newdata = data.frame(X=X,Z=-1), type = 'response')

  #augments
  aug_LogReg <- augments(Y,A,Z,Dhat_LogReg, pi_Zhat, Yp = Yp, Yn = Yn, An = An)

  set.seed(i*7)

  ####Estimation of CATE####
  #IPW_MR with multiply robust
  W2MR <- aug_LogReg$Ymr
  tuning6 <- best.tune_wsvm(X,factor(sign(W2MR)), weight = abs(W2MR), kernel = 'linear',range = list(cost = 2^(-1:3)))
  AR_IW_MR <- mean(predict(tuning6, newdata = X_test)==sign(q_test))
  # as.numeric() on the predicted factor gives the level codes 1/2; "*2-3"
  # maps them back to -1/+1.
  Value_IW_MR <- mean(h_test+q_test*(as.numeric(predict(tuning6,newdata = X_test))*2-3)+0.5*U_test+e_test)

  # #RD
  # modq_rd <- rdlearn(X,a+1,Y, method = "linear")
  # cate_te_rd <- predict(modq_rd, newx = X_test)[,2]-predict(modq_rd, newx = X_test)[,1]
  # AR_RD <- mean(sign(cate_te_rd)==sign(q_test))
  # Value_RD <- mean(q_test*sign(cate_te_rd))+v_test
  # MSE_RD <- mean((cate_te_rd-2*q_test)^2)

  #IV-DL
  fit_IVDL <- IVDL.linear(Y,X,Z,pi_Zhat,Dhat_LogReg, g=0)

  #IV-RDL1
  fit_IVRDL1 <- IVDL.linear(Y,X,Z,pi_Zhat,Dhat_LogReg, g=aug_LogReg$g)


  #IV-RDL2
  fit_IVRDL2 <- IVDL.linear(Y,X,Z,pi_Zhat,Dhat_LogReg, g=aug_LogReg$h2)


  ##MR2022
  reg_forest_Ymr <- regression_forest(X,aug_LogReg$Ymr,tune.parameters = "all",
                                      seed = i*8)
  cate_te_Ymr  <- predict(reg_forest_Ymr, newdata = X_test)$predictions

  #BART
  bart_fit <- bartc(Y,(A+1)/2,X, keepTrees = TRUE,
                    n.samples = 100, n.burn = 15, n.chains = 5,
                    seed = i*9)
  cate_te_BART <- colMeans(predict(bart_fit, newdata = X_test, type = 'icate'))

  #Causal Forest
  cf_fit <- instrumental_forest(X,Y,(A+1)/2,(Z+1)/2, tune.parameters = 'all',
                                seed = i*10)
  # NOTE(review): this estimate is halved, yet the MSE below scores every
  # method against 2 * q_test -- confirm the scaling is intended.
  cate_te_CF <- predict(cf_fit, newdata = X_test)$predictions/2

  #### Results ####
  coef <- tibble(IVDL = fit_IVDL$coef,
                 IVRDL1 = fit_IVRDL1$coef,
                 IVRDL2 = fit_IVRDL2$coef)
  # Linear predictions on the test covariates, one column per method.
  cate_te <- coef %>%
    reframe(across(everything(),.fns = ~as_vector(cbind(1,X_test)%*%as_vector(.x)))) %>%
    mutate(BART = cate_te_BART,
           CF = cate_te_CF,
           # RD =  cate_te_rd,
           MRIV = cate_te_Ymr, .before = everything())

  AR1_500 <- cate_te %>%
    summarise(across(everything(), .fns = ~mean(sign(.x)==sign(2*q_test)))) %>%
    mutate(IPW_MR = AR_IW_MR, .before = everything()) %>%
    tibble_row() %>%
    bind_rows(AR1_500, .)
  # The IPW_MR column of the value table must hold the value of the rule,
  # not its accuracy rate.
  Value1_500 <- cate_te %>%
    summarise(across(everything(), .fns = ~mean(q_test*(sign(.x)))+v_test)) %>%
    mutate(IPW_MR = Value_IW_MR, .before = everything()) %>%
    tibble_row() %>%
    bind_rows(Value1_500, .)
  MSE1_500 <- cate_te %>%
    summarise(across(everything(), .fns = ~mean((.x-2*q_test)^2))) %>%
    tibble_row() %>%
    bind_rows(MSE1_500, .)

  # Progress report every 10 replications.
  msg <- paste0("first  setting rep ", i, " done")
  if (i %% 10 == 0) {
    print(msg)
  }
}


# Summary table for Setting 1: one row per method, one column per criterion,
# each cell formatted as "mean(standard error)" with both scaled by 100.
# The standard error divides by the square root of the number of replications
# actually stored; a fixed divisor of 10 is only correct for 100 replications,
# while this setting runs nrep = 500.
fmt_mean_se <- function(x) {
  paste0(signif(100*mean(x),3), "(", round(100*sd(x)/sqrt(length(x)), 1), ")")
}
res1 <- bind_rows(AR1_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se)),
                  Value1_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se)),
                  MSE1_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se))) %>%
  mutate(criteria = c("AR","Value","MSE"), .before = everything()) %>%
  pivot_longer(-1) %>% pivot_wider(names_from = "criteria", values_from = "value")
res1



####Setting 2 Replication #####
# Number of Monte Carlo replications run by the loop below.
nrep <- 100
# Testing data: generated once with a fixed seed and shared by all replications.
set.seed(1)
n_test <- 5000
# Five covariates, uniform on [-1, 1].
X_test <- matrix(runif(n_test*5,-1,1),n_test,5)
# Instrument coded -1/+1, assigned with probability 1/2.
Z_test <- rbinom(n_test,1,1/2)*2-1
# U enters both the treatment model and the outcome but is never passed to
# any estimator.
U_test <- rbridge(n_test,1/2)
p_A_test <- expit(2*X_test[,1]+2.5*Z_test-0.5*U_test)
# Treatment coded -1/+1.
A_test <- rbinom(n_test,1,p_A_test)*2-1
# Outcome model: Y = h(X) + q(X) * A + U + e, with a non-linear q.
h_test <- 0.5+X_test%*%c(0.5,0.8,0.3,-0.5,0.7)
q_test <- exp(X_test%*%c(-0.6,-0.8,0,0,0))-1
e_test <- rnorm(n_test,0,1)
Y_test <- h_test+q_test*A_test+U_test+e_test
# Part of the mean outcome that does not depend on the treatment.
v_test <- mean(h_test+U_test+e_test)
# Printed for reference: the value obtained when every unit is assigned sign(q_test).
v_test+mean(abs(q_test))

# Accumulators; the loop appends one row per replication.
AR2_500 <- tibble(NULL)
Value2_500 <- tibble(NULL)
MSE2_500 <- tibble(NULL)


# Monte Carlo loop for Setting 2 (non-linear q, local-regression IV-DL).
# Each replication draws a training sample of size 500, estimates the nuisance
# quantities, fits every competing method and appends one row to each of
# AR2_500 (sign agreement with q_test), Value2_500 (value of the sign rule)
# and MSE2_500 (squared error against 2 * q_test).
for (i in seq_len(nrep)) {
  set.seed(i)
  n <- 500
  X <- matrix(runif(n*5,-1,1),n,5)
  Z <- rbinom(n,1,1/2)*2-1
  U <- rbridge(n,1/2)
  p_A <- expit(2*X[,1]+2.5*Z-0.5*U)
  A <- rbinom(n,1,p_A)*2-1
  h <- 0.5+X%*%c(0.5,0.8,0.3,-0.5,0.7)
  q <- exp(X%*%c(-0.6,-0.8,0,0,0))-1
  e <- rnorm(n,0,1)
  Y <- h+q*A+U+e
  a <- (A+1)/2

  #### Estimation of nuisance ####
  #pi_Zhat: forest regression of Z (coded -1/+1) on X, mapped to [0, 1]
  reg_forest_Z.X <- regression_forest(X,Z,tune.parameters = "all",
                                      seed = i*1)
  Zhat <- reg_forest_Z.X$predictions[,1]
  pi_Zhat <- (Zhat+1)/2

  #conditional means of Y and A given (X, Z), evaluated at Z = 1 and Z = -1
  XZ <- data.frame(X,Z)
  reg_forest_Y.XZ <- regression_forest(XZ,Y,tune.parameters = "all",
                                       seed = i*3)
  reg_forest_A.XZ <- regression_forest(XZ,A,tune.parameters = "all",
                                       seed = i*4)
  Yp <- predict(reg_forest_Y.XZ, newdata = data.frame(X,Z=1))$predictions
  Yn <- predict(reg_forest_Y.XZ, newdata = data.frame(X,Z=-1))$predictions
  An <- predict(reg_forest_A.XZ, newdata = data.frame(X,Z=-1))$predictions


  ####Estimation of delta(x): CATE of Z on A####
  #CF
  cau_forest_delta <- causal_forest(X,(A+1)/2,(Z+1)/2,tune.parameters = "all",
                                    seed = i*5)
  Dhat_CF <- cau_forest_delta$predictions

  set.seed(i*6)
  #Logistic Regression (Dhat_LogReg is not used further in this setting)
  modA <- glm(factor(A)~X+Z, family = 'binomial')
  Dhat_LogReg <- predict.glm(modA, newdata = data.frame(X=X,Z=1), type = 'response') -
    predict(modA,newdata = data.frame(X=X,Z=-1), type = 'response')

  #augments
  aug_CF <- augments(Y,A,Z,Dhat = Dhat_CF, pi_Zhat, Yp = Yp, Yn = Yn, An = An)

  set.seed(i*7)

  #IPW-MR
  W2MR <- aug_CF$Ymr
  tuning6 <- best.tune_wsvm(X,factor(sign(W2MR)), weight = abs(W2MR),
                            kernel = 'radial',range = list(gamma = 2^(-1:1),
                                                           cost = 2^(2:4)))
  AR_IW_MR <- mean(predict(tuning6, newdata = X_test)==sign(q_test))
  # as.numeric() on the predicted factor gives the level codes 1/2, not the
  # decisions; "*2-3" maps them back to -1/+1 as in the other settings.
  Value_IW_MR <- mean(h_test+q_test*(as.numeric(predict(tuning6, newdata = X_test))*2-3))

  # #RD
  # modq_rd <- rdlearn(X,a+1,Y, method = "kernel", kernel = 'gaussian')
  # cate_te_rd <- predict(modq_rd, newx = X_test)[,2]-predict(modq_rd, newx = X_test)[,1]

  #MR2022
  reg_forest_Ymr <- regression_forest(X,aug_CF$Ymr,tune.parameters = "all",
                                      seed = i*8)
  cate_te_Ymr  <- predict(reg_forest_Ymr, newdata = X_test)$predictions

  #BART
  bart_fit <- bartc(Y,(A+1)/2,X, keepTrees = TRUE,
                    n.samples = 100, n.burn = 15, n.chains = 5,
                    seed = i*9)
  cate_te_BART <- colMeans(predict(bart_fit, newdata = X_test, type = 'icate'))

  #CF
  cf_fit <- instrumental_forest(X,Y,(A+1)/2,(Z+1)/2,
                                seed = i*10)

  # NOTE(review): this estimate is halved, yet the MSE below scores every
  # method against 2 * q_test -- confirm the scaling is intended.
  cate_te_CF <- predict(cf_fit, newdata = X_test)$predictions/2

  #We implemented local regression to estimate CATE
  #IV-DL
  fit_IVDL <- IVDL.local(Y,X,X_test,Z,pi_Zhat,Dhat_CF, g=0)
  cate_te_IVDL <- fit_IVDL$pred

  #IV-RDL1
  fit_IVRDL1 <- IVDL.local(Y,X,X_test,Z,pi_Zhat,Dhat_CF, g=aug_CF$g)
  cate_te_IVRDL1 <- fit_IVRDL1$pred

  #IV-RDL2
  fit_IVRDL2 <- IVDL.local(Y,X,X_test,Z,pi_Zhat,Dhat_CF, g=aug_CF$h2)
  cate_te_IVRDL2 <- fit_IVRDL2$pred

  #### RF ####
  # These three forest fits are not included in the result tables below.
  fit_RF_DCF <- regression_forest(X,2*Z*(Y)/Dhat_CF, sample.weights = pi_Zhat, tune.parameters = "all")
  cate_te_RF_DCF <- predict(fit_RF_DCF, newdata = X_test)$predictions

  fit_RFg_DCF <- regression_forest(X,2*Z*(Y-aug_CF$g)/Dhat_CF, sample.weights = pi_Zhat, tune.parameters = "all")
  cate_te_RFg_DCF <- predict(fit_RFg_DCF, newdata = X_test)$predictions

  fit_RFh2_DCF <- regression_forest(X,2*Z*(Y-aug_CF$h2)/Dhat_CF, sample.weights = pi_Zhat, tune.parameters = "all")
  cate_te_RFh2_DCF <- predict(fit_RFh2_DCF, newdata = X_test)$predictions


  ####results####
  cate_te <- tibble(BART = cate_te_BART,
                    CF = cate_te_CF,
                    # RD = cate_te_rd,
                    MRIV = cate_te_Ymr,
                    IVDL = cate_te_IVDL,
                    IVRDL1 = cate_te_IVRDL1,
                    IVRDL2 = cate_te_IVRDL2)

  AR2_500 <- cate_te %>%
    summarise(across(everything(), .fns = ~mean(sign(.x)==sign(q_test)))) %>%
    mutate(IPW_MR = AR_IW_MR,
           .before = everything()) %>% tibble_row() %>%
    bind_rows(AR2_500, .)

  Value2_500 <- cate_te %>%
    summarise(across(everything(), .fns = ~mean(h_test+q_test*(sign(.x))))) %>%
    mutate(IPW_MR = Value_IW_MR,
           .before = everything()) %>% tibble_row() %>%
    bind_rows(Value2_500, .)

  MSE2_500 <- cate_te %>%
    summarise(across(everything(), .fns = ~mean((.x-2*q_test)^2))) %>%
    tibble_row() %>%
    bind_rows(MSE2_500, .)

  # Progress report every 10 replications.
  msg <- paste0("second setting rep ", i, " done")
  if (i %% 10 == 0) {
    print(msg)
  }
}

# Summary table for Setting 2: one row per method, one column per criterion,
# each cell formatted as "mean(standard error)" with both scaled by 100.
# The standard error divides by the square root of the number of replications
# actually stored, so it stays correct if nrep is changed from 100.
fmt_mean_se <- function(x) {
  paste0(signif(100*mean(x),3), "(", round(100*sd(x)/sqrt(length(x)), 1), ")")
}
res2 <- bind_rows(AR2_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se)),
                  Value2_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se)),
                  MSE2_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se))) %>%
  mutate(criteria = c("AR","Value","MSE"), .before = everything()) %>%
  pivot_longer(-1) %>% pivot_wider(names_from = "criteria", values_from = "value")

res2


####Setting 3 Replication #####
# Number of Monte Carlo replications run by the loop below.
nrep <- 100
# Testing data: generated once with a fixed seed and shared by all replications.
set.seed(100)
n_test <- 10000
# Five covariates, uniform on [-1, 1].
X_test <- matrix(runif(n_test*5,-1,1),n_test,5)
# In this setting the instrument probability depends on the first covariate.
p_Z_test <- expit(2*X_test[,1])

# Instrument coded -1/+1.
Z_test <- rbinom(n_test,1,p_Z_test)*2-1
# U enters both the treatment model and the outcome but is never passed to
# any estimator.
U_test <- bridgedist::rbridge(n_test,1/2)
p_A_test <- expit(2*X_test[,1]+2.5*Z_test-0.5*U_test)
# Treatment coded -1/+1.
A_test <- rbinom(n_test,1,p_A_test)*2-1
# Outcome model: Y = h(X) + q(X) * A + 0.5 * U + e.
h_test <- 0.5+X_test%*%c(0.5,0.8,0.3,-0.5,0.7)
q_test <- 0.2+X_test%*%c(-0.6,-0.8,0,0,0)
e_test <- rnorm(n_test,0,1)
Y_test <- h_test+q_test*A_test+0.5*U_test+e_test
# Part of the mean outcome that does not depend on the treatment.
v_test <- mean(h_test+0.5*U_test+e_test)
# Printed for reference: since A_test is -1/+1, abs(q_test*A_test) equals
# abs(q_test), so this is the value of always assigning sign(q_test).
mean(h_test+abs(q_test*A_test)+0.5*U_test+e_test)


# Accumulators; the loop appends one row per replication.
AR3_500 <- tibble(NULL)
Value3_500 <- tibble(NULL)
MSE3_500 <- tibble(NULL)


# Monte Carlo loop for Setting 3 (instrument probability depends on X; all
# nuisance quantities come from parametric models).
# Each replication draws a training sample of size 500, fits every competing
# method and appends one row to each of AR3_500, Value3_500 and MSE3_500.
for (i in 1:nrep) {
  set.seed(i)
  n <- 500
  X <- matrix(runif(n*5,-1,1),n,5)
  p_Z <- expit(2*X[,1])
  Z <- rbinom(n,1,p_Z)*2-1
  U <- bridgedist::rbridge(n,1/2)
  p_A <- expit(2*X[,1]+2.5*Z-0.5*U)
  A <- rbinom(n,1,p_A)*2-1
  h <- 0.5+X%*%c(0.5,0.8,0.3,-0.5,0.7)
  q <- 0.2+X%*%c(-0.6,-0.8,0,0,0)
  e <- rnorm(n,0,1)
  Y <- h+q*A+0.5*U+e
  a <- (A+1)/2
  #### Estimation of nuisance ####
  #pi_Zhat: fitted probabilities from a logistic regression of Z on X
  modZ <- glm(factor(Z)~X, family = 'binomial')
  pi_Zhat <- predict.glm(modZ, type = 'response')

  #delta(x): CATE of Z on A
  # Despite its name, Dhat_CF here comes from a logistic regression: the
  # difference of the fitted treatment probabilities at Z = 1 and Z = -1.
  set.seed(i*6)
  modA <- glm(factor(A)~X+Z, family = 'binomial')
  Dhat_CF <- predict.glm(modA, newdata = data.frame(X=X,Z=1), type = 'response') -
    predict(modA,newdata = data.frame(X=X,Z=-1), type = 'response')
  # An is a fitted probability here.
  An <- predict(modA,newdata = data.frame(X=X,Z=-1), type = 'response')

  #conditional means: weighted linear model for Y, evaluated at Z = 1 and Z = -1
  modY <- lm(Y~., data = data.frame(Y,X,Z), weights = pi_Zhat)
  Yp <- predict(modY,newdata = data.frame(X,Z=1))
  Yn <- predict(modY,newdata = data.frame(X,Z=-1))

  #augments
  aug_CF <- augments(Y,A,Z,Dhat_CF, pi_Zhat, Yp = Yp, Yn = Yn, An = An)

  #### Estimation of CATE A on Y ####
  #IPW_MR
  set.seed(i*7)
  W2MR <- aug_CF$Ymr
  tuning6 <- best.tune_wsvm(X,factor(sign(W2MR)), weight = abs(W2MR), kernel = 'linear',range = list(cost = 2^(-1:3)))
  AR_IW_MR <- mean(predict(tuning6, newdata = X_test)==sign(q_test))
  # as.numeric() on the predicted factor gives the level codes 1/2; "*2-3"
  # maps them back to -1/+1.
  Value_IW_MR <- mean(h_test+q_test*(as.numeric(predict(tuning6,newdata = X_test))*2-3)+0.5*U_test+e_test)

  # #RD
  # modq_rd <- rdlearn(X,a+1,Y, p = pi_Zhat, method = "linear")
  # cate_te_rd <- predict(modq_rd, newx = X_test)[,2]-predict(modq_rd, newx = X_test)[,1]
  # AR_RD <- mean(sign(cate_te_rd)==sign(q_test))
  # Value_RD <- mean(q_test*sign(cate_te_rd))+v_test
  # MSE_RD <- mean((cate_te_rd-2*q_test)^2)

  #IV-DL
  fit_IVDL <- IVDL.linear(Y,X,Z,pi_Zhat,Dhat_CF, g=0)

  #IV-RDL1
  fit_IVRDL1 <- IVDL.linear(Y,X,Z,pi_Zhat,Dhat_CF, g=aug_CF$g)

  #IV-RDL2
  fit_IVRDL2 <- IVDL.linear(Y,X,Z,pi_Zhat,Dhat_CF, g=aug_CF$h2)

  ##MR2022
  reg_forest_Ymr <- regression_forest(X,aug_CF$Ymr,tune.parameters = "all",
                                      seed = i*8)
  cate_te_Ymr  <- predict(reg_forest_Ymr, newdata = X_test)$predictions

  #BART
  bart_fit <- bartc(Y,(A+1)/2,X, keepTrees = TRUE,
                    n.samples = 100, n.burn = 15, n.chains = 5,
                    seed = i*9)
  cate_te_BART <- colMeans(predict(bart_fit, newdata = X_test, type = 'icate'))

  #Causal Forest
  cf_fit <- instrumental_forest(X,Y,(A+1)/2,(Z+1)/2, tune.parameters = 'all',
                                seed = i*10)
  # NOTE(review): this estimate is halved, yet the MSE below scores every
  # method against 2 * q_test -- confirm the scaling is intended.
  cate_te_CF <- predict(cf_fit, newdata = X_test)$predictions/2

  #### Results ####
  coef <- tibble(IVDL = fit_IVDL$coef,
                 IVRDL1 = fit_IVRDL1$coef,
                 IVRDL2 = fit_IVRDL2$coef)
  # Linear predictions on the test covariates, one column per method.
  cate_te <- coef %>%
    reframe(across(everything(),.fns = ~as_vector(cbind(1,X_test)%*%as_vector(.x)))) %>%
    mutate(BART = cate_te_BART,
           CF = cate_te_CF,
           # RD =  cate_te_rd,
           MRIV = cate_te_Ymr, .before = everything())

  # Share of test points whose estimated sign matches the sign of q_test.
  AR3_500 <- cate_te %>%
    summarise(across(everything(),.fns = ~mean(sign(.x)==sign(2*q_test)))) %>%
    mutate(IPW_MR = AR_IW_MR, .before = everything()) %>%
    tibble_row() %>%
    bind_rows(AR3_500, .)
  # Value of assigning each test point the sign of its estimate.
  Value3_500 <- cate_te %>%
    summarise(across(everything(),.fns = ~mean(q_test*(sign(.x)))+v_test)) %>%
    mutate(IPW_MR = Value_IW_MR,.before = everything()) %>%
    tibble_row() %>%
    bind_rows(Value3_500, .)
  # Mean squared error against 2 * q_test.
  MSE3_500 <- cate_te %>%
    summarise(across(everything(),.fns = ~mean((.x-2*q_test)^2))) %>%
    tibble_row() %>%
    bind_rows(MSE3_500, .)

  # Progress report every 10 replications.
  msg <- paste0("third  setting rep ", i, " done")
  if(round(i/10)==i/10){
    print(msg)
  }
}

# Summary table for Setting 3: one row per method, one column per criterion,
# each cell formatted as "mean(standard error)" with both scaled by 100.
# The standard error divides by the square root of the number of replications
# actually stored, so it stays correct if nrep is changed from 100.
fmt_mean_se <- function(x) {
  paste0(signif(100*mean(x),3), "(", round(100*sd(x)/sqrt(length(x)), 1), ")")
}
res3 <- bind_rows(AR3_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se)),
                  Value3_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se)),
                  MSE3_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se))) %>%
  mutate(criteria = c("AR","Value","MSE"), .before = everything()) %>%
  pivot_longer(-1) %>% pivot_wider(names_from = "criteria", values_from = "value")
res3


####Setting 4 Replication #####

# Number of Monte Carlo replications run by the loop below.
nrep <- 100
# Testing data: generated once with a fixed seed and shared by all replications.
set.seed(10)
n_test <- 10000
# Five covariates, uniform on [-1, 1].
X_test <- matrix(runif(n_test*5,-1,1),n_test,5)
# The instrument probability depends on the first covariate.
p_Z_test <- expit(2*X_test[,1])

# Instrument coded -1/+1.
Z_test <- rbinom(n_test,1,p_Z_test)*2-1
# U enters both the treatment model and the outcome but is never passed to
# any estimator.
U_test <- bridgedist::rbridge(n_test,1/2)
p_A_test <- expit(2*X_test[,1]+2.5*Z_test-0.5*U_test)
# Treatment coded -1/+1.
A_test <- rbinom(n_test,1,p_A_test)*2-1
# Outcome model: Y = h(X) + q(X) * A + 0.5 * U + e.
h_test <- 0.5+X_test%*%c(0.5,0.8,0.3,-0.5,0.7)
q_test <- 0.2+X_test%*%c(-0.6,-0.8,0,0,0)
e_test <- rnorm(n_test,0,1)
Y_test <- h_test+q_test*A_test+0.5*U_test+e_test
# Part of the mean outcome that does not depend on the treatment.
v_test <- mean(h_test+0.5*U_test+e_test)
# Printed for reference: since A_test is -1/+1, abs(q_test*A_test) equals
# abs(q_test), so this is the value of always assigning sign(q_test).
mean(h_test+abs(q_test*A_test)+0.5*U_test+e_test)

# Accumulators; the loop appends one row per replication.
AR4_500 <- tibble(NULL)
Value4_500 <- tibble(NULL)
MSE4_500 <- tibble(NULL)


# Monte Carlo loop for Setting 4 (instrument probability depends on X but
# pi_Zhat is fixed at 1/2).
# Each replication draws a training sample of size 500, estimates the nuisance
# quantities, fits every competing method and appends one row to each of
# AR4_500 (sign agreement with q_test), Value4_500 (value of the sign rule)
# and MSE4_500 (squared error against 2 * q_test).
for (i in seq_len(nrep)) {
  set.seed(i)
  n <- 500
  X <- matrix(runif(n*5,-1,1),n,5)
  p_Z <- expit(2*X[,1])
  Z <- rbinom(n,1,p_Z)*2-1
  U <- bridgedist::rbridge(n,1/2)
  p_A <- expit(2*X[,1]+2.5*Z-0.5*U)
  A <- rbinom(n,1,p_A)*2-1
  h <- 0.5+X%*%c(0.5,0.8,0.3,-0.5,0.7)
  q <- 0.2+X%*%c(-0.6,-0.8,0,0,0)
  e <- rnorm(n,0,1)
  Y <- h+q*A+0.5*U+e
  a <- (A+1)/2
  #### Estimation of nuisance ####
  #pi_Zhat: held constant at 1/2 in this setting
  pi_Zhat <- rep(1/2,n)

  #delta(x): CATE of Z on A
  # Logistic regression; Ap and An are the fitted treatment probabilities at
  # Z = 1 and Z = -1, and Dhat_LogReg is their difference.
  set.seed(i*6)
  modA <- glm(factor(A)~X+Z, family = 'binomial')
  Dhat_LogReg <- predict.glm(modA, newdata = data.frame(X=X,Z=1), type = 'response') -
    predict(modA,newdata = data.frame(X=X,Z=-1), type = 'response')
  Ap <- predict(modA,newdata = data.frame(X=X,Z=1), type = 'response')
  An <- predict(modA,newdata = data.frame(X=X,Z=-1), type = 'response')

  #conditional means
  XZ <- data.frame(X,Z)
  reg_forest_Y.XZ <- regression_forest(XZ,Y,tune.parameters = "all",
                                       seed = i*3)
  # Both predictions come from the single forest fitted above; the objects
  # reg_forest_Y.XZ1 / reg_forest_Y.XZ2 referenced previously are never created.
  Yp <- predict(reg_forest_Y.XZ, newdata = data.frame(X,Z=1))$predictions
  Yn <- predict(reg_forest_Y.XZ, newdata = data.frame(X,Z=-1))$predictions

  #augments
  aug_LogReg <- augments(Y,A,Z,Dhat_LogReg, pi_Zhat, Yp = Yp, Yn = Yn,
                         Ap = Ap, An = An)

  #### Estimation of CATE A on Y ####
  #IPW_MR
  set.seed(i*7)
  W2MR <- aug_LogReg$Ymr
  tuning6 <- best.tune_wsvm(X,factor(sign(W2MR)), weight = abs(W2MR), kernel = 'linear',range = list(cost = 2^(-1:3)))
  AR_IW_MR <- mean(predict(tuning6, newdata = X_test)==sign(q_test))
  # as.numeric() on the predicted factor gives the level codes 1/2; "*2-3"
  # maps them back to -1/+1.
  Value_IW_MR <- mean(h_test+q_test*(as.numeric(predict(tuning6,newdata = X_test))*2-3)+0.5*U_test+e_test)

  # #RD
  # modq_rd <- rdlearn(X,a+1,Y, p = pi_Zhat, method = "linear")
  # cate_te_rd <- predict(modq_rd, newx = X_test)[,2]-predict(modq_rd, newx = X_test)[,1]
  # AR_RD <- mean(sign(cate_te_rd)==sign(q_test))
  # Value_RD <- mean(q_test*sign(cate_te_rd))+v_test
  # MSE_RD <- mean((cate_te_rd-2*q_test)^2)

  #IVDL
  fit_IVDL <- IVDL.linear(Y,X,Z,pi_Zhat,Dhat_LogReg, g=0)

  #IV-RDL1
  fit_IVRDL1 <- IVDL.linear(Y,X,Z,pi_Zhat,Dhat_LogReg, g=aug_LogReg$g)

  #IV-RDL2
  fit_IVRDL2 <- IVDL.linear(Y,X,Z,pi_Zhat,Dhat_LogReg, g=aug_LogReg$h2)


  ##MR2022
  reg_forest_Ymr <- regression_forest(X,aug_LogReg$Ymr,tune.parameters = "all",
                                      seed = i*8)
  cate_te_Ymr  <- predict(reg_forest_Ymr, newdata = X_test)$predictions

  #BART
  bart_fit <- bartc(Y,(A+1)/2,X, keepTrees = TRUE,
                    n.samples = 100, n.burn = 15, n.chains = 5,
                    seed = i*9)
  cate_te_BART <- colMeans(predict(bart_fit, newdata = X_test, type = 'icate'))

  #Causal Forest
  cf_fit <- instrumental_forest(X,Y,(A+1)/2,(Z+1)/2, tune.parameters = 'all',
                                seed = i*10)
  # NOTE(review): this estimate is halved, yet the MSE below scores every
  # method against 2 * q_test -- confirm the scaling is intended.
  cate_te_LogReg <- predict(cf_fit, newdata = X_test)$predictions/2

  #### Results ####
  coef <- tibble(IVDL = fit_IVDL$coef,
                 IVRDL1 = fit_IVRDL1$coef,
                 IVRDL2 = fit_IVRDL2$coef)
  # Linear predictions on the test covariates, one column per method.
  cate_te <- coef %>%
    reframe(across(everything(),.fns = ~as_vector(cbind(1,X_test)%*%as_vector(.x)))) %>%
    mutate(BART = cate_te_BART,
           CF = cate_te_LogReg,
           # RD =  cate_te_rd,
           MRIV = cate_te_Ymr, .before = everything())

  AR4_500 <- cate_te %>%
    summarise(across(everything(),.fns = ~mean(sign(.x)==sign(2*q_test)))) %>%
    mutate(IPW_MR = AR_IW_MR, .before = everything()) %>%
    tibble_row() %>%
    bind_rows(AR4_500, .)
  Value4_500 <- cate_te %>%
    summarise(across(everything(),.fns = ~mean(q_test*(sign(.x)))+v_test)) %>%
    mutate(IPW_MR = Value_IW_MR, .before = everything()) %>%
    tibble_row() %>%
    bind_rows(Value4_500, .)
  MSE4_500 <- cate_te %>%
    summarise(across(everything(),.fns = ~mean((.x-2*q_test)^2))) %>%
    tibble_row() %>%
    bind_rows(MSE4_500, .)

  # Progress report every 10 replications.
  msg <- paste0("fourth setting rep ", i, " done")
  if (i %% 10 == 0) {
    print(msg)
  }
}


# Summary table for Setting 4: one row per method, one column per criterion,
# each cell formatted as "mean(standard error)" with both scaled by 100.
# The standard error divides by the square root of the number of replications
# actually stored, so it stays correct if nrep is changed from 100.
fmt_mean_se <- function(x) {
  paste0(signif(100*mean(x),3), "(", round(100*sd(x)/sqrt(length(x)), 1), ")")
}
res4 <- bind_rows(AR4_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se)),
                  Value4_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se)),
                  MSE4_500 %>% as_tibble %>% reframe(across(everything(), .fns = fmt_mean_se))) %>%
  mutate(criteria = c("AR","Value","MSE"), .before = everything()) %>%
  pivot_longer(-1) %>% pivot_wider(names_from = "criteria", values_from = "value")

res4

####Summary of the Results ####
# Print the four summary tables. Each one requires its Setting section to
# have been run to completion first.
# Setting 1
res1
# Setting 2
res2
# Setting 3
res3
# Setting 4
res4
