# Remove every object from the global environment so the script starts from an
# empty workspace.
# NOTE(review): helpers used further down (kernels, get_pre_y, generate_kx,
# fast_vector_operations, create_train_matrices, ...) are not defined in this
# file. If they were sourced into the global environment beforehand, this call
# deletes them -- confirm they come from an attached package.
rm(list=ls())
#####generation of train data#####
# Simulate the training sample.
#   X1, X2 : independent Uniform(0, 1) covariates
#   A_     : latent binary treatment A*, drawn with probability `prop`
#   A, Z   : binary variables drawn conditionally on A* and the covariates
#   Y      : outcome; switching A* from 0 to 1 shifts Y by X1 + X2
# The ini_* objects keep the true data-generating quantities so that an oracle
# pseudo-outcome (or_phi1 / or_phi0) can be built from them.
n <- 5000
X1 <- runif(n)
X2 <- runif(n)

# True pr(A* = 1 | X).
prop <- 1 / (1 + exp(X1) + exp(X2))
A_ <- rbinom(n, size = 1, prob = prop)
ini_pr_A_S_1_given_X <- prop
ini_pr_A_S_0_given_X <- 1 - prop

# Calculate pr(A = 1 | A* = 1, X) pr(A = 1 | A* = 0, X)
pr_A_1_given_A_1_X <- plogis(1 + 2 * X1)
pr_A_1_given_A_0_X <- plogis(-1 + 0.5 * X1 - X2)
ini_pr_A_1_given_A_star_1_X <- pr_A_1_given_A_1_X
ini_pr_A_1_given_A_star_0_X <- pr_A_1_given_A_0_X

# Both candidate draws are generated inside ifelse(); the latent treatment
# selects which one each observation keeps.
A <- ifelse(A_ == 1,
            rbinom(n, size = 1, prob = pr_A_1_given_A_1_X),
            rbinom(n, size = 1, prob = pr_A_1_given_A_0_X))

# Calculate pr(Z = 1 | A* = 1, X) pr(Z = 1 | A* = 0, X)
pr_Z_1_given_A_1_X <- plogis(-1 - X1 + 0.5 * X2)
pr_Z_1_given_A_0_X <- plogis(2 + X1)
ini_pr_Z_1_given_A_star_1_X <- pr_Z_1_given_A_1_X
ini_pr_Z_1_given_A_star_0_X <- pr_Z_1_given_A_0_X

Z <- ifelse(A_ == 1,
            rbinom(n, size = 1, prob = pr_Z_1_given_A_1_X),
            rbinom(n, size = 1, prob = pr_Z_1_given_A_0_X))

# Outcome noise and outcome. The literal 3.1415926 is kept (rather than the
# built-in constant) because this script later assigns its own object to `pi`.
sgm <- rnorm(n)
Y <- sin(3.1415926 * X1) + (A_ - 0.5) * (X1 + X2) + sgm

# True conditional means E[Y | A* = a, X]. The noise term has mean zero and
# therefore does not appear here, matching the definition used for the test
# sample further down in this file.
ini_E_Y_given_A_star_1_X <- sin(3.1415926 * X1) + (1 - 0.5) * (X1 + X2)
ini_E_Y_given_A_star_0_X <- sin(3.1415926 * X1) + (0 - 0.5) * (X1 + X2)

# Oracle pseudo-outcomes built from the true nuisance quantities.
or_phi1 <- (A - ini_pr_A_1_given_A_star_0_X) *
  (Z - ini_pr_Z_1_given_A_star_0_X) *
  (Y - ini_E_Y_given_A_star_1_X) /
  ((ini_pr_A_1_given_A_star_1_X - ini_pr_A_1_given_A_star_0_X) *
     (ini_pr_Z_1_given_A_star_1_X - ini_pr_Z_1_given_A_star_0_X) *
     (ini_pr_A_S_1_given_X)) +
  ini_E_Y_given_A_star_1_X
or_phi0 <- (A - ini_pr_A_1_given_A_star_1_X) *
  (Z - ini_pr_Z_1_given_A_star_1_X) *
  (Y - ini_E_Y_given_A_star_0_X) /
  ((ini_pr_A_1_given_A_star_0_X - ini_pr_A_1_given_A_star_1_X) *
     (ini_pr_Z_1_given_A_star_0_X - ini_pr_Z_1_given_A_star_1_X) *
     (ini_pr_A_S_0_given_X)) +
  ini_E_Y_given_A_star_0_X


# library(gplm)
# Assemble the training data frame; A, X1, X2 and Y are stored as doubles.
data <- data.frame(X1 = X1, X2 = X2, Y = Y, Z = Z, A = A, A_star = A_)
data[, c("A", "X1", "X2", "Y")] <- lapply(data[, c("A", "X1", "X2", "Y")], as.numeric)
data_train <- data
#set 1
# Random starting values (p0_*) for the iterative update that follows. Every
# "..._0_..." complement is one minus its partner. The runif() calls are kept
# in this exact order so a seeded run yields the same numbers.
p0_A_star_1_given_X <- runif(n)
p0_A_star_0_given_X <- 1 - p0_A_star_1_given_X

p0_Y_given_A_star_1_X <- runif(n)
p0_Y_given_A_star_0_X <- runif(n)

# pr(A = 1 | A* = 1, X) starts in [0.5, 1], pr(A = 1 | A* = 0, X) in [0, 0.5].
p0_A_1_given_A_star_1_X <- runif(n, 0.5, 1)
p0_A_0_given_A_star_1_X <- 1 - p0_A_1_given_A_star_1_X
p0_A_1_given_A_star_0_X <- runif(n, 0, 0.5)
p0_A_0_given_A_star_0_X <- 1 - p0_A_1_given_A_star_0_X

# pr(Z = 1 | A* = 1, X) starts in [0, 0.5], pr(Z = 1 | A* = 0, X) in [0.5, 1].
p0_Z_1_given_A_star_1_X <- runif(n, 0, 0.5)
p0_Z_0_given_A_star_1_X <- 1 - p0_Z_1_given_A_star_1_X
p0_Z_1_given_A_star_0_X <- runif(n, 0.5, 1)
p0_Z_0_given_A_star_0_X <- 1 - p0_Z_1_given_A_star_0_X

# Estimate pr(A = 1 | X): get_pre_y() is called with the covariates, the
# observed A as response and the same covariates as evaluation points; the last
# column of its result is taken as the fitted probability.
# NOTE(review): get_pre_y() and kernels() are defined outside this file; they
# are presumably kernel-regression helpers -- confirm.
x <- data[, c("X1", "X2")]
kx <- x
a <- data[["A"]]

# Kernel weights computed from 1 - A. K, Ksum and pa are not used by the
# get_pre_y() call in this section.
A_one <- 1
A_a <- A_one - A
A_a <- as.data.frame(A_a)
K <- kernels(A_a)
Ksum <- sum(K)
pa <- K / max(K)

kernel_y <- get_pre_y(as.matrix(x), a, as.matrix(kx))
colnames(kernel_y) <- c(colnames(x), "y_")
pr_A_1_given_X <- unlist(kernel_y[, ncol(kernel_y)])
pr_A_0_given_X <- 1 - pr_A_1_given_X

# Probability of the value of A that was actually observed.
pr_A_given_X <- ifelse(A == 1, pr_A_1_given_X, pr_A_0_given_X)

# Estimate pr(Z = 1 | X) the same way as for A: observed Z is the response
# passed to get_pre_y(), evaluated at the covariates already stored in x / kx.
z <- data[["Z"]]

# Kernel weights computed from 1 - Z. K, Ksum and pz are not used by the
# get_pre_y() call in this section.
Z_one <- 1
Z_a <- Z_one - Z
Z_a <- as.data.frame(Z_a)
K <- kernels(Z_a)
Ksum <- sum(K)
pz <- K / max(K)

kernel_y <- get_pre_y(as.matrix(x), z, as.matrix(kx))
colnames(kernel_y) <- c(colnames(x), "y_")

pr_Z_1_given_X <- unlist(kernel_y[, ncol(kernel_y)])
pr_Z_0_given_X <- 1 - pr_Z_1_given_X

# Probability of the value of Z that was actually observed.
pr_Z_given_X <- ifelse(Z == 1, pr_Z_1_given_X, pr_Z_0_given_X)
# Estimate pr(Y | X). For every observed outcome value Y[k] the indicator
# 1{Y == Y[k]} is passed to get_pre_y() as response and the last column of the
# result is kept, giving one length-nrow(data) chunk per k.
# NOTE(review): the indicator uses exact equality on Y; for a continuous
# outcome it is presumably 1 only for observation k itself -- confirm intent.
Y <- data[, c("Y")]
x <- kx <- data[, c("X1", "X2")]

# Store one chunk per evaluation value and concatenate once at the end;
# appending with c() inside the loop would copy the whole vector every time.
pr_Y_chunks <- vector("list", length(Y))
for (x_estimate in seq_along(Y)) {
  if (x_estimate %% 100 == 0) print(x_estimate)
  Y_estimate <- Y[x_estimate]
  Y_y <- Y_estimate - Y
  Y_y <- as.data.frame(Y_y)
  kxx <- ifelse(Y_y == 0, 1, 0)
  results <- get_pre_y(as.matrix(x), kxx, as.matrix(kx))
  pr_Y_chunks[[x_estimate]] <- unlist(results[, ncol(results)])
}
pr_Y_given_X <- unlist(pr_Y_chunks)
pr_Y_given_X_save <- pr_Y_given_X

# pp[i] adds up, over all chunks, the value at position i within the chunk.
# The whole vector is then divided by the average of those totals.
pp <- vapply(
  seq_len(nrow(data)),
  function(i) {
    sum(pr_Y_given_X_save[seq(i, length(pr_Y_given_X_save), nrow(data))])
  },
  numeric(1)
)
pr_Y_given_X_save <- pr_Y_given_X_save / mean(pp)

# Keep, for observation k, the k-th entry of the k-th chunk (chunks have
# length n), i.e. the estimate for Y[k] evaluated at observation k.
diag_index <- (seq_along(Y) - 1) * n + seq_along(Y)
pr_Y_given_X__ <- pr_Y_given_X_save[diag_index]

pr_Y_given_X <- pr_Y_given_X__

# Iterative update of the latent-treatment quantities.
# Start from the random values p0_* and, for A and Z, pick per observation the
# starting probability that matches the value actually observed.
pr_A_S_1_given_X=p0_A_star_1_given_X
pr_A_S_0_given_X=p0_A_star_0_given_X
pr_Y_given_A_star_1_X=p0_Y_given_A_star_1_X
pr_Y_given_A_star_0_X=p0_Y_given_A_star_0_X
pr_A_given_A_star_1_X<-ifelse(A==1,p0_A_1_given_A_star_1_X,p0_A_0_given_A_star_1_X)
pr_Z_given_A_star_1_X<-ifelse(Z==1,p0_Z_1_given_A_star_1_X,p0_Z_0_given_A_star_1_X)
pr_A_given_A_star_0_X<-ifelse(A==1,p0_A_1_given_A_star_0_X,p0_A_0_given_A_star_0_X)
pr_Z_given_A_star_0_X<-ifelse(Z==1,p0_Z_1_given_A_star_0_X,p0_Z_0_given_A_star_0_X)
# pi_save holds the previous iteration's pi; t is the largest absolute change
# in pi between two iterations; o counts iterations (it is not read afterwards
# in this section).
# NOTE(review): `pi` and `t` reuse the names of base R's constant `pi` and
# function `t()`; from here on `pi` is a vector, not 3.14159....
pi_save<-0
t=1;o=1
# Repeat until pi changes by at most 0.001 in every component.
# NOTE(review): there is no iteration cap, and a NaN in pi would make the loop
# condition fail with an error -- confirm convergence is guaranteed.
while(t > 0.001){
  o=o+1
  # print(o)
  # Per-observation products of the current estimates for A* = 1 and A* = 0.
  pA1<-pr_A_S_1_given_X*pr_Y_given_A_star_1_X*pr_A_given_A_star_1_X*pr_Z_given_A_star_1_X
  pA0<-pr_A_S_0_given_X*pr_Y_given_A_star_0_X*pr_A_given_A_star_0_X*pr_Z_given_A_star_0_X
  
  # Normalised weight of the A* = 1 product.
  pi<-pA1/(pA1+pA0)
  # pr(A* = 1 | X): get_pre_y() with pi as response on (X1, X2).
  kx<-x <-data[,c('X1','X2')]
  kernel_y<-get_pre_y(as.matrix(x),pi,as.matrix(x))
  colnames(kernel_y)<-c(colnames(x),'y_')
  pr_A_S_1_given_X<-unlist(kernel_y[,ncol(kernel_y)])
  pr_A_S_0_given_X=1-pr_A_S_1_given_X
  # pr(A* = 1 | Y, X): same call with Y added to the covariates.
  kx<-x<-data[,c('Y','X1','X2')]
  kernel_y<-get_pre_y(as.matrix(x),pi,as.matrix(x))
  colnames(kernel_y)<-c(colnames(x),'y_')
  pr_A_S_1_given_Y_X<-unlist(kernel_y[,ncol(kernel_y)])
  pr_A_S_0_given_Y_X=1-pr_A_S_1_given_Y_X
  # pr(A* = 1 | A = 0, X): fitted on the rows with A == 0, evaluated at the
  # covariates of every row.
  x<-data[data$A==0,c('X1','X2')]
  pi_<-pi[data$A==0]
  kx<-data[,c('X1','X2')]
  kernel_y<-get_pre_y(as.matrix(x),pi_,as.matrix(kx))
  colnames(kernel_y)<-c(colnames(x),'y_')
  pr_A_S_1_given_A_0_X<-unlist(kernel_y[,ncol(kernel_y)])
  pr_A_S_0_given_A_0_X=1-pr_A_S_1_given_A_0_X
  # pr(A* = 1 | Z = 1, X): fitted on the rows with Z == 1, evaluated at the
  # covariates of every row.
  x<-data[data$Z==1,c('X1','X2')]
  pi_<-pi[data$Z==1]
  kx<-data[,c('X1','X2')]
  kernel_y<-get_pre_y(as.matrix(x),pi_,as.matrix(kx))
  colnames(kernel_y)<-c(colnames(x),'y_')
  pr_A_S_1_given_Z_1_X<-unlist(kernel_y[,ncol(kernel_y)])
  pr_A_S_0_given_Z_1_X=1-pr_A_S_1_given_Z_1_X
  # Each update below has the form
  #   pr(V | A* = a, X) = pr(V | X) * pr(A* = a | V, X) / pr(A* = a | X)
  # for V = Y, V = {A = 0} and V = {Z = 1}; the complementary probabilities
  # are obtained as one minus the updated value.
  pr_Y_given_A_star_0_X<-pr_Y_given_X*pr_A_S_0_given_Y_X/pr_A_S_0_given_X
  pr_Y_given_A_star_1_X<-pr_Y_given_X*pr_A_S_1_given_Y_X/pr_A_S_1_given_X
  pr_A_0_given_A_star_0_X<-pr_A_0_given_X*pr_A_S_0_given_A_0_X/pr_A_S_0_given_X
  pr_A_1_given_A_star_0_X=1-pr_A_0_given_A_star_0_X
  pr_A_given_A_star_0_X<-ifelse(A==1,pr_A_1_given_A_star_0_X,pr_A_0_given_A_star_0_X)
  pr_A_0_given_A_star_1_X<-pr_A_0_given_X*pr_A_S_1_given_A_0_X/pr_A_S_1_given_X
  pr_A_1_given_A_star_1_X=1-pr_A_0_given_A_star_1_X
  pr_A_given_A_star_1_X<-ifelse(A==1,pr_A_1_given_A_star_1_X,pr_A_0_given_A_star_1_X)
  pr_Z_1_given_A_star_0_X<-pr_Z_1_given_X*pr_A_S_0_given_Z_1_X/pr_A_S_0_given_X
  pr_Z_0_given_A_star_0_X=1-pr_Z_1_given_A_star_0_X
  pr_Z_given_A_star_0_X<-ifelse(Z==1,pr_Z_1_given_A_star_0_X,pr_Z_0_given_A_star_0_X)
  pr_Z_1_given_A_star_1_X<-(pr_Z_1_given_X*pr_A_S_1_given_Z_1_X)/pr_A_S_1_given_X
  pr_Z_0_given_A_star_1_X<-1-pr_Z_1_given_A_star_1_X
  pr_Z_given_A_star_1_X<-ifelse(Z==1,pr_Z_1_given_A_star_1_X,pr_Z_0_given_A_star_1_X)
  # Convergence measure: largest absolute change in pi since the last pass.
  t=max(abs(pi-pi_save))
  pi_save<-pi
  # print(t)
  
}
# Evaluate the fit of pi on (Y, X1, X2) at the points returned by
# generate_kx(), processed in t = 10 equally sized row chunks.
# NOTE(review): generate_kx() is defined elsewhere; the later code treats its
# result as nrow(data) consecutive blocks of nrow(data) rows -- confirm.
x <- data[, c("Y", "X1", "X2")]
kx <- generate_kx(x)

t <- 10
rows_per_chunk <- nrow(kx) / t
chunk_fits <- vector("list", t)
for (i in seq_len(t)) {
  chunk_rows <- ((i - 1) * rows_per_chunk + 1):(i * rows_per_chunk)
  kxx <- as.data.frame(kx[chunk_rows, ])
  kernel_y <- get_pre_y(as.matrix(x), pi, as.matrix(kxx))
  colnames(kernel_y) <- c(colnames(x), "y_")
  chunk_fits[[i]] <- kernel_y
  print(i)
}
# Stack the chunk results in order; the last column holds the fitted values.
kernel_y_ <- do.call(rbind, chunk_fits)

p1 <- unlist(kernel_y_[, ncol(kernel_y_)])
p0 <- 1 - p1

# Switch pr_Y_given_X back to the full (all chunks) normalised vector.
pr_Y_given_X <- pr_Y_given_X_save
# }
# Combine pr(Y | X) with the grid values p1 / p0 through
# fast_vector_operations() and turn each block of nrow(data) values into a
# weighted mean of Y, giving E[Y | A* = 1, X] and E[Y | A* = 0, X].
# NOTE(review): fast_vector_operations() is defined elsewhere; only the two
# list elements read below are relied on here.
ii <- cbind(pr_Y_given_X, p1, p0)
nrow_data <- nrow(data)
li <- fast_vector_operations(as.matrix(ii), pr_A_S_1_given_X, pr_A_S_0_given_X, nrow_data)
pr_Y_given_A_star_1_X_all <- li[["pr_Y_given_A_star_1_X_all"]]
pr_Y_given_A_star_0_X_all <- li[["pr_Y_given_A_star_0_X_all"]]

# One expectation per block; preallocate instead of growing with c().
n_blocks <- length(pr_Y_given_A_star_1_X_all) / nrow_data
E_Y_given_A_star_1_X <- numeric(n_blocks)
E_Y_given_A_star_0_X <- numeric(n_blocks)

for (j in seq_len(n_blocks)) {
  block_rows <- ((j - 1) * nrow_data + 1):(j * nrow_data)
  o_1 <- pr_Y_given_A_star_1_X_all[block_rows]
  o_2 <- pr_Y_given_A_star_0_X_all[block_rows]
  # Normalise the weights within the block so they sum to one.
  # NOTE(review): the normalising sums do not drop NA while the weighted sums
  # below do; a single NA weight therefore yields an expectation of 0.
  o_1 <- o_1 / sum(o_1)
  o_2 <- o_2 / sum(o_2)
  E_Y_given_A_star_1_X[j] <- sum(o_1 * Y, na.rm = TRUE)
  E_Y_given_A_star_0_X[j] <- sum(o_2 * Y, na.rm = TRUE)
}
# Build the estimated pseudo-outcomes for the training sample. Each arm is
#   numerator / denominator + add
# and phi is the difference between the A* = 1 and A* = 0 arms.
numerator_1 <- (A - pr_A_1_given_A_star_0_X) *
  (Z - pr_Z_1_given_A_star_0_X) *
  (Y - E_Y_given_A_star_1_X)
denominator_1 <- ((pr_A_1_given_A_star_1_X - pr_A_1_given_A_star_0_X) *
                    (pr_Z_1_given_A_star_1_X - pr_Z_1_given_A_star_0_X) *
                    (pr_A_S_1_given_X))
add_1 <- E_Y_given_A_star_1_X

numerator_0 <- (A - pr_A_1_given_A_star_1_X) *
  (Z - pr_Z_1_given_A_star_1_X) *
  (Y - E_Y_given_A_star_0_X)
denominator_0 <- ((pr_A_1_given_A_star_0_X - pr_A_1_given_A_star_1_X) *
                    (pr_Z_1_given_A_star_0_X - pr_Z_1_given_A_star_1_X) *
                    (pr_A_S_0_given_X))
add_0 <- E_Y_given_A_star_0_X

phi1 <- numerator_1 / denominator_1 + add_1
phi0 <- numerator_0 / denominator_0 + add_0
phi <- phi1 - phi0

# True effect of the latent treatment in the simulation design.
truth_Y <- (X1 + X2)
X <- data[, c("X1", "X2")]

# From here on Y holds the pseudo-outcome, not the simulated outcome.
Y <- phi

# Oracle counterpart built from the true nuisance quantities.
or_phi <- or_phi1 - or_phi0

save(data_train, truth_Y, X, A, Z, Y, phi, or_phi, file = "train_5000_con.RData")

#####generation of test data#####
# Simulate the test sample with the same design as the training sample:
# covariates X1, X2, latent treatment A_, observed A and Z, and outcome Y.
# The random draws are made in this exact order.
n <- 5000
X1 <- runif(n)
X2 <- runif(n)

# True pr(A* = 1 | X) and the latent treatment drawn from it.
prop <- 1 / (1 + exp(X1) + exp(X2))
A_ <- rbinom(n, size = 1, prob = prop)
ini_pr_A_S_1_given_X <- prop
ini_pr_A_S_0_given_X <- 1 - prop

# Calculate pr(A = 1 | A* = 1, X) pr(A = 1 | A* = 0, X)
pr_A_1_given_A_1_X <- plogis(1 + 2 * X1)
pr_A_1_given_A_0_X <- plogis(-1 + 0.5 * X1 - X2)
ini_pr_A_1_given_A_star_1_X <- pr_A_1_given_A_1_X
ini_pr_A_1_given_A_star_0_X <- pr_A_1_given_A_0_X

A <- ifelse(
  A_ == 1,
  rbinom(n, size = 1, prob = pr_A_1_given_A_1_X),
  rbinom(n, size = 1, prob = pr_A_1_given_A_0_X)
)

# Calculate pr(Z = 1 | A* = 1, X) pr(Z = 1 | A* = 0, X)
pr_Z_1_given_A_1_X <- plogis(-1 - X1 + 0.5 * X2)
pr_Z_1_given_A_0_X <- plogis(2 + X1)
ini_pr_Z_1_given_A_star_1_X <- pr_Z_1_given_A_1_X
ini_pr_Z_1_given_A_star_0_X <- pr_Z_1_given_A_0_X

Z <- ifelse(
  A_ == 1,
  rbinom(n, size = 1, prob = pr_Z_1_given_A_1_X),
  rbinom(n, size = 1, prob = pr_Z_1_given_A_0_X)
)

# Outcome with additive standard-normal noise, and the noise-free conditional
# means for A* = 1 and A* = 0.
sgm <- rnorm(n)
Y <- sin(3.1415926 * X1) + (A_ - 0.5) * (X1 + X2) + sgm
ini_E_Y_given_A_star_1_X <- sin(3.1415926 * X1) + (1 - 0.5) * (X1 + X2)
ini_E_Y_given_A_star_0_X <- sin(3.1415926 * X1) + (0 - 0.5) * (X1 + X2)

# Assemble the test data frame; A, X1, X2 and Y are stored as doubles.
data <- data.frame(X1 = X1, X2 = X2, Y = Y, Z = Z, A = A, A_star = A_)
numeric_cols <- c("A", "X1", "X2", "Y")
data[numeric_cols] <- lapply(data[numeric_cols], as.numeric)
data_test <- data

#set 1
# Random starting values (p0_*) for the iterative update that follows. Every
# "..._0_..." complement is one minus its partner. The runif() calls are kept
# in this exact order so a seeded run yields the same numbers.
p0_A_star_1_given_X <- runif(n)
p0_A_star_0_given_X <- 1 - p0_A_star_1_given_X

p0_Y_given_A_star_1_X <- runif(n)
p0_Y_given_A_star_0_X <- runif(n)

# pr(A = 1 | A* = 1, X) starts in [0.5, 1], pr(A = 1 | A* = 0, X) in [0, 0.5].
p0_A_1_given_A_star_1_X <- runif(n, 0.5, 1)
p0_A_0_given_A_star_1_X <- 1 - p0_A_1_given_A_star_1_X
p0_A_1_given_A_star_0_X <- runif(n, 0, 0.5)
p0_A_0_given_A_star_0_X <- 1 - p0_A_1_given_A_star_0_X

# pr(Z = 1 | A* = 1, X) starts in [0, 0.5], pr(Z = 1 | A* = 0, X) in [0.5, 1].
p0_Z_1_given_A_star_1_X <- runif(n, 0, 0.5)
p0_Z_0_given_A_star_1_X <- 1 - p0_Z_1_given_A_star_1_X
p0_Z_1_given_A_star_0_X <- runif(n, 0.5, 1)
p0_Z_0_given_A_star_0_X <- 1 - p0_Z_1_given_A_star_0_X

# Estimate pr(A = 1 | X) on the test sample: get_pre_y() is called with the
# covariates, the observed A as response and the same covariates as evaluation
# points; the last column of its result is taken as the fitted probability.
# NOTE(review): get_pre_y() and kernels() are defined outside this file; they
# are presumably kernel-regression helpers -- confirm.
x <- data[, c("X1", "X2")]
kx <- x
a <- data[["A"]]

# Kernel weights computed from 1 - A. K, Ksum and pa are not used by the
# get_pre_y() call in this section.
A_one <- 1
A_a <- A_one - A
A_a <- as.data.frame(A_a)
K <- kernels(A_a)
Ksum <- sum(K)
pa <- K / max(K)

kernel_y <- get_pre_y(as.matrix(x), a, as.matrix(kx))
colnames(kernel_y) <- c(colnames(x), "y_")
pr_A_1_given_X <- unlist(kernel_y[, ncol(kernel_y)])
pr_A_0_given_X <- 1 - pr_A_1_given_X

# Probability of the value of A that was actually observed.
pr_A_given_X <- ifelse(A == 1, pr_A_1_given_X, pr_A_0_given_X)

# Estimate pr(Z = 1 | X) on the test sample: observed Z is the response passed
# to get_pre_y(), evaluated at the covariates already stored in x / kx.
z <- data[["Z"]]

# Kernel weights computed from 1 - Z. K, Ksum and pz are not used by the
# get_pre_y() call in this section.
Z_one <- 1
Z_a <- Z_one - Z
Z_a <- as.data.frame(Z_a)
K <- kernels(Z_a)
Ksum <- sum(K)
pz <- K / max(K)

kernel_y <- get_pre_y(as.matrix(x), z, as.matrix(kx))
colnames(kernel_y) <- c(colnames(x), "y_")

pr_Z_1_given_X <- unlist(kernel_y[, ncol(kernel_y)])
pr_Z_0_given_X <- 1 - pr_Z_1_given_X

# Probability of the value of Z that was actually observed.
pr_Z_given_X <- ifelse(Z == 1, pr_Z_1_given_X, pr_Z_0_given_X)
# Estimate pr(Y | X) on the test sample. For every observed outcome value Y[k]
# the indicator 1{Y == Y[k]} is passed to get_pre_y() as response and the last
# column of the result is kept, giving one length-nrow(data) chunk per k.
# NOTE(review): the indicator uses exact equality on Y; for a continuous
# outcome it is presumably 1 only for observation k itself -- confirm intent.
Y <- data[, c("Y")]
x <- kx <- data[, c("X1", "X2")]

# Store one chunk per evaluation value and concatenate once at the end;
# appending with c() inside the loop would copy the whole vector every time.
pr_Y_chunks <- vector("list", length(Y))
for (x_estimate in seq_along(Y)) {
  if (x_estimate %% 100 == 0) print(x_estimate)
  Y_estimate <- Y[x_estimate]
  Y_y <- Y_estimate - Y
  Y_y <- as.data.frame(Y_y)
  kxx <- ifelse(Y_y == 0, 1, 0)
  results <- get_pre_y(as.matrix(x), kxx, as.matrix(kx))
  pr_Y_chunks[[x_estimate]] <- unlist(results[, ncol(results)])
}
pr_Y_given_X <- unlist(pr_Y_chunks)
pr_Y_given_X_save <- pr_Y_given_X

# pp[i] adds up, over all chunks, the value at position i within the chunk.
# The whole vector is then divided by the average of those totals.
pp <- vapply(
  seq_len(nrow(data)),
  function(i) {
    sum(pr_Y_given_X_save[seq(i, length(pr_Y_given_X_save), nrow(data))])
  },
  numeric(1)
)
pr_Y_given_X_save <- pr_Y_given_X_save / mean(pp)

# Keep, for observation k, the k-th entry of the k-th chunk (chunks have
# length n), i.e. the estimate for Y[k] evaluated at observation k.
diag_index <- (seq_along(Y) - 1) * n + seq_along(Y)
pr_Y_given_X__ <- pr_Y_given_X_save[diag_index]

pr_Y_given_X <- pr_Y_given_X__

# Iterative update of the latent-treatment quantities on the test sample.
# Start from the random values p0_* and, for A and Z, pick per observation the
# starting probability that matches the value actually observed.
pr_A_S_1_given_X=p0_A_star_1_given_X
pr_A_S_0_given_X=p0_A_star_0_given_X
pr_Y_given_A_star_1_X=p0_Y_given_A_star_1_X
pr_Y_given_A_star_0_X=p0_Y_given_A_star_0_X
pr_A_given_A_star_1_X<-ifelse(A==1,p0_A_1_given_A_star_1_X,p0_A_0_given_A_star_1_X)
pr_Z_given_A_star_1_X<-ifelse(Z==1,p0_Z_1_given_A_star_1_X,p0_Z_0_given_A_star_1_X)
pr_A_given_A_star_0_X<-ifelse(A==1,p0_A_1_given_A_star_0_X,p0_A_0_given_A_star_0_X)
pr_Z_given_A_star_0_X<-ifelse(Z==1,p0_Z_1_given_A_star_0_X,p0_Z_0_given_A_star_0_X)
# pi_save holds the previous iteration's pi; t is the largest absolute change
# in pi between two iterations; o counts iterations (it is not read afterwards
# in this section).
# NOTE(review): `pi` and `t` reuse the names of base R's constant `pi` and
# function `t()`; from here on `pi` is a vector, not 3.14159....
pi_save<-0
t=1;o=1
# Repeat until pi changes by at most 0.001 in every component.
# NOTE(review): there is no iteration cap, and a NaN in pi would make the loop
# condition fail with an error -- confirm convergence is guaranteed.
while(t > 0.001){
  o=o+1
  # print(o)
  # Per-observation products of the current estimates for A* = 1 and A* = 0.
  pA1<-pr_A_S_1_given_X*pr_Y_given_A_star_1_X*pr_A_given_A_star_1_X*pr_Z_given_A_star_1_X
  pA0<-pr_A_S_0_given_X*pr_Y_given_A_star_0_X*pr_A_given_A_star_0_X*pr_Z_given_A_star_0_X
  
  # Normalised weight of the A* = 1 product.
  pi<-pA1/(pA1+pA0)
  # pr(A* = 1 | X): get_pre_y() with pi as response on (X1, X2).
  kx<-x <-data[,c('X1','X2')]
  kernel_y<-get_pre_y(as.matrix(x),pi,as.matrix(x))
  colnames(kernel_y)<-c(colnames(x),'y_')
  pr_A_S_1_given_X<-unlist(kernel_y[,ncol(kernel_y)])
  pr_A_S_0_given_X=1-pr_A_S_1_given_X
  # pr(A* = 1 | Y, X): same call with Y added to the covariates.
  kx<-x<-data[,c('Y','X1','X2')]
  kernel_y<-get_pre_y(as.matrix(x),pi,as.matrix(x))
  colnames(kernel_y)<-c(colnames(x),'y_')
  pr_A_S_1_given_Y_X<-unlist(kernel_y[,ncol(kernel_y)])
  pr_A_S_0_given_Y_X=1-pr_A_S_1_given_Y_X
  # pr(A* = 1 | A = 0, X): fitted on the rows with A == 0, evaluated at the
  # covariates of every row.
  x<-data[data$A==0,c('X1','X2')]
  pi_<-pi[data$A==0]
  kx<-data[,c('X1','X2')]
  kernel_y<-get_pre_y(as.matrix(x),pi_,as.matrix(kx))
  colnames(kernel_y)<-c(colnames(x),'y_')
  pr_A_S_1_given_A_0_X<-unlist(kernel_y[,ncol(kernel_y)])
  pr_A_S_0_given_A_0_X=1-pr_A_S_1_given_A_0_X
  # pr(A* = 1 | Z = 1, X): fitted on the rows with Z == 1, evaluated at the
  # covariates of every row.
  x<-data[data$Z==1,c('X1','X2')]
  pi_<-pi[data$Z==1]
  kx<-data[,c('X1','X2')]
  kernel_y<-get_pre_y(as.matrix(x),pi_,as.matrix(kx))
  colnames(kernel_y)<-c(colnames(x),'y_')
  pr_A_S_1_given_Z_1_X<-unlist(kernel_y[,ncol(kernel_y)])
  pr_A_S_0_given_Z_1_X=1-pr_A_S_1_given_Z_1_X
  # Each update below has the form
  #   pr(V | A* = a, X) = pr(V | X) * pr(A* = a | V, X) / pr(A* = a | X)
  # for V = Y, V = {A = 0} and V = {Z = 1}; the complementary probabilities
  # are obtained as one minus the updated value.
  pr_Y_given_A_star_0_X<-pr_Y_given_X*pr_A_S_0_given_Y_X/pr_A_S_0_given_X
  pr_Y_given_A_star_1_X<-pr_Y_given_X*pr_A_S_1_given_Y_X/pr_A_S_1_given_X
  pr_A_0_given_A_star_0_X<-pr_A_0_given_X*pr_A_S_0_given_A_0_X/pr_A_S_0_given_X
  pr_A_1_given_A_star_0_X=1-pr_A_0_given_A_star_0_X
  pr_A_given_A_star_0_X<-ifelse(A==1,pr_A_1_given_A_star_0_X,pr_A_0_given_A_star_0_X)
  pr_A_0_given_A_star_1_X<-pr_A_0_given_X*pr_A_S_1_given_A_0_X/pr_A_S_1_given_X
  pr_A_1_given_A_star_1_X=1-pr_A_0_given_A_star_1_X
  pr_A_given_A_star_1_X<-ifelse(A==1,pr_A_1_given_A_star_1_X,pr_A_0_given_A_star_1_X)
  pr_Z_1_given_A_star_0_X<-pr_Z_1_given_X*pr_A_S_0_given_Z_1_X/pr_A_S_0_given_X
  pr_Z_0_given_A_star_0_X=1-pr_Z_1_given_A_star_0_X
  pr_Z_given_A_star_0_X<-ifelse(Z==1,pr_Z_1_given_A_star_0_X,pr_Z_0_given_A_star_0_X)
  pr_Z_1_given_A_star_1_X<-(pr_Z_1_given_X*pr_A_S_1_given_Z_1_X)/pr_A_S_1_given_X
  pr_Z_0_given_A_star_1_X<-1-pr_Z_1_given_A_star_1_X
  pr_Z_given_A_star_1_X<-ifelse(Z==1,pr_Z_1_given_A_star_1_X,pr_Z_0_given_A_star_1_X)
  # Convergence measure: largest absolute change in pi since the last pass.
  t=max(abs(pi-pi_save))
  pi_save<-pi
  # print(t)
  
}
# Evaluate the fit of pi on (Y, X1, X2) at the points returned by
# generate_kx(), processed in t = 10 equally sized row chunks.
# NOTE(review): generate_kx() is defined elsewhere; the later code treats its
# result as nrow(data) consecutive blocks of nrow(data) rows -- confirm.
x <- data[, c("Y", "X1", "X2")]
kx <- generate_kx(x)

t <- 10
rows_per_chunk <- nrow(kx) / t
chunk_fits <- vector("list", t)
for (i in seq_len(t)) {
  chunk_rows <- ((i - 1) * rows_per_chunk + 1):(i * rows_per_chunk)
  kxx <- as.data.frame(kx[chunk_rows, ])
  kernel_y <- get_pre_y(as.matrix(x), pi, as.matrix(kxx))
  colnames(kernel_y) <- c(colnames(x), "y_")
  chunk_fits[[i]] <- kernel_y
  print(i)
}
# Stack the chunk results in order; the last column holds the fitted values.
kernel_y_ <- do.call(rbind, chunk_fits)

p1 <- unlist(kernel_y_[, ncol(kernel_y_)])
p0 <- 1 - p1

# Switch pr_Y_given_X back to the full (all chunks) normalised vector.
pr_Y_given_X <- pr_Y_given_X_save
# }
# Combine pr(Y | X) with the grid values p1 / p0 through
# fast_vector_operations() and turn each block of nrow(data) values into a
# weighted mean of Y, giving E[Y | A* = 1, X] and E[Y | A* = 0, X].
# NOTE(review): fast_vector_operations() is defined elsewhere; only the two
# list elements read below are relied on here.
ii <- cbind(pr_Y_given_X, p1, p0)
nrow_data <- nrow(data)
li <- fast_vector_operations(as.matrix(ii), pr_A_S_1_given_X, pr_A_S_0_given_X, nrow_data)
pr_Y_given_A_star_1_X_all <- li[["pr_Y_given_A_star_1_X_all"]]
pr_Y_given_A_star_0_X_all <- li[["pr_Y_given_A_star_0_X_all"]]

# One expectation per block; preallocate instead of growing with c().
n_blocks <- length(pr_Y_given_A_star_1_X_all) / nrow_data
E_Y_given_A_star_1_X <- numeric(n_blocks)
E_Y_given_A_star_0_X <- numeric(n_blocks)

for (j in seq_len(n_blocks)) {
  block_rows <- ((j - 1) * nrow_data + 1):(j * nrow_data)
  o_1 <- pr_Y_given_A_star_1_X_all[block_rows]
  o_2 <- pr_Y_given_A_star_0_X_all[block_rows]
  # Normalise the weights within the block so they sum to one.
  # NOTE(review): the normalising sums do not drop NA while the weighted sums
  # below do; a single NA weight therefore yields an expectation of 0.
  o_1 <- o_1 / sum(o_1)
  o_2 <- o_2 / sum(o_2)
  E_Y_given_A_star_1_X[j] <- sum(o_1 * Y, na.rm = TRUE)
  E_Y_given_A_star_0_X[j] <- sum(o_2 * Y, na.rm = TRUE)
}
# Build the estimated pseudo-outcomes for the test sample. Each arm is
#   numerator / denominator + add
# and phi_t is the difference between the A* = 1 and A* = 0 arms.
numerator_1 <- (A - pr_A_1_given_A_star_0_X) *
  (Z - pr_Z_1_given_A_star_0_X) *
  (Y - E_Y_given_A_star_1_X)
denominator_1 <- ((pr_A_1_given_A_star_1_X - pr_A_1_given_A_star_0_X) *
                    (pr_Z_1_given_A_star_1_X - pr_Z_1_given_A_star_0_X) *
                    (pr_A_S_1_given_X))
add_1 <- E_Y_given_A_star_1_X

numerator_0 <- (A - pr_A_1_given_A_star_1_X) *
  (Z - pr_Z_1_given_A_star_1_X) *
  (Y - E_Y_given_A_star_0_X)
denominator_0 <- ((pr_A_1_given_A_star_0_X - pr_A_1_given_A_star_1_X) *
                    (pr_Z_1_given_A_star_0_X - pr_Z_1_given_A_star_1_X) *
                    (pr_A_S_0_given_X))
add_0 <- E_Y_given_A_star_0_X

phi1 <- numerator_1 / denominator_1 + add_1
phi0 <- numerator_0 / denominator_0 + add_0
phi_t <- phi1 - phi0

# True effect of the latent treatment in the simulation design.
truth_Y_t <- (X1 + X2)
X_t <- data[, c("X1", "X2")]

# The test pseudo-outcome is the phi_t computed just above from the test
# sample. It must not be replaced by `phi`, which still holds the training
# pseudo-outcome at this point of the script.
Y_t <- phi_t
A_t <- A
Z_t <- Z
save(data_test, truth_Y_t, X_t, A_t, Z_t, Y_t, phi_t, file = "test_5000_con.RData")


#####hidden treatment forest#####
# First forest of this section: regress A on the covariates with
# regression_train() to obtain W.hat for the centring step that follows.
# NOTE(review): A, Z, Y and the pr_* / E_Y_* vectors read here hold whatever the
# preceding section left in the workspace, while X is the training covariate
# frame -- confirm that mixing is intended.
# NOTE(review): create_train_matrices(), regression_train(), do.call.rcpp(),
# get_legacy_seed() and validate_equalize_cluster_weights() are defined outside
# this file.

# Numerator, denominator and additive term of each pseudo-outcome arm.
numerator_1 <- (A - pr_A_1_given_A_star_0_X) *
  (Z - pr_Z_1_given_A_star_0_X) *
  (Y - E_Y_given_A_star_1_X)
denominator_1 <- ((pr_A_1_given_A_star_1_X - pr_A_1_given_A_star_0_X) *
                    (pr_Z_1_given_A_star_1_X - pr_Z_1_given_A_star_0_X) *
                    (pr_A_S_1_given_X))
add_1 <- E_Y_given_A_star_1_X
numerator_0 <- (A - pr_A_1_given_A_star_1_X) *
  (Z - pr_Z_1_given_A_star_1_X) *
  (Y - E_Y_given_A_star_0_X)
denominator_0 <- ((pr_A_1_given_A_star_0_X - pr_A_1_given_A_star_1_X) *
                    (pr_Z_1_given_A_star_0_X - pr_Z_1_given_A_star_1_X) *
                    (pr_A_S_0_given_X))
add_0 <- E_Y_given_A_star_0_X

data_save <- data_train

# Forest settings, one per line.
num.trees <- 4000
clusters <- numeric(0)
sample.weights <- NULL
equalize.cluster.weights <- FALSE
samples.per.cluster <- validate_equalize_cluster_weights(
  equalize.cluster.weights, clusters, sample.weights
)
sample.fraction <- 0.5
mtry <- 8
min.node.size <- 6
honesty <- TRUE
honesty.fraction <- 0.5
honesty.prune.leaves <- TRUE
alpha <- 0.05
imbalance.penalty <- 0
stabilize.splits <- TRUE
ci.group.size <- 4
compute.oob.predictions <- TRUE
num.threads <- 0
seed <- runif(1, 0, .Machine$integer.max)
tune.parameters <- "none"

# `data` is reused here for the list of training matrices; it no longer holds
# the simulated data frame after this call.
data <- create_train_matrices(
  X, X_t,
  outcome = A,
  instrument = rep(0, n),
  sample.weights = sample.weights,
  numerator_1 = numerator_1,
  denominator_1 = denominator_1,
  add_1 = add_1,
  numerator_0 = numerator_0,
  denominator_0 = denominator_0,
  add_0 = add_0,
  outcome_t = A_t
)
args <- list(
  num.trees = num.trees,
  clusters = clusters,
  samples.per.cluster = samples.per.cluster,
  sample.fraction = sample.fraction,
  mtry = mtry,
  min.node.size = min.node.size,
  honesty = honesty,
  honesty.fraction = honesty.fraction,
  honesty.prune.leaves = honesty.prune.leaves,
  alpha = alpha,
  imbalance.penalty = imbalance.penalty,
  ci.group.size = ci.group.size,
  compute.oob.predictions = compute.oob.predictions,
  num.threads = num.threads,
  seed = seed,
  legacy.seed = get_legacy_seed()
)
forest <- do.call.rcpp(regression_train, c(data, args))
# Second forest of this section: the estimated pseudo-outcome phi is the
# outcome and the centred, rescaled A is passed as `instrument`.
# W.hat is the prediction of the forest fitted just before; the residual
# A_t - W.hat is multiplied by 1.2.
W.hat <- forest$predictions
W.centered <- (A_t - W.hat) * 1.2

# Forest settings, one per line.
# NOTE(review): num.trees is set to 3000 here but the args list below passes
# the literal 2000, and mtry is set to 10 and then to 7 before it is used.
num.trees <- 3000
sample.weights <- NULL
clusters <- NULL
equalize.cluster.weights <- FALSE
sample.fraction <- 0.5
mtry <- 10
min.node.size <- 7
honesty <- TRUE
honesty.fraction <- 0.5
honesty.prune.leaves <- TRUE
alpha <- 0.05
imbalance.penalty <- 0
ci.group.size <- 5
tune.parameters <- "none"
tune.num.trees <- 50
tune.num.reps <- 100
tune.num.draws <- 1000
compute.oob.predictions <- TRUE
num.threads <- NULL
seed <- runif(1, 0, .Machine$integer.max)

# Input validation through helpers defined outside this file.
validate_sample_weights(sample.weights, X)
Y <- validate_observations(Y, X)
clusters <- validate_clusters(clusters, X)
samples.per.cluster <- validate_equalize_cluster_weights(
  equalize.cluster.weights, clusters, sample.weights
)
num.threads <- validate_num_threads(num.threads)

data <- create_train_matrices(
  X, X_t,
  outcome = phi,
  instrument = W.centered,
  sample.weights = sample.weights,
  numerator_1 = numerator_1,
  denominator_1 = denominator_1,
  add_1 = add_1,
  numerator_0 = numerator_0,
  denominator_0 = denominator_0,
  add_0 = add_0,
  outcome_t = phi_t
)
mtry <- 7
args <- list(
  num.trees = 2000,
  clusters = clusters,
  samples.per.cluster = samples.per.cluster,
  sample.fraction = sample.fraction,
  mtry = mtry,
  min.node.size = min.node.size,
  honesty = honesty,
  honesty.fraction = honesty.fraction,
  honesty.prune.leaves = honesty.prune.leaves,
  alpha = alpha,
  imbalance.penalty = imbalance.penalty,
  ci.group.size = ci.group.size,
  compute.oob.predictions = compute.oob.predictions,
  num.threads = num.threads,
  seed = seed,
  legacy.seed = get_legacy_seed()
)
forest <- do.call.rcpp(regression_train, c(data, args))
# Evaluate the forest against the true effect on the test sample: coverage of
# the 1.96-standard-error intervals, mean squared error and average interval
# width, followed by a truth-versus-prediction scatter plot.
preds <- forest
df <- data.frame(
  predictions = preds$predictions,
  truth = truth_Y_t,
  upper = preds$predictions + 1.96 * sqrt(preds$variance.estimates),
  lower = preds$predictions - 1.96 * sqrt(preds$variance.estimates)
)
n <- nrow(df)

# Interval bounds that are NA are replaced by 0 before they are used.
ci_lower <- ifelse(is.na(df$lower), 0, df$lower)
ci_upper <- ifelse(is.na(df$upper), 0, df$upper)

# Share of intervals containing the truth, computed on whole vectors instead
# of an element-by-element loop.
percent_llf <- sum(ci_lower <= df$truth & df$truth <= ci_upper) / n
percent_llf
mse <- sum((forest$predictions - truth_Y_t)^2) / n
mse
avg_llf <- sum(abs(ci_upper - ci_lower)) / n
avg_llf

library(ggplot2)
pic2 <- ggplot(data = df, aes(x = truth, y = predictions)) +
  geom_point(size = 0.01) +
  geom_abline(slope = 1, intercept = 0, colour = "#E41A1C") +
  theme_bw() +
  labs(
    x = paste("MSE:", round(mse, 3)),
    y = "",
    title = paste("Hiden Treatment Forest")
  ) +
  scale_x_continuous(limits = c(-0.2, 2)) +
  scale_y_continuous(limits = c(-0.2, 2)) +
  theme(
    plot.title = element_text(size = 12),
    axis.title.x = element_text(size = 12),
    axis.text.x = element_blank()
  )
pic2

#####oracle hidden treatment forest#####
# Same forest as the preceding one, but with the oracle pseudo-outcome or_phi
# as the outcome and min.node.size raised to 8. All other settings are the
# values already present in the workspace.
min.node.size <- 8
data <- create_train_matrices(
  X, X_t,
  outcome = or_phi,
  instrument = W.centered,
  sample.weights = sample.weights,
  numerator_1 = numerator_1,
  denominator_1 = denominator_1,
  add_1 = add_1,
  numerator_0 = numerator_0,
  denominator_0 = denominator_0,
  add_0 = add_0,
  outcome_t = phi_t
)
mtry <- 7
args <- list(
  num.trees = 2000,
  clusters = clusters,
  samples.per.cluster = samples.per.cluster,
  sample.fraction = sample.fraction,
  mtry = mtry,
  min.node.size = min.node.size,
  honesty = honesty,
  honesty.fraction = honesty.fraction,
  honesty.prune.leaves = honesty.prune.leaves,
  alpha = alpha,
  imbalance.penalty = imbalance.penalty,
  ci.group.size = ci.group.size,
  compute.oob.predictions = compute.oob.predictions,
  num.threads = num.threads,
  seed = seed,
  legacy.seed = get_legacy_seed()
)
forest <- do.call.rcpp(regression_train, c(data, args))
# Evaluate the oracle forest against the true effect on the test sample:
# coverage of the 1.96-standard-error intervals, mean squared error and
# average interval width, followed by a truth-versus-prediction scatter plot.
preds <- forest
df <- data.frame(
  predictions = preds$predictions,
  truth = truth_Y_t,
  upper = preds$predictions + 1.96 * sqrt(preds$variance.estimates),
  lower = preds$predictions - 1.96 * sqrt(preds$variance.estimates)
)
n <- nrow(df)

# Interval bounds that are NA are replaced by 0 before they are used.
ci_lower <- ifelse(is.na(df$lower), 0, df$lower)
ci_upper <- ifelse(is.na(df$upper), 0, df$upper)

# Share of intervals containing the truth, computed on whole vectors instead
# of an element-by-element loop.
percent_llf <- sum(ci_lower <= df$truth & df$truth <= ci_upper) / n
percent_llf
mse <- sum((forest$predictions - truth_Y_t)^2) / n
mse
avg_llf <- sum(abs(ci_upper - ci_lower)) / n
avg_llf

library(ggplot2)
# Only one x scale and one y scale are added: a second scale for the same
# aesthetic replaces the first and makes ggplot2 emit a message, so the
# limits that took effect (-0.2 to 2) are the ones kept.
pic3 <- ggplot(data = df, aes(x = truth, y = predictions)) +
  geom_point(size = 0.01) +
  geom_abline(slope = 1, intercept = 0, colour = "#E41A1C") +
  theme_bw() +
  labs(
    x = paste("MSE:", round(mse, 3)),
    y = "",
    title = paste("Oracle Forest")
  ) +
  scale_x_continuous(limits = c(-0.2, 2)) +
  scale_y_continuous(limits = c(-0.2, 2)) +
  theme(
    plot.title = element_text(size = 12),
    axis.title.x = element_text(size = 12),
    axis.text.x = element_blank(),
    axis.text.y = element_blank()
  )
pic3

#####A replace A_star forest######
# Baseline that treats the observed A as if it were the latent treatment.
# Work on the training sample again. The global A still holds the treatment
# of the sample processed last, so it is reloaded from the training data
# before anything in this section reads it.
data <- data_train
A <- data[, c("A")]
kx <- x <- data[, c("X1", "X2")]
a <- data[, c("A")]

# Kernel weights computed from 1 - A, rescaled by their maximum, are used as
# the response of get_pre_y() in this section.
# NOTE(review): kernels() and get_pre_y() are defined outside this file.
A_one <- 1
A_a <- A_one - A
A_a <- as.data.frame(A_a)
K <- kernels(A_a)
Ksum <- sum(K)
pa <- K / (max(K))
kernel_y <- get_pre_y(as.matrix(x), pa, as.matrix(kx))
colnames(kernel_y) <- c(colnames(x), "y_")

# The last column of the result is taken as pr(A = 1 | X).
pr_A_1_given_X <- unlist(kernel_y[, ncol(kernel_y)])
pr_A_0_given_X <- 1 - pr_A_1_given_X
pr_A_given_X <- ifelse(A == 1, pr_A_1_given_X, pr_A_0_given_X)

# Estimate pr(Y | X). For every observed outcome value Y[k] the indicator
# 1{Y == Y[k]} is passed to get_pre_y() as response and the last column of the
# result is kept, giving one length-nrow(data) chunk per k.
# NOTE(review): the indicator uses exact equality on Y; for a continuous
# outcome it is presumably 1 only for observation k itself -- confirm intent.
Y <- data[, c("Y")]
x <- kx <- data[, c("X1", "X2")]

# Store one chunk per evaluation value and concatenate once at the end;
# appending with c() inside the loop would copy the whole vector every time.
pr_Y_chunks <- vector("list", length(Y))
for (x_estimate in seq_along(Y)) {
  if (x_estimate %% 2000 == 0) print(x_estimate)
  Y_estimate <- Y[x_estimate]
  Y_y <- Y_estimate - Y
  Y_y <- as.data.frame(Y_y)
  kxx <- ifelse(Y_y == 0, 1, 0)
  results <- get_pre_y(as.matrix(x), kxx, as.matrix(kx))
  pr_Y_chunks[[x_estimate]] <- unlist(results[, ncol(results)])
}
pr_Y_given_X <- unlist(pr_Y_chunks)
pr_Y_given_X_save <- pr_Y_given_X

# pp[i] adds up, over all chunks, the value at position i within the chunk.
# The whole vector is then divided by the average of those totals.
pp <- vapply(
  seq_len(nrow(data)),
  function(i) {
    sum(pr_Y_given_X_save[seq(i, length(pr_Y_given_X_save), nrow(data))])
  },
  numeric(1)
)
pr_Y_given_X_save <- pr_Y_given_X_save / mean(pp)

# Estimate pr(A = 1 | Y, X): same construction as for pr(A = 1 | X) but with Y
# added to the covariates. The result is also stored in `pi`, which the grid
# evaluation that follows uses as response.
kx <- x <- data[, c("Y", "X1", "X2")]
a <- data[, c("A")]

# Kernel weights are computed from the treatment column of the data frame in
# use (`a`), not from a global A that may belong to another sample.
A_one <- 1
A_a <- A_one - a
A_a <- as.data.frame(A_a)
K <- kernels(A_a)
Ksum <- sum(K)
pa <- K / (max(K))
kernel_y <- get_pre_y(as.matrix(x), pa, as.matrix(kx))
colnames(kernel_y) <- c(colnames(x), "y_")
pr_A_1_given_Y_X <- unlist(kernel_y[, ncol(kernel_y)])
pi <- pr_A_1_given_Y_X

# Complement of the probability conditional on (Y, X), not of the one
# conditional on X alone.
pr_A_0_given_Y_X <- 1 - pr_A_1_given_Y_X

# Evaluate the fit of pi on (Y, X1, X2) at the points returned by
# generate_kx(), processed in t = 10 equally sized row chunks.
# generate_kx() is called once; its result is all that the loop needs.
# NOTE(review): generate_kx() is defined elsewhere; the later code treats its
# result as nrow(data) consecutive blocks of nrow(data) rows -- confirm.
x <- data[, c("Y", "X1", "X2")]
kx <- generate_kx(data[, c("Y", "X1", "X2")])

t <- 10
rows_per_chunk <- nrow(kx) / t
chunk_fits <- vector("list", t)
for (i in seq_len(t)) {
  chunk_rows <- ((i - 1) * rows_per_chunk + 1):(i * rows_per_chunk)
  kxx <- kx[chunk_rows, ]
  kxx <- as.data.frame(kxx)
  kernel_y <- get_pre_y(as.matrix(x), pi, as.matrix(kxx))
  colnames(kernel_y) <- c(colnames(x), "y_")
  chunk_fits[[i]] <- kernel_y
  print(i)
}
# Stack the chunk results in order; the last column holds the fitted values.
kernel_y_ <- do.call(rbind, chunk_fits)

p1 <- unlist(kernel_y_[, ncol(kernel_y_)])
p0 <- 1 - p1

# Switch pr_Y_given_X to the full (all chunks) normalised vector.
pr_Y_given_X <- pr_Y_given_X_save
# Combine pr(Y | X) with the grid values p1 / p0 through
# fast_vector_operations() -- here with pr(A = 1 | X) and pr(A = 0 | X) as the
# marginal probabilities -- and turn each block of nrow(data) values into a
# weighted mean of Y.
# NOTE(review): fast_vector_operations() is defined elsewhere; only the two
# list elements read below are relied on here.
ii <- cbind(pr_Y_given_X, p1, p0)
nrow_data <- nrow(data)
li <- fast_vector_operations(as.matrix(ii), pr_A_1_given_X, pr_A_0_given_X, nrow_data)
pr_Y_given_A_star_1_X_all <- li[["pr_Y_given_A_star_1_X_all"]]
pr_Y_given_A_star_0_X_all <- li[["pr_Y_given_A_star_0_X_all"]]

# One expectation per block; preallocate instead of growing with c().
n_blocks <- length(pr_Y_given_A_star_1_X_all) / nrow_data
E_Y_given_A_star_1_X <- numeric(n_blocks)
E_Y_given_A_star_0_X <- numeric(n_blocks)

for (j in seq_len(n_blocks)) {
  block_rows <- ((j - 1) * nrow_data + 1):(j * nrow_data)
  o_1 <- pr_Y_given_A_star_1_X_all[block_rows]
  o_2 <- pr_Y_given_A_star_0_X_all[block_rows]
  # Normalise the weights within the block so they sum to one.
  # NOTE(review): the normalising sums do not drop NA while the weighted sums
  # below do; a single NA weight therefore yields an expectation of 0.
  o_1 <- o_1 / sum(o_1)
  o_2 <- o_2 / sum(o_2)
  E_Y_given_A_star_1_X[j] <- sum(o_1 * Y, na.rm = TRUE)
  E_Y_given_A_star_0_X[j] <- sum(o_2 * Y, na.rm = TRUE)
}
# Pseudo-outcome for the baseline that uses the observed A directly: an
# inverse-probability-weighted residual plus the conditional mean, per arm.
data <- data_train
A <- data[, c("A")]
Y <- data[, c("Y")]

phi1 <- A * (Y - E_Y_given_A_star_1_X) / pr_A_1_given_X + E_Y_given_A_star_1_X
phi0 <- (1 - A) * (Y - E_Y_given_A_star_0_X) / pr_A_0_given_X + E_Y_given_A_star_0_X
phi <- phi1 - phi0

# Numerator, denominator and additive terms passed on to
# create_train_matrices().
# NOTE(review): Z and the pr_*_given_A_star_* / pr_A_S_* vectors are not
# recomputed in this section; they hold whatever an earlier section left in
# the workspace -- confirm they match the training rows used here.
numerator_1 <- (A - pr_A_1_given_A_star_0_X) *
  (Z - pr_Z_1_given_A_star_0_X) *
  (Y - E_Y_given_A_star_1_X)
denominator_1 <- ((pr_A_1_given_A_star_1_X - pr_A_1_given_A_star_0_X) *
                    (pr_Z_1_given_A_star_1_X - pr_Z_1_given_A_star_0_X) *
                    (pr_A_S_1_given_X))
add_1 <- E_Y_given_A_star_1_X
numerator_0 <- (A - pr_A_1_given_A_star_1_X) *
  (Z - pr_Z_1_given_A_star_1_X) *
  (Y - E_Y_given_A_star_0_X)
denominator_0 <- ((pr_A_1_given_A_star_0_X - pr_A_1_given_A_star_1_X) *
                    (pr_Z_1_given_A_star_0_X - pr_Z_1_given_A_star_1_X) *
                    (pr_A_S_0_given_X))
add_0 <- E_Y_given_A_star_0_X
# First forest of the baseline: regress A on the training covariates, which
# are passed as both the first and the second matrix argument.
# NOTE(review): create_train_matrices(), regression_train(), do.call.rcpp(),
# get_legacy_seed() and validate_equalize_cluster_weights() are defined outside
# this file.
num.trees <- 4000
clusters <- numeric(0)
sample.weights <- NULL
equalize.cluster.weights <- FALSE
samples.per.cluster <- validate_equalize_cluster_weights(
  equalize.cluster.weights, clusters, sample.weights
)
sample.fraction <- 0.5
mtry <- 8
min.node.size <- 6
honesty <- TRUE
honesty.fraction <- 0.5
honesty.prune.leaves <- TRUE
alpha <- 0.05
imbalance.penalty <- 0
stabilize.splits <- TRUE
ci.group.size <- 4
compute.oob.predictions <- TRUE
num.threads <- 0
seed <- runif(1, 0, .Machine$integer.max)
tune.parameters <- "none"

# `data` is reused here for the list of training matrices.
data <- create_train_matrices(
  X, X,
  outcome = A,
  instrument = rep(0, n),
  sample.weights = sample.weights,
  numerator_1 = numerator_1,
  denominator_1 = denominator_1,
  add_1 = add_1,
  numerator_0 = numerator_0,
  denominator_0 = denominator_0,
  add_0 = add_0,
  outcome_t = A
)
args <- list(
  num.trees = num.trees,
  clusters = clusters,
  samples.per.cluster = samples.per.cluster,
  sample.fraction = sample.fraction,
  mtry = mtry,
  min.node.size = min.node.size,
  honesty = honesty,
  honesty.fraction = honesty.fraction,
  honesty.prune.leaves = honesty.prune.leaves,
  alpha = alpha,
  imbalance.penalty = imbalance.penalty,
  ci.group.size = ci.group.size,
  compute.oob.predictions = compute.oob.predictions,
  num.threads = num.threads,
  seed = seed,
  legacy.seed = get_legacy_seed()
)
forest <- do.call.rcpp(regression_train, c(data, args))
# plot(forest$predictions,A)
W.hat <- forest$predictions
W.centered <- A - W.hat
W.centered <-W.centered*1.2
# W.centered<-rep(1,n)
# ---- Second-stage forest: phi as the outcome, W.centered as the instrument ----
num.trees <- 3000 # not used below: `args` hard-codes num.trees = 2000
sample.weights <- NULL
clusters <- NULL
equalize.cluster.weights <- FALSE
sample.fraction <- 0.5
# mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X));
mtry <- 10 # overwritten with 7 just before `args` is assembled
min.node.size <- 7
honesty <- TRUE
honesty.fraction <- 0.5
honesty.prune.leaves <- TRUE
alpha <- 0.05
imbalance.penalty <- 0
ci.group.size <- 5
tune.parameters <- "none"
tune.num.trees <- 50
tune.num.reps <- 100
tune.num.draws <- 1000
compute.oob.predictions <- TRUE
num.threads <- NULL
seed <- runif(1, 0, .Machine$integer.max)

# Input validation, in the same order as before.
validate_sample_weights(sample.weights, X)
Y <- validate_observations(Y, X)
clusters <- validate_clusters(clusters, X)
samples.per.cluster <- validate_equalize_cluster_weights(
  equalize.cluster.weights, clusters, sample.weights
)
num.threads <- validate_num_threads(num.threads)

data <- create_train_matrices(
  X, X,
  outcome = phi,
  instrument = W.centered,
  sample.weights = sample.weights,
  numerator_1 = numerator_1,
  denominator_1 = denominator_1,
  add_1 = add_1,
  numerator_0 = numerator_0,
  denominator_0 = denominator_0,
  add_0 = add_0,
  outcome_t = phi
)
mtry <- 7
args <- list(
  num.trees = 2000,
  clusters = clusters,
  samples.per.cluster = samples.per.cluster,
  sample.fraction = sample.fraction,
  mtry = mtry,
  min.node.size = min.node.size,
  honesty = honesty,
  honesty.fraction = honesty.fraction,
  honesty.prune.leaves = honesty.prune.leaves,
  alpha = alpha,
  imbalance.penalty = imbalance.penalty,
  ci.group.size = ci.group.size,
  compute.oob.predictions = compute.oob.predictions,
  num.threads = num.threads,
  seed = seed,
  legacy.seed = get_legacy_seed()
)
forest <- do.call.rcpp(regression_train, c(data, args))
preds <- forest

# Point predictions next to the reference values, with +/- 1.96 standard-error
# bounds derived from the forest's variance estimates.
df <- data.frame(
  predictions = preds$predictions,
  truth = truth_Y,
  upper = preds$predictions + 1.96 * sqrt(preds$variance.estimates),
  lower = preds$predictions - 1.96 * sqrt(preds$variance.estimates)
)
# Interval coverage, mean squared error and average interval width for the
# forest fitted above. Computed on whole columns instead of a row-by-row
# loop, which also avoids the `1:n` trap when `df` has no rows.
n <- nrow(df)
# A missing bound (e.g. the square root of a negative variance estimate) is
# replaced by 0, exactly as the original loop did.
ci_lower <- ifelse(is.na(df$lower), 0, df$lower)
ci_upper <- ifelse(is.na(df$upper), 0, df$upper)
covered <- ci_lower <= df$truth & df$truth <= ci_upper
# Share of rows whose interval contains the reference value.
percent_llf <- sum(covered) / n
percent_llf
mse <- sum((forest$predictions - truth_Y)^2) / n
mse
# Mean interval width.
avg_llf <- sum(abs(ci_upper - ci_lower)) / n
avg_llf
library(ggplot2)
# Predicted vs. reference values with the identity line in red; the x-axis
# title carries the rounded MSE and both axes are clipped to [-0.2, 2].
pic4 <- ggplot(data = df, aes(x = truth, y = predictions)) +
  geom_point(size = 0.01) +
  geom_abline(slope = 1, intercept = 0, colour = "#E41A1C") +
  theme_bw() +
  labs(
    x = paste("MSE:", round(mse, 3)),
    y = "",
    title = "Surrogate Observed Forest"
  ) +
  scale_x_continuous(limits = c(-0.2, 2)) +
  scale_y_continuous(limits = c(-0.2, 2)) +
  theme(
    axis.title.x = element_text(size = 12),
    plot.title = element_text(size = 12)
  )
pic4

#####Z replace A_star forest######
# Restart from the training data and estimate pr(Z = 1 | X) with the
# project's kernel helpers (`kernels`, `get_pre_y`), using X1 and X2 both as
# the fitting points and as the evaluation points.
data <- data_train
x <- data[, c("X1", "X2")]
kx <- x
Z <- data[, c("Z")]

# 1 - Z, wrapped in a one-column data frame named `Z_a`, is what `kernels()`
# receives.
Z_one <- 1
Z_a <- Z_one - Z
Z_a <- as.data.frame(Z_a)
K <- kernels(Z_a)
Ksum <- sum(K)
# Weights are normalised by their maximum (an alternative "+ min(K)" term was
# left commented out in the original).
pa <- K / (max(K))

kernel_y <- get_pre_y(as.matrix(x), pa, as.matrix(kx))
colnames(kernel_y) <- c(colnames(x), "y_")
# The last column of the helper's output is taken as the estimate.
# NOTE(review): the original carried a trailing "0.3556" remark here; its
# meaning is not visible from this code.
pr_Z_1_given_X <- unlist(kernel_y[, ncol(kernel_y)])
pr_Z_0_given_X <- 1 - pr_Z_1_given_X
# Probability of the Z value that was actually observed in each row.
pr_Z_given_X <- ifelse(Z == 1, pr_Z_1_given_X, pr_Z_0_given_X)

# For every observed outcome value Y[k], run the kernel helper with an
# indicator of "this row's Y equals Y[k]" as the response, and stack the
# resulting last-column estimates into one long vector (one chunk per k).
#
# The original grew `pr_Y_given_X` with c() on every iteration, which copies
# the whole accumulated vector each time (quadratic in length(Y)). The chunks
# are now collected in a preallocated list and flattened once.
Y <- data[, c("Y")]
x <- data[, c("X1", "X2")]
kx <- x

# NOTE(review): the original relabelled `kernel_y` (not `results`) inside the
# loop. The statement does not depend on the loop variable, so it is done once
# here; it may have been intended for `results` -- confirm.
colnames(kernel_y) <- c(colnames(x), "y_")

chunks <- vector("list", length(Y))
for (x_estimate in seq_along(Y)) {
  # X_x <-  x_estimate - x
  if (x_estimate %% 2000 == 0) print(x_estimate)
  Y_estimate <- Y[x_estimate]
  Y_y <- Y_estimate - Y
  Y_y <- as.data.frame(Y_y)
  # K <-kernels(Y_y)
  # Ksum <- sum(K)
  # kxx<-K/max(K)
  # Exact-match indicator: 1 where the row's Y equals the current Y value.
  kxx <- ifelse(Y_y == 0, 1, 0)
  results <- get_pre_y(as.matrix(x), kxx, as.matrix(kx))

  # final <- do.call('c',results)
  results_ <- rbind(NULL, results)

  # `[<-` with list() keeps the slot even if the chunk were empty.
  chunks[x_estimate] <- list(unlist(results_[, ncol(results_)]))
}
pr_Y_given_X <- unlist(chunks)
pr_Y_given_X_save <- pr_Y_given_X

# Sum every nrow(data)-th element starting at offset i, for each row i, then
# rescale the whole vector by the mean of those sums.
n_rows <- nrow(data)
n_total <- length(pr_Y_given_X_save)
pp <- vapply(
  seq_len(n_rows),
  function(i) sum(pr_Y_given_X_save[seq(i, n_total, n_rows)]),
  numeric(1)
)
pr_Y_given_X_save <- pr_Y_given_X_save / mean(pp)

# Kernel estimate of pr(Z = 1 | Y, X): same recipe as the pr(Z = 1 | X) step,
# but with Y added to the conditioning columns.
x <- data[, c("Y", "X1", "X2")]
kx <- x
Z <- data[, c("Z")]
Z_one <- 1
Z_a <- Z_one - Z
Z_a <- as.data.frame(Z_a)
K <- kernels(Z_a)
Ksum <- sum(K)
# Weights normalised by their maximum (a "+ min(K)" variant was left
# commented out in the original).
pa <- K / (max(K))
kernel_y <- get_pre_y(as.matrix(x), pa, as.matrix(kx))
colnames(kernel_y) <- c(colnames(x), "y_")
# The last column of the helper's output is taken as the estimate.
pr_Z_1_given_Y_X <- unlist(kernel_y[, ncol(kernel_y)])
# NOTE(review): this masks base R's constant `pi` for the rest of the script;
# the name is kept because later code reads `pi`.
pi <- pr_Z_1_given_Y_X
# Fixed: the complement must be taken from pr(Z = 1 | Y, X); the original
# subtracted pr_Z_1_given_X (the estimate without Y) by mistake.
pr_Z_0_given_Y_X <- 1 - pr_Z_1_given_Y_X

# Build the evaluation grid with generate_kx() and push it through
# get_pre_y() in `t` consecutive row chunks, stacking the chunk outputs.
#
# Changes from the original:
#   * generate_kx() was called twice on identical input; it is called once.
#     NOTE(review): this assumes generate_kx() has no side effects -- confirm.
#   * chunk limits are whole numbers, so no row is skipped when nrow(kx) is
#     not a multiple of `t` (fractional limits silently dropped rows before).
#   * chunk outputs are collected in a list and row-bound once.
# for(i in 1:nrow(data[,c('Y','X1','X2')])){
#   kx<-rbind(kx,cbind(data[rep(i,each=nrow(data)),c('X1','X2')],rep(data$Y)))
# }
x <- data[, c("Y", "X1", "X2")]
kx <- generate_kx(data[, c("Y", "X1", "X2")])

t <- 10
# Row i-chunk covers (chunk_bounds[i] + 1):chunk_bounds[i + 1]; identical to
# the original limits whenever nrow(kx) is divisible by t.
chunk_bounds <- floor(seq(0, nrow(kx), length.out = t + 1))
chunk_results <- vector("list", t)
for (i in seq_len(t)) {
  if (chunk_bounds[i + 1] > chunk_bounds[i]) {
    kxx <- kx[(chunk_bounds[i] + 1):chunk_bounds[i + 1], , drop = FALSE]
    kxx <- as.data.frame(kxx)
    kernel_y <- get_pre_y(as.matrix(x), pi, as.matrix(kxx))
    colnames(kernel_y) <- c(colnames(x), "y_")
    chunk_results[i] <- list(kernel_y)
  }
  print(i)
}
# rbind() ignores NULL entries left by empty chunks.
kernel_y_ <- do.call(rbind, chunk_results)
# The last column of the stacked output is taken as p1; p0 is its complement.
p1 <- unlist(kernel_y_[, ncol(kernel_y_)])
p0 <- 1 - p1
# Combine the stacked pieces through fast_vector_operations() and turn the
# two returned long vectors into per-block weighted means of Y.
pr_Y_given_X <- pr_Y_given_X_save
ii <- cbind(pr_Y_given_X, p1, p0)
nrow_data <- nrow(data)
li <- fast_vector_operations(
  as.matrix(ii), pr_Z_1_given_X, pr_Z_0_given_X, nrow_data
)
pr_Y_given_Z_star_1_X_all <- li$pr_Y_given_A_star_1_X_all
pr_Y_given_Z_star_0_X_all <- li$pr_Y_given_A_star_0_X_all

# Fixed: the original loop read `pr_Y_given_A_star_*_X_all`, i.e. whatever an
# earlier part of the script had left behind, instead of the two vectors
# extracted from `li` just above. The loop now uses the fresh results.
# The result vectors are preallocated instead of being grown with c().
n_blocks <- length(pr_Y_given_Z_star_1_X_all) %/% nrow_data
E_Y_given_A_star_1_X <- numeric(n_blocks)
E_Y_given_A_star_0_X <- numeric(n_blocks)

for (j in seq_len(n_blocks)) {
  block_rows <- ((j - 1) * nrow_data + 1):(j * nrow_data)
  o_1 <- pr_Y_given_Z_star_1_X_all[block_rows]
  o_2 <- pr_Y_given_Z_star_0_X_all[block_rows]
  # Normalise each block to sum to one before using it as weights on Y.
  o_1 <- o_1 / sum(o_1)
  o_2 <- o_2 / sum(o_2)
  E_Y_given_A_star_1_X[j] <- sum(o_1 * Y, na.rm = TRUE)
  E_Y_given_A_star_0_X[j] <- sum(o_2 * Y, na.rm = TRUE)
}
# phi1<-(A-pr_A_1_given_A_star_0_X)*(Z-pr_Z_1_given_A_star_0_X)*(Y-E_Y_given_A_star_1_X)/((pr_A_1_given_A_star_1_X-pr_A_1_given_A_star_0_X)*(pr_Z_1_given_A_star_1_X-pr_Z_1_given_A_star_0_X)*(pr_A_S_1_given_X))+E_Y_given_A_star_1_X
# Re-read Z and Y from the training data, then form the pseudo-outcome `phi`
# with Z in the role of the treatment indicator: an inverse-probability
# weighted residual plus the outcome regression for each arm, differenced.
data <- data_train
Z <- data[, c("Z")]
Y <- data[, c("Y")]

phi1 <- (Z) * (Y - E_Y_given_A_star_1_X) / ((pr_Z_1_given_X)) +
  E_Y_given_A_star_1_X
phi0 <- (1 - Z) * (Y - E_Y_given_A_star_0_X) / ((pr_Z_0_given_X)) +
  E_Y_given_A_star_0_X
phi <- phi1 - phi0

# Pieces forwarded to create_train_matrices() alongside `phi`.
numerator_1 <- (A - pr_A_1_given_A_star_0_X) *
  (Z - pr_Z_1_given_A_star_0_X) *
  (Y - E_Y_given_A_star_1_X)
denominator_1 <- (
  (pr_A_1_given_A_star_1_X - pr_A_1_given_A_star_0_X) *
    (pr_Z_1_given_A_star_1_X - pr_Z_1_given_A_star_0_X) *
    (pr_A_S_1_given_X)
)
add_1 <- E_Y_given_A_star_1_X

numerator_0 <- (A - pr_A_1_given_A_star_1_X) *
  (Z - pr_Z_1_given_A_star_1_X) *
  (Y - E_Y_given_A_star_0_X)
denominator_0 <- (
  (pr_A_1_given_A_star_0_X - pr_A_1_given_A_star_1_X) *
    (pr_Z_1_given_A_star_0_X - pr_Z_1_given_A_star_1_X) *
    (pr_A_S_0_given_X)
)
add_0 <- E_Y_given_A_star_0_X
# ---- First-stage forest: Z as the outcome, X as the features ----
# The forest predictions become W.hat, and Z - W.hat (rescaled) is used as the
# "instrument" of the second-stage forest.
num.trees <- 4000
clusters <- numeric(0)
sample.weights <- NULL
equalize.cluster.weights <- FALSE
samples.per.cluster <- validate_equalize_cluster_weights(
  equalize.cluster.weights, clusters, sample.weights
)
sample.fraction <- 0.5
mtry <- 8
min.node.size <- 6
honesty <- TRUE
honesty.fraction <- 0.5
honesty.prune.leaves <- TRUE
alpha <- 0.05
imbalance.penalty <- 0
stabilize.splits <- TRUE # set here but not forwarded in `args` below
ci.group.size <- 4
compute.oob.predictions <- TRUE
num.threads <- 0
seed <- runif(1, 0, .Machine$integer.max)
tune.parameters <- "none"

# The instrument is a vector of zeros at this stage; the pseudo-outcome pieces
# are passed through unchanged.
data <- create_train_matrices(
  X, X,
  outcome = Z,
  instrument = rep(0, n),
  sample.weights = sample.weights,
  numerator_1 = numerator_1,
  denominator_1 = denominator_1,
  add_1 = add_1,
  numerator_0 = numerator_0,
  denominator_0 = denominator_0,
  add_0 = add_0,
  outcome_t = Z
)
args <- list(
  num.trees = num.trees,
  clusters = clusters,
  samples.per.cluster = samples.per.cluster,
  sample.fraction = sample.fraction,
  mtry = mtry,
  min.node.size = min.node.size,
  honesty = honesty,
  honesty.fraction = honesty.fraction,
  honesty.prune.leaves = honesty.prune.leaves,
  alpha = alpha,
  imbalance.penalty = imbalance.penalty,
  ci.group.size = ci.group.size,
  compute.oob.predictions = compute.oob.predictions,
  num.threads = num.threads,
  seed = seed,
  legacy.seed = get_legacy_seed()
)
forest <- do.call.rcpp(regression_train, c(data, args))
# plot(forest$predictions,A)
W.hat <- forest$predictions
# NOTE(review): the 1.2 inflation factor is hard-coded; its rationale is not
# visible in this part of the script.
W.centered <- (Z - W.hat) * 1.2
# W.centered<-rep(1,n)
# ---- Second-stage forest: phi as the outcome, W.centered as the instrument ----
num.trees <- 3000 # not used below: `args` hard-codes num.trees = 2000
sample.weights <- NULL
clusters <- NULL
equalize.cluster.weights <- FALSE
sample.fraction <- 0.5
# mtry = min(ceiling(sqrt(ncol(X)) + 20), ncol(X));
mtry <- 10 # overwritten with 7 just before `args` is assembled
min.node.size <- 7
honesty <- TRUE
honesty.fraction <- 0.5
honesty.prune.leaves <- TRUE
alpha <- 0.05
imbalance.penalty <- 0
ci.group.size <- 5
tune.parameters <- "none"
tune.num.trees <- 50
tune.num.reps <- 100
tune.num.draws <- 1000
compute.oob.predictions <- TRUE
num.threads <- NULL
seed <- runif(1, 0, .Machine$integer.max)

# Input validation, in the same order as before.
validate_sample_weights(sample.weights, X)
Y <- validate_observations(Y, X)
clusters <- validate_clusters(clusters, X)
samples.per.cluster <- validate_equalize_cluster_weights(
  equalize.cluster.weights, clusters, sample.weights
)
num.threads <- validate_num_threads(num.threads)

data <- create_train_matrices(
  X, X,
  outcome = phi,
  instrument = W.centered,
  sample.weights = sample.weights,
  numerator_1 = numerator_1,
  denominator_1 = denominator_1,
  add_1 = add_1,
  numerator_0 = numerator_0,
  denominator_0 = denominator_0,
  add_0 = add_0,
  outcome_t = phi
)
mtry <- 7
args <- list(
  num.trees = 2000,
  clusters = clusters,
  samples.per.cluster = samples.per.cluster,
  sample.fraction = sample.fraction,
  mtry = mtry,
  min.node.size = min.node.size,
  honesty = honesty,
  honesty.fraction = honesty.fraction,
  honesty.prune.leaves = honesty.prune.leaves,
  alpha = alpha,
  imbalance.penalty = imbalance.penalty,
  ci.group.size = ci.group.size,
  compute.oob.predictions = compute.oob.predictions,
  num.threads = num.threads,
  seed = seed,
  legacy.seed = get_legacy_seed()
)
forest <- do.call.rcpp(regression_train, c(data, args))
preds <- forest

# Point predictions next to the reference values, with +/- 1.96 standard-error
# bounds derived from the forest's variance estimates.
df <- data.frame(
  predictions = preds$predictions,
  truth = truth_Y,
  upper = preds$predictions + 1.96 * sqrt(preds$variance.estimates),
  lower = preds$predictions - 1.96 * sqrt(preds$variance.estimates)
)
# Interval coverage, mean squared error and average interval width for the
# forest fitted above. Computed on whole columns instead of a row-by-row
# loop, which also avoids the `1:n` trap when `df` has no rows.
n <- nrow(df)
# A missing bound (e.g. the square root of a negative variance estimate) is
# replaced by 0, exactly as the original loop did.
ci_lower <- ifelse(is.na(df$lower), 0, df$lower)
ci_upper <- ifelse(is.na(df$upper), 0, df$upper)
covered <- ci_lower <= df$truth & df$truth <= ci_upper
# Share of rows whose interval contains the reference value.
percent_llf <- sum(covered) / n
percent_llf
mse <- sum((forest$predictions - truth_Y)^2) / n
mse
# Mean interval width.
avg_llf <- sum(abs(ci_upper - ci_lower)) / n
avg_llf
library(ggplot2)
# Predicted vs. reference values with the identity line in red; the x-axis
# title carries the rounded MSE. Only the x-axis is clipped here -- the
# matching y-axis limit is deliberately left disabled.
pic5 <- ggplot(data = df, aes(x = truth, y = predictions)) +
  geom_point(size = 0.005) +
  geom_abline(slope = 1, intercept = 0, colour = "#E41A1C") +
  theme_bw() +
  labs(
    x = paste("MSE:", round(mse, 3)),
    y = "",
    title = "Proxy Observed Forest"
  ) +
  scale_x_continuous(limits = c(-0.2, 2)) +
  # scale_y_continuous(limits = c(-0.2,2))+
  theme(
    axis.title.x = element_text(size = 12),
    plot.title = element_text(size = 12)
  )
pic5

library(patchwork)
# 2 x 2 grid of the four scatter plots (pic2, pic3 on top; pic4, pic5 below)
# with a centred caption that serves as the shared x-axis label.
combined_plot <- (pic2 + pic3) / (pic4 + pic5) +
  plot_annotation(caption = "True HTE") &
  theme(plot.caption = element_text(hjust = 0.5, size = 18))

# Wrap the grid as a single element so that a rotated tag on the left can
# serve as the shared y-axis label; the expression is printed as before.
wrap_elements(combined_plot) +
  labs(tag = "Prediction HTE") +
  theme(
    plot.tag = element_text(size = 18, angle = 90),
    plot.tag.position = "left"
  )
