rm(list = ls());gc()
library(doSNOW)
library(randomForest)
library(grf)

# Progress callback for doSNOW: prints a timestamped count of finished tasks.
# `nfin` is the number of tasks completed so far (formatted with %d, so it
# must be integer-valued).
progress <- function(nfin) {
  stamp <- Sys.time()
  line <- sprintf('%s: tasks completed: %d.\n', stamp, nfin)
  cat(line)
}

# Option list handed to foreach through `.options.snow`.
opts <- list(progress = progress)

# Pairwise squared differences of a numeric vector.
# Returns an n x n matrix (n = length(vec)) whose (i, j) entry is
# (vec[i] - vec[j])^2. The result carries no dimnames.
dis_vec <- function(vec) {
  vals <- as.vector(vec)  # drop names/dim so the result has no dimnames
  gap <- outer(vals, vals, FUN = "-")
  gap^2
}

# Squared Euclidean distance matrix between observations.
# If `mat` is a matrix, its rows are observations and the result is the sum
# over columns of the per-column squared-difference matrices from dis_vec().
# Any non-matrix input is treated as a single coordinate and passed straight
# to dis_vec().
dis_matrix <- function(mat) {
  if (!is.matrix(mat)) {
    return(dis_vec(mat))
  }
  per_column <- apply(mat, 2, dis_vec, simplify = FALSE)
  Reduce(`+`, per_column)
}

# Number of rejections of a Benjamini-Hochberg style step-up rule.
#
# p_k   : numeric vector of p-values (must not contain NA).
# alpha : target level.
#
# Each p-value gets rank r = #{j : p_j <= p_i} (ties share the largest rank).
# A p-value "passes" when p_i < alpha * r / m with m = length(p_k); note the
# strict inequality. Returns the largest rank among passing p-values, or 0
# when none passes or when p_k is empty.
#
# The declared defaults are self-referential, so both arguments must always
# be supplied by the caller.
BH <- function(p_k = p_k, alpha = alpha) {
  m <- length(p_k)
  if (m == 0L) {
    return(0)
  }
  # rank(ties.method = "max") equals the count of values <= p_i, in
  # O(m log m) instead of the O(m^2) pairwise comparison.
  rank_p <- rank(p_k, ties.method = "max")
  passes <- p_k < alpha * rank_p / m
  if (sum(passes) == 0) {
    return(0)
  }
  max(rank_p[passes])
}



# Novelty detection from kernel-localized conformal p-values with a
# conditionally calibrated rejection rule.
#
# Arguments
#   cal_score  : nonconformity scores of the calibration points.
#   test_score : nonconformity scores of the test points.
#   h          : kernel bandwidth; weights are exp(-squared distance / h).
#   alpha      : target level.
# The defaults of cal_score/test_score are self-referential, so both must be
# supplied by the caller.
#
# Besides its arguments the function reads these variables from the calling
# environment: X_cal, X_test, d (column used for localization), dis_mat
# (squared-distance matrix ordered calibration first, then test), n_cal and
# n_test.
#
# Returns the indices (into test_score) of the test points declared novelties.
LCP_detect<-function(cal_score=cal_score,test_score=test_score,h=0.01,alpha=0.1){
  # R_tilde[i]: number of rejections obtained among the other test points when
  # test point i is moved into the calibration set.
  R_tilde<-rep(NA,length(test_score))
  for (i in 1:length(test_score)) {
    # Augmented calibration scores (test point i appended) and remaining tests.
    cal_score_i<-c(cal_score,test_score[i])
    test_score_i<-test_score[-i]
    # Squared distances on covariate d, ordered: calibration, test i, other tests.
    dis_mat_i<-dis_matrix(c(X_cal[,d],X_test[i,d],X_test[-i,d]))
    # score_cal[a, b] = cal_score_i[b], so t(score_cal) <= score_cal is the
    # indicator cal_score_i[a] <= cal_score_i[b].
    score_cal<-matrix(rep(cal_score_i,each=length(cal_score_i)),ncol = length(cal_score_i))
    weight_matrix<-exp(-dis_mat_i/h)
    # Per augmented-calibration point: kernel-weighted count of scores <= its own,
    # and the total kernel weight, both over the augmented calibration set.
    weight_cal<-colSums(weight_matrix[1:length(cal_score_i),1:length(cal_score_i)]*(t(score_cal)<=score_cal))
    K_sum_cal<-colSums(weight_matrix[1:length(cal_score_i),1:length(cal_score_i)])
    # Rows: augmented calibration points; columns: remaining test points. Each
    # column adds the contribution of that test point to numerator/denominator.
    K_cal<-weight_cal+(matrix(rep(test_score_i,each=length(cal_score_i)),ncol=length(test_score_i))>=matrix(rep(cal_score_i,times=length(test_score_i)),ncol = length(test_score_i)))*weight_matrix[1:length(cal_score_i),length(cal_score_i)+(1:length(test_score_i))]
    F_cal<-K_cal/(K_sum_cal+weight_matrix[1:length(cal_score_i),length(cal_score_i)+(1:length(test_score_i))])
    # Same ratio for each remaining test point itself (own weight counted as 1).
    K_test<-(colSums((matrix(rep(test_score_i,each=length(cal_score_i)),ncol=length(test_score_i))<=matrix(rep(cal_score_i,times=length(test_score_i)),ncol = length(test_score_i)))*weight_matrix[1:length(cal_score_i),length(cal_score_i)+(1:length(test_score_i))])+1)
    F_test<-K_test/(colSums(weight_matrix[1:length(cal_score_i),length(cal_score_i)+(1:length(test_score_i))])+1)
    F_all<-rbind(F_cal,F_test)
    # p-value of a test point: share of entries in its column that are <= the
    # last entry (its own value).
    p_value<-apply(F_all, 2, function(x){sum(x<=x[length(x)])/length(x)})
    # Count rejections with the p-value of test point i fixed at 0.
    re<-which(p_value<=alpha*BH(c(0,p_value),alpha)/length(c(p_value,0)))
    R_tilde[i]<-length(re)
  }
  # Same construction on the original calibration/test split, using the
  # precomputed global dis_mat.
  score_cal<-matrix(rep(cal_score,each=n_cal),ncol = n_cal)
  weight_matrix<-exp(-dis_mat/h)
  weight_cal<-colSums(weight_matrix[1:n_cal,1:n_cal]*(t(score_cal)<=score_cal))
  K_sum_cal<-colSums(weight_matrix[1:n_cal,1:n_cal])
  K_cal<-weight_cal+(matrix(rep(test_score,each=n_cal),ncol=n_test)>=matrix(rep(cal_score,times=n_test),ncol = n_test))*weight_matrix[1:n_cal,n_cal+(1:n_test)]
  F_cal<-K_cal/(K_sum_cal+weight_matrix[1:n_cal,n_cal+(1:n_test)])
  K_test<-(colSums((matrix(rep(test_score,each=n_cal),ncol=n_test)<=matrix(rep(cal_score,times=n_test),ncol = n_test))*weight_matrix[1:n_cal,n_cal+(1:n_test)])+1)
  F_test<-K_test/(colSums(weight_matrix[1:n_cal,n_cal+(1:n_test)])+1)
  F_all<-rbind(F_cal,F_test)
  p_value<-apply(F_all, 2, function(x){sum(x<=x[length(x)])/length(x)})
  # First-stage rejection set: each p-value is compared with its own threshold
  # alpha * R_tilde / (number of test points); strict inequality.
  BH_result<-which(p_value<alpha*R_tilde/length(p_value))
  # Keep the whole set if it is at least as large as every R_tilde inside it;
  # otherwise prune it with independent uniforms. When BH_result is empty,
  # max() of an empty vector emits a warning and the else-branch is taken.
  if(length(BH_result)>=max(max(R_tilde[BH_result]),1)){
    detection_result<-BH_result
  }else{
    u<-runif(length(BH_result))
    p_til<-u*R_tilde[BH_result]/length(BH_result)
    # Keep candidates whose rank among p_til is strictly below BH(p_til, 1).
    detection_result<-BH_result[which(sapply(1:length(p_til), function(x){sum(p_til[x]>=p_til)})<BH(p_til,1))]
  }
  return(detection_result)
}


# Novelty detection with localized conformal p-values, a per-test-point
# data-driven bandwidth, and an auxiliary term built from the test scores
# re-weighted by an estimated train/test ratio.
#
# Arguments
#   train_score : nonconformity scores of the training points.
#   cal_score   : nonconformity scores of the calibration points.
#   test_score  : nonconformity scores of the test points.
#   h_sel       : vector of candidate bandwidths.
#   r1          : mixing weight in [0, 1]; r1 multiplies the calibration+test
#                 kernel sums and (1 - r1) the ratio-weighted test-only sums.
#   alpha       : target level.
# All defaults except alpha are self-referential, so they must be supplied.
#
# Also reads from the calling environment: X_train, X_test, d, n_train, n_cal,
# n_test and dis_mat (squared distances, calibration rows/columns first, then
# test).
#
# Returns the indices (into test_score) of the test points declared novelties.
LCP_au_detect<-function(train_score=train_score,cal_score=cal_score,test_score=test_score,h_sel=h_sel,r1=r1,alpha=0.2){
  # Logistic regression separating training (0) from test (1) observations on
  # covariate d and the score; the fitted test probabilities p are turned into
  # the ratio (1 - p) / p rescaled by n_test / n_train.
  ra_two_class<-data.frame(X=c(X_train[,d],X_test[,d]),S=c(train_score,test_score))
  ra_tar<-c(rep(0,n_train),rep(1,n_test))
  ra_model<-glm(tar ~ ., data = data.frame(ra_two_class,tar=ra_tar), family = binomial)
  ra_predictions_test <- predict(ra_model, newdata = data.frame(X=X_test[,d],S=test_score), type = "response")
  ra_predictions_test<-as.double(ra_predictions_test)
  ra_test<-(1-ra_predictions_test)*(n_test)/(n_train*ra_predictions_test)
  # sign_matrix[a, b] = 1 if score[a] <= score[b], over the pooled
  # (calibration, test) scores.
  score_all<-matrix(rep(c(cal_score,test_score),each=(n_cal+n_test)),ncol = (n_cal+n_test))
  sign_matrix<-matrix(as.numeric(t(score_all)<=score_all),ncol = (n_cal+n_test))
  rm(score_all)
  # Row i = candidate bandwidth h_sel[i], column k = test point k.
  # h_dis_nu: BH count among the other test points; R_con: the same count with
  # an extra p-value of 0 appended.
  h_dis_nu<-matrix(NA,ncol = n_test,nrow = length(h_sel))
  R_con<-matrix(NA,ncol = n_test,nrow = length(h_sel))
  for (i in 1:length((h_sel))) {
    weight_matrix<-exp(-dis_mat/h_sel[i])
    # eme[a, b] = kernel weight * 1{score[a] <= score[b]}.
    eme<-weight_matrix*sign_matrix
    # Numerator pieces: sums over calibration columns, and the test columns
    # (raw and multiplied column-wise by ra_test).
    weight_cal<-rowSums(eme[,1:n_cal])
    weight_test<-eme[,-(1:n_cal)]
    ra_weight_test<-t(t(weight_test)*ra_test)
    ra_weight_sum<-rowSums(ra_weight_test)
    # Denominator pieces: the same quantities without the score indicator.
    K_sum_cal<-rowSums(weight_matrix[,1:n_cal])
    K_sum_test<-weight_matrix[,-(1:n_cal)]
    ra_K_sum_test<-t(t(K_sum_test)*ra_test)
    ra_K_sum<-rowSums(ra_K_sum_test)
    rm(eme);gc()
    dis_num<-vector(,length = n_test)
    R_con_k<-vector(,length = n_test)
    for (k in 1:n_test) {
      # For fixed test point k, every quantity below has one column/entry per
      # other test point x (the [,-k] / [-k] drops x = k).
      # *_cal: rows are calibration points; *_sel: the row of test point k;
      # *_te: the row of test point x itself.
      nu_cal=r1*(weight_cal[1:n_cal]+weight_test[1:n_cal,k]+weight_test[1:n_cal,])[,-k]+(1-r1)*(ra_weight_sum[1:n_cal]-ra_weight_test[1:n_cal,k]-ra_weight_test[1:n_cal,])[,-k]
      # NOTE(review): the ratio-weighted part subtracts element [n_cal+x, k]
      # while the r1 part adds element [n_cal+k, x]; confirm the differing
      # row/column order is intended (same pattern in de_sel below).
      nu_sel=r1*(weight_cal[n_cal+k]+weight_test[n_cal+k,]+weight_test[n_cal+k,k])[-k]+(1-r1)*(sapply(1:n_test, function(x){ra_weight_sum[n_cal+k]-ra_weight_test[n_cal+k,k]-ra_weight_test[n_cal+x,k]})[-k])
      nu_te=r1*((weight_cal[-(1:n_cal)]+diag(weight_test[-(1:n_cal),])+weight_test[-(1:n_cal),k]))[-k]+(1-r1)*((sapply(1:n_test,function(x){ra_weight_sum[n_cal+x]-ra_weight_test[n_cal+x,k]-ra_weight_test[n_cal+x,x]})[-k]))
      de_cal=r1*((K_sum_cal[1:n_cal]+K_sum_test[1:n_cal,k]+K_sum_test[1:n_cal,]))[,-k]+(1-r1)*(((ra_K_sum[1:n_cal]-ra_K_sum_test[1:n_cal,k]-ra_K_sum_test[1:n_cal,])[,-k]))
      de_sel=r1*((K_sum_cal[n_cal+k]+K_sum_test[n_cal+k,]+K_sum_test[n_cal+k,k]))[-k]+(1-r1)*((sapply(1:n_test, function(x){ra_K_sum[n_cal+k]-ra_K_sum_test[n_cal+k,k]-ra_K_sum_test[n_cal+x,k]})[-k]))
      de_te=r1*((K_sum_cal[-(1:n_cal)]+diag(K_sum_test[-(1:n_cal),])+K_sum_test[-(1:n_cal),k]))[-k]+(1-r1)*((sapply(1:n_test,function(x){ra_K_sum[n_cal+x]-ra_K_sum_test[n_cal+x,k]-ra_K_sum_test[n_cal+x,x]})[-k]))
      # One column per other test point; its p-value is the share of entries
      # in the column that are <= the last entry (its own ratio).
      p_k_all<-rbind(nu_cal/de_cal,nu_sel/de_sel,nu_te/de_te)
      p_k<-apply(p_k_all, 2, function(x){sum(x[length(x)]>=x)/length(x)})
      # Undefined ratios (e.g. 0/0) are treated as non-significant.
      p_k[is.na(p_k)]<-1
      dis_num[k]<-BH(p_k,alpha)
      R_con_k[k]<-BH(c(p_k,0),alpha)
    }
    h_dis_nu[i,]<-dis_num
    R_con[i,]<-R_con_k
  }
  # Per test point: pick the first bandwidth that maximizes h_dis_nu and keep
  # the matching R_con value.
  R_selected<-sapply(1:n_test, function(x){R_con[,x][which(h_dis_nu[,x]==max(h_dis_nu[,x]))[1]]})
  h_selected<-apply(h_dis_nu, 2, function(x){h_sel[which(x==max(x))[1]]})
  # Final p-value of each test point, computed with its own selected bandwidth.
  p_k<-rep(NA,n_test)
  for (k in 1:n_test) {
    weight_matrix<-exp(-dis_mat/h_selected[k])
    eme<-weight_matrix*sign_matrix
    weight_cal<-rowSums(eme[,1:n_cal])
    weight_test<-eme[,-(1:n_cal)]
    ra_weight_test<-t(t(weight_test)*ra_test)
    ra_weight_sum<-rowSums(ra_weight_test)
    K_sum_cal<-rowSums(weight_matrix[,1:n_cal])
    K_sum_test<-weight_matrix[,-(1:n_cal)]
    ra_K_sum_test<-t(t(K_sum_test)*ra_test)
    ra_K_sum<-rowSums(ra_K_sum_test)
    rm(eme);gc()
    # Ratios for the calibration points and for test point k (own weight 1);
    # the ratio-weighted sums exclude test point k's own column.
    nu_cal<-r1*(weight_cal[1:n_cal]+weight_test[1:n_cal,k])+(1-r1)*(ra_weight_sum[1:n_cal]-ra_weight_test[1:n_cal,k])
    nu_test<-r1*((weight_cal[n_cal+k]+1))+(1-r1)*(ra_weight_sum[n_cal+k]-ra_weight_test[n_cal+k,k])
    de_cal<-r1*((K_sum_cal[1:n_cal]+K_sum_test[1:n_cal,k]))+(1-r1)*(ra_K_sum[1:n_cal]-ra_K_sum_test[1:n_cal,k])
    de_test<-r1*((K_sum_cal[n_cal+k]+1))+(1-r1)*((ra_K_sum[n_cal+k]-ra_K_sum_test[n_cal+k,k]))
    p_k_all<-c(nu_cal/de_cal,nu_test/de_test)
    p_k[k]<-sum(p_k_all[length(p_k_all)]>=p_k_all)/length(p_k_all)
  }
  p_k[is.na(p_k)]<-1
  # First-stage rejections: p-value against its own threshold
  # alpha * R_selected / n_test.
  BH_result<-which(p_k<=alpha*R_selected/length(p_k))
  # Keep the whole set if it is at least as large as every R_selected inside
  # it; otherwise prune with independent uniforms. With an empty BH_result,
  # max() of an empty vector emits a warning and the else-branch is taken.
  if(length(BH_result)>=max(max(R_selected[BH_result]),1)){
    detection_result<-BH_result
  }else{
    u<-runif(length(BH_result))
    p_til<-u*R_selected[BH_result]/length(BH_result)
    detection_result<-BH_result[which(sapply(1:length(p_til), function(x){sum(p_til[x]>=p_til)})<BH(p_til,1))]
  }

  return(detection_result)
}

# Novelty detection with randomly-localized conformal p-values.
#
# Arguments
#   cal_score  : nonconformity scores of the calibration points.
#   test_score : nonconformity scores of the test points.
#   h          : bandwidth; used both as the standard deviation of the random
#                localization draw and as the kernel scale.
#   alpha      : target level.
#
# Also reads from the calling environment: n_test, n_cal, X_test, X_cal and d
# (the single covariate column used for localization).
# mvrnorm() comes from MASS, which the library() calls at the top of this file
# do not attach; it is only listed in the foreach `.packages` argument.
#
# Returns the indices (into test_score) of the rejected test points.
RLCP<-function(cal_score,test_score,h,alpha){
  # One random localization point per test point, drawn from
  # N(X_test[i, d], h^2).
  s_sam <- matrix(0, ncol = 1, nrow = n_test)
  for (i in 1:n_test) {
    s_sam[i,] <- mvrnorm(1, as.numeric(X_test[i, d:d]), (h^2)*diag(1))
  }
  
  # weight[t, j]: kernel weight between the localization point of test t and
  # calibration point j; the last column holds the weight of test t itself.
  # NOTE(review): the kernel is exp(-diff^2 / h^2) while the draw above has
  # variance h^2 - confirm the missing factor 2 is intended.
  weight <- matrix(0, nrow = n_test, ncol = n_cal+1)
  for (j in 1:n_cal) {
    # Only column 1 is filled; the second column stays 0 and does not change
    # the row sums.
    diffmat <- matrix(0, nrow = n_test, ncol = 2)
    for (k in d:d) {
      diffmat[, 1] <- s_sam[, 1] -X_cal[j, k]
    }
    weight[, j] <- exp(-apply(diffmat^2, 1, sum)/(h^2))
  }
  diffmat <- matrix(0, nrow = n_test, ncol = 1)
  for (k in 1:1) {
    diffmat[, k] <- s_sam[, k] - X_test[, d]
  }
  weight[, n_cal+1] <- exp(-apply(diffmat^2, 1, sum)/(h^2))
  # IndQR[t, j] = 1 if test score t is strictly below calibration score j;
  # the last column is a uniform draw that randomizes the test point's own
  # contribution.
  IndQR <- matrix(1, nrow = n_test, ncol = n_cal+1)
  for (j in 1:n_cal) {
    IndQR[,j] <- ifelse(test_score<cal_score[j], 1, 0)
  }
  IndQR[, n_cal+1] <- runif(n_test)
  # Weighted p-value per test point; undefined ratios are set to 1.
  WQR <- weight*IndQR
  pvalues <- (apply(WQR, 1, sum))/(apply(weight, 1, sum))
  pvalues[is.na(pvalues)] <- 1
  # Rtild[j]: step-up rejection count when the randomized own-term of every
  # test point is replaced by 1{test_score <= test_score[j]} and the p-value
  # of test point j is set to 0.
  Rtild <- rep(0, n_test)
  unnorm_p <- apply(WQR, 1, sum)
  sum_weight <- apply(weight, 1, sum)
  for (j in 1:n_test) {
    pvalues_j <- (unnorm_p - WQR[, n_cal+1] + weight[, n_cal+1]*ifelse(test_score<=test_score[j], 1, 0))/sum_weight
    pvalues_j[is.na(pvalues_j)] <- 1
    pvalues_j[j] <- 0
    # Because pvalues_j[j] is 0, the smallest sorted value always passes, so
    # which() below is never empty.
    rej_j <- sort(pvalues_j)<((1:length(pvalues_j))/length(pvalues_j))*alpha
    rejnum_j <- max(which(rej_j==T))
    Rtild[j] <- rejnum_j
  }
  # Per-point thresholds and first-stage candidate set.
  S <- alpha*Rtild/n_test
  R1 <- which(pvalues<=S)
  # Pruning with independent uniforms: R is the largest r for which at least
  # r candidates satisfy xi * Rtild <= r. When R1 is empty, 1:length(R1) runs
  # over c(1, 0) and R ends up 0.
  xi <- runif(n_test)
  R <- 0
  for (r in 1:length(R1)) {
    if(sum(ifelse(pvalues<=S&xi*Rtild<=r, 1, 0))>=r){
      R <- r
    }
  }
  reject <- which(pvalues<=S&xi*Rtild<=R)
  return(reject)
}

# Noise scale applied to inliers: 5 / (1 + |x|), vectorized over x.
rin <- function(x) {
  5 / (1 + abs(x))
}

# Magnitude applied to novelties: 5 + 5 / (1 + |x|), vectorized over x.
rout <- function(x) {
  inlier_part <- 5 / (1 + abs(x))
  5 + inlier_part
}


# Simulation settings.
d <- 10      # number of covariates; column d is the one used for localization
w_1 <- 0.2   # novelty rate
r1 <- 0.8    # weight parameter
alpha <- 0.2
# Regression coefficients of the first d - 1 covariates.
beta <- c(rep(c(0.5, -0.5), length.out = 5), rep(0, 4))


# Result tables: one row per setting, one column per method.
FDR <- matrix(NA, nrow = 5, ncol = 6,
              dimnames = list(NULL, c("CP", "CQ", "LCP", "LCQ", "ALCP", "ALCQ")))
Power <- matrix(NA, nrow = 5, ncol = 6,
                dimnames = list(NULL, c("CP", "CQ", "LCP", "LCQ", "ALCP", "ALCQ")))
results <- list()
# Monte Carlo study over three test-set sizes (n_2 = 500, 750, 1000).
# For each size, `times` independent replicates are run in parallel on a
# SOCK cluster. Every replicate simulates data, fits a random forest and a
# quantile forest on the training half, and stores the rejection sets of six
# procedures. results[[k]] is the cbind of the per-replicate lists, i.e. a
# 6 x times list-matrix indexed as results[[k]][[method, replicate]].
for (k in 1:3) {
  n_1 <- 1000
  n_2 <- 250 * (1 + k)
  ou_nu <- floor(n_2 * w_1)  # number of novelties; they are the first test rows
  times <- 200               # replicates; also read by the summary loop below
  cl <- makeCluster(50, type = "SOCK")
  registerDoSNOW(cl)
  # tryCatch(finally = ) guarantees the workers are shut down even when a
  # replicate fails, instead of leaving 50 orphaned R processes.
  results[[k]] <- tryCatch(
    foreach(i = 1:times, .packages = c("grf", "randomForest", "MASS"),
            .combine = "cbind", .multicombine = TRUE,
            .options.snow = opts) %dopar% {
      # Reference sample: covariates 1..d-1 uniform on [-1, 1], covariate d
      # standard normal; heteroscedastic noise scaled by rin(X[, d]).
      X_1 <- cbind(matrix(runif(n_1 * (d - 1), min = -1, max = 1), ncol = d - 1),
                   rnorm(n_1, 0, 1))
      ep_1 <- rnorm(n_1)
      Y_1 <- X_1[, 1:(d - 1)] %*% beta + rin(X_1[, d]) * ep_1
      # Test sample: the first ou_nu rows are novelties whose deviation has
      # magnitude rout(X[, d]) and random sign; the rest follow the null model.
      X_2 <- cbind(matrix(runif(n_2 * (d - 1), min = -1, max = 1), ncol = d - 1),
                   rnorm(n_2, 0, 1))
      ep_2 <- rnorm(n_2)
      Y_2 <- X_2[, 1:(d - 1)] %*% beta +
        c(rout(X_2[1:ou_nu, d]) * sample(c(-1, 1), ou_nu, replace = TRUE),
          rin(X_2[(ou_nu + 1):n_2, d]) * ep_2[(ou_nu + 1):n_2])
      # Random half/half split of the reference sample.
      n_train <- floor(n_1 * 0.5)
      train_id <- sample(1:n_1, n_train)
      X_train <- X_1[train_id, ]
      X_cal <- X_1[-train_id, ]
      Y_train <- Y_1[train_id]
      Y_cal <- Y_1[-train_id]
      X_test <- X_2
      Y_test <- Y_2
      n_cal <- length(Y_cal)
      n_test <- length(Y_test)
      n_train <- length(Y_train)
      # Squared distances on covariate d, calibration points first, then test.
      dis_mat <- dis_matrix(c(X_cal[, d], X_test[, d]))
      # Absolute-residual scores from a random forest.
      rf_model <- randomForest(Y ~ ., data = data.frame(X = X_train, Y = Y_train))
      cal_score <- abs(Y_cal - predict(rf_model, newdata = data.frame(X = X_cal)))
      train_score <- abs(Y_train - predict(rf_model, newdata = data.frame(X = X_train)))
      test_score <- abs(Y_test - predict(rf_model, newdata = data.frame(X = X_test)))
      # Quantile scores max(Y - q_0.9, q_0.1 - Y) from a quantile forest.
      modelQR <- quantile_forest(X_train, Y_train, quantiles = c(0.1, 0.9))
      quan_cal <- predict(modelQR, X_cal)$predictions
      quan_cal_score <- apply(cbind(Y_cal - quan_cal[, 2], quan_cal[, 1] - Y_cal), 1, max)
      quan_train <- predict(modelQR, X_train)$predictions
      quan_train_score <- apply(cbind(Y_train - quan_train[, 2], quan_train[, 1] - Y_train), 1, max)
      quan_test <- predict(modelQR, X_test)$predictions
      quan_test_score <- apply(cbind(Y_test - quan_test[, 2], quan_test[, 1] - Y_test), 1, max)
      #CP
      p_re_CP <- sapply(test_score, function(x) {
        (sum(x <= cal_score) + 1) / (n_cal + 1)
      })
      p_quan_CP <- sapply(quan_test_score, function(x) {
        (sum(x <= quan_cal_score) + 1) / (n_cal + 1)
      })
      result_re_CP <- which(p_re_CP <= BH(p_re_CP, alpha) * alpha / length(p_re_CP))
      # The quantile-score p-values must be the ones compared with their own
      # cutoff; thresholding p_re_CP here mixes the two score types.
      result_quan_CP <- which(p_quan_CP <= BH(p_quan_CP, alpha) * alpha / length(p_quan_CP))
      #LCP
      d_con <- 1
      result_re_RLCP <- RLCP(cal_score, test_score,
                             h = sqrt(n_cal^{-1/(d_con+2)}), alpha = alpha)
      result_quan_RLCP <- RLCP(quan_cal_score, quan_test_score,
                               h = sqrt(n_cal^{-1/(d_con+2)}), alpha = alpha)
      #ALCP
      h_sel <- (10^(-2:2)) * (n_cal^{-1/(d_con+2)})
      result_re_ALCP <- LCP_au_detect(train_score, cal_score, test_score, h_sel,
                                      r1 = r1, alpha = alpha)
      result_quan_ALCP <- LCP_au_detect(quan_train_score, quan_cal_score,
                                        quan_test_score, h_sel,
                                        r1 = r1, alpha = alpha)
      list(result_re_CP = result_re_CP, result_quan_CP = result_quan_CP,
           result_re_RLCP = result_re_RLCP, result_quan_RLCP = result_quan_RLCP,
           result_re_ALCP = result_re_ALCP, result_quan_ALCP = result_quan_ALCP)
    },
    finally = stopCluster(cl)
  )
  # save() writes an RData-format file despite the .rds extension; read it
  # back with load(), not readRDS().
  save(results, file = "setting1_ex2_result_1000.rds")
}

# Empirical FDR and power per setting (row i) and method (column j), averaged
# over the `times` replicates. Novelties occupy the first 250 * (i + 1) * w_1
# test indices, so a detected index above that count is a false discovery.
for (i in 1:3) {
  n_outlier <- 250 * (i + 1) * w_1
  for (j in 1:6) {
    fdr <- 0
    power <- 0
    for (k in 1:times) {
      detected <- results[[i]][[j, k]]
      n_false <- sum(detected > n_outlier)
      n_true <- sum(detected <= n_outlier)
      # max(1, .) avoids dividing by zero when nothing was detected.
      fdr <- fdr + n_false / (max(1, length(detected)) * times)
      power <- power + n_true / (n_outlier * times)
    }
    FDR[i, j] <- fdr
    Power[i, j] <- power
  }
}
