#' Doubly robust CATE estimation with `ranger`
#'
#' Estimate heterogeneous treatment effects with the DR-learner of
#' Kennedy (2022), using three-way sample splitting and `ranger` random
#' forests for every model.
#'
#' Adapted from <https://github.com/christophergandrud/drlearner>.
#'
#' Rows are randomly assigned to three folds of (nearly) equal size: fold 1
#' fits the propensity score, fold 2 fits the arm-specific outcome
#' regressions, and fold 3 fits the regression of the doubly robust
#' pseudo-outcome on `X`. That pseudo-outcome forest then predicts for every
#' row of `X` and `Xnew`.
#'
#' @param X matrix or data frame of covariates, one row per observation.
#' @param Y numeric vector of outcomes.
#' @param W numeric vector of treatment states \[0, 1\]. If a logical vector is
#'   supplied, it is coerced to numeric with `FALSE = 0` and `TRUE = 1`.
#' @param Xnew new covariate data (same columns as `X`) for CATE prediction.
#' @param trunc truncation level in \[0, 0.5\]: propensity scores are clipped
#'   to \[`trunc`, `1 - trunc`\] before forming the pseudo-outcome.
#'
#' @returns A list with the observed outcomes (`Y`), the numeric treatments
#'   (`W`), the truncated propensity-score estimates of E\[W | X = x\]
#'   (`W.hat`), and the pseudo-outcome forest's estimates of
#'   E\[Y_1 - Y_0 | X = x\] for the rows of `X` (`tau.hat`) and of `Xnew`
#'   (`tau.new`).
#'
#' @references Kennedy, Edward H. (2022) "Towards optimal doubly robust
#' estimation of heterogeneous causal effects".
#' <https://arxiv.org/abs/2004.14497>.
dr_learner <- function(X, Y, W, Xnew, trunc = 0.02) {
  # Smart coercion: FALSE -> 0, TRUE -> 1 (NA stays NA).
  if (is.logical(W)) {
    W <- as.numeric(W)
  }
  stopifnot(
    "trunc must be a single number in [0, 0.5]" =
      is.numeric(trunc) && length(trunc) == 1 && trunc >= 0 && trunc <= 0.5
  )

  n <- nrow(X)
  stopifnot(
    "X, Y, and W must all be of the same length" =
      n == length(Y) & n == length(W)
  )

  # Fold labels 1, 2, 3, 1, 2, 3, ... (rep_len() puts any remainder in folds
  # 1 and 2 and copes with n < 3), then shuffled.
  s <- sample(rep_len(1:3, n))
  stopifnot(
    "Both treatment arms must appear in folds 1 and 2" =
      all(c(0, 1) %in% W[s == 1]) && all(c(0, 1) %in% W[s == 2])
  )

  # Step 1a: propensity score from fold 1. factor() yields levels c("0", "1")
  # (both present, checked above), so column 2 holds P(W = 1 | X).
  # drop = FALSE keeps single-covariate X two-dimensional for ranger().
  pi.model <- ranger(
    y = factor(W[s == 1]), x = X[s == 1, , drop = FALSE], probability = TRUE
  )
  pi.hat <- predict(pi.model, data = X)$predictions[, 2]
  # Clip away from 0 and 1 to bound the inverse-propensity weights in Step 2.
  pi.trunc <- pmin(pmax(pi.hat, trunc), 1 - trunc)

  # Step 1b: arm-specific outcome regressions E[Y | X, W = w] from fold 2,
  # predicted for every row of X.
  in0 <- W == 0 & s == 2
  in1 <- W == 1 & s == 2
  mu0.hat <- predict(ranger(y = Y[in0], x = X[in0, , drop = FALSE]),
                     data = X)$predictions
  mu1.hat <- predict(ranger(y = Y[in1], x = X[in1, , drop = FALSE]),
                     data = X)$predictions

  # Step 2: doubly robust pseudo-outcome
  #   (W - pi) / (pi (1 - pi)) * (Y - mu_W(X)) + mu_1(X) - mu_0(X),
  # regressed on X using fold 3.
  pseudo <- (W - pi.trunc) / (pi.trunc * (1 - pi.trunc)) *
    (Y - W * mu1.hat - (1 - W) * mu0.hat) + mu1.hat - mu0.hat
  tau.model <- ranger(y = pseudo[s == 3], x = X[s == 3, , drop = FALSE])
  tau.hat <- predict(tau.model, data = X)$predictions
  tau.new <- predict(tau.model, data = Xnew)$predictions

  list(Y = Y, W = W, W.hat = pi.trunc, tau.hat = tau.hat, tau.new = tau.new)
}


#' Doubly robust CATE estimation with multicalibrated outcome models (`ranger`)
#'
#' Variant of the DR-learner (Kennedy 2022) with three-way sample splitting
#' and `ranger` forests. The two arm-specific outcome regressions are fit on
#' the scaled outcomes `Ys` and then multicalibrated with `MCBoost` on the
#' audit sample before the doubly robust pseudo-outcome is formed.
#'
#' @param X matrix or data frame of covariates, one row per observation.
#' @param Xt covariates with treatment indicator, row-aligned with `X`; used to
#'   fit and predict the outcome models.
#' @param Xa audit data frame laid out like `Xt`, with the treatment indicator
#'   in a column named `T`.
#' @param Y numeric vector of outcomes.
#' @param Ys numeric vector of scaled outcomes, row-aligned with `X`; the
#'   response of the outcome forests.
#' @param Yas numeric vector of scaled audit outcomes, row-aligned with `Xa`;
#'   the multicalibration labels.
#' @param W numeric vector of treatment states \[0, 1\]. If a logical vector is
#'   supplied, it is coerced to numeric with `FALSE = 0` and `TRUE = 1`.
#' @param Xnew new covariate data (same columns as `X`) for CATE prediction.
#' @param trunc truncation level in \[0, 0.5\]: propensity scores are clipped
#'   to \[`trunc`, `1 - trunc`\].
#' @param eta value passed as `eta` to `MCBoost$new()`.
#' @param iter value passed as `max_iter` to `MCBoost$new()`.
#' @param auditor value passed as `auditor_fitter` to `MCBoost$new()`.
#'
#' @returns A list with `Y`, the numeric `W`, the truncated propensity scores
#'   (`W.hat`), and CATE estimates for the rows of `X` (`tau.hat`) and of
#'   `Xnew` (`tau.new`).
dr_learnermc <- function(X, Xt, Xa, Y, Ys, Yas, W, Xnew,
                         trunc = 0.02, eta = 0.5, iter = 5,
                         auditor = "RidgeAuditorFitter") {
  # Smart coercion: FALSE -> 0, TRUE -> 1 (NA stays NA).
  if (is.logical(W)) {
    W <- as.numeric(W)
  }
  stopifnot(
    "trunc must be a single number in [0, 0.5]" =
      is.numeric(trunc) && length(trunc) == 1 && trunc >= 0 && trunc <= 0.5,
    "Xa must contain a treatment column named T" = "T" %in% names(Xa)
  )

  n <- nrow(X)
  stopifnot(
    "X, Y, and W must all be of the same length" =
      n == length(Y) & n == length(W)
  )

  # Fold labels 1, 2, 3, 1, 2, 3, ... (remainder to folds 1 and 2), shuffled.
  s <- sample(rep_len(1:3, n))
  stopifnot(
    "Both treatment arms must appear in folds 1 and 2" =
      all(c(0, 1) %in% W[s == 1]) && all(c(0, 1) %in% W[s == 2])
  )

  # Step 1a: propensity score from fold 1; with levels c("0", "1"), column 2
  # is P(W = 1 | X). Clipped to bound the inverse-propensity weights.
  pi.model <- ranger(
    y = factor(W[s == 1]), x = X[s == 1, , drop = FALSE], probability = TRUE
  )
  pi.hat <- predict(pi.model, data = X)$predictions[, 2]
  pi.trunc <- pmin(pmax(pi.hat, trunc), 1 - trunc)

  # Step 1b: arm-specific outcome forests on the scaled outcomes, fold 2.
  mu0 <- ranger(y = Ys[W == 0 & s == 2],
                x = Xt[W == 0 & s == 2, , drop = FALSE])
  mu1 <- ranger(y = Ys[W == 1 & s == 2],
                x = Xt[W == 1 & s == 2, , drop = FALSE])

  # Label passed to the project helper rev_scale() to map calibrated
  # predictions back to the outcome scale.
  # NOTE(review): presumably mirrors how Ys/Yas were derived from Y; confirm.
  y.label <- -2 * min(Y) + 2 * Y

  # Wrap a fitted outcome forest in MCBoost and multicalibrate it on the
  # audit rows whose treatment indicator equals `arm`.
  calibrate_arm <- function(fit, arm) {
    force(fit)
    booster <- MCBoost$new(
      init_predictor = function(data) predict(fit, data)$predictions,
      auditor_fitter = auditor,
      alpha = 1e-06,
      weight_degree = 2,
      eta = eta,
      max_iter = iter
    )
    in_arm <- Xa$T == arm
    booster$multicalibrate(Xa[in_arm, , drop = FALSE], Yas[in_arm])
    booster
  }

  mc1 <- calibrate_arm(mu1, 1)
  mu1.hat <- rev_scale(mc1$predict_probs(Xt), label = y.label)
  mc0 <- calibrate_arm(mu0, 0)
  mu0.hat <- rev_scale(mc0$predict_probs(Xt), label = y.label)

  # Step 2: doubly robust pseudo-outcome
  #   (W - pi) / (pi (1 - pi)) * (Y - mu_W(X)) + mu_1(X) - mu_0(X),
  # regressed on X using fold 3.
  pseudo <- (W - pi.trunc) / (pi.trunc * (1 - pi.trunc)) *
    (Y - W * mu1.hat - (1 - W) * mu0.hat) + mu1.hat - mu0.hat
  tau.model <- ranger(y = pseudo[s == 3], x = X[s == 3, , drop = FALSE])
  tau.hat <- predict(tau.model, data = X)$predictions
  tau.new <- predict(tau.model, data = Xnew)$predictions

  list(Y = Y, W = W, W.hat = pi.trunc, tau.hat = tau.hat, tau.new = tau.new)
}



#' Doubly robust CATE estimation with multicalibrated outcome and CATE models
#' (`ranger`)
#'
#' Variant of the DR-learner (Kennedy 2022) with three-way sample splitting
#' and `ranger` forests. The two arm-specific outcome regressions are
#' multicalibrated with `MCBoost` on the audit sample. Doubly robust
#' pseudo-outcomes are then formed on both the training and the audit sample.
#' A `ranger` forest is fit to the scaled training pseudo-outcomes, and that
#' CATE model is itself multicalibrated against the audit pseudo-outcomes.
#'
#' @param X matrix or data frame of covariates, one row per observation.
#' @param Xt covariates with treatment indicator, row-aligned with `X`.
#' @param Xa audit data frame laid out like `Xt`, with the treatment indicator
#'   in a column named `T`; the remaining columns should be the covariates of
#'   `X`.
#' @param Y numeric vector of outcomes.
#' @param Ys numeric vector of scaled outcomes, row-aligned with `X`.
#' @param Ya numeric vector of audit outcomes, row-aligned with `Xa`.
#' @param Yas numeric vector of scaled audit outcomes, row-aligned with `Xa`.
#' @param W numeric vector of treatment states \[0, 1\]. If a logical vector is
#'   supplied, it is coerced to numeric with `FALSE = 0` and `TRUE = 1`.
#' @param Wa numeric vector of audit treatment states, row-aligned with `Xa`.
#' @param Xnew new covariate data (same columns as `X`) for CATE prediction.
#' @param trunc truncation level in \[0, 0.5\]: propensity scores are clipped
#'   to \[`trunc`, `1 - trunc`\].
#' @param eta value passed as `eta` to `MCBoost$new()`.
#' @param iter value passed as `max_iter` to `MCBoost$new()`.
#' @param auditor value passed as `auditor_fitter` to `MCBoost$new()`.
#'
#' @returns A list with `Y`, the numeric `W`, the truncated propensity scores
#'   (`W.hat`), and calibrated CATE estimates for the rows of `X` (`tau.hat`)
#'   and of `Xnew` (`tau.new`).
dr_learnermc2 <- function(X, Xt, Xa, Y, Ys, Ya, Yas, W, Wa, Xnew,
                          trunc = 0.02, eta = 0.5, iter = 5,
                          auditor = "RidgeAuditorFitter") {
  # Smart coercion: FALSE -> 0, TRUE -> 1 (NA stays NA).
  if (is.logical(W)) {
    W <- as.numeric(W)
  }
  stopifnot(
    "trunc must be a single number in [0, 0.5]" =
      is.numeric(trunc) && length(trunc) == 1 && trunc >= 0 && trunc <= 0.5,
    "Xa must contain a treatment column named T" = "T" %in% names(Xa)
  )

  n <- nrow(X)
  stopifnot(
    "X, Y, and W must all be of the same length" =
      n == length(Y) & n == length(W)
  )

  # Fold labels 1, 2, 3, 1, 2, 3, ... (remainder to folds 1 and 2), shuffled.
  s <- sample(rep_len(1:3, n))
  stopifnot(
    "Both treatment arms must appear in folds 1 and 2" =
      all(c(0, 1) %in% W[s == 1]) && all(c(0, 1) %in% W[s == 2])
  )

  # Audit covariates without the treatment indicator. Selected by name:
  # `Xa[, -T]` evaluates `T` as TRUE and drops column 1 instead.
  Xa.cov <- Xa[, names(Xa) != "T", drop = FALSE]

  # Step 1a: propensity score from fold 1; with levels c("0", "1"), column 2
  # is P(W = 1 | X). Clipped to bound the inverse-propensity weights.
  pi.model <- ranger(
    y = factor(W[s == 1]), x = X[s == 1, , drop = FALSE], probability = TRUE
  )
  pi.hat <- predict(pi.model, data = X)$predictions[, 2]
  pi.audit <- predict(pi.model, data = Xa.cov)$predictions[, 2]
  pi.trunc <- pmin(pmax(pi.hat, trunc), 1 - trunc)
  pi.trunc.a <- pmin(pmax(pi.audit, trunc), 1 - trunc)

  # Step 1b: arm-specific outcome forests on the scaled outcomes, fold 2.
  mu0 <- ranger(y = Ys[W == 0 & s == 2],
                x = Xt[W == 0 & s == 2, , drop = FALSE])
  mu1 <- ranger(y = Ys[W == 1 & s == 2],
                x = Xt[W == 1 & s == 2, , drop = FALSE])

  # Label passed to the project helper rev_scale() to map calibrated
  # predictions back to the outcome scale.
  # NOTE(review): presumably mirrors how Ys/Yas were derived from Y; confirm.
  y.label <- -2 * min(Y) + 2 * Y

  # Wrap a fitted outcome forest in MCBoost and multicalibrate it on the
  # audit rows whose treatment indicator equals `arm`.
  calibrate_arm <- function(fit, arm) {
    force(fit)
    booster <- MCBoost$new(
      init_predictor = function(data) predict(fit, data)$predictions,
      auditor_fitter = auditor,
      alpha = 1e-06,
      weight_degree = 2,
      eta = eta,
      max_iter = iter
    )
    in_arm <- Xa$T == arm
    booster$multicalibrate(Xa[in_arm, , drop = FALSE], Yas[in_arm])
    booster
  }

  mc1 <- calibrate_arm(mu1, 1)
  mu1.hat <- rev_scale(mc1$predict_probs(Xt), label = y.label)
  mu1.audit <- rev_scale(mc1$predict_probs(Xa), label = y.label)

  mc0 <- calibrate_arm(mu0, 0)
  mu0.hat <- rev_scale(mc0$predict_probs(Xt), label = y.label)
  mu0.audit <- rev_scale(mc0$predict_probs(Xa), label = y.label)

  # Step 2: doubly robust pseudo-outcome
  #   (w - p) / (p (1 - p)) * (y - mu_w) + mu_1 - mu_0.
  dr_pseudo <- function(w, y, p, m1, m0) {
    (w - p) / (p * (1 - p)) * (y - w * m1 - (1 - w) * m0) + m1 - m0
  }
  pseudo <- dr_pseudo(W, Y, pi.trunc, mu1.hat, mu0.hat)

  # Training and audit pseudo-outcomes are both scaled against the training
  # pseudo-outcomes. NOTE(review): scale() here is a project helper
  # (base::scale() has no `label` argument); confirm its output range.
  pseudo.label <- pseudo * 2
  pseudo.scaled <- scale(pseudo, label = pseudo.label)
  tau.model <- ranger(y = pseudo.scaled[s == 3], x = X[s == 3, , drop = FALSE])

  # Audit pseudo-outcomes, scaled and clipped to [0, 1] for MCBoost.
  pseudo.audit <- dr_pseudo(Wa, Ya, pi.trunc.a, mu1.audit, mu0.audit)
  pseudo.audit.scaled <- scale(pseudo.audit, label = pseudo.label)
  pseudo.audit.scaled <- case_when(
    pseudo.audit.scaled < 0 ~ 0,
    pseudo.audit.scaled > 1 ~ 1,
    TRUE ~ pseudo.audit.scaled
  )

  # Multicalibrate the CATE forest on the audit covariates.
  mc.tau <- MCBoost$new(
    init_predictor = function(data) predict(tau.model, data)$predictions,
    auditor_fitter = auditor,
    alpha = 1e-06,
    weight_degree = 2,
    eta = eta,
    max_iter = iter
  )
  mc.tau$multicalibrate(Xa.cov, pseudo.audit.scaled)

  # Predict from covariates only (X, not Xt), matching the columns tau.model
  # and the auditor were fit on and the layout of Xnew.
  tau.hat <- rev_scale(mc.tau$predict_probs(X), label = pseudo.label)
  tau.new <- rev_scale(mc.tau$predict_probs(Xnew), label = pseudo.label)

  list(Y = Y, W = W, W.hat = pi.trunc, tau.hat = tau.hat, tau.new = tau.new)
}



#' Doubly robust CATE estimation with a multicalibrated CATE model (`ranger`)
#'
#' Variant of the DR-learner (Kennedy 2022) with three-way sample splitting
#' and `ranger` forests. The outcome regressions are used as fit, without
#' calibration. Doubly robust pseudo-outcomes are formed on both the training
#' and the audit sample. A `ranger` forest is fit to the scaled training
#' pseudo-outcomes and then multicalibrated with `MCBoost` against the audit
#' pseudo-outcomes.
#'
#' @param X matrix or data frame of covariates, one row per observation.
#' @param Xt covariates with treatment indicator, row-aligned with `X`.
#' @param Xa audit data frame laid out like `Xt`, with the treatment indicator
#'   in a column named `T`; the remaining columns should be the covariates of
#'   `X`.
#' @param Y numeric vector of outcomes.
#' @param Ya numeric vector of audit outcomes, row-aligned with `Xa`.
#' @param W numeric vector of treatment states \[0, 1\]. If a logical vector is
#'   supplied, it is coerced to numeric with `FALSE = 0` and `TRUE = 1`.
#' @param Wa numeric vector of audit treatment states, row-aligned with `Xa`.
#' @param Xnew new covariate data (same columns as `X`) for CATE prediction.
#' @param trunc truncation level in \[0, 0.5\]: propensity scores are clipped
#'   to \[`trunc`, `1 - trunc`\].
#' @param eta value passed as `eta` to `MCBoost$new()`.
#' @param iter value passed as `max_iter` to `MCBoost$new()`.
#' @param auditor value passed as `auditor_fitter` to `MCBoost$new()`.
#'
#' @returns A list with `Y`, the numeric `W`, the truncated propensity scores
#'   (`W.hat`), and calibrated CATE estimates for the rows of `X` (`tau.hat`)
#'   and of `Xnew` (`tau.new`).
dr_learnermc3 <- function(X, Xt, Xa, Y, Ya, W, Wa, Xnew,
                          trunc = 0.02, eta = 0.5, iter = 5,
                          auditor = "RidgeAuditorFitter") {
  # Smart coercion: FALSE -> 0, TRUE -> 1 (NA stays NA).
  if (is.logical(W)) {
    W <- as.numeric(W)
  }
  stopifnot(
    "trunc must be a single number in [0, 0.5]" =
      is.numeric(trunc) && length(trunc) == 1 && trunc >= 0 && trunc <= 0.5,
    "Xa must contain a treatment column named T" = "T" %in% names(Xa)
  )

  n <- nrow(X)
  stopifnot(
    "X, Y, and W must all be of the same length" =
      n == length(Y) & n == length(W)
  )

  # Fold labels 1, 2, 3, 1, 2, 3, ... (remainder to folds 1 and 2), shuffled.
  s <- sample(rep_len(1:3, n))
  stopifnot(
    "Both treatment arms must appear in folds 1 and 2" =
      all(c(0, 1) %in% W[s == 1]) && all(c(0, 1) %in% W[s == 2])
  )

  # Audit covariates without the treatment indicator. Selected by name:
  # `Xa[, -T]` evaluates `T` as TRUE and drops column 1 instead.
  Xa.cov <- Xa[, names(Xa) != "T", drop = FALSE]

  # Step 1a: propensity score from fold 1; with levels c("0", "1"), column 2
  # is P(W = 1 | X). Clipped to bound the inverse-propensity weights.
  pi.model <- ranger(
    y = factor(W[s == 1]), x = X[s == 1, , drop = FALSE], probability = TRUE
  )
  pi.hat <- predict(pi.model, data = X)$predictions[, 2]
  pi.audit <- predict(pi.model, data = Xa.cov)$predictions[, 2]
  pi.trunc <- pmin(pmax(pi.hat, trunc), 1 - trunc)
  pi.trunc.a <- pmin(pmax(pi.audit, trunc), 1 - trunc)

  # Step 1b: arm-specific outcome forests on the raw outcomes, fold 2,
  # predicted for the training and audit rows.
  mu0 <- ranger(y = Y[W == 0 & s == 2],
                x = Xt[W == 0 & s == 2, , drop = FALSE])
  mu1 <- ranger(y = Y[W == 1 & s == 2],
                x = Xt[W == 1 & s == 2, , drop = FALSE])
  mu0.hat <- predict(mu0, data = Xt)$predictions
  mu1.hat <- predict(mu1, data = Xt)$predictions
  mu0.audit <- predict(mu0, data = Xa)$predictions
  mu1.audit <- predict(mu1, data = Xa)$predictions

  # Step 2: doubly robust pseudo-outcome
  #   (w - p) / (p (1 - p)) * (y - mu_w) + mu_1 - mu_0.
  dr_pseudo <- function(w, y, p, m1, m0) {
    (w - p) / (p * (1 - p)) * (y - w * m1 - (1 - w) * m0) + m1 - m0
  }
  pseudo <- dr_pseudo(W, Y, pi.trunc, mu1.hat, mu0.hat)

  # Training and audit pseudo-outcomes are both scaled against the training
  # pseudo-outcomes. NOTE(review): scale() here is a project helper
  # (base::scale() has no `label` argument); confirm its output range.
  pseudo.label <- pseudo * 2
  pseudo.scaled <- scale(pseudo, label = pseudo.label)
  tau.model <- ranger(y = pseudo.scaled[s == 3], x = X[s == 3, , drop = FALSE])

  # Audit pseudo-outcomes, scaled and clipped to [0, 1] for MCBoost.
  pseudo.audit <- dr_pseudo(Wa, Ya, pi.trunc.a, mu1.audit, mu0.audit)
  pseudo.audit.scaled <- scale(pseudo.audit, label = pseudo.label)
  pseudo.audit.scaled <- case_when(
    pseudo.audit.scaled < 0 ~ 0,
    pseudo.audit.scaled > 1 ~ 1,
    TRUE ~ pseudo.audit.scaled
  )

  # Multicalibrate the CATE forest on the audit covariates.
  mc.tau <- MCBoost$new(
    init_predictor = function(data) predict(tau.model, data)$predictions,
    auditor_fitter = auditor,
    alpha = 1e-06,
    weight_degree = 2,
    eta = eta,
    max_iter = iter
  )
  mc.tau$multicalibrate(Xa.cov, pseudo.audit.scaled)

  # Predict from covariates only (X, not Xt), matching the columns tau.model
  # and the auditor were fit on and the layout of Xnew.
  tau.hat <- rev_scale(mc.tau$predict_probs(X), label = pseudo.label)
  tau.new <- rev_scale(mc.tau$predict_probs(Xnew), label = pseudo.label)

  list(Y = Y, W = W, W.hat = pi.trunc, tau.hat = tau.hat, tau.new = tau.new)
}



#' Doubly robust CATE estimation with `grf` forests
#'
#' Estimate heterogeneous treatment effects with the DR-learner of
#' Kennedy (2022), using three-way sample splitting and `grf` forests
#' (`probability_forest()` for the propensity score, `regression_forest()`
#' for the outcome and pseudo-outcome regressions).
#'
#' Rows are randomly assigned to three folds of (nearly) equal size: fold 1
#' fits the propensity score, fold 2 fits the arm-specific outcome
#' regressions, and fold 3 fits the regression of the doubly robust
#' pseudo-outcome on `X`.
#'
#' @param X matrix or data frame of numeric covariates, one row per
#'   observation.
#' @param Y numeric vector of outcomes.
#' @param W numeric vector of treatment states \[0, 1\]. If a logical vector is
#'   supplied, it is coerced to numeric with `FALSE = 0` and `TRUE = 1`.
#' @param Xnew new covariate data (same columns as `X`) for CATE prediction.
#' @param trunc truncation level in \[0, 0.5\]: propensity scores are clipped
#'   to \[`trunc`, `1 - trunc`\] before forming the pseudo-outcome.
#'
#' @returns A list with the observed outcomes (`Y`), the numeric treatments
#'   (`W`), the truncated propensity-score estimates of E\[W | X = x\]
#'   (`W.hat`), and the pseudo-outcome forest's estimates of
#'   E\[Y_1 - Y_0 | X = x\] for the rows of `X` (`tau.hat`) and of `Xnew`
#'   (`tau.new`).
#'
#' @references Kennedy, Edward H. (2022) "Towards optimal doubly robust
#' estimation of heterogeneous causal effects".
#' <https://arxiv.org/abs/2004.14497>.
dr_learner_grf <- function(X, Y, W, Xnew, trunc = 0.02) {
  # Smart coercion: FALSE -> 0, TRUE -> 1 (NA stays NA).
  if (is.logical(W)) {
    W <- as.numeric(W)
  }
  stopifnot(
    "trunc must be a single number in [0, 0.5]" =
      is.numeric(trunc) && length(trunc) == 1 && trunc >= 0 && trunc <= 0.5
  )

  n <- nrow(X)
  stopifnot(
    "X, Y, and W must all be of the same length" =
      n == length(Y) & n == length(W)
  )

  # Fold labels 1, 2, 3, 1, 2, 3, ... (remainder to folds 1 and 2), shuffled.
  s <- sample(rep_len(1:3, n))
  stopifnot(
    "Both treatment arms must appear in folds 1 and 2" =
      all(c(0, 1) %in% W[s == 1]) && all(c(0, 1) %in% W[s == 2])
  )

  # Step 1a: propensity score from fold 1; with levels c("0", "1"), column 2
  # is P(W = 1 | X). Clipped to bound the inverse-propensity weights.
  pi.model <- probability_forest(Y = factor(W[s == 1]),
                                 X = X[s == 1, , drop = FALSE])
  pi.hat <- predict(pi.model, newdata = X)$predictions[, 2]
  pi.trunc <- pmin(pmax(pi.hat, trunc), 1 - trunc)

  # Step 1b: arm-specific outcome regressions E[Y | X, W = w] from fold 2.
  in0 <- W == 0 & s == 2
  in1 <- W == 1 & s == 2
  mu0 <- regression_forest(Y = Y[in0], X = X[in0, , drop = FALSE])
  mu1 <- regression_forest(Y = Y[in1], X = X[in1, , drop = FALSE])
  mu0.hat <- predict(mu0, newdata = X)$predictions
  mu1.hat <- predict(mu1, newdata = X)$predictions

  # Step 2: doubly robust pseudo-outcome
  #   (W - pi) / (pi (1 - pi)) * (Y - mu_W(X)) + mu_1(X) - mu_0(X),
  # regressed on X using fold 3.
  pseudo <- (W - pi.trunc) / (pi.trunc * (1 - pi.trunc)) *
    (Y - W * mu1.hat - (1 - W) * mu0.hat) + mu1.hat - mu0.hat
  tau.model <- regression_forest(Y = pseudo[s == 3],
                                 X = X[s == 3, , drop = FALSE])
  tau.hat <- predict(tau.model, newdata = X)$predictions
  tau.new <- predict(tau.model, newdata = Xnew)$predictions

  list(Y = Y, W = W, W.hat = pi.trunc, tau.hat = tau.hat, tau.new = tau.new)
}



#' Doubly robust CATE estimation with multicalibrated outcome models (`grf`)
#'
#' Variant of the DR-learner (Kennedy 2022) with three-way sample splitting
#' and `grf` forests. The two arm-specific outcome regressions are fit on the
#' scaled outcomes `Ys` and then multicalibrated with `MCBoost` on the audit
#' sample before the doubly robust pseudo-outcome is formed.
#'
#' @param X matrix or data frame of numeric covariates, one row per
#'   observation.
#' @param Xt covariates with treatment indicator, row-aligned with `X`; used to
#'   fit and predict the outcome models.
#' @param Xa audit data frame laid out like `Xt`, with the treatment indicator
#'   in a column named `T`.
#' @param Y numeric vector of outcomes.
#' @param Ys numeric vector of scaled outcomes, row-aligned with `X`; the
#'   response of the outcome forests.
#' @param Yas numeric vector of scaled audit outcomes, row-aligned with `Xa`;
#'   the multicalibration labels.
#' @param W numeric vector of treatment states \[0, 1\]. If a logical vector is
#'   supplied, it is coerced to numeric with `FALSE = 0` and `TRUE = 1`.
#' @param Xnew new covariate data (same columns as `X`) for CATE prediction.
#' @param trunc truncation level in \[0, 0.5\]: propensity scores are clipped
#'   to \[`trunc`, `1 - trunc`\].
#' @param eta value passed as `eta` to `MCBoost$new()`.
#' @param iter value passed as `max_iter` to `MCBoost$new()`.
#' @param auditor value passed as `auditor_fitter` to `MCBoost$new()`.
#'
#' @returns A list with `Y`, the numeric `W`, the truncated propensity scores
#'   (`W.hat`), and CATE estimates for the rows of `X` (`tau.hat`) and of
#'   `Xnew` (`tau.new`).
dr_learnermc_grf <- function(X, Xt, Xa, Y, Ys, Yas, W, Xnew,
                             trunc = 0.02, eta = 0.5, iter = 5,
                             auditor = "RidgeAuditorFitter") {
  # Smart coercion: FALSE -> 0, TRUE -> 1 (NA stays NA).
  if (is.logical(W)) {
    W <- as.numeric(W)
  }
  stopifnot(
    "trunc must be a single number in [0, 0.5]" =
      is.numeric(trunc) && length(trunc) == 1 && trunc >= 0 && trunc <= 0.5,
    "Xa must contain a treatment column named T" = "T" %in% names(Xa)
  )

  n <- nrow(X)
  stopifnot(
    "X, Y, and W must all be of the same length" =
      n == length(Y) & n == length(W)
  )

  # Fold labels 1, 2, 3, 1, 2, 3, ... (remainder to folds 1 and 2), shuffled.
  s <- sample(rep_len(1:3, n))
  stopifnot(
    "Both treatment arms must appear in folds 1 and 2" =
      all(c(0, 1) %in% W[s == 1]) && all(c(0, 1) %in% W[s == 2])
  )

  # Step 1a: propensity score from fold 1; with levels c("0", "1"), column 2
  # is P(W = 1 | X). Clipped to bound the inverse-propensity weights.
  pi.model <- probability_forest(Y = factor(W[s == 1]),
                                 X = X[s == 1, , drop = FALSE])
  pi.hat <- predict(pi.model, newdata = X)$predictions[, 2]
  pi.trunc <- pmin(pmax(pi.hat, trunc), 1 - trunc)

  # Step 1b: arm-specific outcome forests on the scaled outcomes, fold 2.
  mu0 <- regression_forest(Y = Ys[W == 0 & s == 2],
                           X = Xt[W == 0 & s == 2, , drop = FALSE])
  mu1 <- regression_forest(Y = Ys[W == 1 & s == 2],
                           X = Xt[W == 1 & s == 2, , drop = FALSE])

  # Label passed to the project helper rev_scale() to map calibrated
  # predictions back to the outcome scale.
  # NOTE(review): presumably mirrors how Ys/Yas were derived from Y; confirm.
  y.label <- -2 * min(Y) + 2 * Y

  # Wrap a fitted outcome forest in MCBoost and multicalibrate it on the
  # audit rows whose treatment indicator equals `arm`.
  calibrate_arm <- function(fit, arm) {
    force(fit)
    booster <- MCBoost$new(
      init_predictor = function(data) predict(fit, data)$predictions,
      auditor_fitter = auditor,
      alpha = 1e-06,
      weight_degree = 2,
      eta = eta,
      max_iter = iter
    )
    in_arm <- Xa$T == arm
    booster$multicalibrate(Xa[in_arm, , drop = FALSE], Yas[in_arm])
    booster
  }

  mc1 <- calibrate_arm(mu1, 1)
  mu1.hat <- rev_scale(mc1$predict_probs(Xt), label = y.label)
  mc0 <- calibrate_arm(mu0, 0)
  mu0.hat <- rev_scale(mc0$predict_probs(Xt), label = y.label)

  # Step 2: doubly robust pseudo-outcome
  #   (W - pi) / (pi (1 - pi)) * (Y - mu_W(X)) + mu_1(X) - mu_0(X),
  # regressed on X using fold 3.
  pseudo <- (W - pi.trunc) / (pi.trunc * (1 - pi.trunc)) *
    (Y - W * mu1.hat - (1 - W) * mu0.hat) + mu1.hat - mu0.hat
  tau.model <- regression_forest(Y = pseudo[s == 3],
                                 X = X[s == 3, , drop = FALSE])
  tau.hat <- predict(tau.model, newdata = X)$predictions
  tau.new <- predict(tau.model, newdata = Xnew)$predictions

  list(Y = Y, W = W, W.hat = pi.trunc, tau.hat = tau.hat, tau.new = tau.new)
}



#' Doubly robust CATE estimation with multicalibrated outcome and CATE models
#' (`grf`)
#'
#' Variant of the DR-learner (Kennedy 2022) with three-way sample splitting
#' and `grf` forests. The two arm-specific outcome regressions are
#' multicalibrated with `MCBoost` on the audit sample. Doubly robust
#' pseudo-outcomes are then formed on both the training and the audit sample.
#' A `regression_forest()` is fit to the scaled training pseudo-outcomes, and
#' that CATE model is itself multicalibrated against the audit
#' pseudo-outcomes.
#'
#' @param X matrix or data frame of numeric covariates, one row per
#'   observation.
#' @param Xt covariates with treatment indicator, row-aligned with `X`.
#' @param Xa audit data frame laid out like `Xt`, with the treatment indicator
#'   in a column named `T`; the remaining columns should match the columns of
#'   `X`.
#' @param Y numeric vector of outcomes.
#' @param Ys numeric vector of scaled outcomes, row-aligned with `X`.
#' @param Ya numeric vector of audit outcomes, row-aligned with `Xa`.
#' @param Yas numeric vector of scaled audit outcomes, row-aligned with `Xa`.
#' @param W numeric vector of treatment states \[0, 1\]. If a logical vector is
#'   supplied, it is coerced to numeric with `FALSE = 0` and `TRUE = 1`.
#' @param Wa numeric vector of audit treatment states, row-aligned with `Xa`.
#' @param Xnew new covariate data (same columns as `X`) for CATE prediction.
#' @param trunc truncation level in \[0, 0.5\]: propensity scores are clipped
#'   to \[`trunc`, `1 - trunc`\].
#' @param eta value passed as `eta` to `MCBoost$new()`.
#' @param iter value passed as `max_iter` to `MCBoost$new()`.
#' @param auditor value passed as `auditor_fitter` to `MCBoost$new()`.
#'
#' @returns A list with `Y`, the numeric `W`, the truncated propensity scores
#'   (`W.hat`), and calibrated CATE estimates for the rows of `X` (`tau.hat`)
#'   and of `Xnew` (`tau.new`).
dr_learnermc2_grf <- function(X, Xt, Xa, Y, Ys, Ya, Yas, W, Wa, Xnew,
                              trunc = 0.02, eta = 0.5, iter = 5,
                              auditor = "RidgeAuditorFitter") {
  # Smart coercion: FALSE -> 0, TRUE -> 1 (NA stays NA).
  if (is.logical(W)) {
    W <- as.numeric(W)
  }
  stopifnot(
    "trunc must be a single number in [0, 0.5]" =
      is.numeric(trunc) && length(trunc) == 1 && trunc >= 0 && trunc <= 0.5,
    "Xa must contain a treatment column named T" = "T" %in% names(Xa)
  )

  n <- nrow(X)
  stopifnot(
    "X, Y, and W must all be of the same length" =
      n == length(Y) & n == length(W)
  )

  # Fold labels 1, 2, 3, 1, 2, 3, ... (remainder to folds 1 and 2), shuffled.
  s <- sample(rep_len(1:3, n))
  stopifnot(
    "Both treatment arms must appear in folds 1 and 2" =
      all(c(0, 1) %in% W[s == 1]) && all(c(0, 1) %in% W[s == 2])
  )

  # Audit covariates without the treatment indicator. Selected by name:
  # `Xa[, -T]` evaluates `T` as TRUE and drops column 1 instead.
  Xa.cov <- Xa[, names(Xa) != "T", drop = FALSE]

  # Step 1a: propensity score from fold 1; with levels c("0", "1"), column 2
  # is P(W = 1 | X). Clipped to bound the inverse-propensity weights.
  pi.model <- probability_forest(Y = factor(W[s == 1]),
                                 X = X[s == 1, , drop = FALSE])
  pi.hat <- predict(pi.model, newdata = X)$predictions[, 2]
  pi.audit <- predict(pi.model, newdata = Xa.cov)$predictions[, 2]
  pi.trunc <- pmin(pmax(pi.hat, trunc), 1 - trunc)
  pi.trunc.a <- pmin(pmax(pi.audit, trunc), 1 - trunc)

  # Step 1b: arm-specific outcome forests on the scaled outcomes, fold 2.
  mu0 <- regression_forest(Y = Ys[W == 0 & s == 2],
                           X = Xt[W == 0 & s == 2, , drop = FALSE])
  mu1 <- regression_forest(Y = Ys[W == 1 & s == 2],
                           X = Xt[W == 1 & s == 2, , drop = FALSE])

  # Label passed to the project helper rev_scale() to map calibrated
  # predictions back to the outcome scale.
  # NOTE(review): presumably mirrors how Ys/Yas were derived from Y; confirm.
  y.label <- -2 * min(Y) + 2 * Y

  # Wrap a fitted outcome forest in MCBoost and multicalibrate it on the
  # audit rows whose treatment indicator equals `arm`.
  calibrate_arm <- function(fit, arm) {
    force(fit)
    booster <- MCBoost$new(
      init_predictor = function(data) predict(fit, data)$predictions,
      auditor_fitter = auditor,
      alpha = 1e-06,
      weight_degree = 2,
      eta = eta,
      max_iter = iter
    )
    in_arm <- Xa$T == arm
    booster$multicalibrate(Xa[in_arm, , drop = FALSE], Yas[in_arm])
    booster
  }

  mc1 <- calibrate_arm(mu1, 1)
  mu1.hat <- rev_scale(mc1$predict_probs(Xt), label = y.label)
  mu1.audit <- rev_scale(mc1$predict_probs(Xa), label = y.label)

  mc0 <- calibrate_arm(mu0, 0)
  mu0.hat <- rev_scale(mc0$predict_probs(Xt), label = y.label)
  mu0.audit <- rev_scale(mc0$predict_probs(Xa), label = y.label)

  # Step 2: doubly robust pseudo-outcome
  #   (w - p) / (p (1 - p)) * (y - mu_w) + mu_1 - mu_0.
  dr_pseudo <- function(w, y, p, m1, m0) {
    (w - p) / (p * (1 - p)) * (y - w * m1 - (1 - w) * m0) + m1 - m0
  }
  pseudo <- dr_pseudo(W, Y, pi.trunc, mu1.hat, mu0.hat)

  # Training and audit pseudo-outcomes are both scaled against the training
  # pseudo-outcomes. NOTE(review): scale() here is a project helper
  # (base::scale() has no `label` argument); confirm its output range.
  pseudo.label <- pseudo * 2
  pseudo.scaled <- scale(pseudo, label = pseudo.label)
  tau.model <- regression_forest(Y = pseudo.scaled[s == 3],
                                 X = X[s == 3, , drop = FALSE])

  # Audit pseudo-outcomes, scaled and clipped to [0, 1] for MCBoost.
  pseudo.audit <- dr_pseudo(Wa, Ya, pi.trunc.a, mu1.audit, mu0.audit)
  pseudo.audit.scaled <- scale(pseudo.audit, label = pseudo.label)
  pseudo.audit.scaled <- case_when(
    pseudo.audit.scaled < 0 ~ 0,
    pseudo.audit.scaled > 1 ~ 1,
    TRUE ~ pseudo.audit.scaled
  )

  # Multicalibrate the CATE forest on the audit covariates.
  mc.tau <- MCBoost$new(
    init_predictor = function(data) predict(tau.model, data)$predictions,
    auditor_fitter = auditor,
    alpha = 1e-06,
    weight_degree = 2,
    eta = eta,
    max_iter = iter
  )
  mc.tau$multicalibrate(Xa.cov, pseudo.audit.scaled)

  tau.hat <- rev_scale(mc.tau$predict_probs(X), label = pseudo.label)
  tau.new <- rev_scale(mc.tau$predict_probs(Xnew), label = pseudo.label)

  list(Y = Y, W = W, W.hat = pi.trunc, tau.hat = tau.hat, tau.new = tau.new)
}



#' Doubly robust CATE estimation with a multicalibrated CATE model (`grf`)
#'
#' Variant of the DR-learner (Kennedy 2022) with three-way sample splitting
#' and `grf` forests. The outcome regressions are used as fit, without
#' calibration. Doubly robust pseudo-outcomes are formed on both the training
#' and the audit sample. A `regression_forest()` is fit to the scaled training
#' pseudo-outcomes and then multicalibrated with `MCBoost` against the audit
#' pseudo-outcomes.
#'
#' @param X matrix or data frame of numeric covariates, one row per
#'   observation.
#' @param Xt covariates with treatment indicator, row-aligned with `X`.
#' @param Xa audit data frame laid out like `Xt`, with the treatment indicator
#'   in a column named `T`; the remaining columns should match the columns of
#'   `X`.
#' @param Y numeric vector of outcomes.
#' @param Ya numeric vector of audit outcomes, row-aligned with `Xa`.
#' @param W numeric vector of treatment states \[0, 1\]. If a logical vector is
#'   supplied, it is coerced to numeric with `FALSE = 0` and `TRUE = 1`.
#' @param Wa numeric vector of audit treatment states, row-aligned with `Xa`.
#' @param Xnew new covariate data (same columns as `X`) for CATE prediction.
#' @param trunc truncation level in \[0, 0.5\]: propensity scores are clipped
#'   to \[`trunc`, `1 - trunc`\].
#' @param eta value passed as `eta` to `MCBoost$new()`.
#' @param iter value passed as `max_iter` to `MCBoost$new()`.
#' @param auditor value passed as `auditor_fitter` to `MCBoost$new()`.
#'
#' @returns A list with `Y`, the numeric `W`, the truncated propensity scores
#'   (`W.hat`), and calibrated CATE estimates for the rows of `X` (`tau.hat`)
#'   and of `Xnew` (`tau.new`).
dr_learnermc3_grf <- function(X, Xt, Xa, Y, Ya, W, Wa, Xnew,
                              trunc = 0.02, eta = 0.5, iter = 5,
                              auditor = "RidgeAuditorFitter") {
  # Smart coercion: FALSE -> 0, TRUE -> 1 (NA stays NA).
  if (is.logical(W)) {
    W <- as.numeric(W)
  }
  stopifnot(
    "trunc must be a single number in [0, 0.5]" =
      is.numeric(trunc) && length(trunc) == 1 && trunc >= 0 && trunc <= 0.5,
    "Xa must contain a treatment column named T" = "T" %in% names(Xa)
  )

  n <- nrow(X)
  stopifnot(
    "X, Y, and W must all be of the same length" =
      n == length(Y) & n == length(W)
  )

  # Fold labels 1, 2, 3, 1, 2, 3, ... (remainder to folds 1 and 2), shuffled.
  s <- sample(rep_len(1:3, n))
  stopifnot(
    "Both treatment arms must appear in folds 1 and 2" =
      all(c(0, 1) %in% W[s == 1]) && all(c(0, 1) %in% W[s == 2])
  )

  # Audit covariates without the treatment indicator. Selected by name:
  # `Xa[, -T]` evaluates `T` as TRUE and drops column 1 instead.
  Xa.cov <- Xa[, names(Xa) != "T", drop = FALSE]

  # Step 1a: propensity score from fold 1; with levels c("0", "1"), column 2
  # is P(W = 1 | X). Clipped to bound the inverse-propensity weights.
  pi.model <- probability_forest(Y = factor(W[s == 1]),
                                 X = X[s == 1, , drop = FALSE])
  pi.hat <- predict(pi.model, newdata = X)$predictions[, 2]
  pi.audit <- predict(pi.model, newdata = Xa.cov)$predictions[, 2]
  pi.trunc <- pmin(pmax(pi.hat, trunc), 1 - trunc)
  pi.trunc.a <- pmin(pmax(pi.audit, trunc), 1 - trunc)

  # Step 1b: arm-specific outcome forests on the raw outcomes, fold 2,
  # predicted for the training and audit rows.
  mu0 <- regression_forest(Y = Y[W == 0 & s == 2],
                           X = Xt[W == 0 & s == 2, , drop = FALSE])
  mu1 <- regression_forest(Y = Y[W == 1 & s == 2],
                           X = Xt[W == 1 & s == 2, , drop = FALSE])
  mu0.hat <- predict(mu0, newdata = Xt)$predictions
  mu1.hat <- predict(mu1, newdata = Xt)$predictions
  mu0.audit <- predict(mu0, newdata = Xa)$predictions
  mu1.audit <- predict(mu1, newdata = Xa)$predictions

  # Step 2: doubly robust pseudo-outcome
  #   (w - p) / (p (1 - p)) * (y - mu_w) + mu_1 - mu_0.
  dr_pseudo <- function(w, y, p, m1, m0) {
    (w - p) / (p * (1 - p)) * (y - w * m1 - (1 - w) * m0) + m1 - m0
  }
  pseudo <- dr_pseudo(W, Y, pi.trunc, mu1.hat, mu0.hat)

  # Training and audit pseudo-outcomes are both scaled against the training
  # pseudo-outcomes. NOTE(review): scale() here is a project helper
  # (base::scale() has no `label` argument); confirm its output range.
  pseudo.label <- pseudo * 2
  pseudo.scaled <- scale(pseudo, label = pseudo.label)
  tau.model <- regression_forest(Y = pseudo.scaled[s == 3],
                                 X = X[s == 3, , drop = FALSE])

  # Audit pseudo-outcomes, scaled and clipped to [0, 1] for MCBoost.
  pseudo.audit <- dr_pseudo(Wa, Ya, pi.trunc.a, mu1.audit, mu0.audit)
  pseudo.audit.scaled <- scale(pseudo.audit, label = pseudo.label)
  pseudo.audit.scaled <- case_when(
    pseudo.audit.scaled < 0 ~ 0,
    pseudo.audit.scaled > 1 ~ 1,
    TRUE ~ pseudo.audit.scaled
  )

  # Multicalibrate the CATE forest on the audit covariates.
  mc.tau <- MCBoost$new(
    init_predictor = function(data) predict(tau.model, data)$predictions,
    auditor_fitter = auditor,
    alpha = 1e-06,
    weight_degree = 2,
    eta = eta,
    max_iter = iter
  )
  mc.tau$multicalibrate(Xa.cov, pseudo.audit.scaled)

  tau.hat <- rev_scale(mc.tau$predict_probs(X), label = pseudo.label)
  tau.new <- rev_scale(mc.tau$predict_probs(Xnew), label = pseudo.label)

  list(Y = Y, W = W, W.hat = pi.trunc, tau.hat = tau.hat, tau.new = tau.new)
}
