##1. Package loading
##---------------------------------------------------------------------------
library(enrichR)
library(reshape2)
library(ggplot2)
##---------------------------------------------------------------------------

##2. Read the results
##---------------------------------------------------------------------------
## Each .Rdata file restores one list object into the global environment
## (PLNet_res_list / VPLN_res_list); only the first element of each is used.
load(file.path(".", "PLNet_res_list.Rdata"))
load(file.path(".", "VPLN_res_list.Rdata"))
## PLNet fit: its $lambda_vec and $Omega_est fields are used in section 3.
A.stim <- PLNet_res_list[[1]]
## VPLN fit: $models is a list whose elements expose $density and
## $model_par$Omega (as used in section 7).
fits <- VPLN_res_list[[1]]
bm.stim <- fits$models
##---------------------------------------------------------------------------

##3. Network choosing by density
##---------------------------------------------------------------------------
##Choosing the Network density (using 0.03,0.05,0.07)
den <- 0.03
##PLNet
## Edge density of each estimated precision matrix (one per lambda value):
## (non-zero entries - p diagonal entries) / (p * (p - 1)) off-diagonal slots.
## p is taken from each matrix instead of a hard-coded 200; the "- p" term
## assumes every diagonal entry is non-zero.
density <- vapply(A.stim$Omega_est, function(om) {
  p <- ncol(om)
  (sum(as.matrix(om) != 0) - p) / p / (p - 1)
}, numeric(1))
## Keep the network whose density is closest to the target `den`.
index <- which.min(abs(density - den))
omega.stim <- as.matrix(A.stim$Omega_est[[index]])

##calculate partial correlation
## rho_ij = -omega_ij / sqrt(omega_ii * omega_jj), computed for all (i, j) at
## once; the diagonal becomes -1. Row/column names of omega.stim are kept.
omega.diag <- diag(omega.stim)
pomega.stim <- -omega.stim / sqrt(outer(omega.diag, omega.diag))
##---------------------------------------------------------------------------

##4. GO analysis to find 4 modules
##---------------------------------------------------------------------------
## Enrichr libraries to query; module selection below only reads the
## GO Biological Process 2018 results.
dbs <- c("GO_Molecular_Function_2018", "GO_Cellular_Component_2018",
         "GO_Biological_Process_2018", "KEGG_2019_Human")
## Check the requested libraries against the server's catalogue so that a
## renamed or retired library is reported up front.
available.dbs <- listEnrichrDbs()
if (!is.null(available.dbs)) {
  missing.dbs <- setdiff(dbs, available.dbs$libraryName)
  if (length(missing.dbs) > 0) {
    warning("Enrichr libraries not available: ",
            paste(missing.dbs, collapse = ", "), call. = FALSE)
  }
}

## Gene symbols of the term at row `rank` of an Enrichr result table
## (its Genes column is ";"-separated). Errors instead of silently returning
## NA when the table has fewer rows than requested.
term_genes <- function(res, rank) {
  if (rank > nrow(res)) {
    stop("Enrichr returned only ", nrow(res), " terms; term ", rank,
         " was requested", call. = FALSE)
  }
  strsplit(res$Genes[rank], ";")[[1]]
}

genes <- row.names(omega.stim)
## NOTE(review): the term ranks used below (3, 1, 7, 26 and 755) are fixed
## positions in the Enrichr output, presumably picked by inspecting results for
## this data set; they may point at different terms if Enrichr's data changes.

## Module 1: genes of the 3rd-ranked BP term over all genes.
enriched.D <- enrichr(genes, dbs)
bp.stim <- enriched.D[["GO_Biological_Process_2018"]]
geneclu1 <- term_genes(bp.stim, 3)
clu1index <- which(genes %in% geneclu1)

## Each later round re-runs Enrichr on the genes not yet assigned to a module.
## A logical mask is used because genes[-integer(0)] would drop *every* gene
## if an earlier module came back empty.
enriched.D1 <- enrichr(genes[!seq_along(genes) %in% clu1index], dbs)
bp.stim1 <- enriched.D1[["GO_Biological_Process_2018"]]
geneclu2 <- term_genes(bp.stim1, 1)
clu2index <- which(genes %in% geneclu2)

enriched.D2 <- enrichr(genes[!seq_along(genes) %in% c(clu1index, clu2index)], dbs)
bp.stim2 <- enriched.D2[["GO_Biological_Process_2018"]]
geneclu3 <- term_genes(bp.stim2, 7)
clu3index <- which(genes %in% geneclu3)

enriched.D3 <- enrichr(
  genes[!seq_along(genes) %in% c(clu1index, clu2index, clu3index)], dbs
)
bp.stim3 <- enriched.D3[["GO_Biological_Process_2018"]]
## Module 4 merges the genes of two terms.
geneclu4 <- union(term_genes(bp.stim3, 26), term_genes(bp.stim3, 755))
clu4index <- which(genes %in% geneclu4)

## Row positions of all module genes, grouped module by module.
cluindex <- c(clu1index, clu2index, clu3index, clu4index)
##---------------------------------------------------------------------------

##5. calculate within-between connection ratios of 4 modules
##---------------------------------------------------------------------------
## For each module, ratio = (mass inside the module) / (mass from the module to
## all four modules). "Mass" is the sum of |partial correlation| (weighted) or
## the number of non-zero entries (unweighted). Subtracting length(idx) removes
## the diagonal, which contributes 1 per gene in both variants.
module_ratio <- function(mat, modules, all.index) {
  vapply(modules, function(idx) {
    within <- sum(abs(mat[idx, idx])) - length(idx)
    total <- sum(abs(mat[idx, all.index])) - length(idx)
    within / total
  }, numeric(1))
}
module.list <- list(clu1index, clu2index, clu3index, clu4index)
##calculate weighted ratios
ratio.weight.PLNet <- module_ratio(pomega.stim, module.list, cluindex)
##calculate unweighted ratios
pomega.stim.un <- ifelse(pomega.stim != 0, 1, 0)
ratio.unweight.PLNet <- module_ratio(pomega.stim.un, module.list, cluindex)
##---------------------------------------------------------------------------


##6. plot the heatmap
##---------------------------------------------------------------------------
## Module extents on the reordered axes are derived from the module sizes, so
## the outlines follow whatever modules section 4 produced.
module.sizes <- c(length(clu1index), length(clu2index),
                  length(clu3index), length(clu4index))
module.end <- cumsum(module.sizes)
module.start <- module.end - module.sizes + 1
module.cols <- c("#E64E00", "#FF7F00", "#65B48E", "#3E5CC5")

## Reorder genes so each module forms a contiguous block on the diagonal.
pomega.stim <- pomega.stim[cluindex, cluindex]
## The diagonal of the partial-correlation matrix is -1; blank it so it does
## not show up in the colour scale.
heat.mat <- as.matrix(pomega.stim)
diag(heat.mat) <- 0
omega.melt <- melt(heat.mat)
plot_PLNet <- ggplot(omega.melt, aes(x = Var2, y = Var1)) +
  geom_tile(aes(fill = value)) +
  scale_fill_gradient2(low = "blue", mid = "white", high = "red") +
  # Thick coloured bars along the bottom (y = 0) and left (x = 0) edges mark
  # each module's extent. annotate() draws each rectangle once instead of
  # once per tile row.
  annotate("rect", xmin = module.start - 0.5, xmax = module.end + 0.5,
           ymin = 0, ymax = 0, fill = "transparent",
           colour = module.cols, size = 3) +
  annotate("rect", ymin = module.start - 0.5, ymax = module.end + 0.5,
           xmin = 0, xmax = 0, fill = "transparent",
           colour = module.cols, size = 3) +
  # Thin outlines around each within-module block on the diagonal.
  annotate("rect", xmin = module.start - 0.5, xmax = module.end + 0.5,
           ymin = module.start - 0.5, ymax = module.end + 0.5,
           fill = "transparent", colour = module.cols, size = 0.5) +
  labs(x = NULL, y = NULL) +
  theme(axis.text.x = element_blank(), axis.ticks.x = element_blank(),
        axis.ticks.y = element_blank(), axis.text.y = element_blank())
##---------------------------------------------------------------------------

##7. Repeat the procedures for VPLN
##---------------------------------------------------------------------------
##VPLN
## Each VPLN model reports its own $density; pick the model closest to `den`.
## All fitted models are considered rather than a hard-coded first 80.
density <- vapply(bm.stim, function(model) as.numeric(model$density),
                  numeric(1))
index <- which.min(abs(density - den))
pre_V_choose.stim <- bm.stim[[index]]$model_par$Omega
omega.stim <- pre_V_choose.stim
##calculate partial correlation
## rho_ij = -omega_ij / sqrt(omega_ii * omega_jj), computed for all (i, j) at
## once; the diagonal becomes -1.
pomega.stim <- as.matrix(omega.stim)
omega.diag <- diag(pomega.stim)
pomega.stim <- -pomega.stim / sqrt(outer(omega.diag, omega.diag))

##calculate within-between connection ratios of 4 modules
## Per module: (mass inside the module) / (mass from the module to all four
## modules), where mass is the sum of |partial correlation| (weighted) or the
## number of non-zero entries (unweighted); subtracting length(idx) removes
## each gene's diagonal contribution of 1.
## NOTE(review): clu*index / cluindex are row positions derived from the PLNet
## gene order; this assumes the VPLN Omega lists genes in the same order.
module_ratio <- function(mat, modules, all.index) {
  vapply(modules, function(idx) {
    within <- sum(abs(mat[idx, idx])) - length(idx)
    total <- sum(abs(mat[idx, all.index])) - length(idx)
    within / total
  }, numeric(1))
}
module.list <- list(clu1index, clu2index, clu3index, clu4index)
##calculate weighted ratios
ratio.weight.VPLN <- module_ratio(pomega.stim, module.list, cluindex)
##calculate unweighted ratios
pomega.stim.un <- ifelse(pomega.stim != 0, 1, 0)
ratio.unweight.VPLN <- module_ratio(pomega.stim.un, module.list, cluindex)


##plot the heatmap
## Module extents on the reordered axes are derived from the module sizes, so
## the outlines follow whatever modules section 4 produced.
module.sizes <- c(length(clu1index), length(clu2index),
                  length(clu3index), length(clu4index))
module.end <- cumsum(module.sizes)
module.start <- module.end - module.sizes + 1
module.cols <- c("#E64E00", "#FF7F00", "#65B48E", "#3E5CC5")

## Reorder genes so each module forms a contiguous block on the diagonal.
pomega.stim <- pomega.stim[cluindex, cluindex]
## The diagonal of the partial-correlation matrix is -1; blank it so it does
## not show up in the colour scale.
heat.mat <- as.matrix(pomega.stim)
diag(heat.mat) <- 0
omega.melt <- melt(heat.mat)
plot_VPLN <- ggplot(omega.melt, aes(x = Var2, y = Var1)) +
  geom_tile(aes(fill = value)) +
  scale_fill_gradient2(low = "blue", mid = "white", high = "red") +
  # Thick coloured bars along the bottom (y = 0) and left (x = 0) edges mark
  # each module's extent. annotate() draws each rectangle once instead of
  # once per tile row.
  annotate("rect", xmin = module.start - 0.5, xmax = module.end + 0.5,
           ymin = 0, ymax = 0, fill = "transparent",
           colour = module.cols, size = 3) +
  annotate("rect", ymin = module.start - 0.5, ymax = module.end + 0.5,
           xmin = 0, xmax = 0, fill = "transparent",
           colour = module.cols, size = 3) +
  # Thin outlines around each within-module block on the diagonal.
  annotate("rect", xmin = module.start - 0.5, xmax = module.end + 0.5,
           ymin = module.start - 0.5, ymax = module.end + 0.5,
           fill = "transparent", colour = module.cols, size = 0.5) +
  labs(x = NULL, y = NULL) +
  theme(axis.text.x = element_blank(), axis.ticks.x = element_blank(),
        axis.ticks.y = element_blank(), axis.text.y = element_blank())
##---------------------------------------------------------------------------

##8. Save the result of evaluation
##---------------------------------------------------------------------------
## 4 x 4 table: rows are weighted/unweighted ratios per method (same order as
## before, so positional access is unchanged); columns are modules 1-4.
ratiotable <- rbind(
  weighted.PLNet   = ratio.weight.PLNet,
  weighted.VPLN    = ratio.weight.VPLN,
  unweighted.PLNet = ratio.unweight.PLNet,
  unweighted.VPLN  = ratio.unweight.VPLN
)
colnames(ratiotable) <- paste0("module", seq_len(ncol(ratiotable)))
## File suffix is "00" followed by round(den * 100), e.g.
## "ratiotable_003.Rdata" for den = 0.03.
den.tag <- round(den * 100)
base::save(ratiotable, file = paste0("ratiotable_00", den.tag, ".Rdata"))
ggsave(paste0("PLNet_00", den.tag, ".pdf"), plot = plot_PLNet,
       width = 6.85, height = 6)
ggsave(paste0("VPLN_00", den.tag, ".pdf"), plot = plot_VPLN,
       width = 6.85, height = 6)
##---------------------------------------------------------------------------