library(lava)
source("infoMatrix.R")

## True efficacy model: the logit is quadratic in dose, and each coefficient is linear in the context x
## True (data-generating) efficacy probability for context `x` at `dose`.
## The logit is quadratic in dose; the intercept, linear and quadratic
## coefficients are each linear functions of the context. All operations are
## elementwise, so vector inputs are accepted.
eff_model_true <- function(x, dose){
  intercept <- -0.34 - 0.17*x
  linear    <- 4.3 - 1.09*x
  quadratic <- -3.31 + 0.09*x
  logit <- intercept + linear*dose + quadratic*dose^2
  odds <- exp(logit)
  odds/(1 + odds)
}

## True (data-generating) toxicity probability for context `x` at `dose`.
## The logit is linear in dose; its intercept and slope are linear functions
## of the context. Elementwise, so vector inputs are accepted.
tox_model_true <- function(x, dose){
  intercept <- -2.43 + 1.71*x
  slope <- 2.24 + 0.73*x
  odds <- exp(intercept + slope*dose)
  odds/(1 + odds)
}

## Toxicity probability under the fitted varying-coefficient logistic model.
##
## x:            context value(s). cbind(1, x) must have exactly two columns,
##               so x is a scalar, a vector, or a one-column matrix (d = 1).
## dose_history: dose value(s) paired with the rows of cbind(1, x).
## pars:         c(intercept, intercept-by-x, slope, slope-by-x): the logit
##               intercept and the dose slope are each linear in x.
##
## Returns list(pb = probabilities, eta = linear predictor,
##              xMat = cbind(1, x)); pb and eta are one-column matrices.
tox_model_par <- function(x, dose_history, pars){
  design <- cbind(1, x)
  intercept <- design %*% pars[1:2]
  slope <- design %*% pars[3:4]
  linpred <- intercept + slope*dose_history
  odds <- exp(linpred)
  list(pb = odds/(1 + odds), eta = linpred, xMat = design)
}

## Efficacy probability under the fitted varying-coefficient logistic model.
##
## x:            context value(s); cbind(1, x) must have two columns (d = 1).
## dose_history: dose value(s) paired with the rows of cbind(1, x).
## pars:         six coefficients, taken in pairs (constant, x-multiplier) for
##               the logit intercept, the dose term and the dose^2 term.
##
## Returns list(pb = probabilities, eta = linear predictor,
##              xMat = cbind(1, x)); pb and eta are one-column matrices.
eff_model_par <- function(x, dose_history, pars){
  design <- cbind(1, x)
  intercept <- design %*% pars[1:2]
  linear <- design %*% pars[3:4]
  quadratic <- design %*% pars[5:6]
  linpred <- intercept + linear*dose_history + quadratic*dose_history^2
  odds <- exp(linpred)
  list(pb = odds/(1 + odds), eta = linpred, xMat = design)
}

## Logistic regression with safeguards against degenerate fits.
##
## Fits glm(y_input ~ x_input, binomial(logit)); the formula adds an
## intercept. The fit is warm-started at `start`.
##
## x_input: covariate matrix, one row per observation.
## y_input: binary response, one entry per row of x_input.
## start:   glm starting values; length must be ncol(x_input) + 1.
##
## A fit counts as degenerate when any coefficient is NA (aliased covariate)
## or exceeds a magnitude threshold: 100 with one covariate, 20 otherwise.
##   * one covariate: the coefficients are replaced by a N(0, 0.1^2) draw and
##     explore_flag is set to 1, so the caller can fall back to exploring;
##   * several covariates: the coefficients are replaced by a ridge
##     (alpha = 0) glmnet fit, with lambda chosen by 3-fold CV over a
##     50-point log-spaced path; explore_flag stays 0.
##
## Returns list(par = coefficient vector (intercept first),
##              explore_flag = 0 or 1).
myGLM <- function(x_input, y_input, start){
  explore_flag <- 0
  l <- length(start)
  b <- glm(y_input ~ x_input, family = binomial(link = 'logit'),
           start = start)$coefficients

  # NA coefficients must be tested explicitly. Substituting a sentinel of 100
  # would not exceed the one-covariate threshold (100 > 100 is FALSE), so an
  # aliased slope would slip through as a genuine coefficient of 100.
  threshold <- if(ncol(x_input) == 1) 100 else 20
  degenerate <- anyNA(b) || any(abs(b) > threshold, na.rm = TRUE)

  if(degenerate){
    if(ncol(x_input) == 1){
      b <- rnorm(l, 0, 0.1)
      explore_flag <- 1
    }else{
      # population standard deviation, used to standardize the covariates
      mysd <- function(z) sqrt(sum((z-mean(z))^2)/length(z))
      sx <- scale(x_input, scale = apply(x_input, 2, mysd))

      ## lambda path: from lambda_max down to lambda_max * epsilon (log scale)
      lambda_max <- max(max(abs(colSums(sx*y_input)))/length(y_input), 1e-8)
      epsilon <- .0001
      K <- 50
      lambdapath <- round(exp(seq(log(lambda_max), log(lambda_max*epsilon),
                                  length.out = K)), digits = 10)
      cv.glmnet.fit <- glmnet::cv.glmnet(x = x_input, y = y_input,
                                         family = binomial(link = 'logit'),
                                         alpha = 0, lambda  = lambdapath, nfolds = 3)
      tmp_coeffs <- coef(cv.glmnet.fit, s = "lambda.min")
      # expand the sparse coefficient column (0-based row indices) to a dense
      # vector with the intercept first
      b <- rep(0, l)
      b[tmp_coeffs@i +1] <- tmp_coeffs@x
    }
  }

  list(par = b, explore_flag = explore_flag)
}

## Greedy strategy ("grd") under the varying-coefficient models.
##
## Toxicity: logit = (b1 + b2 x) + (b3 + b4 x) * dose.
## Efficacy: logit = (c1 + c2 x) + (c3 + c4 x) * dose + (c5 + c6 x) * dose^2.
## Both models are refitted with myGLM on rounds 1..t-1 of the "grd" columns.
## They are warm-started from Par.g[t-1, ], or from N(0, 0.1^2) draws when
## that row contains NAs.
## The chosen arm maximises eff - lambda * max(tox - tolerate, 0) at context
## X[t, ]; with probability eps_t an arm is drawn uniformly instead.
##
## Returns list(b.grd = toxicity coefficients, c.grd = efficacy coefficients,
##              grd = chosen arm index).
decision_grd <- function(t, X, Y, Z, Tox_history, Eff_history, Par.g,
                         eff_tox, eff_dose, tolerate, eps_t, lambda){
  n_arms <- length(eff_tox)
  past <- 1:(t-1)

  prev <- Par.g[t-1, ]
  if(anyNA(prev)){
    tox_start <- rnorm(4, 0, 0.1)
    eff_start <- rnorm(6, 0, 0.1)
  }else{
    tox_start <- prev[1:4]
    eff_start <- prev[5:10]
  }

  # design matrices: context, dose, context-by-dose (and dose^2 terms);
  # deparse.level = 0 keeps the columns unnamed
  ctx_past <- X[past, ]
  dose_tox <- Tox_history[past, "grd"]
  dose_eff <- Eff_history[past, "grd"]
  tox_design <- cbind(ctx_past, dose_tox, ctx_past*dose_tox,
                      deparse.level = 0)
  eff_design <- cbind(ctx_past, dose_eff, ctx_past*dose_eff,
                      dose_eff^2, ctx_past*dose_eff^2, deparse.level = 0)

  b.grd <- myGLM(tox_design, Y[past, "grd"], tox_start)$par
  c.grd <- myGLM(eff_design, Z[past, "grd"], eff_start)$par

  # penalised objective at the current context
  ctx <- as.matrix(X[t, ])
  p_tox <- sapply(seq_len(n_arms),
                  function(j) tox_model_par(ctx, eff_tox[j], b.grd)$pb)
  p_eff <- sapply(seq_len(n_arms),
                  function(j) eff_model_par(ctx, eff_dose[j], c.grd)$pb)
  grd <- min(which.max(p_eff - lambda*pmax(p_tox - tolerate, 0)))

  # uniform exploration
  if(runif(1) > 1-eps_t){
    grd <- sample(n_arms, 1)
  }

  list(b.grd = b.grd, c.grd = c.grd, grd = grd)
}

## Varying-coefficient strategy with optimal-design exploration ("ode").
##
## Toxicity: logit = (b1 + b2 x) + (b3 + b4 x) * dose.
## Efficacy: logit = (c1 + c2 x) + (c3 + c4 x) * dose + (c5 + c6 x) * dose^2.
## Both models are refitted with myGLM on rounds 1..t-1 of the "ode" columns.
## They are warm-started from Par.o[t-1, ], or from N(0, 0.1^2) draws when
## that row contains NAs.
## The greedy arm maximises eff - lambda * max(tox - tolerate, 0) at context
## X[t, ]. With probability eps_t it is replaced by the arm that minimises
## eval.info(..., type = "joint"). eval.info is defined in infoMatrix.R and is
## presumably an information-based design criterion (TODO confirm).
##
## Returns list(b.ode = toxicity coefficients, c.ode = efficacy coefficients,
##              ode = chosen arm index).
decision_ode <- function(t, X, Y, Z, Tox_history, Eff_history, Par.o,
                         eff_tox, eff_dose, tolerate, eps_t, lambda){
  K <- length(eff_tox)
  past <- 1:(t-1)

  if(anyNA(Par.o[t-1,])){
    b.ode <- rnorm(4, 0, 0.1)
    c.ode <- rnorm(6, 0, 0.1)
  }else{
    b.ode <- Par.o[t-1, 1:4]
    c.ode <- Par.o[t-1, 5:10]
  }

  dose_tox <- Tox_history[past, "ode"]
  dose_eff <- Eff_history[past, "ode"]
  dmx.y <- cbind(X[past,], dose_tox, X[past,]*dose_tox, deparse.level = 0)
  dmx.z <- cbind(X[past,], dose_eff, X[past,]*dose_eff,
                 dose_eff^2, X[past,]*dose_eff^2, deparse.level = 0)
  ## estimate pars for toxicity and efficacy. Responses are selected by
  ## column name, as everywhere else, rather than by position, so that they
  ## cannot silently point at another strategy's column.
  b.ode <- myGLM(dmx.y, Y[past, "ode"], b.ode)$par
  c.ode <- myGLM(dmx.z, Z[past, "ode"], c.ode)$par

  # penalised objective at the current context
  x <- as.matrix(X[t,])
  tox_sample <- sapply(1:K, function(j){tox_model_par(x, eff_tox[j], b.ode)$pb})
  eff_sample <- sapply(1:K, function(j){eff_model_par(x, eff_dose[j], c.ode)$pb})
  ode <- min(which.max(eff_sample - lambda*pmax(tox_sample - tolerate, 0)))

  # exploration according to optimal design: evaluate the criterion for
  # each candidate arm i and take the smallest
  if(runif(1) > 1-eps_t){
    info <- sapply(1:K, function(i){
      eval.info(Y[past, "ode"], Z[past, "ode"], as.matrix(X[past,]),
                dose_tox, dose_eff,
                c(b.ode, c.ode), eff_tox, eff_dose,
                x, i, type = "joint") })
    ode <- min(which.min(info))
  }

  list(b.ode = b.ode, c.ode = c.ode, ode = ode)
}

## K-separate-models strategy ("ind"): one context-only logistic model per arm
## for toxicity and one for efficacy (each logit = a + b * x).
##
## Only the arm chosen by "ind" in round t-1 received new data, so only that
## arm's two models are refitted. The refit uses every round in which "ind"
## chose that arm and is warm-started from Par.i[, k, t-1].
## If myGLM flags either refit, an arm is drawn uniformly. Otherwise the arm
## maximising eff - lambda * max(tox - tolerate, 0) at X[t, ] is chosen, with
## uniform exploration at probability eps_t.
##
## Returns list(b.ind, c.ind = refitted coefficients of the updated arm,
##              ind = chosen arm index).
decision_ind <- function(t, X, Y, Z, Par.i, Choice,
                         eff_tox, eff_dose, tolerate, eps_t, lambda){
  n_arms <- length(eff_tox)
  expit <- function(eta) exp(eta)/(1 + exp(eta))

  last_arm <- Choice[t-1, "ind"]
  arm_rounds <- which(Choice[, "ind"] == last_arm)
  arm_x <- as.matrix(X[arm_rounds, ])
  tox_fit <- myGLM(arm_x, Y[arm_rounds, "ind"], Par.i[1:2, last_arm, t-1])
  eff_fit <- myGLM(arm_x, Z[arm_rounds, "ind"], Par.i[3:4, last_arm, t-1])

  # coefficient table for all arms, with the refitted arm replaced
  coefs <- Par.i[, , t-1]
  coefs[, last_arm] <- c(tox_fit$par, eff_fit$par)

  if(tox_fit$explore_flag + eff_fit$explore_flag > 0){
    # unreliable fit: explore
    ind <- sample(n_arms, 1)
  }else{
    ctx <- as.matrix(X[t, ])
    p_tox <- sapply(seq_len(n_arms),
                    function(j) expit(coefs[1:2, j] %*% c(1, ctx)))
    p_eff <- sapply(seq_len(n_arms),
                    function(j) expit(coefs[3:4, j] %*% c(1, ctx)))
    ind <- min(which.max(p_eff - lambda*pmax(p_tox - tolerate, 0)))

    if(runif(1) > 1-eps_t){
      ind <- sample(n_arms, 1)
    }
  }

  list(b.ind = tox_fit$par, c.ind = eff_fit$par, ind = ind)
}


## Bracketed-context strategy ("brk"): a context-free model for one bracket.
##
## Toxicity: logit = b1 + b2 * dose. Efficacy: logit = c1 + c2 * dose + c3 * dose^2.
## Both models are fitted on `rel_rounds` (the past rounds whose context fell
## in bracket `cat.x`). They are warm-started from Par.b[, cat.x, t-1], or from
## N(0, 0.1^2) draws when that column contains NAs.
## If myGLM flags either fit, an arm is drawn uniformly. Otherwise the arm
## maximising eff - lambda * max(tox - tolerate, 0) is chosen, with uniform
## exploration at probability eps_t.
##
## Returns list(b.brk, c.brk = fitted coefficients, brk = chosen arm index).
decision_brk <- function(t, Y, Z, Tox_history, Eff_history, Par.b,
                         rel_rounds, cat.x,
                         eff_tox, eff_dose, tolerate, eps_t, lambda){
  n_arms <- length(eff_tox)
  expit <- function(eta) exp(eta)/(1 + exp(eta))

  prev <- Par.b[, cat.x, t-1]
  if(anyNA(prev)){
    tox_start <- rnorm(2, 0, 0.1)
    eff_start <- rnorm(3, 0, 0.1)
  }else{
    tox_start <- prev[1:2]
    eff_start <- prev[3:5]
  }

  dose_tox <- Tox_history[rel_rounds, "brk"]
  dose_eff <- Eff_history[rel_rounds, "brk"]
  tox_fit <- myGLM(as.matrix(dose_tox), Y[rel_rounds, "brk"], tox_start)
  eff_fit <- myGLM(cbind(dose_eff, dose_eff^2, deparse.level = 0),
                   Z[rel_rounds, "brk"], eff_start)

  if(tox_fit$explore_flag + eff_fit$explore_flag > 0){
    brk <- sample(n_arms, 1)
  }else{
    p_tox <- sapply(seq_len(n_arms),
                    function(j) expit(tox_fit$par %*% c(1, eff_tox[j])))
    p_eff <- sapply(seq_len(n_arms),
                    function(j) expit(eff_fit$par %*%
                                        c(1, eff_dose[j], eff_dose[j]^2)))
    brk <- min(which.max(p_eff - lambda*pmax(p_tox - tolerate, 0)))

    if(runif(1) > 1-eps_t){
      brk <- sample(n_arms, 1)
    }
  }

  list(b.brk = tox_fit$par, c.brk = eff_fit$par, brk = brk)
}


## Context-free strategy ("noc"): ignores the context entirely.
##
## Toxicity: logit = b1 + b2 * dose. Efficacy: logit = c1 + c2 * dose + c3 * dose^2.
## Both models are fitted on rounds 1..t-1 of the "noc" columns. They are
## warm-started from Par.c[t-1, ], or from N(0, 0.1^2) draws when that row
## contains NAs.
## The arm maximising eff - lambda * max(tox - tolerate, 0) is chosen; with
## probability eps_t an arm is drawn uniformly instead.
##
## Returns list(b.noc, c.noc = fitted coefficients, noc = chosen arm index).
decision_noc <- function(t, Y, Z, Tox_history, Eff_history, Par.c,
                         eff_tox, eff_dose, tolerate, eps_t, lambda){
  n_arms <- length(eff_tox)
  past <- 1:(t-1)
  expit <- function(eta) exp(eta)/(1 + exp(eta))

  prev <- Par.c[t-1, ]
  if(anyNA(prev)){
    tox_start <- rnorm(2, 0, 0.1)
    eff_start <- rnorm(3, 0, 0.1)
  }else{
    tox_start <- prev[1:2]
    eff_start <- prev[3:5]
  }

  dose_tox <- Tox_history[past, "noc"]
  dose_eff <- Eff_history[past, "noc"]
  b.noc <- myGLM(as.matrix(dose_tox), Y[past, "noc"], tox_start)$par
  c.noc <- myGLM(cbind(dose_eff, dose_eff^2, deparse.level = 0),
                 Z[past, "noc"], eff_start)$par

  # penalised objective (same for every context)
  p_tox <- sapply(seq_len(n_arms),
                  function(j) expit(b.noc %*% c(1, eff_tox[j])))
  p_eff <- sapply(seq_len(n_arms),
                  function(j) expit(c.noc %*% c(1, eff_dose[j], eff_dose[j]^2)))
  noc <- min(which.max(p_eff - lambda*pmax(p_tox - tolerate, 0)))

  # uniform exploration
  if(runif(1) > 1-eps_t){
    noc <- sample(n_arms, 1)
  }

  list(b.noc = b.noc, c.noc = c.noc, noc = noc)
}

## Harm-ignoring strategy ("noh"): maximises fitted efficacy only.
##
## Efficacy: logit = (c1 + c2 x) + (c3 + c4 x) * dose + (c5 + c6 x) * dose^2.
## The model is fitted with myGLM on rounds 1..t-1 of the "noh" columns. It is
## warm-started from Par.h[t-1, ], or from N(0, 0.1^2) draws when that row
## contains NAs.
## The arm with the highest fitted efficacy at X[t, ] is chosen; with
## probability eps_t an arm is drawn uniformly instead. `tolerate` and
## `lambda` are accepted for a uniform signature but are not used.
##
## Returns list(c.noh = efficacy coefficients, noh = chosen arm index).
decision_noh <- function(t, X, Z, Eff_history, Par.h,
                         eff_dose, tolerate, eps_t, lambda){
  n_arms <- length(eff_dose)
  past <- 1:(t-1)

  prev <- Par.h[t-1, ]
  eff_start <- if(anyNA(prev)) rnorm(6, 0, 0.1) else prev

  ctx_past <- X[past, ]
  dose_eff <- Eff_history[past, "noh"]
  eff_design <- cbind(ctx_past, dose_eff, ctx_past*dose_eff,
                      dose_eff^2, ctx_past*dose_eff^2, deparse.level = 0)
  c.noh <- myGLM(eff_design, Z[past, "noh"], eff_start)$par

  ctx <- as.matrix(X[t, ])
  p_eff <- sapply(seq_len(n_arms),
                  function(j) eff_model_par(ctx, eff_dose[j], c.noh)$pb)
  noh <- min(which.max(p_eff))

  # uniform exploration
  if(runif(1) > 1-eps_t){
    noh <- sample(n_arms, 1)
  }

  list(c.noh = c.noh, noh = noh)
}

## Simulate one replication of the sequential dose-finding experiment.
##
## Each round t:
##   1. draws a context x ~ U(0, 1)^d;
##   2. lets six strategies each pick one of the K arms;
##   3. records the oracle arm under the true models;
##   4. samples a binary toxicity and efficacy outcome for every strategy's arm.
##
## Strategies (columns of Y, Z, Tox_history, Eff_history; Choice additionally
## has the oracle "orc" as its first column):
##   grd - varying-coefficient model, greedy with uniform exploration
##   ode - varying-coefficient model, exploration by optimal design
##   ind - K separate context-only models, one per arm
##   brk - context split into L equal-width brackets, a context-free model
##         per bracket
##   noc - ignores the context (brk with a single bracket)
##   noh - ignores toxicity ("harm"), maximises efficacy only
##
## Arguments:
##   Tt        number of rounds.
##   eff_tox   dose value of each arm as used by the toxicity models.
##   eff_dose  dose value of each arm as used by the efficacy models.
##   tolerate  toxicity probability above which the penalty applies.
##   d         context dimension; the fitted models assume d = 1.
##   eps       constant of the exploration rate
##             eps_t = min(1, 5 K log(t - 1) / (eps^2 (t - 1))).
##   lambda    weight of the penalty in eff - lambda * max(tox - tolerate, 0).
##   m         number of forced round-robin passes over the K arms before
##             grd/ode/ind/noc/noh start fitting models.
##   L         number of context brackets for brk; each bracket performs
##             floor(m/L) round-robin passes of its own before fitting.
##
## Returns a list with X, Y, Z, Choice, the true per-arm probabilities
## Tox_pb/Eff_pb (one row per round), and each strategy's per-round
## parameter estimates (Par.g, Par.o, Par.i, Par.b, Par.c, Par.h). Estimate
## entries are NA for rounds before the strategy starts fitting.
dose_sim_par <- function(Tt, eff_tox, eff_dose, tolerate, d = 1,
                         eps = 0.9, lambda  = 1, m = 9, L = 3){
  K <- length(eff_tox)
  strategies <- c("grd", "ode", "ind", "brk", "noc", "noh")

  X <- matrix(nrow = Tt, ncol = d)
  Y <- matrix(nrow = Tt, ncol = 6)
  Z <- matrix(nrow = Tt, ncol = 6)
  Choice <- matrix(nrow = Tt, ncol = 7)
  Tox_history <- matrix(nrow = Tt, ncol = 6)
  Eff_history <- matrix(nrow = Tt, ncol = 6)

  colnames(Y) <- strategies
  colnames(Z) <- strategies
  colnames(Tox_history) <- strategies
  colnames(Eff_history) <- strategies
  colnames(Choice) <- c("orc", strategies)

  Tox_pb <- matrix(nrow = Tt, ncol = K)
  Eff_pb <- matrix(nrow = Tt, ncol = K)
  Par.g <- matrix(nrow = Tt, ncol = 10)
  Par.o <- matrix(nrow = Tt, ncol = 10)
  Par.i <- array(dim = c(4, K, Tt))
  Par.b <- array(dim = c(5, L, Tt))
  cat.size <- rep(0, L)
  Par.c <- matrix(nrow = Tt, ncol = 5)
  Par.h <- matrix(nrow = Tt, ncol = 6)

  for(t in 1:Tt){
    # exploration probability (only consulted once a strategy fits models)
    eps_t <- min(1, 5*K/eps^2*log(t-1)/(t-1))

    x <- matrix(runif(d, 0, 1), ncol = d, byrow = TRUE)
    X[t,] <- x

    if(t <= m*K){
      # forced round-robin initialization: arms 1, 2, ..., K, 1, 2, ...
      grd <- if(t%%K == 0) K else t%%K
      ode <- grd
      ind <- grd
      noc <- grd
      noh <- grd
    }else{
      ## greedy
      step_grd <- decision_grd(t, X, Y, Z, Tox_history, Eff_history, Par.g,
                               eff_tox, eff_dose, tolerate, eps_t, lambda)
      grd <- step_grd$grd
      Par.g[t,] <- c(step_grd$b.grd, step_grd$c.grd)

      ## Optimal design
      step_ode <- decision_ode(t, X, Y, Z, Tox_history, Eff_history, Par.o,
                               eff_tox, eff_dose, tolerate, eps_t, lambda)
      ode <- step_ode$ode
      Par.o[t,] <- c(step_ode$b.ode, step_ode$c.ode)

      ## learn K separate models for each arm
      if(t == m*K+1){
        # initial estimate for every arm
        poor_fit <- FALSE
        for(k in 1:K){
          rel_rounds <- which(Choice[1:(t-1),"ind"] == k)
          rel_x <- as.matrix(X[rel_rounds,])
          rel_y <- Y[rel_rounds, "ind"]
          rel_z <- Z[rel_rounds, "ind"]
          b.ind <- myGLM(rel_x, rel_y, start = rnorm(2, 0, 0.1))
          c.ind <- myGLM(rel_x, rel_z, start = rnorm(2, 0, 0.1))
          Par.i[,k,t] <- c(b.ind$par, c.ind$par)
          # explore if ANY arm's fit is unreliable, not only the last arm's
          poor_fit <- poor_fit || (b.ind$explore_flag + c.ind$explore_flag > 0)
        }
        if(poor_fit){
          ind <- sample(K, 1)
        }else{
          tox_sample <- sapply(1:K, function(j){exp(Par.i[1:2,j,t]%*%c(1,x))/(1+exp(Par.i[1:2,j,t]%*%c(1,x)))})
          eff_sample <- sapply(1:K, function(j){exp(Par.i[3:4,j,t]%*%c(1,x))/(1+exp(Par.i[3:4,j,t]%*%c(1,x)))})
          ind <- min(which.max(eff_sample - lambda*pmax(tox_sample - tolerate, 0)))
          if(runif(1) > 1-eps_t){
            ind <- sample(K, 1)
          }
        }
      }else{
        step_ind <- decision_ind(t, X, Y, Z, Par.i, Choice,
                                 eff_tox, eff_dose, tolerate, eps_t, lambda)
        ind <- step_ind$ind
        Par.i[,,t] <- Par.i[,,t-1]
        k <- Choice[t-1, "ind"]
        Par.i[,k, t] <- c(step_ind$b.ind, step_ind$c.ind)
      }

      ## ignore the context
      step_noc <- decision_noc(t, Y, Z, Tox_history, Eff_history, Par.c,
                               eff_tox, eff_dose, tolerate, eps_t, lambda)
      noc <- step_noc$noc
      Par.c[t,] <- c(step_noc$b.noc, step_noc$c.noc)

      ## ignore the harm
      step_noh <- decision_noh(t, X, Z, Eff_history, Par.h,
                               eff_dose, tolerate, eps_t, lambda)
      noh <- step_noh$noh
      Par.h[t,] <- c(step_noh$c.noh)
    }

    ## categorize the context into L brackets
    ## and learn a context-free model for each category
    cat.x <- ceiling(x*L)[1,1]
    cat.n <- cat.size[cat.x]
    # Carry every bracket's estimates forward first. Otherwise a round spent
    # initializing one bracket would leave the whole slice NA and wipe the
    # other brackets' estimates (and their warm starts).
    if(t > 1){
      Par.b[,,t] <- Par.b[,,t-1]
    }
    # this initialization is different from other methods: each bracket
    # runs its own round-robin over the arms
    if(cat.n < floor(m/L)*K){
      brk <- cat.n%%K+1
    }else{
      # find the rounds in the same bracket
      rel_rounds <- which(ceiling(X[1:(t-1),]*L) == cat.x)
      step_brk <- decision_brk(t, Y, Z, Tox_history, Eff_history, Par.b,
                               rel_rounds, cat.x,
                               eff_tox, eff_dose, tolerate, eps_t, lambda)
      Par.b[,cat.x, t] <- c(step_brk$b.brk, step_brk$c.brk)
      brk <- step_brk$brk
    }
    cat.size[cat.x] <- cat.size[cat.x] + 1

    ## oracle choice under the true models
    tox_true <- sapply(1:K, function(i){tox_model_true(x, eff_tox[i])})
    eff_true <- sapply(1:K, function(i){eff_model_true(x, eff_dose[i])})
    choice_best <- min(which.max(eff_true - lambda*pmax(tox_true - tolerate, 0)))
    Choice[t,] <- c(choice_best, grd, ode, ind, brk, noc, noh)
    Tox_history[t,] <- eff_tox[c(grd, ode, ind, brk, noc, noh)]
    Eff_history[t,] <- eff_dose[c(grd, ode, ind, brk, noc, noh)]
    Tox_pb[t,] <- tox_true
    Eff_pb[t,] <- eff_true

    ## Bernoulli outcomes for each strategy's arm (Choice columns 2..7)
    y <- sapply(c(2:7), function(i){
      runif(1) < Tox_pb[t, Choice[t, i]]
    })

    z <- sapply(c(2:7), function(i){
      runif(1) < Eff_pb[t, Choice[t, i]]
    })

    Y[t,] <- y
    Z[t,] <- z
  }

  list(Y = Y, Z = Z, X = X, Choice = Choice,
       Tox_pb = Tox_pb, Eff_pb = Eff_pb,
       Par.g = Par.g, Par.o = Par.o, Par.i = Par.i,
       Par.b = Par.b, Par.c = Par.c, Par.h = Par.h)
}




## Simulation driver -----------------------------------------------------
## The same 7-point dose grid (0.1, ..., 0.7) feeds both the toxicity and
## the efficacy models.
eff_tox <- seq(0.1, 0.7, by = 0.1)
eff_dose <- seq(0.1, 0.7, by = 0.1)

tr <- 100     # number of independent replications
lambda <- 1   # toxicity-penalty weight; change value to change penalty
set.seed(215550 * lambda)

# Single timed run, kept for reference:
# time.ad <- Sys.time()
# res <- dose_sim_par(Tt=5000, eff_tox, eff_dose,
#                     tolerate = 0.33, d = 1,
#                     eps = 1.75)
# time.ad <- Sys.time() - time.ad

res1 <- vector("list", tr)
for (i in seq_len(tr)) {
  res <- dose_sim_par(Tt = 5000, eff_tox, eff_dose,
                      tolerate = 0.33, d = 1,
                      eps = 1.75, lambda = lambda)
  res1[[i]] <- res
}

# the whole workspace (including res1) is saved, one file per lambda
filename <- paste0("binary_example", "_lam_", lambda, ".RData")
save.image(file = filename)


