library(MASS)
library(kernlab)
library(foreach)
library(doParallel)
require(tidyverse)
require(ggplot2)
library(randomForest)
library(ggsci)
library(ggpubr)


# Setup ----
# Switch to the folder containing this script (requires RStudio, via
# rstudioapi) so the relative helper paths below resolve.
local({
  this_script <- rstudioapi::getSourceEditorContext()$path
  setwd(dirname(this_script))
})

# Load the project helpers into the global environment, in this order.
# NOTE(review): presumably these define DataGen/DataSplit/Tuning/fitting,
# OMT/SCOP/SCPV/MeanSCPV and the classes built with new() below — confirm.
invisible(lapply(c("Functions.R", "AlgorithmClass.R"), source))




# Simulation constants ----
d <- 10     # dimension of covariates
alpha <- 0.1 # target FDR

# Coefficient vector handed to DataGen(); nonzero only in positions 2-4.
beta <- c(0, 1, 2, 3, rep(0, 6))

# Sample sizes of the training / calibration / test splits.
n_train <- 1200
n_cal <- 1200
n_test <- 1200
n_rest <- n_train + n_cal         # passed to DataSplit() alongside the others
n <- n_train + n_cal + n_test     # total rows generated per replicate

# Experimental design ----
ns <- 100 # number of replications

# Data-generating cases; each element's @name selects case-specific behaviour.
algo_array <- c(new("RF1"), new("RF2"))

# Selection rules to compare, and the quantile level used by the
# quantile-based rules ("test" and "exchange").
Thresholding_type_array <- c("test", "exchange", "mean")
Thresholding_value <- 0.7


# Simulation ----
# Runs `ns` independent replications in parallel. In each replication, for
# every case in `algo_array` and every noise level sigma in 0.1, 0.2, ..., 1.0,
# fresh data are generated and split, the model is tuned and fitted on the
# training split, and FDP/Power are recorded for every selection rule in
# `Thresholding_type_array` and every method (OMT, Bonferroni, BY, SCOP, SCPV).
# Rows from all replications are stacked with rbind into `result`.
# NOTE(review): no RNG seed is set, so results differ from run to run.
cl <- makeCluster(10) # parallel implementation
registerDoParallel(cl)
# tryCatch(finally = ...) releases the workers even if a replication errors;
# a bare stopCluster() after foreach would be skipped and leak the cluster.
result <- tryCatch(
  foreach(
    iter = 1:ns,
    .combine = "rbind",
    .packages = c("MASS", "randomForest", "kernlab")
  ) %dopar% {
    rows <- list()

    for (algo in algo_array) {
      for (sigma in seq(0.1, 1, 0.1)) {
        ## Generate the data by the cases and split it
        X <- matrix(runif(n * d, -1, 1), nrow = n, ncol = d)
        data <- DataGen(algo, X, beta, sigma)
        datawork <- DataSplit(data, n, n_test, n_cal, n_rest)
        data_train <- datawork$data_train
        data_cal <- datawork$data_cal
        data_test <- datawork$data_test

        # Every column except column d + 1 is a covariate; `y` is the response.
        X_train <- as.matrix(data_train[colnames(data_train)[-(d + 1)]])
        Y_train <- as.matrix(data_train$y)
        X_cal <- as.matrix(data_cal[colnames(data_cal)[-(d + 1)]])
        Y_cal <- as.matrix(data_cal$y)
        X_test <- as.matrix(data_test[colnames(data_test)[-(d + 1)]])
        Y_test <- as.matrix(data_test$y)

        # Tune on THIS replicate's training split. Tuning() previously ran
        # before X_train/Y_train were assigned: it failed on the first pass
        # and otherwise reused the previous noise level's training data.
        lambda <- Tuning(algo, X_train, Y_train)
        model_train <- fitting(algo, X_train, Y_train, lambda = lambda)

        Y_cal_hat <- predict(model_train, as.data.frame(x = X_cal))
        Y_test_hat <- predict(model_train, as.data.frame(x = X_test))

        # Selection score: a single covariate column, chosen by case.
        score_col <- switch(algo@name,
          SVM = 3,
          RF1 = 1,
          RF2 = 3,
          stop("Unknown setting: ", algo@name)
        )
        T_test <- X_test[, score_col]
        T_cal <- X_cal[, score_col]

        # b_0 and the V scores do not depend on the selection rule, so they
        # are computed once per (case, sigma). For the RF cases b_0 is the
        # 70% quantile of the test responses.
        b_0 <- switch(algo@name,
          RF1 = ,
          RF2 = quantile(Y_test, 0.7),
          SVM = 0.5,
          stop("Unknown setting: ", algo@name)
        )
        V_cal <- -Y_cal_hat
        V_test <- -Y_test_hat

        for (Thresholding_type in Thresholding_type_array) {
          # Selection threshold L on the score T for this rule.
          L <- switch(Thresholding_type,
            constant = 0.4,
            test = quantile(T_test, Thresholding_value),
            exchange = quantile(c(T_test, T_cal), Thresholding_value),
            mean = mean(T_test),
            stop("Unknown selection rule: ", Thresholding_type)
          )

          # Methods are evaluated in this order (OMT, Bonferroni, BY, SCOP,
          # SCPV), which is also the row order in the output.
          method_results <- list(
            OMT = OMT(Y_cal, Y_test, b_0, T_cal, T_test, L,
                      V_cal, V_test, alpha),
            Bonferroni = OMT(Y_cal, Y_test, b_0, T_cal, T_test, L,
                             V_cal, V_test, alpha, Bonfer = TRUE, BY = FALSE),
            BY = OMT(Y_cal, Y_test, b_0, T_cal, T_test, L,
                     V_cal, V_test, alpha, Bonfer = TRUE, BY = TRUE),
            SCOP = SCOP(Y_cal, Y_cal_hat, Y_test_hat, b_0,
                        T_test, T_cal, L, alpha),
            # The "mean" rule has its own SCPV variant.
            SCPV = if (Thresholding_type == "mean") {
              MeanSCPV(Y_cal, Y_test, b_0, T_cal, T_test, L,
                       V_cal, V_test, alpha)
            } else {
              SCPV(Y_cal, Y_test, b_0, T_cal, T_test, L,
                   V_cal, V_test, alpha)
            }
          )

          for (method in names(method_results)) {
            res <- method_results[[method]]
            rows[[length(rows) + 1L]] <- data.frame(
              FDP = res$FDP,
              Power = res$Power,
              Method = method,
              Thresholding_type = Thresholding_type,
              Noise = sigma,
              Setting = algo@name,
              row.names = NULL
            )
          }
        }
      }
    }

    # Bind once at the end instead of growing a data frame inside the loops.
    do.call(rbind, rows)
  },
  finally = stopCluster(cl)
)



# Display labels ----
# Replace values of `x` that appear in names(map) with the mapped label and
# leave every other value untouched. `x` is coerced to character first so a
# factor column is matched by its labels, not by its integer codes.
relabel_values <- function(x, map) {
  x <- as.character(x)
  hit <- x %in% names(map)
  x[hit] <- unname(map[x[hit]])
  x
}

result$Setting <- relabel_values(result$Setting,
                                 c(RF1 = "Case A", RF2 = "Case B"))
result$Setting <- factor(result$Setting, levels = c("Case A", "Case B"))

result$Thresholding_type <- relabel_values(
  result$Thresholding_type,
  c(test = "Quan", exchange = "Exch", mean = "Mean")
)

result$Method <- relabel_values(result$Method,
                                c(Bonferroni = "AMT(BH)", BY = "AMT(BY)"))
method_levels <- c("SCPV", "OMT", "AMT(BH)", "AMT(BY)", "SCOP")
# `%in%` always yields a logical vector (even for zero rows), unlike the
# previous sapply() filter, which returned list() on empty input.
result <- result[result$Method %in% method_levels, ]
result$Method <- factor(result$Method, levels = method_levels)



# Aggregate over replications ----
# Mean and standard deviation of FDP and Power, in percent, for every
# Method x Noise x Setting x selection-rule combination. Power_sd is computed
# before Power is overwritten by its mean, so it uses the raw per-replicate
# values. `.groups = "drop"` returns an ungrouped table instead of one still
# grouped by Method/Noise/Setting.
tab <- result %>%
  group_by(Method, Noise, Setting, Thresholding_type) %>%
  dplyr::summarize(
    FDR = mean(FDP) * 100,
    FDR_sd = sd(FDP) * 100,
    Power_sd = sd(Power) * 100,
    Power = mean(Power) * 100,
    .groups = "drop"
  )

tab$Thresholding_type <- factor(tab$Thresholding_type,
                                levels = c("Exch", "Quan", "Mean"))


# Shared plot styling ----
textsize_title <- 12 # axis titles and legend title/text
textsize_axis <- 10  # tick labels and facet strip labels
pointsize <- 1.5     # point size
linesize <- 0.5      # line width





# FDR panel ----
# FDR (%) against noise strength at noise levels 0.1, 0.3, ..., 0.9, faceted
# by case (rows) and selection rule (columns). The dashed line marks the target
# level alpha; ribbons show +/- one standard error (sd / sqrt(ns)).
# Noise levels are matched after rounding, since exact `==` on doubles
# produced by seq() is fragile.
P1 <- tab[round(tab$Noise, 8) %in% round(seq(0.1, 0.9, 0.2), 8), ] %>%
  ## Map x and y
  ggplot(aes(x = Noise, y = FDR, group = Method)) +
  geom_hline(yintercept = alpha * 100, color = "black",
             linetype = "dashed", linewidth = 1.2) +
  geom_line(aes(color = Method), linewidth = linesize) +
  geom_point(aes(color = Method), size = pointsize) +
  scale_color_nejm(palette = c("default"), alpha = 1) +
  geom_ribbon(
    aes(ymin = FDR - FDR_sd / sqrt(ns), ymax = FDR + FDR_sd / sqrt(ns),
        fill = Method),
    alpha = 0.3, linetype = 1, color = NA
  ) +
  scale_fill_manual(values = c("#BC3C29FF", "#0072B5FF", "#E18727FF",
                               "#20854EFF", "#7876B1FF")) +
  scale_y_continuous(name = "FDR(%)", position = "left") +
  scale_x_continuous(name = "Noise Strength", breaks = c(0.2, 0.5, 0.8)) +
  facet_grid(Setting ~ Thresholding_type, scales = "free") +
  # Apply one complete theme first and tweak it afterwards: adding a complete
  # theme (theme_bw()) replaces every theme() setting added before it, which
  # previously discarded the axis text sizes. theme_set() is not used because
  # it would change the global default theme as a side effect.
  theme_bw() +
  theme(
    axis.title.x = element_text(size = textsize_title),
    axis.title.y = element_text(size = textsize_title),
    axis.text.x = element_text(size = textsize_axis),
    axis.text.y = element_text(size = textsize_axis),
    strip.text.x = element_text(size = textsize_axis),
    strip.text.y = element_text(size = textsize_axis),
    legend.position = "top",
    legend.title = element_text(size = textsize_title),
    legend.text = element_text(size = textsize_title)
  )



# Power panel ----
# Power (%) against noise strength at noise levels 0.1, 0.3, ..., 0.9,
# faceted by case (rows) and selection rule (columns). Ribbons show +/- one
# standard error (sd / sqrt(ns)). Noise levels are matched after rounding,
# since exact `==` on doubles produced by seq() is fragile.
P2 <- tab[round(tab$Noise, 8) %in% round(seq(0.1, 0.9, 0.2), 8), ] %>%
  ## Map x and y
  ggplot(aes(x = Noise, y = Power, group = Method)) +
  geom_line(aes(color = Method), linewidth = linesize) +
  geom_point(aes(color = Method), size = pointsize) +
  scale_color_nejm(palette = c("default"), alpha = 1) +
  geom_ribbon(
    aes(ymin = Power - Power_sd / sqrt(ns), ymax = Power + Power_sd / sqrt(ns),
        fill = Method),
    alpha = 0.3, linetype = 1, color = NA
  ) +
  scale_fill_manual(values = c("#BC3C29FF", "#0072B5FF", "#E18727FF",
                               "#20854EFF", "#7876B1FF")) +
  scale_y_continuous(name = "Power(%)", position = "left") +
  scale_x_continuous(name = "Noise Strength", breaks = c(0.2, 0.5, 0.8)) +
  facet_grid(Setting ~ Thresholding_type, scales = "free") +
  # Apply one complete theme first and tweak it afterwards: adding a complete
  # theme (theme_bw()) replaces every theme() setting added before it, which
  # previously discarded the axis text sizes. theme_set() is not used because
  # it would change the global default theme as a side effect.
  theme_bw() +
  theme(
    axis.title.x = element_text(size = textsize_title),
    axis.title.y = element_text(size = textsize_title),
    axis.text.x = element_text(size = textsize_axis),
    axis.text.y = element_text(size = textsize_axis),
    strip.text.x = element_text(size = textsize_axis),
    strip.text.y = element_text(size = textsize_axis),
    legend.position = "top",
    legend.title = element_text(size = textsize_title),
    legend.text = element_text(size = textsize_title)
  )


# Combined figure ----
# FDR and Power panels side by side with a single shared legend.
P <- ggarrange(P1, P2, common.legend = TRUE, ncol = 2)

# Print explicitly: a bare `P` is only auto-printed at the interactive top
# level, so the figure would not be drawn when the script is run via source().
print(P)












