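# NLCM simulation: ROC comparison of a log_moment()-based two-sample statistic
# against sbm_tst_old() on graphs drawn with sample_dcsbm().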
rm(list=ls())
setwd("~/two sample test")
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
source("R/agg_functions.R")
source("R/ASE.R")
library(nett)
library(ggplot2)
library(igraph)

# Squared Euclidean (Frobenius) norm: the sum of |x|^2 over all entries.
#
# X: numeric or complex vector or matrix.
# Returns a single non-negative double. Summing Re^2 + Im^2 elementwise is
# exact for real input, handles complex entries, and (unlike summing the
# whole Gram matrix Conj(t(X)) %*% X) is also correct when X is a matrix.
l2norm_squared <- function(X) {
  sum(Re(X)^2 + Im(X)^2)
}

# Normalized spectral moments of a square matrix.
#
# For i = 1, ..., k returns sum(lambda^i) / n^i, where lambda are the
# eigenvalues of A and n = nrow(A) (equivalently trace(A^i) / n^i).
#
# A: square matrix accepted by eigen().
# k: number of moments to compute (non-negative integer).
# Returns a vector of length k (empty when k = 0). It is complex if eigen()
# returns complex eigenvalues, e.g. for a non-symmetric A.
log_moment <- function(A, k) {
  # Only the spectrum is needed, so skip computing eigenvectors.
  lambda <- eigen(A, only.values = TRUE)$values
  n <- nrow(A)
  res <- numeric(k)
  # Running powers avoid recomputing lambda^i and n^i from scratch.
  pow_lambda <- 1
  pow_n <- 1
  for (i in seq_len(k)) {
    pow_lambda <- pow_lambda * lambda
    pow_n <- pow_n * n
    res[i] <- sum(pow_lambda) / pow_n
  }
  res
}


# Simulate `num_iter` graph pairs and compute both test statistics.
#
# For each replicate:
#   - draw independent uniform block labels for two graphs with n nodes;
#   - sample the first graph from B_first and the second from B_second with
#     sample_dcsbm();
#   - estimate the labels of both graphs with spec_clust();
#   - record sbm = sbm_tst_old() on the pair with the estimated labels, and
#     nlcm = l2norm_squared() of the difference of their log_moment() vectors.
# Prints the iteration counter every 50 replicates.
# Returns list(sbm = numeric(num_iter), nlcm = numeric(num_iter)).
.nlcm_simulate_stats <- function(B_first, B_second, K, n, num_iter, niter,
                                 n_moments) {
  sbm_stat <- numeric(num_iter)
  nlcm_stat <- numeric(num_iter)
  for (iter in seq_len(num_iter)) {
    if (iter %% 50 == 0) {
      print(iter)
    }
    z1 <- sample(seq_len(K), n, replace = TRUE)
    z2 <- sample(seq_len(K), n, replace = TRUE)
    A1 <- sample_dcsbm(z1, B_first)
    A2 <- sample_dcsbm(z2, B_second)

    # Null and alternative must use identical clustering settings. Otherwise
    # differences in clustering quality leak into the ROC comparison.
    z1hat <- spec_clust(A1, K, niter = niter)
    z2hat <- spec_clust(A2, K, niter = niter)

    sbm_stat[iter] <- sbm_tst_old(A1, A2, z1hat, z2hat)
    nlcm_stat[iter] <- l2norm_squared(
      log_moment(A1, n_moments) - log_moment(A2, n_moments)
    )
  }
  list(sbm = sbm_stat, nlcm = nlcm_stat)
}

# ROC points from null and alternative samples of one statistic.
#
# For the i-th smallest null value t_i (m = number of sorted null values):
#   tpr[i] = i / m                     share of null values up to t_i
#   fpr[i] = sum(alt_stat < t_i) / length(alt_stat)
# With a "reject for large values" rule these are 1 - size and 1 - power at
# threshold t_i. The resulting curve is the usual ROC reflected across the
# line x + y = 1, which leaves the area under the curve unchanged.
.nlcm_roc_points <- function(null_stat, alt_stat) {
  thresholds <- sort(null_stat)
  m <- length(thresholds)
  below <- vapply(thresholds, function(t) sum(alt_stat < t), integer(1))
  data.frame(tpr = seq_len(m) / m, fpr = below / length(alt_stat))
}

# Monte Carlo ROC comparison of two two-sample tests on DCSBM graphs.
#
# B2 is B1 plus rnorm_symmetric_matrix(K, 0, sigma). `num_iter` null pairs
# (both graphs from B1) are simulated first, then `num_iter` alternative
# pairs (B1 vs B2), each graph with n nodes. ROC points are then computed for
#   - sbm_tst_old() on spec_clust() labels  (nlcm = FALSE), and
#   - the squared l2 distance between log_moment() vectors  (nlcm = TRUE).
#
# B1:        K x K block matrix of the null model.
# sigma:     scale passed to rnorm_symmetric_matrix(); presumably the sd of
#            the perturbation (TODO confirm).
# num_iter:  number of simulated pairs under each hypothesis.
# n:         number of nodes per graph.
# niter:     passed to spec_clust() under BOTH hypotheses (default 30).
# n_moments: number of moments computed by log_moment() (default 5).
#
# Returns a data.frame with columns tpr and fpr (see .nlcm_roc_points),
# sigma (factor) and nlcm (logical). Each statistic contributes one row per
# sorted null value.
nlcm_test <- function(B1, sigma, num_iter, n, niter = 30, n_moments = 5) {
  K <- nrow(B1)
  B2 <- B1 + rnorm_symmetric_matrix(K, 0, sigma)

  null_stats <- .nlcm_simulate_stats(B1, B1, K, n, num_iter, niter, n_moments)
  alt_stats <- .nlcm_simulate_stats(B1, B2, K, n, num_iter, niter, n_moments)

  sbm_roc <- .nlcm_roc_points(null_stats$sbm, alt_stats$sbm)
  nlcm_roc <- .nlcm_roc_points(null_stats$nlcm, alt_stats$nlcm)

  res <- rbind(
    cbind(sbm_roc, sigma = sigma, nlcm = FALSE),
    cbind(nlcm_roc, sigma = sigma, nlcm = TRUE)
  )
  res$sigma <- as.factor(res$sigma)
  res
}

# Simulation settings ----
set.seed(1000)
K <- 2
B1 <- runif_symmetric_matrix(K, 0.2, 0.7)
n_sigma <- 5
sigma <- seq(0.001, 0.01, length.out = n_sigma)
num_iter <- 100
n <- 1000

# Run the experiment for every noise level ----
# Each nlcm_test() call returns the ROC points for one sigma. The results are
# stacked once at the end instead of being grown with rbind() in a loop.
# Sigma values are processed in order, so the RNG stream is consumed exactly
# as in a sequential loop.
results <- do.call(
  rbind,
  lapply(sigma, function(s) nlcm_test(B1, s, num_iter, n))
)

# Plot ----
# print() explicitly so the figure is also drawn when the script is run via
# source(), where top-level values are not auto-printed.
roc_plot <- ggplot(results, aes(x = fpr, y = tpr)) +
  geom_line(aes(linetype = nlcm, color = sigma)) +
  geom_abline(intercept = 0, slope = 1) +
  xlim(0, 1) +
  ylim(0, 1) +
  ggtitle(paste0("n = ", n, ", K = ", K))
print(roc_plot)





