
library(dplyr)
library(tidyverse)
library(ggplot2)
library(nloptr)

# Fix the RNG seed so the whole simulation is reproducible.
set.seed(193717)

# True outcome parameters on the log(Y) scale.
# GenY() reads one entry of the control (Y0) or treated (Y1) vectors,
# chosen from the covariate X and the surrogate S.
mean_Y0.vec <- c(2.98, 3.02, 2.59, 2.47, 2.92, 2.84)
mean_Y1.vec <- c(2.50, 3.03, 2.94, 2.72, 2.68, 3.13)
sd_Y0.vec <- c(2.06, 1.70, 0.48, 0.31, 0.85, 0.78)
sd_Y1.vec <- c(0.36, 2.06, 1.27, 0.82, 0.66, 2.01)

# Number of stages and the per-stage weight used in the estimator.
nStage <- 8
rt <- 1 / nStage

# Cumulative probabilities for the delay variable, by covariate value.
# GenD() differences these into the probabilities of D = 0, 1, 2, 3.
# rho_1 applies to treated units (A = 1), rho_0 to controls (A = 0).
rho_1 <- list(
  "X=0" = c(0.60, 0.92, 0.98, 1.00),
  "X=1" = c(0.69, 0.89, 0.98, 1.00)
)

rho_0 <- list(
  "X=0" = c(0.67, 0.93, 1.00, 1.00),
  "X=1" = c(0.60, 0.86, 0.97, 1.00)
)


# Draw `n` binary covariates, each equal to 1 with probability 0.64.
GenX <- function(n) {
  rbinom(n, size = 1, prob = 0.64)
}

# Draw `n` surrogate values from {1, 2, 3} with probabilities
# 0.34, 0.34 and 0.32 (sampling with replacement).
GenS <- function(n) {
  surrogate_probs <- c(0.34, 0.34, 0.32)
  sample.int(3, n, replace = TRUE, prob = surrogate_probs)
}

# Draw one outcome for a unit with treatment A, covariate X and surrogate S.
#
# The outcome is a single normal draw whose mean and standard deviation are
# read from the file-level vectors mean_Y0.vec / sd_Y0.vec when A == 0 and
# from mean_Y1.vec / sd_Y1.vec otherwise.
GenY <- function(A, X, S) {
  is_control <- A == 0

  # The treated branch subtracts 6 so the index lands back inside the
  # six-element parameter vectors.
  shift <- if (is_control) 0 else 6
  cell <- (A * 2 + X) * 3 + S - shift

  if (is_control) {
    mu <- mean_Y0.vec[cell]
    sigma <- sd_Y0.vec[cell]
  } else {
    mu <- mean_Y1.vec[cell]
    sigma <- sd_Y1.vec[cell]
  }

  rnorm(1, mu, sigma)
}


# Draw one delay value D in {0, 1, 2, 3} for a unit with treatment A and
# covariate X.
#
# The file-level lists rho_1 (A == 1) and rho_0 (otherwise) hold cumulative
# probabilities keyed by "X=<value>"; they are differenced into cell
# probabilities and a single multinomial draw picks the delay.
GenD <- function(A, X) {
  stratum <- paste0("X=", X)
  cum_probs <- if (A == 1) rho_1[[stratum]] else rho_0[[stratum]]

  # First cell keeps its cumulative value; the rest are successive increments.
  cell_probs <- c(cum_probs[1], diff(cum_probs[1:4]))

  draw <- rmultinom(1, size = 1, prob = cell_probs)

  # The draw is a one-hot column; its position minus one is the delay.
  which.max(draw[, 1]) - 1
}

# Mean of Y (ignoring NAs) among the rows of `dat` with A == a, X == x and
# S == s. Returns NaN when no row matches.
tauXS <- function(dat, a, x, s) {
  cell <- filter(dat, A == a, X == x, S == s)
  mean(cell$Y, na.rm = TRUE)
}

# Mean of Y (ignoring NAs) among the rows of `dat` with A == a and X == x.
# Returns NaN when no row matches.
tauX <- function(dat, a, x) {
  stratum <- filter(dat, A == a, X == x)
  mean(stratum$Y, na.rm = TRUE)
}


# Return the row of `delay.1` for covariate value `x`, in reversed column
# order: row 2 when x == 1, row 1 otherwise.
rho1X <- function(delay.1, x) {
  stratum_row <- if (x == 1) 2 else 1
  rev(delay.1[stratum_row, ])
}

# Return the row of `delay.0` for covariate value `x`, in reversed column
# order: row 2 when x == 1, row 1 otherwise.
rho0X <- function(delay.0, x) {
  stratum_row <- if (x == 1) 2 else 1
  rev(delay.0[stratum_row, ])
}


# ---- private helpers for simExperiment -------------------------------------

# Simulate one enrolment stage of `n_t` units. The draw order (covariates,
# treatment, delays, surrogates, outcomes) is kept fixed so the random-number
# stream is consumed in a stable order.
.simStage <- function(stage, n_t, p_treat) {
  X <- GenX(n_t)
  A <- rbinom(n_t, 1, p_treat)
  D <- vapply(seq_len(n_t), function(i) GenD(A[i], X[i]), numeric(1))
  S <- GenS(n_t)
  Y <- vapply(seq_len(n_t), function(i) GenY(A[i], X[i], S[i]), numeric(1))
  data.frame(Y = Y, R = rep(stage, n_t), D = D, A = A, X = X, S = S)
}

# Data available at the end of `stage`: units enrolled so far, with Y set to
# NA for every unit whose outcome has not arrived yet (R + D > stage).
.observedBy <- function(units, stage) {
  enrolled <- units[units$R <= stage, , drop = FALSE]
  enrolled$Y[enrolled$R + enrolled$D > stage] <- NA
  enrolled
}

# Empirical delay distribution for arm `a`: a 2 x (max_delay + 1) matrix whose
# row x + 1 (x = 0, 1) and column k + 1 hold the share of units with D <= k.
# An empty stratum yields NaN instead of silently recycling another stratum.
.delayCdf <- function(dat, a, max_delay) {
  cdf <- matrix(NA_real_, nrow = 2, ncol = max_delay + 1)
  for (x in 0:1) {
    delays <- dat$D[dat$A == a & dat$X == x]
    cdf[x + 1, ] <- vapply(0:max_delay, function(k) mean(delays <= k),
                           numeric(1))
  }
  cdf
}

# Per-row fitted means for arm `a`: tauXS() by (X, S) cell when
# `use_surrogate` is TRUE, tauX() by X otherwise. Each distinct cell is
# evaluated once and then mapped back to the rows.
.cellMeans <- function(dat, a, use_surrogate) {
  key <- if (use_surrogate) paste(dat$X, dat$S) else as.character(dat$X)
  first <- which(!duplicated(key))
  cell_mean <- vapply(first, function(i) {
    if (use_surrogate) {
      tauXS(dat, a, dat$X[i], dat$S[i])
    } else {
      tauX(dat, a, dat$X[i])
    }
  }, numeric(1))
  cell_mean[match(key, key[first])]
}

# ---- entry point ------------------------------------------------------------

# Simulate one 8-stage experiment with delayed outcomes and return the
# estimated treatment effect.
#
# Arguments:
#   nStage  Kept for interface compatibility; the design below is fixed at
#           8 stages (4 initial + 4 follow-up) and this argument is not read.
#           The estimator uses the file-level weight `rt`.
#   n_t     Number of units enrolled per stage. Rounded to the nearest whole
#           number, because a fractional value used to be truncated in some
#           places and used as-is in others (index ranges, propensities).
#
# Returns: a single numeric, the mean of the per-unit EIF contributions over
#          the units whose outcome has been observed by the end of stage 8.
#
# An outcome of a unit enrolled at stage R with delay D is observed at the end
# of stage t when R + D <= t. The original follow-up loop compared against a
# stale loop variable (always 4), which erased every stage 5-7 outcome; the
# comparison now uses the actual final stage.
simExperiment <- function(nStage, n_t) {
  n_t <- as.integer(round(n_t))
  stopifnot(length(n_t) == 1, !is.na(n_t), n_t >= 1)

  n_stages <- 8L   # stages 1-4 plus follow-up stages 5-8
  max_delay <- 3L  # D takes the values 0, 1, 2, 3

  # Treatment probabilities: 1/2 in stages 1-4; stages 5-8 use e.tilde.star,
  # which is currently also fixed at 1/2, so no interim estimate is needed.
  # .observedBy(units, 4) gives the stage-4 snapshot if one is required later.
  e.tilde.star <- rep(1 / 2, 4)
  p_treat <- c(rep(0.5, 4), e.tilde.star)

  stages <- vector("list", n_stages)
  e.hat <- numeric(n_stages)
  for (r in seq_len(n_stages)) {
    stages[[r]] <- .simStage(r, n_t, p_treat[r])
    e.hat[r] <- sum(stages[[r]]$A) / n_t  # realised treated share
  }
  units <- do.call(rbind, stages)

  # Everything known at the end of the experiment.
  df.t8 <- .observedBy(units, n_stages)

  # Delay mechanism, estimated from all enrolled units.
  delay.1 <- .delayCdf(df.t8, 1, max_delay)
  delay.0 <- .delayCdf(df.t8, 0, max_delay)

  # Outcome regressions are fitted on the observed outcomes only.
  all.dat <- df.t8[!is.na(df.t8$Y), , drop = FALSE]

  tau1.XS <- .cellMeans(all.dat, 1, use_surrogate = TRUE)
  tau0.XS <- .cellMeans(all.dat, 0, use_surrogate = TRUE)
  tau1.X <- .cellMeans(all.dat, 1, use_surrogate = FALSE)
  tau0.X <- .cellMeans(all.dat, 0, use_surrogate = FALSE)

  X <- all.dat$X
  Y <- all.dat$Y
  A <- all.dat$A

  et <- e.hat

  # NOTE(review): rho1X()/rho0X() return 4 values while `et` has 8, so R
  # recycles them and stages 1-4 reuse the stage 5-8 values. Under the
  # observation rule above, units from stages 1-5 are always observed by
  # stage 8 (D <= 3) -- confirm whether the weights should instead be
  # c(rep(1, 4), rho). The arithmetic is left exactly as it was.
  p1.obs <- X * sum(rt * et * rho1X(delay.1, 1)) +
    (1 - X) * sum(rt * et * rho1X(delay.1, 0))
  p0.obs <- X * sum(rt * (1 - et) * rho0X(delay.0, 1)) +
    (1 - X) * sum(rt * (1 - et) * rho0X(delay.0, 0))

  EIF <- tau1.X - tau0.X +
    A * (Y - tau1.XS) / p1.obs +
    (1 - A) * (Y - tau0.XS) / p0.obs +
    A * (tau1.XS - tau1.X) / (sum(rt * et)) +
    (1 - A) * (tau0.XS - tau0.X) / (sum(rt * (1 - et)))

  tau.hat <- mean(EIF)

  tau.hat
}


# Per-stage sample sizes to evaluate. seq() alone yields fractional values
# (e.g. 66.67), which are not valid unit counts: they were truncated when
# drawing data but used as-is for index ranges and propensities. Round them
# to whole units up front.
n_t.vec <- round(seq(50, 200, length.out = 10))
Nrep <- 400  # Monte Carlo replications per sample size

# res[k, j] holds the estimate from replication k at sample size n_t.vec[j].
res <- matrix(NA, nrow = Nrep, ncol = length(n_t.vec))
for (j in seq_along(n_t.vec)) {
  print(j)  # progress indicator
  n_t <- n_t.vec[j]
  for (k in seq_len(Nrep)) {

    res[k, j] <- simExperiment(8, n_t)
  }
}



# Monte Carlo mean and standard deviation of the estimator per sample size.
mu.cr <- colMeans(res)
sd.cr <- apply(res, 2, sd)

