#include <RcppArmadillo.h>
#include <Rcpp.h>
#include <omp.h>
#include <cmath>
// [[Rcpp::depends(RcppArmadillo)]]
// [[Rcpp::plugins(openmp)]]

using namespace Rcpp;
using namespace arma;

static omp_lock_t lock;

// basic_fun: pairwise update of the off-diagonal entries of a covariance-like
// matrix, one independent 2x2 problem per column pair (i, j) with i < j.
//
// Steps, for every pair i < j of columns of data_use:
//   1. Initial value
//        s0(i,j) = log( sum_k data(k,i)*data(k,j)/S_depth(k)^2 / n_rows )
//                  - (mlemu1(i) + mlesigmahat1(i,i)/2)
//                  - (mlemu1(j) + mlesigmahat1(j,j)/2)
//      Pairs whose sum is exactly 0 (log would be -inf) are flagged instead.
//   2. Flagged pairs receive the smaller-magnitude of the two per-row
//      "smallest |s0|" values (min_vec) of rows i and j.
//   3. If the 2x2 matrix [[s_ii, s0],[s0, s_jj]] has a negative determinant,
//      s0 is replaced by sign(s0) * 0.99 * sqrt(s_ii * s_jj).
//   4. The result is s0 - gradiant_int(i,j) / Hessian_int(i,j); if that value
//      makes the 2x2 determinant negative, s0 is returned instead.
//
// Parameters:
//   data_use      n_rows x dim matrix; columns are paired with each other.
//   S_depth       length n_rows; each row's products are divided by S_depth^2.
//   mlemu1        at least dim entries, indexed by column.
//   mlesigmahat1  at least dim x dim; only its diagonal is read.
//   gradiant_int, Hessian_int
//                 only entries (i,j) with i < j are read.
//   core_num      number of OpenMP threads; values < 1 are treated as 1.
//
// Returns a dim x dim matrix in which only the strict upper triangle is
// filled; the diagonal and the lower triangle are zero.
//
// Errors: input sizes are checked before any parallel region and reported via
// Rcpp::stop, because an exception thrown inside an OpenMP region cannot
// propagate and would terminate the process.
//
// No locking is needed: every parallel iteration writes only to matrix/vector
// elements that no other iteration touches.
// [[Rcpp::export]]
arma::mat basic_fun(arma::mat data_use, arma::vec S_depth, arma::vec mlemu1, arma::mat mlesigmahat1,
                    arma::mat gradiant_int, arma::mat Hessian_int,
                          int core_num){
  const int dim_use = static_cast<int>(data_use.n_cols);
  const int sample_size_use = static_cast<int>(data_use.n_rows);

  // omp_set_num_threads expects a positive value; set it once for all regions.
  omp_set_num_threads(core_num > 0 ? core_num : 1);

  arma::mat mlesigmahat_res(dim_use, dim_use, arma::fill::zeros);
  if (dim_use < 2) {
    // No column pairs exist, so there is nothing to compute.
    return mlesigmahat_res;
  }

  // Validate sizes up front so no Armadillo bounds/size error can be raised
  // from inside a parallel region.
  const arma::uword dim_u = data_use.n_cols;
  if (S_depth.n_elem != data_use.n_rows) {
    Rcpp::stop("basic_fun: S_depth must have one entry per row of data_use");
  }
  if (mlemu1.n_elem < dim_u) {
    Rcpp::stop("basic_fun: mlemu1 must have one entry per column of data_use");
  }
  if (mlesigmahat1.n_rows < dim_u || mlesigmahat1.n_cols < dim_u) {
    Rcpp::stop("basic_fun: mlesigmahat1 must be at least dim x dim");
  }
  // Only the strict upper triangle (row < col < dim) of these two is read.
  if (gradiant_int.n_rows + 1 < dim_u || gradiant_int.n_cols < dim_u) {
    Rcpp::stop("basic_fun: gradiant_int is too small for the columns of data_use");
  }
  if (Hessian_int.n_rows + 1 < dim_u || Hessian_int.n_cols < dim_u) {
    Rcpp::stop("basic_fun: Hessian_int is too small for the columns of data_use");
  }

  // ---- Step 1: initial off-diagonal values -------------------------------
  arma::mat sigma0_mat(dim_use, dim_use, arma::fill::zeros);
  arma::mat if_infinite_sigma0_mat(dim_use, dim_use, arma::fill::zeros);
  // Loop-invariant denominator, computed once instead of once per pair.
  const arma::vec depth_sq = arma::pow(S_depth, 2);

  // Triangular iteration space: dynamic scheduling balances the uneven rows.
#pragma omp parallel for schedule(dynamic)
  for (int i = 0; i < (dim_use - 1); ++i) {
    const arma::vec col_i = data_use.col(i);
    const double offset_i = mlemu1(i) + mlesigmahat1(i, i) * 0.5;
    for (int j = (i + 1); j < dim_use; ++j) {
      const double offset_j = mlemu1(j) + mlesigmahat1(j, j) * 0.5;
      const arma::vec scaled_prod = (col_i % data_use.col(j)) / depth_sq;
      const double prod_sum = arma::sum(scaled_prod);
      if (prod_sum == 0) {
        // log(0) would be -inf: flag the pair and fill it in during step 2.
        if_infinite_sigma0_mat(i, j) = 1;
      } else {
        const double value =
          std::log(prod_sum / (double)sample_size_use) - offset_i - offset_j;
        sigma0_mat(i, j) = value;
        sigma0_mat(j, i) = value;
      }
    }
  }

  // ---- Step 2a: per-row entry of smallest magnitude ----------------------
  // NOTE(review): the flag above is set only at (i,j) with i < j, while
  // sigma0_mat(j,i) stays 0 for such a pair. The scan of row j therefore sees
  // an unflagged 0 at column i. This matches the original behaviour and is
  // kept because gradiant_int/Hessian_int are computed elsewhere and may
  // assume the same starting value -- confirm whether the flag was meant to
  // be symmetric.
  arma::vec min_vec(dim_use, arma::fill::zeros);
#pragma omp parallel for
  for (int i = 0; i < dim_use; ++i) {
    double min_value = 1e+5;  // sentinel kept when the row has no usable entry
    for (int j = 0; j < dim_use; ++j) {
      if (i == j || if_infinite_sigma0_mat(i, j) == 1) {
        continue;
      }
      // std::abs guarantees the floating-point overload (no int truncation).
      if (std::abs(min_value) > std::abs(sigma0_mat(i, j))) {
        min_value = sigma0_mat(i, j);
      }
    }
    min_vec(i) = min_value;
  }

  // ---- Step 2b: fill the flagged pairs (upper triangle only) -------------
#pragma omp parallel for
  for (int i = 0; i < (dim_use - 1); ++i) {
    for (int j = (i + 1); j < dim_use; ++j) {
      if (if_infinite_sigma0_mat(i, j) == 1) {
        sigma0_mat(i, j) =
          (std::abs(min_vec(i)) > std::abs(min_vec(j))) ? min_vec(j) : min_vec(i);
      }
    }
  }

  // ---- Steps 3 and 4: clamp, update, and fall back -----------------------
#pragma omp parallel for schedule(dynamic)
  for (int ii = 0; ii < (dim_use - 1); ++ii) {
    const double sigma1 = mlesigmahat1(ii, ii);
    for (int jj = (ii + 1); jj < dim_use; ++jj) {
      const double sigma2 = mlesigmahat1(jj, jj);
      double sigmaint0 = sigma0_mat(ii, jj);

      arma::mat pair_cov(2, 2);
      pair_cov(0, 0) = sigma1;
      pair_cov(1, 1) = sigma2;
      pair_cov(1, 0) = sigmaint0;
      pair_cov(0, 1) = sigmaint0;

      // Negative determinant: pull the starting value back to
      // 0.99 * sqrt(sigma1 * sigma2), keeping its sign.
      if (arma::det(pair_cov) < 0) {
        sigmaint0 = arma::sign(sigmaint0) * (std::sqrt(sigma1 * sigma2) * 0.99);
      }

      // Single step from the starting value using the supplied derivatives.
      double updated = sigmaint0 - gradiant_int(ii, jj) / Hessian_int(ii, jj);

      pair_cov(1, 0) = updated;
      pair_cov(0, 1) = updated;
      // Reject an update that makes the determinant negative.
      if (arma::det(pair_cov) < 0) {
        updated = sigmaint0;
      }
      mlesigmahat_res(ii, jj) = updated;
    }
  }

  return mlesigmahat_res;
}
