library("rlecuyer")
library("MASS")
library("hdi")
library("purrr")
library("tidyverse")
library("ggplot2")
library("PSweight")
library("nleqslv")
library("nloptr")
library("R.utils")
library("latex2exp")
library("grf")



################### (0) Data generation process ###################
## Generate covariates
# Draws each subject's subgroup label independently (with replacement)
# from the labels 1..m.
# Input:
#     n: sample size
#     m: number of subgroups
#     sampling_prob: sampling probability of each subgroup (length m)
# Output:
#     X: integer vector of length n holding subgroup labels in 1..m
GenX <- function(n, m, sampling_prob) {
  # sample.int(m, ...) is what sample(1:m, ...) dispatches to internally,
  # so the random stream and the returned integers are the same.
  sample.int(m, size = n, replace = TRUE, prob = sampling_prob)
}


## Generate subgroup membership
# Builds a logical indicator matrix: entry [i, j] is TRUE when subject i
# belongs to subgroup j (i.e. X[i] == j).
# Input:
#     X: covariates (subgroup labels)
#     m: number of subgroups
# Output:
#     S: length(X) x m logical matrix with columns named "A", "B", ...
GenS <- function(X, m) {
  # Compare every label against every subgroup index in one shot
  membership <- outer(unname(X), seq_len(m), FUN = "==")
  dimnames(membership) <- list(NULL, LETTERS[seq_len(m)])
  membership
}


## Generate outcomes
# Draws one normal outcome per subject from the distribution of the subject's
# arm (treatment/control) within the subject's subgroup.
# Input:
#     n: sample size (uses Tr[1:n] and rows 1:n of S)
#     Tr: treatment indicator (1 = treatment arm, 0 = control arm)
#     S: subgroup membership matrix, one row per subject and one column per
#        subgroup; the column equal to 1/TRUE selects the subgroup parameters
#     mu1_vec: treatment arm mean of each subgroup
#     mu0_vec: control arm mean of each subgroup
#     sd1_vec: treatment arm standard deviation of each subgroup
#     sd0_vec: control arm standard deviation of each subgroup
# Output:
#     Y: numeric vector of length n with the simulated outcomes; entries whose
#        Tr is neither 0 nor 1 are left as NA
GenY <- function(n,Tr,S,mu1_vec,mu0_vec,sd1_vec,sd0_vec){

  # Preallocate so the result always has length n (and n = 0 yields numeric(0))
  Y <- rep(NA_real_, n)

  # Subjects are drawn one at a time, in order, so the random number stream is
  # the same as a per-subject simulation.
  for (i in seq_len(n)) {
    idx <- which(S[i, ] == 1)

    if (Tr[i] == 1) {
      # treatment arm
      Y[i] <- rnorm(1, mu1_vec[idx], sd1_vec[idx])
    } else if (Tr[i] == 0) {
      # control arm
      Y[i] <- rnorm(1, mu0_vec[idx], sd0_vec[idx])
    }
  }

  return(Y)
}



################### (1) Compute optimal subgroup treatment allocation ###################
## Our design strategy
# Computes subgroup treatment probabilities by solving a constrained
# optimization with nloptr (COBYLA) over x = (e_1, ..., e_m, t), where the
# objective is the slack variable t.
# Input:
#     tau_old: estimated subgroup treatment effects
#     sd_old.t: treatment arm standard deviation of each subgroup
#     sd_old.c: control arm standard deviation of each subgroup
#     th: threshold of the allocation constraint sum_j p_j * e_j >= th
#         (p_j = subgroup proportion in S, e_j = treatment probability)
#     S: subgroup membership matrix (n rows, one named column per subgroup)
#     n: sample size (rows of S)
#     m: number of subgroups
# Output:
#     e_star: computed optimal treatment probability of each subgroup, listed
#             in decreasing order of tau_old and named by the columns of S
SubAlloc <- function(tau_old,sd_old.t,sd_old.c,th,S,n,m){

  # Subgroups sorted by estimated effect; position 1 is the apparent best one
  ranking <- order(tau_old, decreasing = TRUE)
  keep <- seq_len(m)

  # Reorder by ATE ranking
  sigma1_vec <- sd_old.t[ranking][keep]
  sigma0_vec <- sd_old.c[ranking][keep]
  tau_vec <- tau_old[ranking][keep]

  p <- colSums(S) / n
  p_ranked <- p[ranking][keep]

  # Objective: the slack variable t = x[m + 1]
  eval_f0 <- function(x) {
    x[m + 1]
  }

  # Inequality constraints, nloptr convention g(x) <= 0
  eval_g0 <- function(x) {
    e <- x[keep]

    # Variance of the IPW effect estimator in each subgroup,
    # sigma1^2 / (p e) + sigma0^2 / (p (1 - e)). The same formula is now used
    # for every subgroup; previously the control term for subgroups 2..m used
    # (1 - p_j) * e_j instead of p_j * (1 - e_j).
    var_vec <- sigma1_vec^2 / (p_ranked * e) +
      sigma0_vec^2 / (p_ranked * (1 - e))

    # Pairwise constraints of every other subgroup against the top-ranked one
    # (empty when m = 1). The numerator is >= 0, so a constraint holds once
    # t > 2 * (var_1 + var_j), or trivially when tau_j == tau_1.
    # NOTE(review): parenthesization kept from the original; confirm the
    # intended form of this constraint.
    others <- keep[-1]
    pair_constr <- (tau_vec[others] - tau_vec[1])^2 /
      (2 * (var_vec[1] + var_vec[others]) - x[m + 1])

    # Allocation constraint sum_j p_j * e_j >= th.
    # NOTE(review): the original hard-coded 0.5 here and ignored th; the
    # direction (lower bound on the treated fraction) is kept unchanged.
    c(pair_constr, th - sum(p_ranked * e))
  }

  res1 <- nloptr(x0 = c(rep(0.1, m), 0.0001),
                 eval_f = eval_f0,
                 lb = c(rep(0.1, m), -20),
                 ub = c(rep(0.9, m), 9000),
                 eval_g_ineq = eval_g0,
                 opts = list("algorithm" = "NLOPT_LN_COBYLA",
                             "xtol_rel" = 1.0e-8,
                             "maxtime" = 60))

  # optimal e*
  e_star <- res1$solution[keep]

  names(e_star) <- colnames(S)[ranking][keep]

  return(e_star)
}


## Multi-armed bandit algorithm (Epsilon Greedy Algorithm)
# Learns a treatment probability for each subgroup by running an
# epsilon-greedy two-armed bandit (arm 1 = treatment, arm 0 = control) on
# outcomes simulated from the given arm distributions.
# Input:
#     epsilon: exploration probability; with probability epsilon a random arm
#              is played, otherwise the arm with the largest Q-value
#     m: number of subgroups
#     sampling_prob: sampling probability of each subgroup
#     n: sample size; subgroup i runs round(sampling_prob[i] * n) rounds
#     mu1_vec: treatment arm mean
#     mu0_vec: control arm mean
#     sd1_vec: treatment arm standard deviation
#     sd0_vec: control arm standard deviation
# Output:
#     e_star: named vector (A, B, ...) with the fraction of rounds that played
#             the treatment arm in each subgroup; a subgroup with no rounds
#             gets 1/2 (the complete-randomization allocation)
SubAlloc_eg <- function(epsilon,m,sampling_prob,n,mu1_vec,mu0_vec,sd1_vec,sd0_vec){

  nk <- round(sampling_prob * n)

  # Choose an action: explore with probability epsilon, otherwise exploit
  choose_action <- function(Q_values) {
    if (runif(1) < epsilon) {
      # Explore: Choose a random arm
      return(sample(0:1, 1))
    }
    # Exploit: which.max breaks ties towards index 1 (the control arm)
    return(which.max(Q_values) - 1)
  }

  # Simulate one outcome for the chosen arm in subgroup s
  Gen <- function(chosen_arm,s,mu1_vec,mu0_vec,sd1_vec,sd0_vec){
    # treatment arm
    if(chosen_arm==1){
      Y <- rnorm(1,mu1_vec[s],sd1_vec[s])
    }
    # control arm
    if(chosen_arm==0){
      Y <- rnorm(1,mu0_vec[s],sd0_vec[s])
    }
    return(Y)
  }

  # Reward = minus the estimated variance term var0 / (1 - e) + var1 / e,
  # where an arm with fewer than two observations contributes 0
  get_reward <- function(outcome, arm){
    e <- sum(arm == 1) / length(arm)

    if ((sum(arm == 0) == 0) || (sum(arm == 0) == 1)){
      var0 <- 0
    }else{
      var0 <- var(outcome[arm == 0]) / (1 - e)
    }

    if ((sum(arm == 1) == 0) || (sum(arm == 1) == 1)){
      var1 <- 0
    }else{
      var1 <- var(outcome[arm == 1]) / e
    }

    reward <- -(var0 + var1)

    return(reward)
  }

  # Incremental-mean update of the Q-value of the most recently played arm
  # (base indexing instead of dplyr::last, so no tidyverse dependency)
  update_Q_values <- function(arm, Q_values, reward){
    a <- arm[length(arm)] + 1
    Q_values[a] <- Q_values[a] + (reward - Q_values[a]) / sum(arm == arm[a - 1 + 0 * a] )
    return(Q_values)
  }

  e_star <- numeric(m)
  for (i in seq_len(m)){
    n_i <- nk[i]

    # No rounds for this subgroup: 1:n_i would iterate over c(1, 0) and fail,
    # so fall back to the complete-randomization allocation
    if (is.na(n_i) || n_i < 1) {
      e_star[i] <- 0.5
      next
    }

    # arm/outcome grow with t on purpose: get_reward() and update_Q_values()
    # use the whole history up to the current round
    arm <- NULL
    outcome <- NULL
    # Initialize Q-values for each arm
    Q_values <- rep(0, 2)
    for (t in seq_len(n_i)) {
      arm[t] <- choose_action(Q_values)
      outcome[t] <- Gen(arm[t],i,mu1_vec,mu0_vec,sd1_vec,sd0_vec)
      reward <- get_reward(outcome, arm)
      Q_values <- update_Q_values(arm, Q_values, reward)
    }
    e_star[i] <- sum(arm == 1) / length(arm)
  }

  names(e_star) <- LETTERS[seq_len(m)]

  return(e_star)
}


## Multi-armed bandit algorithm (Upper Confidence Bound 1 Algorithm)
# Learns a treatment probability for each subgroup by running a UCB1
# two-armed bandit (arm 1 = treatment, arm 0 = control) on outcomes simulated
# from the given arm distributions.
# Input:
#     m: number of subgroups
#     sampling_prob: sampling probability of each subgroup
#     n: sample size; subgroup i runs round(sampling_prob[i] * n) rounds
#     mu1_vec: treatment arm mean
#     mu0_vec: control arm mean
#     sd1_vec: treatment arm standard deviation
#     sd0_vec: control arm standard deviation
# Output:
#     e_star: named vector (A, B, ...) with the fraction of rounds that played
#             the treatment arm in each subgroup; a subgroup with no rounds
#             gets 1/2 (the complete-randomization allocation)
SubAlloc_ucb1 <- function(m,sampling_prob,n,mu1_vec,mu0_vec,sd1_vec,sd0_vec){

  nk <- round(sampling_prob * n)

  # Choose an action at round t: play each arm once, then maximize the UCB1
  # index (mean reward so far + sqrt(2 log t / pulls)). The bandit state is
  # passed in explicitly instead of being read from the enclosing frame.
  choose_action <- function(t, total_rewards, num_pulls) {
    if (t <= 2) {
      # Play each arm once initially
      return(t - 1)
    }
    ucb_values <- total_rewards / num_pulls + sqrt(2 * log(t) / num_pulls)
    return(which.max(ucb_values) - 1)
  }

  # Simulate one outcome for the chosen arm in subgroup s
  Gen <- function(chosen_arm,s,mu1_vec,mu0_vec,sd1_vec,sd0_vec){
    # treatment arm
    if(chosen_arm==1){
      Y <- rnorm(1,mu1_vec[s],sd1_vec[s])
    }
    # control arm
    if(chosen_arm==0){
      Y <- rnorm(1,mu0_vec[s],sd0_vec[s])
    }
    return(Y)
  }

  # Reward = minus the estimated variance term var0 / (1 - e) + var1 / e,
  # where an arm with fewer than two observations contributes 0 (rewards are
  # therefore <= 0 rather than scaled to [0, 1])
  get_reward <- function(outcome, arm){
    e <- sum(arm == 1) / length(arm)

    if ((sum(arm == 0) == 0) || (sum(arm == 0) == 1)){
      var0 <- 0
    }else{
      var0 <- var(outcome[arm == 0]) / (1 - e)
    }

    if ((sum(arm == 1) == 0) || (sum(arm == 1) == 1)){
      var1 <- 0
    }else{
      var1 <- var(outcome[arm == 1]) / e
    }

    reward <- -(var0 + var1)

    return(reward)
  }

  e_star <- numeric(m)
  for (i in seq_len(m)){
    n_i <- nk[i]

    # No rounds for this subgroup: 1:n_i would iterate over c(1, 0) and fail,
    # so fall back to the complete-randomization allocation
    if (is.na(n_i) || n_i < 1) {
      e_star[i] <- 0.5
      next
    }

    # arm/outcome grow with t on purpose: get_reward() uses the whole history
    arm <- NULL
    outcome <- NULL
    num_pulls <- numeric(2)
    total_rewards <- numeric(2)
    for (t in seq_len(n_i)) {
      arm[t] <- choose_action(t, total_rewards, num_pulls)
      outcome[t] <- Gen(arm[t],i,mu1_vec,mu0_vec,sd1_vec,sd0_vec)
      reward <- get_reward(outcome, arm)
      total_rewards[arm[t]+1] <- total_rewards[arm[t]+1] + reward
      num_pulls[arm[t]+1] <- num_pulls[arm[t]+1] + 1
    }
    e_star[i] <- sum(arm == 1) / length(arm)
  }

  names(e_star) <- LETTERS[seq_len(m)]

  return(e_star)
}



################### (2) Synthetic adaptive experiments ###################
# Input: 
#     n1: first-stage sample size
#     n: number of subjects enrolled in each of the subsequent stage
#     c1: cost constraint threshold
#     m: number of subgroups
#     K: the Kth smallest parameters of interest (the biggest: K = m)
#     true_p: population subgroup proportions
#     mu1_vec: treatment arm mean
#     mu0_vec: control arm mean
#     sd1_vec: treatment arm standard deviation
#     sd0_vec: control arm standard deviation
#     num_stage: number of experimental stages
#     B: times of resampling
#     truetie: actual tie set
#     mab: whether uses multi-armed bandit algorithm when identifying best subgroup
#          (0 refers to complete randomization; 1 refers to epsilon greedy algorithm; 2 refers to UCB1 algorithm)
#     tie: whether includes tie set identification
#     merge: whether merge subgroups identified as tie set
#     method.select: hyperparameter selection method (1 refers to double bootstrap; 2 refers to single bootstrap)
#     method.identify: dynamic identification of the best subgroups method (1 refers to naive bootstrap; 2 refers to separate bootstrap)
# Output:
#     tau_opt: subgroup treatment effects under the proposed design after all stages
#     tieset: tie set selected after each stage under proposed design
#     tietau: estimated treatment effect in the best subgroup after each stage under the proposed design
simMultiStageRAR <- function(n1,n,c1,m,K,true_p,mu1_vec,mu0_vec,sd1_vec,sd0_vec,num_stage,B,truetie,mab=FALSE,tie=TRUE,merge=TRUE,method.select=NULL,method.identify=NULL){
  
  tieset <- numeric(1+num_stage)
  tietau <- numeric(1+num_stage)
  tau_opt_list <- list()
  
  ################### Stage 1 ###################
  ## Data generation
  # Generate covariates
  X_1 <- GenX(n1,m,true_p)
  
  # Generate subgroup memberships
  S_1 <- GenS(X_1,m)
  
  # Assign treatment randomly
  T_1 <- rbinom(n1,1,0.5)
  
  # Generate outcomes
  Y_1 <-  GenY(n1,T_1,S_1,mu1_vec,mu0_vec,sd1_vec,sd0_vec)
  
  # Randomly assign treatments
  e_1 <- rep(1/2, m)
  
  ## Estimation
  # Estimated subgroup proportions
  p_1 <- colSums(S_1)/n1
  
  # Estimate propensity scores
  e_1.hat <- NULL
  for(k in 1:ncol(S_1)){
    e_1.hat[k] <- sum(T_1[S_1[,k]])/(sum(S_1[,k]))
  }
  
  # Estimate subgroup ATEs
  # tau_1 <- sd_1.t <- sd_1.c <- NULL
  # for (j in 1:ncol(S_1)){
  #   dat1 <- as.data.frame(cbind(Y_1[S_1[,j]],T_1[S_1[,j]],X_1[S_1[,j]]))
  #   names(dat1) <- c("Y","Tr","X")
  #   
  #   tau_1[j] <-  mean(dat1$Y*dat1$Tr/e_1.hat[j] - dat1$Y*(1-dat1$Tr)/(1-e_1.hat[j]))
  #   sd_1.t[j]<- sd(dat1$Y[dat1$Tr==1])
  #   sd_1.c[j]<- sd(dat1$Y[dat1$Tr==0])
  # }
  S_1_num <- matrix(as.numeric(S_1), nrow = nrow(S_1))
  causal_forest_model <- causal_forest(S_1_num, Y_1, T_1)
  individual_treatment_effects <- predict(causal_forest_model)$predictions
  
  tau_1 <- sd_1.t <- sd_1.c <- NULL
  for (j in 1:ncol(S_1)){
    tau_1[j] <- mean(individual_treatment_effects[S_1[,j]])
    sd_1.t[j]<- sd(individual_treatment_effects[S_1[,j] & T_1==1])
    sd_1.c[j]<- sd(individual_treatment_effects[S_1[,j] & T_1==0])
  }
  
  # Name subgroup ATEs
  names(tau_1) <- LETTERS[seq(1:m)]
  # Name subgroup SDs
  names(sd_1.t) <- names(sd_1.c) <- LETTERS[seq(1:m)]
  
  tau_old <- tau_1
  sd_old.t <- sd_1.t
  sd_old.c <- sd_1.c
  S_old <- S_1
  n_old <- n1
  T_old <- T_1
  X_old <- X_1
  Y_old <- Y_1
  e_1.hat_old <- e_1.hat
  # p_old <- p_1
  # var_opt <- 1/p_old*(sd_old.t^2/e_1.hat_old + sd_old.c^2/(1-e_1.hat_old))
  tau_opt <- tau_old
  tau_opt_list[[1]] <- tau_opt
  
  
  if (tie && (merge || ((!merge) && (num_stage == 0)))){
    ## Identifiy tie set based on Stage 1 data
    # Generate random samples
    newob2 <- function(the,s,nk){
      data <- mvrnorm(n = 1,the,s,tol = 2)
      return(data)
    }
    
    ## Find the tie set for \hat{\tau}_{\hat{1}} and find debiased estimates based on the tie set
    # Input:
    #     thetas: \hat{\tau}
    #     s: the estimated variances of \hat{\tau}
    #     nk: the size of each subgroup
    #     sigm: the estimated covariance matrix
    #     n1: first-stage sample size
    #     m: the number of subgroups
    #     K: the Kth smallest parameters of interest (the biggest: K = m)
    #     cl: c_L
    #     cr: c_R
    #     B: times of resampling
    #     truetie: actual tie set
    # Output:
    #     result: a vecter including correction selection probability, 
    #             and three kinds of debiased estimates based on the tie set
    #     pattern: tie set selected after each stage under proposed design
    mm2 <- function(thetas,s,nk,sigm,n1,m,K,cl,cr,B,truetie){
      
      them = sort(thetas)[K]
      epsiss = numeric(B)
      tilde.beta = numeric(B)
      w = which(thetas==them); d=1/4; tn = (s[w])^d
      tie <- matrix(NA, nrow = B, ncol = m)
      bl <- n1^{-d}*cl*tn
      br <- n1^{-d}*cr*tn
      
      for (j in 1:B) {
        epsi = newob2(thetas,sigm)
        e = epsi-sort(epsi)[K]
        wkm = numeric(m)
        temp = ((e<=br)+(e>=-bl))==2
        wkm[temp] = 1
        tie[j,] = wkm
        epsiss[j] = sum(wkm*epsi)/sum(wkm)
        tilde.beta[j] = sum(wkm*thetas)/sum(wkm)
      }
      
      # convert tie dataframe rows to character strings
      tieset_strings <- apply(tie, 1, paste, collapse = "")
      # Count the occurrence of each tie set pattern
      string_counts <- table(tieset_strings)
      ratio <- string_counts / B
      
      epsiss.sort = sort(epsiss)
      lower = epsiss.sort[0.025*B] # estimate the 95% confidence interval lower bound
      upper = epsiss.sort[0.975*B] # estimate the 95% confidence interval upper bound
      width = upper-lower
      esti = mean(epsiss.sort)
      esti2 = median(epsiss.sort) # estimate the mth smallest parameter
      mtbeta = median(tilde.beta)
      
      if (sum(names(ratio) == truetie) == 0){
        correct <- 0
      }else{
        correct <- as.numeric(ratio[names(ratio) == truetie])
      }
      
      pattern <- as.numeric(strsplit(names(ratio)[ratio == max(ratio)],split='')[[1]])
      
      return(list(result = data.frame(paste(pattern, collapse = ""), correct, esti, esti2, mtbeta), pattern = pattern))
    }
    
    met1 <- function(thetahh,s,nk,sigm,n1,m,K,cl,cr,R,B){
      
      r = mean(s)/var(thetahh); sig = sqrt(sum(s*nk)/m); d = 1/4
      temp = r*(sig/sqrt(mean(s)))^0.1; tri = min(c(1,temp))
      thetah = tri*mean(thetahh)+(1-tri)*thetahh
      them = sort(thetah)[K]
      b = numeric(B)
      
      for (i in 1:B) {
        thetash = newob2(thetah,sigm)
        temp = sort(thetash)[K]; w = which(thetash==temp)
        
        for (j in 1:R) {
          epsi = newob2(thetash,sigm); tn = (s[w])^d
          bl <- n1^{-d}*cl*tn
          br <- n1^{-d}*cr*tn
          e = epsi-sort(epsi)[K]
          wkm = numeric(m)
          temp = ((e<=br)+(e>=-bl))==2
          wkm[temp] = 1
          epsiss = sum(wkm*epsi)/sum(wkm)
          b[i] = b[i]+ifelse(them-epsiss>=0,1,0)
        }
        
        b[i] = b[i]/R
      }
      
      l = sum((sort(b)-seq(1,B)/(B+1))^2)/B
      return(l)
    }
    
    met2 <- function(thetahh,s,nk,sigm,n1,m,K,cl,cr,B){
      
      r = mean(s)/var(thetahh); sig = sqrt(sum(s*nk)/m); d = 1/4
      temp = r*(sig/sqrt(mean(s)))^0.1; tri = min(c(1,temp))
      thetah = tri*mean(thetahh)+(1-tri)*thetahh
      them = sort(thetah)[K]
      b = numeric(B)
      
      for (i in 1:B) {
        epsi = newob2(thetah,sigm)
        temp = sort(epsi)[K]; w = which(epsi==temp)
        tn = (s[w])^d
        bl <- n1^{-d}*cl*tn
        br <- n1^{-d}*cr*tn
        e = epsi-sort(epsi)[K]
        wkm = numeric(m)
        temp = ((e<=br)+(e>=-bl))==2
        wkm[temp] = 1
        epsiss = sum(wkm*epsi)/sum(wkm)
        b[i] = b[i]+ifelse(them-epsiss>=0,1,0)
      }
      
      l = (sum(((sort(b)==1)-mean(sort(b)==1))^2)/B+sum(((sort(b)==0)-mean(sort(b)==0))^2)/B)/2
      return(l)
    }
    
    gam1 = function(B){
      l = numeric(20000)
      for (i in 1:20000) {
        l[i] = sum((sort(runif(B))-(1:B)/(B+1))^2)/B
      }
      return(quantile(l,0.975))
    }
    
    gam2 = function(B){
      l = numeric(20000)
      for (i in 1:20000) {
        b = rbinom(B, size = 1, prob = 0.5)
        l[i] = (sum(((sort(b)==1)-mean(sort(b)==1))^2)/B+sum(((sort(b)==0)-mean(sort(b)==0))^2)/B)/2
      }
      return(quantile(l,0.975))
    }
    
    ## Function for finding the lower and upper tuning parameters
    # Input:
    #     cl: choices of c_L
    #     cr: choices of c_R
    #     thetah: \hat{\tau}
    #     s: the estimated variances of \hat{\tau}
    #     nk: the size of each subgroup
    #     sigm: the estimated covariance matrix
    #     n1: first-stage sample size
    #     m: the number of subgroups
    #     K: the Kth smallest parameters of interest (the biggest: K = m)
    #     gam: threshold
    #     R, B: times of resampling
    # Output:
    #     return the best selected (cl,cr)
    tuningf1 <- function(cl,cr,thetah,s,nk,sigm,n1,m,K,gam,R,B){
      
      l1 = length(cr); l2 = length(cl); d=1/4
      l = matrix(0,nrow = l2,ncol = l1)
      
      for (n1 in 1:l2) {
        c1=cl[n1]
        
        for (n2 in 1:l1) {
          c2=cr[n2]
          l[n1,n2]=met1(thetah,s,nk,sigm,n1,m,K,c1,c2,R,B)
        }
        
      }
      
      if (!any(l<=gam)) {
        mi = min(l); temp = which(l==mi,arr.ind = T)[1,]
        c1 = temp[1]; c2 = temp[2]
        
        return(list(cl = cl[c1],cr = cr[c2]))
      }
      
      temp = which(l<=gam,arr.ind = T)
      c1 = mean(cl[temp[,1]]); c2 = mean(cr[temp[,2]])
      
      return(list(cl = c1,cr = c2))
    }
    
    tuningf2 <- function(cl,cr,thetah,s,nk,sigm,n1,m,K,gam,B){
      
      l1 = length(cr); l2 = length(cl); d=1/4
      l = matrix(0,nrow = l2,ncol = l1)
      
      for (n1 in 1:l2) {
        c1=cl[n1]
        
        for (n2 in 1:l1) {
          c2=cr[n2]
          l[n1,n2]=met2(thetah,s,nk,sigm,n1,m,K,c1,c2,B)
        }
        
      }
      
      if (!any(l<=gam)) {
        mi = min(l); temp = which(l==mi,arr.ind = T)[1,]
        c1 = temp[1]; c2 = temp[2]
        
        return(list(cl = cl[c1],cr = cr[c2]))
      }
      
      temp = which(l<=gam,arr.ind = T)
      c1 = mean(cl[temp[,1]]); c2 = mean(cr[temp[,2]])
      
      return(list(cl = c1,cr = c2))
    }
    
    ## simulation
    # Input:
    #     n1: first-stage sample size
    #     m: the number of subgroups
    #     K: the Kth smallest parameters of interest (the biggest: K = m)
    #     true_p: population subgroup proportions
    #     var_old: estimated variances
    #     B: times of resampling
    #     tau_old: estimated subgroup treatment effects
    #     truetie: actual tie set
    #     method.select: hyperparameter selection method (1 refers to double bootstrap; 2 refers to single bootstrap)
    # Output:
    #     pattern: tie set selected after each stage under proposed design
    sec.fun.1 <- function(n1,m,K,true_p,var_old,B,tau_old,truetie,method.select) {
      
      nk <- round(true_p*n1) # the size of each subgroup
      
      ga1 = gam1(5)
      ga2 = gam2(5)
      
      s <- var_old
      sigm <- diag(s)/n1
      theta <- tau_old
      s <- s/n1
      thetas <- c(theta)
      
      cl <- seq(19, 21, length.out = 5) 
      cr <- seq(19, 21, length.out = 5)
      
      if (method.select == 1){
        # double-bootstrap
        temp1 = tuningf1(cl,cr,thetas,s,nk,sigm,n1,m,K,ga1,200,40)
        CL1 = temp1$cl; CR1 = temp1$cr
        temp2 <- mm2(thetas,s,nk,sigm,n1,m,K,CL1,CR1,B,truetie)
        result <- temp2$result
        pattern <- temp2$pattern
      }else if (method.select == 2){
        # single-bootstrap
        temp3 = tuningf2(cl,cr,thetas,s,nk,sigm,n1,m,K,ga2,200)
        CL2 = temp3$cl; CR2 = temp3$cr
        temp4 <- mm2(thetas,s,nk,sigm,n1,m,K,CL2,CR2,B,truetie)
        result <- temp4$result
        pattern <- temp4$pattern
      }
      
      return(list(result = result, pattern = pattern))
    }
    
    temp.1 <- sec.fun.1(n1,m,K,true_p,var_opt,B,tau_old,truetie,method.select)
    result <- temp.1$result
    
    if (merge){
      pattern <- temp.1$pattern
      ## Merge subgroups who are ties
      m <- m - sum(pattern==1) +1
      a <- which(pattern==1)
      a0 <- which(pattern==0)
      
      mu1_vec_new <- rep(NA,m)
      mu1_vec_new[a[1]] <- sum(mu1_vec[a]*true_p[a])/sum(true_p[a])
      mu1_vec_new[-a[1]] <- mu1_vec[-a]
      mu1_vec <- mu1_vec_new
      
      mu0_vec_new <- rep(NA,m)
      mu0_vec_new[a[1]] <- sum(mu0_vec[a]*true_p[a])/sum(true_p[a])
      mu0_vec_new[-a[1]] <- mu0_vec[-a]
      mu0_vec <- mu0_vec_new
      
      sd1_vec_new <- rep(NA,m)
      sd1_vec_new[a[1]] <- sqrt(sum(((sd1_vec[a])^2)*((true_p[a])^2)))/sum(true_p[a])
      sd1_vec_new[-a[1]] <- sd1_vec[-a]
      sd1_vec <- sd1_vec_new
      
      sd0_vec_new <- rep(NA,m)
      sd0_vec_new[a[1]] <- sqrt(sum(((sd0_vec[a])^2)*((true_p[a])^2)))/sum(true_p[a])
      sd0_vec_new[-a[1]] <- sd0_vec[-a]
      sd0_vec <- sd0_vec_new
      
      name_new <- rep(NA,m)
      name_new[a[1]] <- paste(sort(names(true_p)[a]), collapse = "")
      name_new[-a[1]] <- names(true_p)[a0]
      tieset[1] <- paste(name_new, collapse = ",")
      
      true_p_new <- rep(NA,m)
      true_p_new[a[1]] <- sum(true_p[a])
      true_p_new[-a[1]] <- true_p[-a]
      true_p <- true_p_new
      names(true_p) <- name_new
      
      
      X_old[X_old %in% a] <- a[1]
      for (j in 1:length(unique(X_old))){
        X_old[X_old==sort(unique(X_old))[j]] <- j
      }
      
      S_old <- matrix(NA,nrow = length(X_old),ncol = m)
      for(j in 1:m){
        S_old[,j] <- (X_old == j)
      }
      colnames(S_old) <- LETTERS[seq(1:m)]
      
      # update ATE and standard deviation
      p_old <- e_1.hat_old <-tau_old <- sd_old.t <- sd_old.c <- NULL
      
      p_old <- colSums(S_old)/n1
      
      for(k in 1:ncol(S_old)){
        e_1.hat_old[k] <- sum(T_old[S_old[,k]])/(sum(S_old[,k]))
      }
      
      for (j in 1:ncol(S_old)){
        
        dat1 <- as.data.frame(cbind(Y_old[S_old[,j]],T_old[S_old[,j]],X_old[S_old[,j]]))
        
        names(dat1) <- c("Y","Tr","X")
        
        ps <- sum(dat1$Tr)/nrow(dat1)
        
        tau_old[j] <-  mean(dat1$Y*dat1$Tr/ e_1.hat_old[j] - dat1$Y*(1-dat1$Tr)/(1-e_1.hat_old[j]))
        
        sd_old.t[j]<- sd(dat1$Y[dat1$Tr==1])
        
        sd_old.c[j]<- sd(dat1$Y[dat1$Tr==0])
      }
      
      names(tau_old) <- LETTERS[seq(1:m)]
      
      names(sd_old.t) <- names(sd_old.c)<- LETTERS[seq(1:m)]
      
      tau_opt <- tau_old
      var_opt <- 1/p_old*(sd_old.t^2/e_1.hat_old + sd_old.c^2/(1-e_1.hat_old))
      tietau[1] <- tau_opt[a[1]]
    }
    
  }
  
  
  ################### Stage t ###################
  if (num_stage != 0){
    ## Revise treatment allocation for Stage 2,3,..,T
    for(i in 1:num_stage){ 
      
      # Generate covariates
      X_i <- GenX(n,m,true_p)
      
      # Generate subgroup memberships
      S_i <- GenS(X_i,m)
      
      # Update old information
      S_new <- rbind(S_old,S_i)
      
      # Update sample size
      n_new <- n_old + n
      
      group_name <- colnames(S_i)
      
      # Use previous stage compute treatment allocation for the new subject
      if (mab == FALSE){
        e_opt <- SubAlloc(tau_old,sd_old.t,sd_old.c,c1,S_new,n_new,m)
      }
      
      # Complete randomization
      if (mab == 0){
        e_opt <- rep(1/2, m)
        names(e_opt) <- LETTERS[seq(1:m)]
      }
      
      # Multi-armed bandit algorithm
      if (mab == 1){
        e_opt <- SubAlloc_eg(0.1,m,true_p,n_new,mu1_vec,mu0_vec,sd1_vec,sd0_vec)
      }
      if (mab == 2){
        e_opt <- SubAlloc_ucb1(m,true_p,n_new,mu1_vec,mu0_vec,sd1_vec,sd0_vec)
      }
      
      e_opt <- e_opt[group_name] 
      
      # current allocation
      e_current <-   e_1.hat_old
      
      p_new <-  sum(S_new[,group_name])/n_new
      
      # we do calibration because we want: n_1*e1.hat + \sum_{t=2}^T n_t*e_t.hat = e_opt
      e_calibrated <- (e_opt*colSums(S_new[,group_name]) -  e_current*colSums(S_old[,group_name]))/colSums(S_i[,group_name])
      
      e_calibrated_2 <- NULL
      
      for(k in 1:m){
        if(e_calibrated[k]>=1){
          e_calibrated_2[k] <- 0.99
        }else if(e_calibrated[k]<=0){
          e_calibrated_2[k] <- 0.01
        }else{
          e_calibrated_2[k] <- e_calibrated[k]
        }
      }
      
      T_i <- NULL
      for(j in 1:n){
        idx <- which(S_i[j,])
        T_i[j] <- rbinom(1,1,e_calibrated_2[idx])
      }
      
      # Generate outcomes
      Y_i <-  GenY(n,T_i,S_i,mu1_vec,mu0_vec,sd1_vec,sd0_vec)
      
      # Update old information
      T_old <- c(T_old, T_i)
      
      Y_old <- c(Y_old, Y_i)
      
      S_old <- S_new
      
      X_old <- c(X_old, X_i)
      
      n_old <- n_new
      
      # Update ATE and standard deviation
      tau_old <- sd_old.t <- sd_old.c <- NULL
      
      for(k in 1:ncol(S_old)){
        e_1.hat_old[k] <- sum(T_old[S_old[,k]])/(sum(S_old[,k]))
      }
      
      # for (j in 1:ncol(S_old)){
      #   
      #   dat1 <- as.data.frame(cbind(Y_old[S_old[,j]],T_old[S_old[,j]],X_old[S_old[,j]]))
      #   
      #   names(dat1) <- c("Y","Tr","X")
      #   
      #   ps <- sum(dat1$Tr)/nrow(dat1)
      #   
      #   tau_old[j] <-  mean(dat1$Y*dat1$Tr/ e_1.hat_old[j] - dat1$Y*(1-dat1$Tr)/(1-e_1.hat_old[j]))
      #   
      #   sd_old.t[j]<- sd(dat1$Y[dat1$Tr==1])
      #   
      #   sd_old.c[j]<- sd(dat1$Y[dat1$Tr==0])
      # }
      
      S_old_num <- matrix(as.numeric(S_old), nrow = nrow(S_old))
      causal_forest_model <- causal_forest(S_old_num, Y_old, T_old)
      individual_treatment_effects <- predict(causal_forest_model)$predictions
      
      tau_old <- sd_old.t <- sd_old.c <- NULL
      for (j in 1:ncol(S_1)){
        tau_old[j] <- mean(individual_treatment_effects[S_old[,j]])
        sd_old.t[j]<- sd(individual_treatment_effects[S_old[,j] & T_old==1])
        sd_old.c[j]<- sd(individual_treatment_effects[S_old[,j] & T_old==0])
      }
      
      names(tau_old) <- LETTERS[seq(1:m)]
      
      names(sd_old.t) <- names(sd_old.c) <- LETTERS[seq(1:m)]
      
      p_old <- colMeans(S_old)
      tau_opt <- tau_old
      var_opt <- 1/p_old*(sd_old.t^2/e_1.hat_old + sd_old.c^2/(1-e_1.hat_old))
      tau_opt_list[[i+1]] <- tau_opt
      
      
      if (tie && (merge || ((!merge) && (i == num_stage)))){
        ## Find the tie set for \hat{\tau}_{\hat{1}} and find debiased estimates based on the tie set
        # Input:
        #     thetas: \hat{\tau}
        #     s: the estimated variances of \hat{\tau}
        #     nk: the size of each subgroup
        #     n1: first-stage sample size
        #     n: number of subjects enrolled in each of the subsequent stage
        #     num_stage: number of experimental stages
        #     m: the number of subgroups
        #     K: the Kth smallest parameters of interest (the biggest: K = m)
        #     cl: c_L
        #     cr: c_R
        #     B: times of resampling
        #     Y_old: outcomes
        #     X_old: covariates
        #     T_old: treatment
        #     S_old: subgroup memberships
        #     truetie: actual tie set
        # Output:
        #     result: a vecter including correction selection probability, 
        #             and three kinds of debiased estimates based on the tie set
        #     pattern: tie set selected under proposed design
        mm3 <- function(thetas,s,nk,n1,n,num_stage,m,K,cl,cr,B,Y_old,X_old,T_old,S_old,truetie){
          
          N <- n1 + n * num_stage
          them = sort(thetas)[K]
          epsiss = numeric(B)
          tilde.beta = numeric(B)
          w = which(thetas==them); d=1/4; tn = (s[w])^d
          tie <- matrix(NA, nrow = B, ncol = m)
          bl <- N^{-d}*cl*tn
          br <- N^{-d}*cr*tn
          
          for (j in 1:B) {
            set <- NULL
            set <- sample(x=1:N,size=N,replace = TRUE)
            
            Y_old_0 <- Y_old[set]
            X_old_0 <- X_old[set]
            T_old_0 <- T_old[set]
            S_old_0 <- S_old[set,]
            
            # update ATE and standard deviation
            e_1.hat_old_0 <- tau_old_0 <- sd_old.t_0 <- sd_old.c_0 <- NULL
            
            for(k in 1:ncol(S_old_0)){
              e_1.hat_old_0[k] <- sum(T_old_0[S_old_0[,k]])/(sum(S_old_0[,k]))
            }
            
            for (k in 1:ncol(S_old_0)){
              
              dat1_0 <- as.data.frame(cbind(Y_old_0[S_old_0[,k]],T_old_0[S_old_0[,k]],X_old_0[S_old_0[,k]]))
              
              names(dat1_0) <- c("Y","Tr","X")
              
              tau_old_0[k] <-  mean(dat1_0$Y*dat1_0$Tr/ e_1.hat_old_0[k] - dat1_0$Y*(1-dat1_0$Tr)/(1-e_1.hat_old_0[k]))
              
              sd_old.t_0[k]<- sd(dat1_0$Y[dat1_0$Tr==1])
              
              sd_old.c_0[k]<- sd(dat1_0$Y[dat1_0$Tr==0])
            }
            
            names(tau_old_0) <- LETTERS[seq(1:m)]
            
            names(sd_old.t_0) <- names(sd_old.c_0) <- LETTERS[seq(1:m)]
            
            p_old_0 <- colMeans(S_old_0)
            tau_opt_0 <- tau_old_0
            var_opt_0 <- 1/p_old_0*(sd_old.t_0^2/e_1.hat_old_0 + sd_old.c_0^2/(1-e_1.hat_old_0))
            epsi = tau_opt_0
            e = epsi-sort(epsi)[K]
            wkm = numeric(m)
            temp = ((e<=br)+(e>=-bl))==2
            wkm[temp] = 1
            tie[j,] = wkm
            epsiss[j] = sum(wkm*epsi)/sum(wkm)
            tilde.beta[j] = sum(wkm*thetas)/sum(wkm)
          }
          
          # convert tie dataframe rows to character strings
          tieset_strings <- apply(tie, 1, paste, collapse = "")
          # Count the occurrence of each tie set pattern
          string_counts <- table(tieset_strings)
          ratio <- string_counts / B
          
          epsiss.sort = sort(epsiss)
          lower = epsiss.sort[0.025*B] # estimate the 95% confidence interval lower bound
          upper = epsiss.sort[0.975*B] # estimate the 95% confidence interval upper bound
          width = upper-lower
          esti = mean(epsiss.sort)
          esti2 = median(epsiss.sort) # estimate the mth smallest parameter
          mtbeta = median(tilde.beta)
          
          if (sum(names(ratio) == truetie) == 0){
            correct <- 0
          }else{
            correct <- as.numeric(ratio[names(ratio) == truetie])
          }
          
          pattern <- as.numeric(strsplit(names(ratio)[ratio == max(ratio)],split='')[[1]])
          
          return(list(result = data.frame(paste(pattern, collapse = ""), correct, esti, esti2, mtbeta), pattern = pattern))
        }
        
        ## Separate (stage-stratified) bootstrap estimate of the tie set
        # Resamples subjects separately within the first stage (rows 1..n1) and
        # within each later stage (n rows each), re-estimates every subgroup effect
        # by inverse-probability weighting, and marks as "tied" the subgroups whose
        # estimate lies within [-bl, br] of the K-th smallest estimate.
        # Input:
        #     thetas: estimated subgroup treatment effects
        #     s: estimated variances of thetas
        #     nk: subgroup sizes (not used here)
        #     n1: first-stage sample size
        #     n: number of subjects enrolled in each subsequent stage
        #     num_stage: number of subsequent stages
        #     m: number of subgroups
        #     K: index of the order statistic of interest (the biggest: K = m)
        #     cl, cr: lower/upper tuning constants of the tie window
        #     B: number of bootstrap replicates
        #     Y_old: outcomes; X_old: covariates (not used); T_old: treatment (0/1)
        #     S_old: logical subgroup-membership matrix, one column per subgroup
        #     truetie: true tie-set pattern as a 0/1 string, e.g. "01010"
        # Output:
        #     result: data.frame with the most frequent pattern string, the share
        #             of replicates whose pattern equals truetie, the mean and
        #             median of the tie-set averaged estimate, and the median
        #             tie-set average of thetas
        #     pattern: the most frequent tie-set pattern as a 0/1 numeric vector
        mm4 <- function(thetas,s,nk,n1,n,num_stage,m,K,cl,cr,B,Y_old,X_old,T_old,S_old,truetie){
          
          N <- n1 + n * num_stage
          d <- 1/4
          them <- sort(thetas)[K]
          w <- which(thetas == them)
          tn <- (s[w])^d
          # Half-widths of the window around the K-th smallest estimate.
          bl <- N^(-d) * cl * tn
          br <- N^(-d) * cr * tn
          
          # Row blocks: stage 0 is 1..n1, stage k is (index[k] + 1)..index[k + 1].
          index <- cumsum(c(n1, rep(n, num_stage)))
          
          epsiss <- numeric(B)
          tilde.beta <- numeric(B)
          tie_mat <- matrix(NA, nrow = B, ncol = m)
          
          for (j in seq_len(B)) {
            # sample.int() draws the same indices the former sample(a:b, ...) calls
            # did, but avoids sample()'s length-one trap: with a single number x,
            # sample(x, ...) draws from 1:x (this broke stages with n == 1).
            set <- sample.int(n1, size = n1, replace = TRUE)
            for (k in seq_len(num_stage)) {
              set <- c(set, index[k] + sample.int(n, size = n, replace = TRUE))
            }
            
            Y_b <- Y_old[set]
            T_b <- T_old[set]
            # drop = FALSE keeps a one-column matrix when a single subgroup is left.
            # S_old is expected to be logical (as built by GenS); 0/1 numeric
            # columns would be misused as row indices.
            S_b <- S_old[set, , drop = FALSE]
            
            # IPW estimate of each subgroup effect in this replicate.
            epsi <- numeric(ncol(S_b))
            for (k in seq_len(ncol(S_b))) {
              in_k <- S_b[, k]
              y_k <- Y_b[in_k]
              t_k <- T_b[in_k]
              e_k <- sum(t_k) / sum(in_k)
              epsi[k] <- mean(y_k * t_k / e_k - y_k * (1 - t_k) / (1 - e_k))
            }
            
            e <- epsi - sort(epsi)[K]
            wkm <- numeric(m)
            # Tie-set indicator; NA comparisons leave the entry at 0.
            wkm[(e <= br) + (e >= -bl) == 2] <- 1
            tie_mat[j, ] <- wkm
            epsiss[j] <- sum(wkm * epsi) / sum(wkm)
            tilde.beta[j] <- sum(wkm * thetas) / sum(wkm)
          }
          
          # Share of replicates producing each tie-set pattern, e.g. "01010".
          ratio <- table(apply(tie_mat, 1, paste, collapse = "")) / B
          
          # sort() drops NaN replicates (an empty tie set gives 0/0) before averaging.
          epsiss.sort <- sort(epsiss)
          esti <- mean(epsiss.sort)
          esti2 <- median(epsiss.sort)
          mtbeta <- median(tilde.beta)
          
          hit <- ratio[names(ratio) == truetie]
          correct <- if (length(hit) == 0) 0 else as.numeric(hit)
          
          # Most frequent pattern (the first one on ties) as a 0/1 vector.
          pattern <- as.numeric(strsplit(names(ratio)[which.max(ratio)], split = "")[[1]])
          
          return(list(result = data.frame(paste(pattern, collapse = ""), correct, esti, esti2, mtbeta), pattern = pattern))
        }
        
        ## Tuning criterion for the naive bootstrap
        # Shrinks the subgroup estimates toward their mean by a data-driven weight.
        # Then, over B bootstrap replicates that resample all N subjects together,
        # it records b = 1 when the shrunken K-th smallest estimate is >= the
        # replicate's tie-set average (otherwise 0).
        # Input:
        #     thetahh: estimated subgroup treatment effects
        #     s: estimated variances of thetahh
        #     nk: subgroup sizes
        #     n1: first-stage sample size
        #     n: number of subjects enrolled in each subsequent stage
        #     num_stage: number of subsequent stages
        #     m: number of subgroups
        #     K: index of the order statistic of interest
        #     cl, cr: one candidate pair of tuning constants (c_L, c_R)
        #     B: number of bootstrap replicates
        #     Y_old: outcomes; X_old: covariates (not used); T_old: treatment (0/1)
        #     S_old: logical subgroup-membership matrix, one column per subgroup
        # Output:
        #     the mean of the empirical variances of 1{b == 1} and 1{b == 0}
        #     (equal to p * (1 - p), p = mean(b), when no replicate is NA)
        met3 <- function(thetahh,s,nk,n1,n,num_stage,m,K,cl,cr,B,Y_old,X_old,T_old,S_old){
          
          N <- n1 + n * num_stage
          d <- 1/4
          # Shrinkage weight toward the grand mean, capped at 1.
          r <- mean(s) / var(thetahh)
          sig <- sqrt(sum(s * nk) / m)
          tri <- min(c(1, r * (sig / sqrt(mean(s)))^0.1))
          thetah <- tri * mean(thetahh) + (1 - tri) * thetahh
          them <- sort(thetah)[K]
          
          b <- numeric(B)
          for (i in seq_len(B)) {
            # Naive bootstrap: resample N row indices from the pooled sample.
            set <- sample.int(N, size = N, replace = TRUE)
            Y_b <- Y_old[set]
            T_b <- T_old[set]
            # drop = FALSE keeps a matrix even with a single subgroup column.
            S_b <- S_old[set, , drop = FALSE]
            
            # IPW estimate of each subgroup effect in this replicate.
            tau_b <- numeric(ncol(S_b))
            for (k in seq_len(ncol(S_b))) {
              in_k <- S_b[, k]
              y_k <- Y_b[in_k]
              t_k <- T_b[in_k]
              e_k <- sum(t_k) / sum(in_k)
              tau_b[k] <- mean(y_k * t_k / e_k - y_k * (1 - t_k) / (1 - e_k))
            }
            
            epsi <- tri * mean(tau_b) + (1 - tri) * tau_b
            kth <- sort(epsi)[K]
            tn <- (s[which(epsi == kth)])^d
            bl <- N^(-d) * cl * tn
            br <- N^(-d) * cr * tn
            e <- epsi - kth
            wkm <- numeric(m)
            # Tie-set indicator; NA comparisons leave the entry at 0.
            wkm[(e <= br) + (e >= -bl) == 2] <- 1
            epsiss <- sum(wkm * epsi) / sum(wkm)
            b[i] <- ifelse(them - epsiss >= 0, 1, 0)
          }
          
          # Kept verbatim: the value is compared against gam3() thresholds that
          # are computed with the same expression, so exact numerics matter.
          l = (sum(((sort(b)==1)-mean(sort(b)==1))^2)/B+sum(((sort(b)==0)-mean(sort(b)==0))^2)/B)/2
          return(l)
        }
        
        ## Tuning criterion for the separate (stage-stratified) bootstrap
        # Same criterion as the naive-bootstrap version, but each replicate
        # resamples within the first stage (rows 1..n1) and within each later
        # stage (n rows each) separately.
        # Input:
        #     thetahh: estimated subgroup treatment effects
        #     s: estimated variances of thetahh
        #     nk: subgroup sizes
        #     n1: first-stage sample size
        #     n: number of subjects enrolled in each subsequent stage
        #     num_stage: number of subsequent stages
        #     m: number of subgroups
        #     K: index of the order statistic of interest
        #     cl, cr: one candidate pair of tuning constants (c_L, c_R)
        #     B: number of bootstrap replicates
        #     Y_old: outcomes; X_old: covariates (not used); T_old: treatment (0/1)
        #     S_old: logical subgroup-membership matrix, one column per subgroup
        # Output:
        #     the mean of the empirical variances of 1{b == 1} and 1{b == 0},
        #     where b = 1 when the shrunken K-th smallest estimate is >= the
        #     replicate's tie-set average
        met4 <- function(thetahh,s,nk,n1,n,num_stage,m,K,cl,cr,B,Y_old,X_old,T_old,S_old){
          
          N <- n1 + n * num_stage
          d <- 1/4
          # Shrinkage weight toward the grand mean, capped at 1.
          r <- mean(s) / var(thetahh)
          sig <- sqrt(sum(s * nk) / m)
          tri <- min(c(1, r * (sig / sqrt(mean(s)))^0.1))
          thetah <- tri * mean(thetahh) + (1 - tri) * thetahh
          them <- sort(thetah)[K]
          
          # Row blocks: stage 0 is 1..n1, stage k is (index[k] + 1)..index[k + 1].
          index <- cumsum(c(n1, rep(n, num_stage)))
          
          b <- numeric(B)
          for (i in seq_len(B)) {
            # sample.int() draws the same indices the former sample(a:b, ...) calls
            # did, but avoids sample()'s length-one trap: with a single number x,
            # sample(x, ...) draws from 1:x (this broke stages with n == 1).
            set <- sample.int(n1, size = n1, replace = TRUE)
            for (k in seq_len(num_stage)) {
              set <- c(set, index[k] + sample.int(n, size = n, replace = TRUE))
            }
            
            Y_b <- Y_old[set]
            T_b <- T_old[set]
            # drop = FALSE keeps a matrix even with a single subgroup column.
            S_b <- S_old[set, , drop = FALSE]
            
            # IPW estimate of each subgroup effect in this replicate.
            tau_b <- numeric(ncol(S_b))
            for (k in seq_len(ncol(S_b))) {
              in_k <- S_b[, k]
              y_k <- Y_b[in_k]
              t_k <- T_b[in_k]
              e_k <- sum(t_k) / sum(in_k)
              tau_b[k] <- mean(y_k * t_k / e_k - y_k * (1 - t_k) / (1 - e_k))
            }
            
            epsi <- tri * mean(tau_b) + (1 - tri) * tau_b
            kth <- sort(epsi)[K]
            tn <- (s[which(epsi == kth)])^d
            bl <- N^(-d) * cl * tn
            br <- N^(-d) * cr * tn
            e <- epsi - kth
            wkm <- numeric(m)
            # Tie-set indicator; NA comparisons leave the entry at 0.
            wkm[(e <= br) + (e >= -bl) == 2] <- 1
            epsiss <- sum(wkm * epsi) / sum(wkm)
            b[i] <- ifelse(them - epsiss >= 0, 1, 0)
          }
          
          # Kept verbatim: the value is compared against gam-style thresholds
          # computed with the same expression, so exact numerics matter.
          l = (sum(((sort(b)==1)-mean(sort(b)==1))^2)/B+sum(((sort(b)==0)-mean(sort(b)==0))^2)/B)/2
          return(l)
        }
        
        ## Threshold for the naive-bootstrap tuning criterion
        # Simulates 20000 vectors of B fair 0/1 draws, evaluates the same
        # criterion on each as met3, and returns the 97.5% sample quantile.
        # Input:
        #     B: length of each simulated 0/1 vector
        # Output:
        #     a single named value (name "97.5%") as produced by quantile()
        gam3 <- function(B){
          n_draws <- 20000
          crit <- vapply(seq_len(n_draws), function(draw) {
            sorted_b <- sort(rbinom(B, size = 1, prob = 0.5))
            is_one <- sorted_b == 1
            is_zero <- sorted_b == 0
            (sum((is_one - mean(is_one))^2) / B + sum((is_zero - mean(is_zero))^2) / B) / 2
          }, numeric(1))
          quantile(crit, 0.975)
        }
        
        ## Threshold for the separate-bootstrap tuning criterion
        # Simulates 20000 vectors of B fair 0/1 draws, evaluates the same
        # criterion on each as met4, and returns the 97.5% sample quantile.
        # Input:
        #     B: length of each simulated 0/1 vector
        # Output:
        #     a single named value (name "97.5%") as produced by quantile()
        gam4 <- function(B){
          draws_total <- 20000
          values <- vapply(seq_len(draws_total), function(iter) {
            coin <- sort(rbinom(B, size = 1, prob = 0.5))
            heads <- coin == 1
            tails <- coin == 0
            spread_heads <- sum((heads - mean(heads))^2) / B
            spread_tails <- sum((tails - mean(tails))^2) / B
            (spread_heads + spread_tails) / 2
          }, numeric(1))
          quantile(values, 0.975)
        }
        
        ## Function for finding the lower and upper tuning parameters (naive bootstrap)
        # Evaluates met3 on every (c_L, c_R) pair of the grid. If some pairs have
        # a criterion <= gam, returns the averages of their c_L and c_R values;
        # otherwise returns the first pair with the smallest criterion.
        # Input:
        #     cl: choices of c_L
        #     cr: choices of c_R
        #     thetah: \hat{\tau}
        #     s: the estimated variances of \hat{\tau}
        #     nk: the size of each subgroup
        #     n1: first-stage sample size
        #     n: number of subjects enrolled in each of the subsequent stage
        #     num_stage: number of experimental stages
        #     m: the number of subgroups
        #     K: the Kth smallest parameter of interest (the biggest: K = m)
        #     gam: threshold
        #     B: times of resampling
        #     Y_old: outcomes
        #     X_old: covariates
        #     T_old: treatment
        #     S_old: subgroup memberships
        # Output:
        #     list(cl, cr): the selected (c_L, c_R)
        tuningf3 <- function(cl,cr,thetah,s,nk,n1,n,num_stage,m,K,gam,B,Y_old,X_old,T_old,S_old){
          
          # loss[i, j] is the met3 criterion at (cl[i], cr[j]).
          loss <- matrix(0, nrow = length(cl), ncol = length(cr))
          
          # The grid indices must not be called n1: that name is the first-stage
          # sample size and is passed through to met3 unchanged.
          for (i_l in seq_along(cl)) {
            for (i_r in seq_along(cr)) {
              loss[i_l, i_r] <- met3(thetah,s,nk,n1,n,num_stage,m,K,cl[i_l],cr[i_r],B,Y_old,X_old,T_old,S_old)
            }
          }
          
          ok <- which(loss <= gam, arr.ind = TRUE)
          if (nrow(ok) > 0) {
            # Average the constants over every admissible cell.
            return(list(cl = mean(cl[ok[, 1]]), cr = mean(cr[ok[, 2]])))
          }
          
          if (all(is.na(loss))) {
            stop("tuningf3: met3 returned NA for every (cl, cr) pair", call. = FALSE)
          }
          # No admissible cell: take the first cell with the smallest criterion.
          best <- which(loss == min(loss, na.rm = TRUE), arr.ind = TRUE)[1, ]
          list(cl = cl[best[1]], cr = cr[best[2]])
        }
        
        ## Function for finding the lower and upper tuning parameters (separate bootstrap)
        # Evaluates met4 on every (c_L, c_R) pair of the grid. If some pairs have
        # a criterion <= gam, returns the averages of their c_L and c_R values;
        # otherwise returns the first pair with the smallest criterion.
        # Input:
        #     cl, cr: choices of c_L and c_R
        #     thetah: \hat{\tau}; s: the estimated variances of \hat{\tau}
        #     nk: the size of each subgroup
        #     n1: first-stage sample size
        #     n: number of subjects enrolled in each of the subsequent stage
        #     num_stage: number of experimental stages
        #     m: the number of subgroups
        #     K: the Kth smallest parameter of interest (the biggest: K = m)
        #     gam: threshold
        #     B: times of resampling
        #     Y_old, X_old, T_old, S_old: outcomes, covariates, treatment, subgroup memberships
        # Output:
        #     list(cl, cr): the selected (c_L, c_R)
        tuningf4 <- function(cl,cr,thetah,s,nk,n1,n,num_stage,m,K,gam,B,Y_old,X_old,T_old,S_old){
          
          # loss[i, j] is the met4 criterion at (cl[i], cr[j]).
          loss <- matrix(0, nrow = length(cl), ncol = length(cr))
          
          # The grid indices must not be called n1: met4 uses n1 to locate the
          # first-stage rows and the later stage blocks.
          for (i_l in seq_along(cl)) {
            for (i_r in seq_along(cr)) {
              loss[i_l, i_r] <- met4(thetah,s,nk,n1,n,num_stage,m,K,cl[i_l],cr[i_r],B,Y_old,X_old,T_old,S_old)
            }
          }
          
          ok <- which(loss <= gam, arr.ind = TRUE)
          if (nrow(ok) > 0) {
            # Average the constants over every admissible cell.
            return(list(cl = mean(cl[ok[, 1]]), cr = mean(cr[ok[, 2]])))
          }
          
          if (all(is.na(loss))) {
            stop("tuningf4: met4 returned NA for every (cl, cr) pair", call. = FALSE)
          }
          # No admissible cell: take the first cell with the smallest criterion.
          best <- which(loss == min(loss, na.rm = TRUE), arr.ind = TRUE)[1, ]
          list(cl = cl[best[1]], cr = cr[best[2]])
        }
        
        ## simulation
        # Tunes the tie-set window (c_L, c_R) on the current data and then
        # estimates the tie set with the chosen bootstrap scheme.
        # Input:
        #     n1: first-stage sample size
        #     n: number of subjects enrolled in each of the subsequent stage
        #     num_stage: number of subsequent stages in the data so far
        #     m: the number of subgroups
        #     K: the Kth smallest parameters of interest (the biggest: K = m)
        #     true_p: population subgroup proportions
        #     var_old: estimated variances
        #     B: times of resampling
        #     tau_old: estimated subgroup treatment effects
        #     Y_old: outcomes
        #     X_old: covariates
        #     T_old: treatment
        #     S_old: subgroup memberships
        #     truetie: actual tie set
        #     method.identify: dynamic identification of the best subgroups method
        #                      (1 = naive bootstrap via tuningf3/mm3,
        #                       2 = separate bootstrap via tuningf4/mm4);
        #                      any other value is an error
        #     cl_grid, cr_grid: candidate values of c_L and c_R
        #     B_tune: bootstrap replicates per grid cell during tuning
        #     B_gam: vector length passed to gam3/gam4 for the tuning threshold
        # Output:
        #     result: a data.frame row with the selected tie-set pattern, the
        #             correct selection probability, and three estimates based
        #             on the tie set
        #     pattern: tie set selected under proposed design
        sec.fun.t <- function(n1,n,num_stage,m,K,true_p,var_old,B,tau_old,Y_old,X_old,T_old,S_old,truetie,method.identify,
                              cl_grid = seq(18, 21, length.out = 5), cr_grid = seq(18, 21, length.out = 5),
                              B_tune = 200, B_gam = 5){
          
          # Fail loudly: without this, an unknown method left `result` undefined
          # (or resolved it to a stale value from an enclosing scope).
          if (length(method.identify) != 1 || !(method.identify %in% c(1, 2))) {
            stop("method.identify must be 1 (naive bootstrap) or 2 (separate bootstrap)", call. = FALSE)
          }
          
          N <- n1 + n * num_stage
          nk <- round(true_p * N) # the size of each subgroup
          s <- var_old / N        # variances of the subgroup estimates
          thetas <- c(tau_old)
          
          if (method.identify == 1) {
            # naive bootstrap
            gam <- gam3(B_gam)
            tuned <- tuningf3(cl_grid, cr_grid, thetas, s, nk, n1, n, num_stage, m, K, gam, B_tune, Y_old, X_old, T_old, S_old)
            fit <- mm3(thetas, s, nk, n1, n, num_stage, m, K, tuned$cl, tuned$cr, B, Y_old, X_old, T_old, S_old, truetie)
          } else {
            # separate bootstrap: threshold from its own simulator (gam4)
            gam <- gam4(B_gam)
            tuned <- tuningf4(cl_grid, cr_grid, thetas, s, nk, n1, n, num_stage, m, K, gam, B_tune, Y_old, X_old, T_old, S_old)
            fit <- mm4(thetas, s, nk, n1, n, num_stage, m, K, tuned$cl, tuned$cr, B, Y_old, X_old, T_old, S_old, truetie)
          }
          
          return(list(result = fit$result, pattern = fit$pattern))
        }
        
        temp.t <- sec.fun.t(n1,n,i,m,m,true_p,var_opt,B,tau_opt,Y_old,X_old,T_old,S_old,truetie,method.identify)
        result <- temp.t$result
        
        if (merge){
          pattern <- temp.t$pattern
          ## Merge subgroups who are ties
          m <- m - sum(pattern==1) +1
          a <- which(pattern==1)
          a0 <- which(pattern==0)
          
          mu1_vec_new <- rep(NA,m)
          mu1_vec_new[a[1]] <- sum(mu1_vec[a]*true_p[a])/sum(true_p[a])
          mu1_vec_new[-a[1]] <- mu1_vec[-a]
          mu1_vec <- mu1_vec_new
          
          mu0_vec_new <- rep(NA,m)
          mu0_vec_new[a[1]] <- sum(mu0_vec[a]*true_p[a])/sum(true_p[a])
          mu0_vec_new[-a[1]] <- mu0_vec[-a]
          mu0_vec <- mu0_vec_new
          
          sd1_vec_new <- rep(NA,m)
          sd1_vec_new[a[1]] <- sqrt(sum(((sd1_vec[a])^2)*((true_p[a])^2)))/sum(true_p[a])
          sd1_vec_new[-a[1]] <- sd1_vec[-a]
          sd1_vec <- sd1_vec_new
          
          sd0_vec_new <- rep(NA,m)
          sd0_vec_new[a[1]] <- sqrt(sum(((sd0_vec[a])^2)*((true_p[a])^2)))/sum(true_p[a])
          sd0_vec_new[-a[1]] <- sd0_vec[-a]
          sd0_vec <- sd0_vec_new
          
          name_new <- rep(NA,m)
          name_new[a[1]] <- paste(sort(strsplit((paste(names(true_p)[a], collapse = "")), split = "")[[1]]), collapse = "")
          name_new[-a[1]] <- names(true_p)[a0]
          tieset[i+1] <- paste(name_new, collapse = ",")
          
          true_p_new <- rep(NA,m)
          true_p_new[a[1]] <- sum(true_p[a])
          true_p_new[-a[1]] <- true_p[-a]
          true_p <- true_p_new
          names(true_p) <- name_new
          
          
          X_old[X_old %in% a] <- a[1]
          for (j in 1:length(unique(X_old))){
            X_old[X_old==sort(unique(X_old))[j]] <- j
          }
          
          S_old <- matrix(NA,nrow = length(X_old),ncol = m)
          for(j in 1:m){
            S_old[,j] <- (X_old == j)
          }
          colnames(S_old) <- LETTERS[seq(1:m)]
          
          # update ATE and standard deviation
          p_old <- e_1.hat_old <-tau_old <- sd_old.t <- sd_old.c <- NULL
          
          p_old <- colMeans(S_old)
          
          for(k in 1:ncol(S_old)){
            e_1.hat_old[k] <- sum(T_old[S_old[,k]])/(sum(S_old[,k]))
          }
          
          for (j in 1:ncol(S_old)){
            
            dat1 <- as.data.frame(cbind(Y_old[S_old[,j]],T_old[S_old[,j]],X_old[S_old[,j]]))
            
            names(dat1) <- c("Y","Tr","X")
            
            ps <- sum(dat1$Tr)/nrow(dat1)
            
            tau_old[j] <-  mean(dat1$Y*dat1$Tr/ e_1.hat_old[j] - dat1$Y*(1-dat1$Tr)/(1-e_1.hat_old[j]))
            
            sd_old.t[j]<- sd(dat1$Y[dat1$Tr==1])
            
            sd_old.c[j]<- sd(dat1$Y[dat1$Tr==0])
          }
          
          names(tau_old) <- LETTERS[seq(1:m)]
          
          names(sd_old.t) <- names(sd_old.c)<- LETTERS[seq(1:m)]
          
          tau_opt <- tau_old
          var_opt <- 1/p_old*(sd_old.t^2/e_1.hat_old + sd_old.c^2/(1-e_1.hat_old))
          tietau[i+1] <- tau_opt[a[1]]
        }
      }
    }
  }
  
  ################### Estimation after final stage ###################
  sd_opt <- sqrt(var_opt)
  hi_opt <- tau_opt + 1.96*sd_opt/sqrt(n1+n*num_stage)
  lo_opt <- tau_opt - 1.96*sd_opt/sqrt(n1+n*num_stage)
  
  
  if (!tie){
    return(list(tau_opt_list = tau_opt_list,
                esti = c(tau_opt[which(tau_opt == max(tau_opt))],
                         sd_opt[which(tau_opt == max(tau_opt))],
                         hi_opt[which(tau_opt == max(tau_opt))],
                         lo_opt[which(tau_opt == max(tau_opt))])))
  }else if (tie && (!merge)){
    return(result)
  }else if (tie && merge){
    return(list(tieset = tieset, tietau = tietau,
                esti = c(tau_opt[which(tau_opt == max(tau_opt))],
                         sd_opt[which(tau_opt == max(tau_opt))],
                         hi_opt[which(tau_opt == max(tau_opt))],
                         lo_opt[which(tau_opt == max(tau_opt))])))
  }
  
}



################### (3) Main Function Call ###################
## Replicate the multi-stage design Nrep times and summarise the results
# Input:
#     Nrep: number of simulation replicates
#     n1, n, c1, m, K, true_p, mu1_vec, mu0_vec, sd1_vec, sd0_vec, num_stage, B,
#     truetie, mab, tie, merge, method.select, method.identify: passed unchanged
#         to simMultiStageRAR() for every replicate
# Output (depends on tie / merge):
#     !tie:        tau_opt_list (one Nrep x m matrix per stage) and esti
#                  (Nrep x 4: estimate, sd, upper and lower 95% bound for the
#                  subgroup with the largest estimate)
#     tie, !merge: correct (share of replicates whose tie-set pattern equals
#                  truetie), ratio (share per pattern), res (one row per replicate)
#     tie, merge:  correct (per stage, share of replicates whose merged subgroup
#                  labelling equals the one implied by truetie, e.g. "01010" ->
#                  "A,BD,C,E"), tieset, tietau, esti for the replicates that ran
#                  without error; failed replicates are dropped with a warning
main.fun <- function(Nrep,n1,n,c1,m,K,true_p,mu1_vec,mu0_vec,sd1_vec,sd0_vec,num_stage,B,truetie,mab=FALSE,tie=TRUE,merge=TRUE,method.select=NULL,method.identify=NULL){
  
  # One replicate of the design with this call's settings.
  run_once <- function() {
    simMultiStageRAR(n1, n, c1, m, K, true_p, mu1_vec, mu0_vec, sd1_vec, sd0_vec,
                     num_stage, B, truetie, mab, tie, merge,
                     method.select, method.identify)
  }
  n_col <- num_stage + 1 # first stage plus each subsequent stage
  
  if (!tie) {
    tau_opt_list <- lapply(seq_len(n_col), function(k) matrix(NA, nrow = Nrep, ncol = m))
    esti <- matrix(NA, nrow = Nrep, ncol = 4)
    for (i in seq_len(Nrep)) {
      res <- run_once()
      for (j in seq_len(n_col)) {
        tau_opt_list[[j]][i, ] <- res$tau_opt_list[[j]]
      }
      esti[i, ] <- res$esti
    }
    result <- list(tau_opt_list = tau_opt_list, esti = esti)
    
  } else if (!merge) {
    res <- data.frame(matrix(NA, nrow = Nrep, ncol = 5))
    for (i in seq_len(Nrep)) {
      res[i, ] <- run_once()
    }
    ratio <- table(res[, 1]) / Nrep
    hit <- ratio[names(ratio) == truetie]
    correct <- if (length(hit) == 0) 0 else as.numeric(hit)
    result <- list(correct = correct, ratio = ratio, res = res)
    
  } else {
    # Merged labelling implied by truetie, computed up front so that a bad
    # truetie fails before any replicate is run.
    true <- as.numeric(strsplit(truetie, split = "")[[1]])
    a <- which(true == 1)
    a0 <- which(true == 0)
    if (length(a) == 0) {
      stop("truetie must contain at least one '1' when merge = TRUE", call. = FALSE)
    }
    name <- LETTERS[seq_len(m)]
    name_new <- rep(NA, m - length(a) + 1)
    name_new[a[1]] <- paste(name[a], collapse = "")
    name_new[-a[1]] <- name[a0]
    truetie_label <- paste(name_new, collapse = ",")
    
    tieset <- matrix(NA, nrow = Nrep, ncol = n_col)
    tietau <- matrix(NA, nrow = Nrep, ncol = n_col)
    esti <- matrix(NA, nrow = Nrep, ncol = 4)
    failed <- logical(Nrep)
    first_error <- NULL
    for (i in seq_len(Nrep)) {
      # Best effort: a replicate that errors is recorded and skipped.
      res <- tryCatch(run_once(), error = function(e) e)
      if (inherits(res, "error")) {
        failed[i] <- TRUE
        if (is.null(first_error)) first_error <- conditionMessage(res)
        next
      }
      tieset[i, ] <- res$tieset
      tietau[i, ] <- res$tietau
      esti[i, ] <- res$esti
    }
    if (any(failed)) {
      warning(sprintf("%d of %d replicates failed and were dropped (first error: %s)",
                      sum(failed), Nrep, first_error), call. = FALSE)
    }
    
    # drop = FALSE keeps matrices even when a single replicate survives.
    keep <- !failed
    tieset <- tieset[keep, , drop = FALSE]
    tietau <- tietau[keep, , drop = FALSE]
    esti <- esti[keep, , drop = FALSE]
    
    correct <- vapply(seq_len(n_col), function(j) {
      ratio <- table(tieset[, j]) / nrow(tieset)
      hit <- ratio[names(ratio) == truetie_label]
      if (length(hit) == 0) 0 else as.numeric(hit)
    }, numeric(1))
    
    result <- list(correct = correct, tieset = tieset, tietau = tietau, esti = esti)
  }
  
  result
}


## Simulation settings ----
n1 <- 400       # first-stage sample size
n <- 400        # subjects enrolled in each subsequent stage
m <- 5          # number of subgroups
num_stage <- 14 # number of subsequent stages
# Total sample size, using the same formula as the design code
# (n1 + n * num_stage); the earlier n + n1 * num_stage only agreed because n == n1.
N <- n1 + n * num_stage
# True subgroup treatment effects (approximately mu1_vec - mu0_vec).
tau <- c(-2.769924, 10.531104, -1.212645, 10.886470, -1.458249)
true_p <- c(0.2788462, 0.1250000, 0.2980769, 0.1121795, 0.1858974)
names(true_p) <- LETTERS[seq_len(m)]
mu1_vec <- c(42.56808, 50.44094, 44.36636, 44.30218, 37.70726)
mu0_vec <- c(45.33801, 39.90983, 45.57901, 33.41571, 39.16550)
sd1_vec <- c(10.84663, 12.29309, 12.64410, 14.28020, 14.64300)
sd0_vec <- c(11.49844, 15.18284, 14.56797, 13.08778, 15.06209)

# True tie set as a 0/1 string: subgroups B and D (the two largest effects).
truetie <- "01010"
true <- as.numeric(strsplit(truetie, split = "")[[1]])
is_tie <- true == 1
# Proportion-weighted effect of the merged tie set.
tietau <- sum(tau[is_tie] * true_p[is_tie]) / sum(true_p[is_tie])


B <- 2000 # times of resampling
Nrep <- 200 # specify simulation parameters
c1 <- 0.8 # cost constraint cl



# Causal tree
set.seed(1001)
result03 <- main.fun(Nrep, n1, n, c1, m, m, true_p, mu1_vec, mu0_vec, sd1_vec,
                     sd0_vec, num_stage, B, truetie,
                     mab = FALSE, tie = FALSE, merge = FALSE)
save(result03, file = "result03.RData")


# results
# Monte Carlo means of (estimate, sd, upper bound, lower bound) for the
# subgroup with the largest estimated effect.
colMeans(result03$esti)

# sqrt(N)-scaled mean absolute error against the largest true effect.
sqrt(N) * mean(abs(result03$esti[, 1] - max(tau)))