library(kernlab)
library(KernSmooth)
library(MASS)
library("ks")
library(foreach)
library(glmnet)
library(randomForest)
library(doParallel)
require(tidyverse)
require(ggplot2)
library(ggpubr)
library(caret)
library(nnet)


# Make relative paths ("adult.csv", the sourced helper files) resolve next to
# this script. rstudioapi can only report the script location from inside an
# RStudio session with a saved file, so outside that setting (Rscript,
# source() from a console) the current working directory is left untouched
# and the script must be started from its own folder.
if (requireNamespace("rstudioapi", quietly = TRUE) && rstudioapi::isAvailable()) {
  script_path <- rstudioapi::getSourceEditorContext()$path
  if (length(script_path) == 1L && nzchar(script_path)) {
    setwd(dirname(script_path))
  }
}

# Project helpers. The functions called further down that are not defined in
# this file (NullIndex, DataSplit, OMT, SCPV, MeanSCPV) are presumably defined
# in these two files -- confirm.
source("Functions.R")
source("AlgorithmClass.R")


## Read and preprocess the data
data0 <- read.table("adult.csv", header = TRUE, sep = ",")

# "?" entries are treated as missing; every row containing one is dropped.
data0[data0 == "?"] <- NA
data1 <- na.omit(data0)

# Keep only the United-States rows, then drop the country column.
# The column is removed by name with base R: with dplyr attached,
# `select(df, select = -c(col))` is a *named* tidyselect selection, which
# renames the kept columns instead of merely dropping `col`.
USindex <- which(data1$native.country == "United-States")
data1 <- data1[USindex, setdiff(colnames(data1), "native.country"), drop = FALSE]

# Columns handled as numeric; every other column is handled as a factor.
varcontinue <- c(
  "age", "fnlwgt", "education.num",
  "capital.gain", "capital.loss", "hours.per.week"
)
colname <- colnames(data1)

# Binary response: 1 when income is ">50K", 0 otherwise.
y <- as.numeric(data1$income == ">50K")

# Rebuild the frame with the numeric columns first, followed by the remaining
# columns converted to factors (original relative order preserved).
data1 <- cbind(
  lapply(data1[, varcontinue], function(x) as.numeric(as.character(x))),
  as.data.frame(
    lapply(data1[, setdiff(colname, varcontinue)], function(x) factor(x))
  )
)

# Fit the one-hot encoder on the predictors only. The response column is
# excluded by name so the encoding does not depend on its position.
dummy <- dummyVars(
  " ~ .",
  data = data1[, colnames(data1) != "income", drop = FALSE]
)

# Perform one-hot encoding on the data frame
data <- data.frame(predict(dummy, newdata = data1))

# Append the binary response as the last column.
data$y <- y


## Experiment settings

# Total number of rows available for sampling.
Number <- nrow(data)

# Empirical share of rows with y != 1. This value is overridden by the fixed
# setting at the end of this section.
sr <- 1 - sum(data$y == 1) / Number

# Number of columns other than the response y (y is the last column).
d <- ncol(data) - 1

# Sample sizes: n + N rows are drawn per replication and handed to DataSplit.
n <- 2000
N <- 2000
n_train <- round(n * 0.5)
n_cal <- round(n * 0.5)

# Replication times
ns <- 100

# Level passed to the testing procedures, and the quantile level / list of
# rules used to build the threshold L in each replication.
alpha <- 0.1
Thresholding_value <- 0.7
Thresholding_type_array <- c("test", "exchange", "mean")

# Fixed share of "Null" rows in every sampled data set.
sr <- 0.8

## Parallel simulation: ns independent replications.
## Each replication returns a data frame with one row per
## (Method, Thresholding_type) pair holding FDP and Power; the rows of all
## replications are stacked into `result`.

# Use at most 10 workers, and no more than the machine reports having.
n_workers <- max(1L, min(10L, parallel::detectCores(), na.rm = TRUE))
cl <- makeCluster(n_workers) # parallel implementation
registerDoParallel(cl)

# tryCatch(..., finally = ) shuts the workers down even when a replication
# fails, so an error does not leave orphaned R processes behind.
result <- tryCatch(
  foreach(
    iter = seq_len(ns),
    .combine = "rbind",
    .packages = c("MASS", "randomForest", "kernlab", "nnet")
  ) %dopar% {
    # Append one result row; `res` is the list returned by a testing
    # procedure and must provide $FDP and $Power.
    add_result <- function(info, res, method, type) {
      rbind(
        info,
        list(
          FDP = res$FDP, Power = res$Power,
          Method = method, Thresholding_type = type
        )
      )
    }

    info <- data.frame()

    ## Sampling the data
    # Draw round((n + N) * sr) rows from the "Null" index set and the rest
    # from its complement, then shuffle the combined rows.
    Value <- list(type = "==A,R", v = 0)
    Null <- NullIndex(data$y, Value)
    Alter <- setdiff(seq_len(Number), Null)

    n_total <- n + N
    n_null <- round(n_total * sr)
    IndexSample <- c(
      sample(Null, n_null, replace = FALSE),
      sample(Alter, n_total - n_null, replace = FALSE)
    )
    newdata <- data[sample(IndexSample, n_total, replace = FALSE), ]

    ## Split into training / calibration / rest / test parts
    datawork <- DataSplit(newdata, n + N, N, n_cal, n)
    data_train <- datawork$data_train
    data_cal <- datawork$data_cal
    data_rest <- datawork$data_rest
    data_test <- datawork$data_test

    # NOTE(review): these three index sets are not used below. The calls are
    # kept because NullIndex is defined outside this file and could not be
    # confirmed to be free of side effects (e.g. random number consumption).
    Null_cal <- NullIndex(data_cal$y, Value)
    Null_rest <- NullIndex(data_rest$y, Value)
    Null_test <- NullIndex(data_test$y, Value)

    # The response y sits in column d + 1; every other column is a feature.
    response_col <- d + 1
    X_train <- data_train[colnames(data_train)[-response_col]]
    Y_train <- as.matrix(data_train$y)
    X_cal <- data_cal[colnames(data_cal)[-response_col]]
    Y_cal <- as.matrix(data_cal$y)
    X_test <- data_test[colnames(data_test)[-response_col]]
    Y_test <- as.matrix(data_test$y)

    ## Fit the classifier and score the calibration and test points with the
    ## predicted probability of the second class level.
    model <- randomForest(
      y ~ .,
      data = data.frame(X_train, y = as.factor(Y_train)),
      ntree = 500
    )
    Y_cal_hat <- as.numeric(predict(model, X_cal, type = "prob")[, 2])
    Y_test_hat <- as.numeric(predict(model, X_test, type = "prob")[, 2])

    # The thresholding variable is the first feature column.
    T_test <- X_test[, 1]
    T_cal <- X_cal[, 1]

    b_0 <- 0.5

    # Scores handed to the procedures: negated predicted probabilities.
    # They do not depend on the thresholding rule, so compute them once.
    V_cal <- -Y_cal_hat
    V_test <- -Y_test_hat

    for (Thresholding_type in Thresholding_type_array) {
      # Threshold L for the current rule. An unrecognised rule stops with an
      # explicit error instead of silently reusing the previous L.
      L <- switch(Thresholding_type,
        constant = 20000,
        test = quantile(T_test, Thresholding_value),
        exchange = quantile(c(T_test, T_cal), Thresholding_value),
        mean = mean(T_test),
        stop("Unknown Thresholding_type: ", Thresholding_type, call. = FALSE)
      )

      # OMT
      Result_OMT <- OMT(
        Y_cal, Y_test, b_0, T_cal, T_test, L, V_cal, V_test, alpha
      )
      info <- add_result(info, Result_OMT, "OMT", Thresholding_type)

      # Bonferroni
      Result_Bonfer <- OMT(
        Y_cal, Y_test, b_0, T_cal, T_test, L, V_cal, V_test, alpha,
        Bonfer = TRUE, BY = FALSE
      )
      info <- add_result(info, Result_Bonfer, "Bonferroni", Thresholding_type)

      # BY
      Result_BY <- OMT(
        Y_cal, Y_test, b_0, T_cal, T_test, L, V_cal, V_test, alpha,
        Bonfer = TRUE, BY = TRUE
      )
      info <- add_result(info, Result_BY, "BY", Thresholding_type)

      # SCPV (the "mean" rule has its own implementation)
      if (Thresholding_type == "mean") {
        Result_SCPV <- MeanSCPV(
          Y_cal, Y_test, b_0, T_cal, T_test, L, V_cal, V_test, alpha
        )
      } else {
        Result_SCPV <- SCPV(
          Y_cal, Y_test, b_0, T_cal, T_test, L, V_cal, V_test, alpha
        )
      }
      info <- add_result(info, Result_SCPV, "SCPV", Thresholding_type)
    }

    info
  },
  finally = stopCluster(cl)
)

# Average the replications: mean and standard deviation of FDP and Power for
# every (Method, Thresholding_type) combination.
tab <- result %>%
  group_by(Method, Thresholding_type) %>%
  dplyr::summarize(
    FDR = mean(FDP),
    FDR_sd = sd(FDP),
    # Power_sd has to come before Power: once Power is replaced by its mean,
    # later expressions would see the summarised value.
    Power_sd = sd(Power),
    Power = mean(Power)
  )
tab


