##1. Packages loading
##----------------------------------------------------------------------------------
library(PLNet)
library(orthopolynom)
library(Rcpp)
library(CVXR)
library(parallel)
library(Seurat)
library(SeuratData)
library(GMPR)
library(PLNmodels)
library(glasso)
library(doParallel)
##----------------------------------------------------------------------------------

## Parallel back-end: a PSOCK cluster registered with doParallel.
## The script asks for 20 workers; the request is capped at the number of
## detected cores so smaller machines are not oversubscribed. detectCores()
## may return NA, in which case the request is used unchanged.
clnum <- detectCores()          # Get the number of CPU cores
cat('cores', clnum, '\n')

n_workers <- 20L
if (!is.na(clnum)) {
  n_workers <- min(n_workers, clnum)
}
cl <- makeCluster(n_workers)
registerDoParallel(cl)
###########################
## mle_newton: Newton-type estimation of the feature means (mu) and the
## covariance matrix (sigmahat) of a Poisson model with Gaussian latent
## log-rates. For every feature j, the per-sample integrand combines the
## Poisson kernel x * z - S * exp(z) with a Gaussian kernel in z. The
## integral over z is approximated by 10-point Gauss-Hermite quadrature.
##
## Outline
##   1-2. Moment-style initial estimates of mu and sigmahat from the
##        size-normalized counts. Non-positive diagonal entries and infinite
##        off-diagonal entries are repaired.
##   3.   Gauss-Hermite nodes and weights.
##   4.   First pass: k_max Newton steps on (mu_j, sigma_jj) for all features,
##        recording the gradients. From this history each feature gets:
##        - if_convergence_feature: |sigma gradient| at the last step < 1e-3;
##        - a column of update_iter_mat: TRUE for the steps before its |mu| or
##          |sigma| gradient first increases, FALSE afterwards.
##   5.   Second pass: restart from the initial estimates and take k_max
##        Newton steps again. Features not flagged as converged are frozen
##        from the first step whose update_iter_mat entry is FALSE. Only the
##        last step also updates the off-diagonal entries, via
##        PLNet::integrated_fun() and basic_fun().
##
## Arguments
##   data_use : count matrix, samples (cells) in rows, features in columns.
##   S_depth  : size factors, one per row of data_use.
##   core_num : forwarded to PLNet::integrated_fun() and basic_fun().
##   k_max    : number of Newton steps in each pass.
##
## Value: a list with
##   mlemu, mlesigmahat : k_max + 1 iterates of mu / sigmahat. Entry 1 is the
##                        initial estimate; entries 2..k_max + 1 come from the
##                        second pass. Only the last entry includes the
##                        off-diagonal update.
##   mu_grad, sigma_grad: k_max gradient vectors (second-pass values; the
##                        first-pass values are overwritten).
##   update_iter_mat    : the k_max-row logical matrix described in step 4.
##
## Side effects: at every step, the second pass prints the step number and
## summaries of the current estimates.
mle_newton<-function(data_use,S_depth,core_num = 1,k_max = 10){
  ## Per-step storage of the mu / sigmahat iterates and of the gradients.
  mlemu<-list()
  mlesigmahat<-list()
  mu_grad<-list()
  sigma_grad<-list()
  early_stop_iternum<-1  # only referenced by commented-out code below
  ##
  ##1. Preparation
  ##-------------------------------------------------------------------------------
  dim_use<-ncol(data_use)      # number of features
  sample_size<-nrow(data_use)  # number of samples
  ##-------------------------------------------------------------------------------
  
  ##2. Initial estimates of mu and sigma
  ##-------------------------------------------------------------------------------
  ## Size-normalized counts (S_depth is recycled down the rows).
  data_use_nor<-data_use/as.vector(S_depth)
  log_vec1<-as.vector(log(colMeans(data_use_nor)))
  ## Entry [i, j] = log(mean(X_i * X_j / S^2)) - log(mean(X_i / S)) - log(mean(X_j / S)).
  sigmahat<-t(log((t(data_use_nor) %*% data_use_nor)/nrow(data_use_nor)) - log_vec1) - log_vec1
  ## Diagonal: log(mean(X * (X - 1) / S^2)) - 2 * log(mean(X / S)).
  diag(sigmahat)<-(log(colMeans((data_use * (data_use - 1)) / as.vector(S_depth^2))) - 2 * log(colMeans(data_use/as.vector(S_depth))))
  ## mu uses this diagonal before the non-positive repair just below.
  mu<-log(colMeans(data_use/as.vector(S_depth))) - diag(sigmahat)/2
  ## Replace non-positive diagonal entries by the smallest positive diagonal entry.
  neg_index<-which(diag(sigmahat)<=0)
  if(length(neg_index)>0){
    diag(sigmahat)[neg_index]<-min(diag(sigmahat)[setdiff(1:dim(sigmahat)[1],neg_index)])
  }
  ##
  ## Adjustments:
  ##  - drop features whose counts are all zero;
  ##  - set the diagonal to 0 for features that never reach a count of 2;
  ##  - replace infinite entries (log of a zero mean) by a finite value.
  allzero_set<-which(colSums(data_use) == 0)
  zero_plus_one_set<-setdiff(which(colSums(ifelse(data_use>=2,1,0)) == 0),allzero_set)
  choose_index<-setdiff(1:ncol(data_use),allzero_set)
  ## NOTE(review): sigma_me1 (and hence sigmahat) drops the all-zero features
  ## but mu does not, so their dimensions disagree if any column of data_use is
  ## all zero. The calling script filters such genes out beforehand.
  sigma_me1<-sigmahat[choose_index,choose_index]
  ## NOTE(review): a zero diagonal makes sigma_diag^(-1) infinite in the first
  ## Newton step; confirm solve() copes with the resulting Hessian. The calling
  ## script keeps only genes with some count > 1, so this set is empty there.
  diag(sigma_me1)[which(choose_index %in% zero_plus_one_set)]<-0
  ##
  ## Mark infinite entries in the upper triangle. Each one is replaced by
  ## min(min_vec[i], min_vec[j]), where min_vec[i] is the smallest finite
  ## off-diagonal entry of row i. The result is then mirrored to the lower
  ## triangle.
  isinfinite_mat<-ifelse(is.infinite(sigma_me1),1,0)
  isinfinite_mat[lower.tri(isinfinite_mat)]<-0
  min_vec<-rep(NA,ncol(sigma_me1))
  for(i in 1:ncol(sigma_me1)){
    min_vec[i] <- min(sigma_me1[i,-i][is.finite(c(sigma_me1[i,-i]))])
  }
  for(i in 1:(ncol(sigma_me1)-1)){
    index_select<-which(isinfinite_mat[i,]==1)
    index_select<-index_select[which(index_select>i)]
    if(length(index_select)>0){
      for(j in 1:length(index_select)){
        isinfinite_mat[i,index_select[j]]<-min(min_vec[i],min_vec[index_select[j]])
      }
    }
  }
  isinfinite_mat<-isinfinite_mat + t(isinfinite_mat)
  diag(isinfinite_mat)<-0
  sigma_me1<-ifelse(is.infinite(sigma_me1),0,sigma_me1)
  sigma_me1<-sigma_me1 + isinfinite_mat
  sigmahat<-sigma_me1
  ##
  ## Iterate 1 = initial estimate (shared by both Newton passes).
  mlemu[[1]]<-mu
  mlesigmahat[[1]]<-sigmahat
  ##
  ##-------------------------------------------------------------------------------
  
  ##3. Gauss-Hermite quadrature (orthopolynom): t.root[[11]] holds the 10 roots
  ##   of H_10. omega.root holds the matching weights
  ##   2^9 * 10! * sqrt(pi) / (10^2 * H_9(t)^2).
  ##-------------------------------------------------------------------------------
  t.root<-polynomial.roots(monic.polynomial.recurrences(hermite.h.recurrences(10, normalized=FALSE)))
  omega.root<-2^9*factorial(10)*sqrt(pi)/100/polynomial.values(hermite.h.polynomials(10,normalized=FALSE), t.root[[11]])[[10]]^2
  ##-------------------------------------------------------------------------------
  
  ##4. First pass: gradient, Hessian and Newton step for (mu_j, sigma_jj).
  ##   The mu and sigma gradients of every feature are recorded at every step.
  ##-------------------------------------------------------------------------------
  gradiant_iter_mat<-matrix(NA,nrow = k_max,ncol = ncol(sigma_me1))
  gradiant_iter_mat_sigma<-matrix(NA,nrow = k_max,ncol = ncol(sigma_me1))
  for (k in 2:(k_max + 1)){
    # for (k in 2:4){
    # print(k)
    ## Start from the previous iterate.
    sigmahat<-mlesigmahat[[k-1]]
    mu<-mlemu[[k-1]]
    ##
    sigma_diag_max<-max(diag(sigmahat))  # cap on |delta.sigma| below
    ## Quadrature nodes in z (10 x p): z = mu_j + sqrt(2) * t * sigma_jj.
    ## Matching weights: GH weight * exp(t^2) * sqrt(2) * sigma_jj.
    z.root_mat<-t(t(matrix(t.root[[11]]*sqrt(2),ncol = 1) %*% matrix(diag(sigmahat),nrow = 1)) + mu)
    m.root_mat<-matrix( omega.root*exp(t.root[[11]]^2)*sqrt(2), ncol = 1) %*% matrix(diag(sigmahat),nrow = 1)
    ##
    ## share12 = z - mu; share11 = exp(-(z - mu)^2 / (2 * sigma_jj)).
    share12_mat<-t(t(z.root_mat) - mu)
    share11_mat<-exp(t((-1) * (((t(share12_mat)))^2)/(2 * as.vector(diag(sigmahat)))))
    ##
    ## root1[r, i, j] = x_ij * z_rj - S_i * exp(z_rj).
    root1_array<-array(0,dim = c(10,dim(data_use)))
    for(root_index in 1:10){
      root1_array[root_index,,]<-t(t(data_use) * as.vector(z.root_mat[root_index,])) - matrix(S_depth, ncol = 1) %*% matrix(exp(z.root_mat[root_index,]),nrow = 1)  
    }
    ## Subtract the max over nodes before exp() to avoid overflow; the factor
    ## cancels in the ratios share2..share6 below.
    max_mat<-apply(root1_array,MARGIN = c(2,3),max)
    share1_array<-array(0,dim = c(10,dim(data_use)))
    for(root_index in 1:10){
      share1_array[root_index,,]<-exp(root1_array[root_index,,] - max_mat)
    }
    ##
    ## Integrands (each 10 x p x n): share1 times the Gaussian kernel and its
    ## derivatives w.r.t. mu and sigma_jj.
    sigma_diag<-diag(sigmahat)
    gz.up.mu_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(t(t(share1_array[,sample_index,] * share12_mat) / as.vector(sigma_diag)) * share11_mat)},simplify = "array")
    
    ## NOTE(review): array_temp is computed but never used.
    array_temp<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(t(t(share1_array[,sample_index,] * (share12_mat)^2) / as.vector(sigma_diag^2) ))},simplify = "array")
    
    gz.up.mu.2_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(t(t(share1_array[,sample_index,] * (share12_mat)^2) / as.vector(sigma_diag^2)) * share11_mat)},simplify = "array")
    
    gz.up.sigma_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(gz.up.mu.2_array[,,sample_index]/2)},simplify = "array")
    
    ## Second-derivative factor in sigma: (z-mu)^4 / (4 s^4) - (z-mu)^2 / s^3.
    mid_mat1<-t(t(share12_mat^4) /(4 * (as.vector(sigma_diag))^4)) - t(t(share12_mat^2) /((as.vector(sigma_diag))^3))
    gz.up.sigma.2_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(share1_array[,sample_index,] *
               (mid_mat1)*
               share11_mat)},simplify = "array")
    
    gz.down_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(share1_array[,sample_index,] * share11_mat)},simplify = "array")
    
    ## Mixed-derivative factor: (z-mu)^3 / (2 s^3) - (z-mu) / s^2.
    mid_mat2<-t(t(share12_mat^3) /(2 * (as.vector(sigma_diag))^3)) - t(t(share12_mat) /((as.vector(sigma_diag))^2))
    gz.up.inter_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(share1_array[,sample_index,] *
               (mid_mat2) *
               share11_mat)},simplify = "array")
    ##
    
    ## Denominator (n x p): quadrature sum of share1 * kernel.
    mat_temp1<-t(sapply(X = 1:sample_size,FUN = function(sample_index){colSums(gz.down_array[,,sample_index] * m.root_mat)
    }))
    
    ## Ratios (n x p) of each derivative integral to the denominator.
    share2_mat<-t(sapply(X = 1:sample_size,FUN = function(sample_index){
      colSums(gz.up.mu_array[,,sample_index] * m.root_mat)/mat_temp1[sample_index,]
    }))
    share3_mat<-t(sapply(X = 1:sample_size,FUN = function(sample_index){
      colSums(gz.up.sigma_array[,,sample_index] * m.root_mat)/mat_temp1[sample_index,]
    }))
    share4_mat<-t(sapply(X = 1:sample_size,FUN = function(sample_index){
      colSums(gz.up.sigma.2_array[,,sample_index] * m.root_mat)/mat_temp1[sample_index,]
    }))
    share5_mat<-t(sapply(X = 1:sample_size,FUN = function(sample_index){
      colSums(gz.up.inter_array[,,sample_index] * m.root_mat)/mat_temp1[sample_index,]
    }))
    share6_mat<-t(sapply(X = 1:sample_size,FUN = function(sample_index){
      colSums(gz.up.mu.2_array[,,sample_index] * m.root_mat)/mat_temp1[sample_index,]
    }))
    
    ## 0/0 ratios (NaN) are treated as 0.
    share2_mat<-ifelse(is.na(share2_mat),0,share2_mat)
    share3_mat<-ifelse(is.na(share3_mat),0,share3_mat)
    share4_mat<-ifelse(is.na(share4_mat),0,share4_mat)
    share5_mat<-ifelse(is.na(share5_mat),0,share5_mat)
    share6_mat<-ifelse(is.na(share6_mat),0,share6_mat)
    ##
    ## Sample-averaged gradient and 2x2 Hessian entries per feature. The
    ## -1/(2 s) and +1/(2 s^2) terms come from the Gaussian normalizing constant.
    gradiant.mu<-colMeans(share2_mat)
    gradiant.sigma<-colMeans(share3_mat)-sigma_diag^(-1)/2
    Hessian.mu<-colMeans(share6_mat)- sigma_diag^(-1) - colMeans((share2_mat)^2)
    Hessian.sigma<- colMeans(share4_mat) - colMeans((share3_mat)^2) + (0.5) * (sigma_diag^(-2))
    Hessian.int<-colMeans(share5_mat) - colMeans(share2_mat * share3_mat)
    ##
    mu_grad[[k-1]]<-gradiant.mu
    sigma_grad[[k-1]]<-gradiant.sigma
    ##
    ## Per-feature Newton step: delta = H^{-1} g, applied as theta - delta.
    delta.mu<-c()
    delta.sigma<-c()
    for(dim_index in 1:dim_use){
      Hessian.matrix<-matrix(c(Hessian.mu[dim_index],Hessian.int[dim_index],Hessian.int[dim_index],Hessian.sigma[dim_index]),2,2)
      gradiant<-matrix(c(gradiant.mu[dim_index],gradiant.sigma[dim_index]),2,1)
      gradiant_iter_mat[k-1,dim_index]<-gradiant[1,1]
      gradiant_iter_mat_sigma[k-1,dim_index]<-gradiant[2,1]
      # ##test
      # if(dim_index %in% c(35)){
      #   print(paste("dim_index = ",dim_index,sep = ""))
      #   # gradiant_list[[k-1]][[dim_index-2]]<-gradiant
      #   # hess_list[[k-1]][[dim_index-2]]<-Hessian.matrix
      #   print(gradiant)
      #   print(Hessian.matrix)
      # }
      # ##
      ## NOTE(review): solve() stops with an error if this 2x2 Hessian is singular.
      Hessian.matrix_solve<-solve(Hessian.matrix)
      mat_temp<-Hessian.matrix_solve%*%gradiant
      delta.mu[dim_index]<-mat_temp[1]
      delta.sigma[dim_index]<-mat_temp[2]
      ##
      ## If the joint step would make sigma_jj non-positive, fall back to
      ## separate one-dimensional Newton steps.
      if (sigmahat[dim_index,dim_index]-delta.sigma[dim_index]<=0){
        delta.mu[dim_index]<-gradiant.mu[dim_index]/Hessian.mu[dim_index]
        delta.sigma[dim_index]<-gradiant.sigma[dim_index]/Hessian.sigma[dim_index]
        # if (sigmahat[dim_index,dim_index]-delta.sigma[dim_index]<=0){
        #   delta.mu[dim_index]<-0
        #   delta.sigma[dim_index]<-0
        # }
      }
      ## Skip the step if it exceeds the largest current diagonal entry.
      if (abs(delta.sigma[dim_index]) > sigma_diag_max){
        delta.mu[dim_index]<-0
        delta.sigma[dim_index]<-0
      }
    }
    
    mlemu1<-mu-delta.mu
    ##
    sigmahat<-mlesigmahat[[k-1]]
    ## Replace non-positive diagonal entries by the smallest positive diagonal entry.
    neg_index<-which(diag(sigmahat)<=0)
    if(length(neg_index)>0){
      diag(sigmahat)[neg_index]<-min(diag(sigmahat)[setdiff(1:dim(sigmahat)[1],neg_index)])
    }
    ##
    ## Features whose step would make sigma_jj non-positive get no step. Their
    ## diagonal is set to the smallest updated diagonal of the other features.
    base_diag<-delta.sigma
    reduce_diag<-diag(sigmahat) - base_diag
    adjust_index<-which(reduce_diag<=0)
    if(length(adjust_index)>0){
      #adjust the step size
      reduce_mat<-diag(delta.sigma)
      diag(reduce_mat)[adjust_index]<-0
      mlesigmahat10<-sigmahat - reduce_mat
      diag(mlesigmahat10)[adjust_index]<-min(diag(mlesigmahat10)[-adjust_index])
      mlesigmahat1<-mlesigmahat10
    }else{
      mlesigmahat1<-sigmahat-diag(delta.sigma) 
    }
    ##
    mlesigmahat1_diag<-diag(mlesigmahat1)
    ##
    ##
    # print(summary(diag(mlesigmahat1)))
    # print(summary(mlesigmahat1[upper.tri(mlesigmahat1)]))
    # print(summary(as.vector(mlemu1)))
    ##
    mlemu[[k]]<-mlemu1
    mlesigmahat[[k]]<-mlesigmahat1
  }
  ## A feature counts as converged if |sigma gradient| at the last first-pass
  ## step is below 1e-3.
  ## NOTE(review): a NaN gradient gives NA here, which would make the
  ## if() in the second pass fail.
  if_convergence_feature<-ifelse(abs(sigma_grad[[k_max]])<1e-3,TRUE,FALSE)
  ##
  ## update_iter_mat[r, j] is TRUE for the steps r before the |mu| or |sigma|
  ## gradient of feature j first increases, and FALSE from then on.
  update_iter_mat<-matrix(NA,nrow = nrow(gradiant_iter_mat),ncol = ncol(gradiant_iter_mat))
  for(dim_index in 1:ncol(gradiant_iter_mat)){
    vec_use<-abs(gradiant_iter_mat[,dim_index])
    vec_use_sigma<-abs(gradiant_iter_mat_sigma[,dim_index])
    aaa_vec<-which(diff(vec_use)>0)
    aaa_vec_sigma<-which(diff(vec_use_sigma)>0)
    if((length(aaa_vec)>0) | (length(aaa_vec_sigma)>0)){
      if(length(aaa_vec)>0){
        min1<-min(aaa_vec)
      }else{
        min1<-k_max
      }
      ##
      if(length(aaa_vec_sigma)>0){
        min2<-min(aaa_vec_sigma)
      }else{
        min2<-k_max
      }
      update_iter_mat[min(c(min1 + 1, min2 + 1,k_max)) :k_max,dim_index]<-FALSE
      update_iter_mat[1 :min(c(min2,min1)),dim_index]<-TRUE
      
    }else{
      update_iter_mat[,dim_index]<-TRUE
    }
  }
  # ##add
  # nonconvergence_index<-which(if_convergence_feature == FALSE)
  # if(length(nonconvergence_index)>0){
  #   for(dim_index in nonconvergence_index){
  #     update_iter_mat[2:nrow(update_iter_mat),dim_index]<-FALSE
  #   } 
  # }
  ##--------------------------------------------------------------------------
  ##5. Second pass: same Newton steps restarted from mlemu[[1]] / mlesigmahat[[1]].
  ##   Non-converged features are frozen (permanently, via if_update_feature)
  ##   once update_iter_mat says so. The last step also updates off-diagonals.
  ##   gradiant_iter_mat is refilled below but not read afterwards.
  gradiant_iter_mat<-matrix(NA,nrow = k_max,ncol = ncol(sigma_me1))
  if_update_feature<-rep(TRUE,ncol(sigma_me1))
  for (k in 2:(k_max + 1)){
    # for (k in 2:(early_stop_iternum)){
    # for (k in 2:4){
    print(k)
    ## Start from the previous iterate (identical computations to the first pass).
    sigmahat<-mlesigmahat[[k-1]]
    mu<-mlemu[[k-1]]
    ##
    sigma_diag_max<-max(diag(sigmahat))
    z.root_mat<-t(t(matrix(t.root[[11]]*sqrt(2),ncol = 1) %*% matrix(diag(sigmahat),nrow = 1)) + mu)
    m.root_mat<-matrix( omega.root*exp(t.root[[11]]^2)*sqrt(2), ncol = 1) %*% matrix(diag(sigmahat),nrow = 1)
    ##
    share12_mat<-t(t(z.root_mat) - mu)
    share11_mat<-exp(t((-1) * (((t(share12_mat)))^2)/(2 * as.vector(diag(sigmahat)))))
    ##
    root1_array<-array(0,dim = c(10,dim(data_use)))
    for(root_index in 1:10){
      root1_array[root_index,,]<-t(t(data_use) * as.vector(z.root_mat[root_index,])) - matrix(S_depth, ncol = 1) %*% matrix(exp(z.root_mat[root_index,]),nrow = 1)  
    }
    max_mat<-apply(root1_array,MARGIN = c(2,3),max)
    share1_array<-array(0,dim = c(10,dim(data_use)))
    for(root_index in 1:10){
      share1_array[root_index,,]<-exp(root1_array[root_index,,] - max_mat)
    }
    ##
    sigma_diag<-diag(sigmahat)
    gz.up.mu_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(t(t(share1_array[,sample_index,] * share12_mat) / as.vector(sigma_diag)) * share11_mat)},simplify = "array")
    
    ## NOTE(review): array_temp is computed but never used.
    array_temp<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(t(t(share1_array[,sample_index,] * (share12_mat)^2) / as.vector(sigma_diag^2) ))},simplify = "array")
    
    gz.up.mu.2_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(t(t(share1_array[,sample_index,] * (share12_mat)^2) / as.vector(sigma_diag^2)) * share11_mat)},simplify = "array")
    
    gz.up.sigma_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(gz.up.mu.2_array[,,sample_index]/2)},simplify = "array")
    
    mid_mat1<-t(t(share12_mat^4) /(4 * (as.vector(sigma_diag))^4)) - t(t(share12_mat^2) /((as.vector(sigma_diag))^3))
    gz.up.sigma.2_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(share1_array[,sample_index,] *
               (mid_mat1)*
               share11_mat)},simplify = "array")
    
    gz.down_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(share1_array[,sample_index,] * share11_mat)},simplify = "array")
    
    mid_mat2<-t(t(share12_mat^3) /(2 * (as.vector(sigma_diag))^3)) - t(t(share12_mat) /((as.vector(sigma_diag))^2))
    gz.up.inter_array<-sapply(X = 1:sample_size,FUN = function(sample_index){
      return(share1_array[,sample_index,] *
               (mid_mat2) *
               share11_mat)},simplify = "array")
    ##
    
    mat_temp1<-t(sapply(X = 1:sample_size,FUN = function(sample_index){colSums(gz.down_array[,,sample_index] * m.root_mat)
    }))
    
    share2_mat<-t(sapply(X = 1:sample_size,FUN = function(sample_index){
      colSums(gz.up.mu_array[,,sample_index] * m.root_mat)/mat_temp1[sample_index,]
    }))
    share3_mat<-t(sapply(X = 1:sample_size,FUN = function(sample_index){
      colSums(gz.up.sigma_array[,,sample_index] * m.root_mat)/mat_temp1[sample_index,]
    }))
    share4_mat<-t(sapply(X = 1:sample_size,FUN = function(sample_index){
      colSums(gz.up.sigma.2_array[,,sample_index] * m.root_mat)/mat_temp1[sample_index,]
    }))
    share5_mat<-t(sapply(X = 1:sample_size,FUN = function(sample_index){
      colSums(gz.up.inter_array[,,sample_index] * m.root_mat)/mat_temp1[sample_index,]
    }))
    share6_mat<-t(sapply(X = 1:sample_size,FUN = function(sample_index){
      colSums(gz.up.mu.2_array[,,sample_index] * m.root_mat)/mat_temp1[sample_index,]
    }))
    
    share2_mat<-ifelse(is.na(share2_mat),0,share2_mat)
    share3_mat<-ifelse(is.na(share3_mat),0,share3_mat)
    share4_mat<-ifelse(is.na(share4_mat),0,share4_mat)
    share5_mat<-ifelse(is.na(share5_mat),0,share5_mat)
    share6_mat<-ifelse(is.na(share6_mat),0,share6_mat)
    ##
    gradiant.mu<-colMeans(share2_mat)
    gradiant.sigma<-colMeans(share3_mat)-sigma_diag^(-1)/2
    Hessian.mu<-colMeans(share6_mat)- sigma_diag^(-1) - colMeans((share2_mat)^2)
    Hessian.sigma<- colMeans(share4_mat) - colMeans((share3_mat)^2) + (0.5) * (sigma_diag^(-2))
    Hessian.int<-colMeans(share5_mat) - colMeans(share2_mat * share3_mat)
    ##
    ## These overwrite the first-pass gradients.
    mu_grad[[k-1]]<-gradiant.mu
    sigma_grad[[k-1]]<-gradiant.sigma
    ##
    delta.mu<-c()
    delta.sigma<-c()
    for(dim_index in 1:dim_use){
      Hessian.matrix<-matrix(c(Hessian.mu[dim_index],Hessian.int[dim_index],Hessian.int[dim_index],Hessian.sigma[dim_index]),2,2)
      gradiant<-matrix(c(gradiant.mu[dim_index],gradiant.sigma[dim_index]),2,1)
      gradiant_iter_mat[k-1,dim_index]<-gradiant[1,1]
      # ##test
      # if(dim_index %in% c(3,4)){
      #   print(paste("dim_index = ",dim_index,sep = ""))
      #   # gradiant_list[[k-1]][[dim_index-2]]<-gradiant
      #   # hess_list[[k-1]][[dim_index-2]]<-Hessian.matrix
      #   print(gradiant)
      #   print(Hessian.matrix)
      # }
      # ##
      Hessian.matrix_solve<-solve(Hessian.matrix)
      mat_temp<-Hessian.matrix_solve%*%gradiant
      delta.mu[dim_index]<-mat_temp[1]
      delta.sigma[dim_index]<-mat_temp[2]
      ##
      # if((k>=(early_stop_iternum+1)) & (if_convergence_feature[dim_index] == FALSE)){
      ## Non-converged feature: freeze it (zero step, now and in all later
      ## steps) once update_iter_mat is FALSE; otherwise step as in pass one.
      if((if_convergence_feature[dim_index] == FALSE)){
        # # vec_temp<-gradiant_iter_mat[(k+1 - early_stop_iternum): early_stop_iternum,dim_index]
        # vec_temp<-gradiant_iter_mat[(k - 2): (k - 1),dim_index]
        # # if((which.min(abs(vec_temp)) == 2) | (if_update_feature[dim_index] == FALSE)){
        # # diff_sign<-sign(diff(vec_temp))
        # # if((diff_sign[1] != diff_sign[2]) | (if_update_feature[dim_index] == FALSE)){
        # if((abs(vec_temp)[2]>abs(vec_temp)[1]) | (if_update_feature[dim_index] == FALSE)){
        if((update_iter_mat[(k-1),dim_index] == FALSE) | (if_update_feature[dim_index] == FALSE)){
          if_update_feature[dim_index]<-FALSE
          delta.mu[dim_index]<-0
          delta.sigma[dim_index]<-0
        }else{
          if (sigmahat[dim_index,dim_index]-delta.sigma[dim_index]<=0){
            delta.mu[dim_index]<-gradiant.mu[dim_index]/Hessian.mu[dim_index]
            delta.sigma[dim_index]<-gradiant.sigma[dim_index]/Hessian.sigma[dim_index]
            # if (sigmahat[dim_index,dim_index]-delta.sigma[dim_index]<=0){
            #   delta.mu[dim_index]<-0
            #   delta.sigma[dim_index]<-0
            # }
          }
          if (abs(delta.sigma[dim_index]) > sigma_diag_max){
            delta.mu[dim_index]<-0
            delta.sigma[dim_index]<-0
          }
        }
      }else{
        ## Converged feature: same step rules as the first pass
        ## (if_update_feature is never set to FALSE for these features).
        if(if_update_feature[dim_index] == TRUE){
          if (sigmahat[dim_index,dim_index]-delta.sigma[dim_index]<=0){
            delta.mu[dim_index]<-gradiant.mu[dim_index]/Hessian.mu[dim_index]
            delta.sigma[dim_index]<-gradiant.sigma[dim_index]/Hessian.sigma[dim_index]
            # if (sigmahat[dim_index,dim_index]-delta.sigma[dim_index]<=0){
            #   delta.mu[dim_index]<-0
            #   delta.sigma[dim_index]<-0
            # }
          }
          if (abs(delta.sigma[dim_index]) > sigma_diag_max){
            delta.mu[dim_index]<-0
            delta.sigma[dim_index]<-0
          }
        }
      }
      ##
    }
    #####
    
    mlemu1<-mu-delta.mu
    ##
    sigmahat<-mlesigmahat[[k-1]]
    ## Replace non-positive diagonal entries by the smallest positive diagonal entry.
    neg_index<-which(diag(sigmahat)<=0)
    if(length(neg_index)>0){
      diag(sigmahat)[neg_index]<-min(diag(sigmahat)[setdiff(1:dim(sigmahat)[1],neg_index)])
    }
    ##
    ## Same diagonal update / non-positive guard as in the first pass.
    base_diag<-delta.sigma
    reduce_diag<-diag(sigmahat) - base_diag
    adjust_index<-which(reduce_diag<=0)
    if(length(adjust_index)>0){
      #adjust the step size
      reduce_mat<-diag(delta.sigma)
      diag(reduce_mat)[adjust_index]<-0
      mlesigmahat10<-sigmahat - reduce_mat
      diag(mlesigmahat10)[adjust_index]<-min(diag(mlesigmahat10)[-adjust_index])
      mlesigmahat1<-mlesigmahat10
    }else{
      mlesigmahat1<-sigmahat-diag(delta.sigma) 
    }
    ##
    mlesigmahat1_diag<-diag(mlesigmahat1)
    ##
    # time_1<-Sys.time()
    
    ## Last step only: update the off-diagonal entries using gradient/Hessian
    ## terms from PLNet::integrated_fun(). NA gradients become 0 and NA Hessian
    ## entries become 1. The result is then symmetrized from its upper
    ## triangle, and this step's diagonal is restored.
    ## NOTE(review): basic_fun is called unqualified; presumably it is exported
    ## by the attached PLNet package. Confirm.
    if(k == (k_max+1)){
      integrate_list<-PLNet::integrated_fun(data_use = data_use, S_depth = S_depth, mlemu1 = mlemu1, mlesigmahat1 = mlesigmahat1,
                                            t_root_vec = t.root[[11]], omega_root_vec = omega.root,
                                            core_num = core_num)
      
      # time_2<-Sys.time()
      # 
      # time_2 - time_1
      gradiant.int<-integrate_list$gradiant_int
      Hessian.int<-integrate_list$Hessian_int
      ##
      gradiant.int<-ifelse(is.na(gradiant.int),0,gradiant.int)
      Hessian.int<-ifelse(is.na(Hessian.int),1,Hessian.int)
      ##
      mlesigmahat1<-basic_fun( data_use = data_use,  S_depth = S_depth,  mlemu1 = mlemu1,  mlesigmahat1 = mlesigmahat1,
                               gradiant_int = gradiant.int,  Hessian_int = Hessian.int,
                               core_num = core_num)
      ##
      mlesigmahat1[lower.tri(mlesigmahat1)]<-0
      mlesigmahat1<-mlesigmahat1 + t(mlesigmahat1)    
      diag(mlesigmahat1)<-mlesigmahat1_diag
    }
    # integrate_list<-PLNet::integrated_fun(data_use = data_use, S_depth = S_depth, mlemu1 = mlemu1, mlesigmahat1 = mlesigmahat1,
    #                                       t_root_vec = t.root[[11]], omega_root_vec = omega.root,
    #                                       core_num = core_num)
    # 
    # # time_2<-Sys.time()
    # # 
    # # time_2 - time_1
    # gradiant.int<-integrate_list$gradiant_int
    # Hessian.int<-integrate_list$Hessian_int
    # ##
    # gradiant.int<-ifelse(is.na(gradiant.int),0,gradiant.int)
    # Hessian.int<-ifelse(is.na(Hessian.int),1,Hessian.int)
    # ##
    # mlesigmahat1<-basic_fun( data_use = data_use,  S_depth = S_depth,  mlemu1 = mlemu1,  mlesigmahat1 = mlesigmahat1,
    #                          gradiant_int = gradiant.int,  Hessian_int = Hessian.int,
    #                          core_num = core_num)
    # ##
    # mlesigmahat1[lower.tri(mlesigmahat1)]<-0
    # mlesigmahat1<-mlesigmahat1 + t(mlesigmahat1)    
    # diag(mlesigmahat1)<-mlesigmahat1_diag
    ## Progress output: summaries of the diagonal, upper off-diagonal and mu.
    print(summary(diag(mlesigmahat1)))
    print(summary(mlesigmahat1[upper.tri(mlesigmahat1)]))
    print(summary(as.vector(mlemu1)))
    ##
    mlemu[[k]]<-mlemu1
    mlesigmahat[[k]]<-mlesigmahat1
    
  }
  
  return(list(mlemu = mlemu,
              mlesigmahat = mlesigmahat,
              mu_grad = mu_grad,
              sigma_grad = sigma_grad,
              update_iter_mat = update_iter_mat))
  
}
##2. Data loading and pre-processing
##----------------------------------------------------------------------------------
##2.1 Install and load the SeuratData "ifnb" object. Take its raw RNA counts
##    (genes in rows, cells in columns) and two per-cell labels: the cell-type
##    annotation and the sample of origin.
SeuratData::InstallData("ifnb")
data("ifnb")
count_mat <- ifnb@assays$RNA@counts
anno_vec <- ifnb$seurat_annotations
anno_vec_2 <- ifnb$orig.ident

##2.2 Keep the stimulated CD14 cells. Both labels are picked by position among
##    their table() names: the first cell-type label and the second
##    orig.ident label.
##    NOTE(review): confirm these positions resolve to the CD14 monocyte
##    annotation and the stimulated sample for the installed dataset version.
celltype_label <- names(table(anno_vec))[1]
condition_label <- names(table(anno_vec_2))[2]
CD14_stim <- which(anno_vec == celltype_label & anno_vec_2 == condition_label)
count_CD14_stim <- count_mat[, CD14_stim]

##2.3 Gene set of interest. Among the 500 variable features of these cells,
##    take every transcription factor listed in TF_list_all[[1]] plus the
##    first 200 variable features, with the transcription factors placed
##    first. Then drop genes that never exceed a count of 1 in any cell.
P3se_ifnb_stim <- CreateSeuratObject(counts = count_CD14_stim, min.cells = 3)
P3se_ifnb_stim <- NormalizeData(P3se_ifnb_stim, normalization.method = "LogNormalize", scale.factor = 10000)
P3se_ifnb_stim <- FindVariableFeatures(P3se_ifnb_stim, nfeatures = 500)
variable_gene_stim <- Seurat::VariableFeatures(P3se_ifnb_stim)
load("./TF_list_all.Rdata")  # provides TF_list_all
tf_genes <- TF_list_all[[1]]
gene_names <- unique(c(variable_gene_stim[variable_gene_stim %in% tf_genes],
                       variable_gene_stim[1:200]))
is_tf <- gene_names %in% tf_genes
gene_names <- c(gene_names[is_tf], setdiff(gene_names, gene_names[is_tf]))
count_CD14_stim_ifnb <- count_CD14_stim[gene_names, ]
keep_gene <- rowSums(as.matrix(count_CD14_stim_ifnb) > 1) >= 1
datafinal.stim <- as.matrix(count_CD14_stim_ifnb[which(keep_gene), ])
out_smooth <- datafinal.stim

##2.4 Per-cell size factors via GMPR; GMPR receives the cells x genes matrix.
S_depth <- GMPR(t(out_smooth), 1, 1)
##----------------------------------------------------------------------------------

##3. Run PLNet
##----------------------------------------------------------------------------------
##3.1 Estimate the covariance matrix with the MLE (Newton) estimator.
##    mle_newton() returns k_max + 1 covariance iterates in `mlesigmahat`.
##    Entry 1 is the moment-based start. The off-diagonal entries are refined
##    only in the final step (k == k_max + 1), so the last entry is the MLE
##    passed on to PLNet_main(). The previous code used entry 2, which still
##    carried the initial off-diagonals.
print("PLNet_MLE Start.")
k_max <- 10
time_1 <- Sys.time()
cov_input <- mle_newton(data_use = t(as.matrix(out_smooth)),
                        S_depth = S_depth,
                        k_max = k_max,
                        core_num = 20)
# save(cov_input,file = "./cov_input.Rdata")
time_2 <- Sys.time()
time_2 - time_1
cov_mle <- cov_input$mlesigmahat[[length(cov_input$mlesigmahat)]]


##3.2 Estimate the precision matrix with the D-trace loss, both with and
##    without penalizing the diagonal.
time_1 <- Sys.time()
PLNet_res_list <- list()
for (if_penalize.diagonal in c(TRUE, FALSE)) {
  res_name <- if (if_penalize.diagonal) "penalize.diagonal" else "not penalize.diagonal"
  PLNet_res <- PLNet_main(obs_mat = t(out_smooth),
                          Sd_est = S_depth,
                          n_lambda = 100,
                          penalize.diagonal = if_penalize.diagonal,
                          cov_input = cov_mle,
                          weight_mat = NULL, zero_mat = NULL,
                          core_num = 1
  )
  PLNet_res_list[[res_name]] <- PLNet_res
}
time_2 <- Sys.time()
time_2 - time_1
##
save(PLNet_res_list, file = "./PLNet_res_list_benchmark.Rdata")

##----------------------------------------------------------------------------------

##4. Run VPLN (PLNmodels::PLNnetwork), with and without penalizing the diagonal
##----------------------------------------------------------------------------------
print("VPLN Start.")
time_1 <- Sys.time()
VPLN_res_list <- list()
## The PLNmodels input does not depend on the penalty setting, so it is built
## once here rather than on every loop iteration. Cells are rows, a single
## all-zero dummy covariate is supplied, and S_depth is used as the offset.
n_cells <- ncol(out_smooth)
original_list <- list(data_1 = as.data.frame(as.matrix(t(out_smooth))),
                      Covariate = as.data.frame(matrix(rep(0, n_cells), ncol = 1)))
rownames(original_list$Covariate) <- rownames(original_list$data_1)
pre_data <- prepare_data(counts = original_list$data_1, covariates = original_list$Covariate, offset = S_depth)
## Rename the second column of the prepared data
## (NOTE(review): presumably the dummy covariate; confirm).
names(pre_data)[2] <- "covariates"
for (if_penalize.diagonal in c(TRUE, FALSE)) {
  fits <- PLNnetwork(Abundance ~ 1 + offset(log(Offset)), data = pre_data,
                     control_init = list(nPenalties = 100, min.ratio = 1e-4),
                     control_main = list(xtol_rel = 1e-2, penalize_diagonal = if_penalize.diagonal))
  res_name <- if (if_penalize.diagonal) "penalize.diagonal" else "not penalize.diagonal"
  VPLN_res_list[[res_name]] <- fits
}
time_2 <- Sys.time()
time_2 - time_1
##
save(VPLN_res_list, file = "./VPLN_res_list_benchmark.Rdata")

## Release the PSOCK workers started at the top of the script.
stopCluster(cl)
##----------------------------------------------------------------------------------