# Dose-response function (drf)
# Linear dose-response function: the response is the dose itself.
drf_linear <- function(x) {
  identity(x)
}

# Polynomial (quadratic) dose-response function:
# 0.2 * (x - 5)^2 - x - 5, which evaluates to 0 at x = 0.
drf_poly <- function(x) {
  shifted <- x - 5
  quadratic <- 0.2 * shifted^2
  quadratic - x - 5
}

# Exponential dose-response function:
# log(1 + exp(x) / (x + 0.1)), shifted by log(11) so that it is 0 at x = 0.
drf_exp <- function(x) {
  ratio <- exp(x) / (x + .1)
  log(1 + ratio) - log(11)
}

# Sinusoidal dose-response function: a sine wave of amplitude 5 added to a
# linear trend in the dose.
drf_sin <- function(x) {
  wave <- 5 * sin(x)
  wave + x
}

# Helper: logistic sigmoid, 1 / (1 + exp(-x)), applied element-wise.
sigmoid <- function(x) {
  denom <- 1 + exp(-x)
  1 / denom
}

# DGP: confounded (biased) treatment assignment
# t = 20 * pbeta(sigmoid(X * beta_X + Z * beta_Z), 2, 3) + nu,
# where pbeta(., 2, 3) is the Beta(2, 3) CDF and nu ~ Uniform(-0.1, 0.1).
#
# X, Z:           covariate matrices with the same number of rows
# beta_x, beta_z: coefficient vectors for the columns of X and Z
#
# Returns a plain numeric vector with one dose level per row of X, floored
# to a whole number and clamped to the range 0..19.
biased_treatment_fun <- function(X, Z, beta_x, beta_z) {
  stopifnot(nrow(X) == nrow(Z))

  n <- nrow(X)
  # treatment assignment
  propensity <- sigmoid(X %*% beta_x + Z %*% beta_z)
  t <- 20 * pbeta(propensity, 2, 3) + runif(n, -.1, .1)

  # discretize: the matrix product makes `t` an n x 1 matrix, so drop the
  # dimensions first. The vectorised clamp is type-stable and returns
  # numeric(0) for zero-row input instead of an empty list.
  pmax(0, pmin(19, floor(as.vector(t))))
}

# Randomized treatment assignment: n dose levels drawn uniformly at random,
# with replacement, from 0, 1, ..., 19.
unbiased_treatment_fun <- function(n) {
  one_based <- sample.int(20, size = n, replace = TRUE)
  one_based - 1
}

#' Outcome function
#'
#' Computes
#' `drf(t) + 0.2 * (X[, 1]^2 + X[, 4]) * t + cbind(X, U) %*% c(beta_x, beta_u)`
#' plus independent Uniform(-0.1, 0.1) noise for each row.
#'
#' @param t Treatment
#' @param X Confounder matrix; needs at least 4 columns because columns 1
#'   and 4 drive the treatment-covariate interaction term
#' @param U Outcome specific adjustment covariates matrix (same number of
#'   rows as `X`)
#' @param beta_x Coefficients for the columns of `X`
#' @param beta_u Coefficients for the columns of `U`
#' @param drf Type of dose response function: a function, or the name of
#'   one, invoked on `t` through `do.call()`
#' @return Outcome, with the shape produced by adding the matrix product
#'   (a one-column matrix with one row per row of `X`)
outcome_fun <- function(t, X, U, beta_x, beta_u, drf) {
  stopifnot(
    nrow(X) == nrow(U),
    ncol(X) >= 4
  )
  n <- nrow(X)
  # The noise is drawn before `drf` is evaluated so the order of random
  # number consumption stays the same for seeded runs.
  err <- runif(n, -.1, .1)

  dose_effect <- do.call(drf, list(t))
  # heterogeneous treatment effect driven by the 1st and 4th confounders
  interaction <- 0.2 * (X[, 1]^2 + X[, 4]) * t
  covariate_effect <- cbind(X, U) %*% c(beta_x, beta_u)

  dose_effect + interaction + covariate_effect + err
}


#' Generate synthetic train/test data for dose-response estimation
#'
#' Draws sparse coefficient vectors (each entry is Uniform(-1, 1) and is
#' zeroed out with probability 0.2), then simulates standard-normal
#' covariates, a treatment and an outcome for a train and a test split.
#' The train split uses `biased_treatment_fun()` (treatment depends on X and
#' Z); the test split uses `unbiased_treatment_fun()` (randomized treatment).
#' The response `y` is the outcome under the assigned treatment minus the
#' outcome under `t = 0`; the two outcomes are simulated by separate calls
#' to `outcome_fun()`, each with its own noise draw.
#'
#' @param n Total number of observations; `floor(0.7 * n)` go to the train
#'   split and the remainder to the test split
#' @param dim_x Number of confounder columns (X); must be at least 4 because
#'   `outcome_fun()` reads column 4 of X
#' @param dim_z Number of columns of Z, used only by the biased treatment
#'   assignment
#' @param dim_u Number of columns of U, used only by the outcome
#' @param drf Dose-response function (or its name) passed to `outcome_fun()`
#' @return An unnamed list of two data frames, train first and test second,
#'   each with columns `y`, `t` and the prefixed covariates `X.*`, `Z.*`,
#'   `U.*`
gen_data <- function(n, dim_x, dim_z, dim_u, drf) {
  stopifnot(dim_x >= 4)

  n_train <- floor(n * .7)
  n_test <- n - n_train

  # Sparse coefficient vector: uniform magnitudes masked by a 0/1 indicator.
  draw_beta <- function(dim) {
    runif(dim, -1, 1) * sample(0:1, dim, replace = TRUE, prob = c(.2, .8))
  }

  # Standard-normal covariate matrix with n_obs rows and dim columns.
  draw_covariates <- function(n_obs, dim) {
    matrix(rnorm(n_obs * dim, 0, 1), n_obs, dim)
  }

  # Simulate one split. Draw order (X, Z, U, t, y, y0) is kept fixed so
  # seeded runs reproduce the same data.
  make_split <- function(n_obs, biased) {
    X <- draw_covariates(n_obs, dim_x)
    Z <- draw_covariates(n_obs, dim_z)
    U <- draw_covariates(n_obs, dim_u)

    t <- if (biased) {
      biased_treatment_fun(X, Z, beta_x, beta_z)
    } else {
      unbiased_treatment_fun(n_obs)
    }
    y <- outcome_fun(t, X, U, beta_x, beta_u, drf)
    y0 <- outcome_fun(rep(0, n_obs), X, U, beta_x, beta_u, drf)

    cbind(
      y = y - y0,
      t = t,
      X = as.data.frame(X),
      Z = as.data.frame(Z),
      U = as.data.frame(U)
    )
  }

  beta_x <- draw_beta(dim_x)
  beta_z <- draw_beta(dim_z)
  beta_u <- draw_beta(dim_u)

  # train: confounded treatment assignment
  train <- make_split(n_train, biased = TRUE)

  # test: randomized treatment assignment
  test <- make_split(n_test, biased = FALSE)

  list(train, test)
}
