# Script setup ----
# NOTE(review): rm(list = ls()) wipes the caller's global environment and
# setwd() requires "~/two sample test" to exist; both tie this script to one
# machine/session. Consider an RStudio project or relative paths instead.
rm(list=ls())
setwd("~/two sample test")
# Project helpers, sourced relative to the working directory set above.
# Functions used later that are not defined in this file (e.g.
# runif_symmetric_matrix, misclassification_rate, sbm_tst_old, log_moment,
# ase, fast_mmd, get_fpr_tpr) are presumably provided by these files or by
# the packages loaded below -- verify if a name fails to resolve.
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
source("R/agg_functions.R")
source("R/ASE.R")
source("R/FPR_TPR.R")
library(nett)
library(ggplot2)
library(igraph)

# Squared L2 norm of a real or complex vector (squared Frobenius norm for a
# matrix): the sum of squared moduli of all entries.
# The previous form, Re(sum(Conj(t(X)) %*% X)), summed every entry of the Gram
# matrix; that equals the squared norm only for vectors, not for matrices.
l2norm_squared <- function(X) sum(Mod(X)^2)

# Simulation settings ----
# Seed first so that B1/B2 (and everything drawn afterwards) are reproducible.
set.seed(1000)
rho <- 0.1      # scale factor applied to both connectivity matrices
K <- 20         # number of blocks / clusters
sigma <- 0.005  # noise level passed to rnorm_symmetric_matrix() for B2

# B1 comes from runif_symmetric_matrix(K, 0.2, 0.7) (presumably uniform
# entries in [0.2, 0.7] -- confirm in R/generate.R); B2 adds symmetric noise
# to the unscaled B1. Both are then multiplied by rho.
B1 <- runif_symmetric_matrix(K, 0.2, 0.7)
B2 <- rho * (B1 + rnorm_symmetric_matrix(K, 0, sigma))
B1 <- rho * B1

num_reps <- 100  # Monte Carlo replicates per hypothesis
d <- 5           # embedding dimension passed to ase()
n <- 10000       # number of nodes per graph

# Monte Carlo simulation ----
#
# Each replicate draws two independent labelings of n nodes into K blocks,
# samples one graph per labeling with sample_dcsbm(), re-estimates the labels
# with spectral clustering and computes three two-sample statistics:
#   * tstat     -- sbm_tst_old() using the estimated labels,
#   * ase_stat  -- biased MMD between d-dimensional ASE embeddings,
#   * nlcm_stat -- squared L2 distance between log_moment() vectors.
# Under the null both graphs use B1; under the alternative the second graph
# uses B2.

#' Run one simulation replicate.
#'
#' @param rep_id Replicate index; printed every 5 replicates as progress.
#' @param B_first,B_second Connectivity matrices for the first/second graph.
#' @param n Number of nodes per graph.
#' @param K Number of blocks (also the number of clusters fitted).
#' @param d Embedding dimension passed to ase().
#' @return A one-row data.frame with columns rep, mis1, mis2, tstat,
#'   ase_stat, nlcm_stat.
simulate_replicate <- function(rep_id, B_first, B_second, n, K, d) {
  if (rep_id %% 5 == 0) {
    print(rep_id)
  }
  z1 <- sample(seq_len(K), n, replace = TRUE)
  z2 <- sample(seq_len(K), n, replace = TRUE)
  A1 <- sample_dcsbm(z1, B_first)
  A2 <- sample_dcsbm(z2, B_second)

  z1hat <- spec_clust(A1, K, niter = 50)
  z2hat <- spec_clust(A2, K, niter = 50)

  mis1 <- misclassification_rate(z1, z1hat)
  mis2 <- misclassification_rate(z2, z2hat)

  That <- sbm_tst_old(A1, A2, z1hat, z2hat)
  # OracleThat <- sbm_tst_old(A1, A2, z1, z2)
  g1 <- log_moment(A1, 5)
  g2 <- log_moment(A2, 5)
  nlcm_dist <- l2norm_squared(g1 - g2)
  # NOTE(review): ASE embeddings are only identified up to an orthogonal
  # transformation; confirm fast_mmd() (or its kernel) accounts for that.
  ase_dist <- fast_mmd(ase(A1, d), ase(A2, d), sigvec = 1)$biased

  data.frame(rep = rep_id, mis1, mis2, tstat = That,
             ase_stat = ase_dist, nlcm_stat = nlcm_dist)
}

# Null: both graphs are drawn from B1.
res <- do.call(rbind, lapply(seq_len(num_reps), simulate_replicate,
                             B_first = B1, B_second = B1,
                             n = n, K = K, d = d))

# Alternative: the second graph is drawn from the perturbed matrix B2.
res_alt <- do.call(rbind, lapply(seq_len(num_reps), simulate_replicate,
                                 B_first = B1, B_second = B2,
                                 n = n, K = K, d = d))

################################################################################
# Histograms of each statistic under the null and the alternative ----

#' Histogram of one column of a simulation result.
#'
#' @param df Data frame of simulation results (e.g. res or res_alt).
#' @param stat Name of the column to plot, as a string.
#' @param x_label x-axis label.
#' @param title Plot title.
#' @return A ggplot object (not printed).
plot_stat_hist <- function(df, stat, x_label, title) {
  ggplot(df, aes(x = .data[[stat]])) +
    geom_histogram(bins = 10, fill = "blue", alpha = 0.5) +
    xlab(x_label) +
    ylab("Frequency") +
    ggtitle(title)
}

# paste0 avoids the doubled spaces that paste()'s default sep produced.
hist_title <- paste0("Distribution of test statistic\n",
                     "n = ", n, ", K = ", K)

# Column name -> axis label. Each label names its own statistic, so the ASE
# and NLCM histograms are no longer labelled as "T".
stat_labels <- c(
  tstat = "T (SBM test)",
  ase_stat = "ASE MMD statistic",
  nlcm_stat = "NLCM statistic"
)

# print() is needed because ggplot objects are not auto-printed when the
# script is run with source() (echo/print.eval default to FALSE).
for (stat in names(stat_labels)) {
  print(plot_stat_hist(res, stat,
                       paste(stat_labels[[stat]], "under null"),
                       hist_title))
  print(plot_stat_hist(res_alt, stat,
                       paste(stat_labels[[stat]], "under alternative"),
                       hist_title))
}

# ROC curves ----
# get_fpr_tpr() (project helper) receives the null and alternative samples of
# one statistic and is expected to return a list with $fpr and $tpr vectors.
#Oracle_Tstat_ROC <- get_fpr_tpr(res$oracle_tstat, res_alt$oracle_tstat)
Tstat_ROC <- get_fpr_tpr(res$tstat, res_alt$tstat)
NLCM_ROC <- get_fpr_tpr(res$nlcm_stat, res_alt$nlcm_stat)
ASE_ROC <- get_fpr_tpr(res$ase_stat, res_alt$ase_stat)

# Mean misclassification rate over both graphs of every replicate. Each
# result has one row per replicate, so pooling mis1 and mis2 equals the mean
# of the two per-graph means. Summing the two means would double the rate.
avg_mis <- mean(c(res$mis1, res$mis2))
avg_mis_alt <- mean(c(res_alt$mis1, res_alt$mis2))
miscl_rate <- round((avg_mis + avg_mis_alt) / 2, 3)

# Long-format data so all three curves share a single colour mapping.
roc_df <- rbind(
  data.frame(method = "SBM Test", fpr = Tstat_ROC$fpr, tpr = Tstat_ROC$tpr),
  data.frame(method = "NLCM", fpr = NLCM_ROC$fpr, tpr = NLCM_ROC$tpr),
  data.frame(method = "ASE", fpr = ASE_ROC$fpr, tpr = ASE_ROC$tpr)
)

roc_plot <- ggplot(roc_df, aes(x = fpr, y = tpr, colour = method)) +
  geom_line() +
  geom_abline() +
  xlab("FPR") +
  ylab("TPR") +
  labs(colour = "Method") +
  ggtitle(paste0("n = ", n, ", K = ", K, ", eps = ", sigma,
                 ", miscl_rate = ", miscl_rate)) +
  coord_cartesian(xlim = c(0, 1), ylim = c(0, 1), clip = "off")

# Explicit print() so the plot also appears when the script is source()d.
print(roc_plot)
 
################################################################################
#' Spectral clustering using MacQueen k-means.
#'
#' Embeds A with spec_repr() and clusters the rows of the embedding with
#' kmeans(algorithm = "MacQueen").
#'
#' @param A Adjacency matrix.
#' @param K Number of clusters; also passed to spec_repr().
#' @param type,tau,ignore_first_col Forwarded to spec_repr().
#' @param nstart Number of k-means restarts (kmeans `nstart`).
#' @param niter Maximum k-means iterations (kmeans `iter.max`).
#' @return Integer vector of cluster labels, one per row of the embedding.
spec_clust2 <- function(A, K, type = "lap",
                        tau = 0.25, nstart = 20, niter = 10,
                        ignore_first_col = FALSE) {
  embedding <- spec_repr(A, K, type = type, tau = tau,
                         ignore_first_col = ignore_first_col)
  fit <- kmeans(embedding, centers = K, iter.max = niter, nstart = nstart,
                algorithm = "MacQueen")
  fit$cluster
}

## Single-replicate sanity check under the null (both graphs drawn from B1).
## Observation from earlier runs: with K = 8 and rho < 0.08 the
## misclassification rate is poor (~0.32 -> 0.8).
set.seed(1000)
rho <- 0.1
K <- 20
sigma <- 0.005
B1 <- runif_symmetric_matrix(K, 0.2, 0.7)
B2 <- B1 + rnorm_symmetric_matrix(K, 0, sigma)
num_reps <- 100
d <- 5
n <- 10000

B1 <- rho * B1
B2 <- rho * B2

z1 <- sample(seq_len(K), n, replace = TRUE)
z2 <- sample(seq_len(K), n, replace = TRUE)
A1 <- sample_dcsbm(z1, B1)
A2 <- sample_dcsbm(z2, B1)

z1hat <- spec_clust(A1, K, niter = 100)
z2hat <- spec_clust(A2, K, niter = 100)

mis1 <- misclassification_rate(z1, z1hat)
mis2 <- misclassification_rate(z2, z2hat)

That <- sbm_tst_old(A1, A2, z1hat, z2hat)

g1 <- log_moment(A1, 5)
ase_dist <- fast_mmd(ase(A1, d), ase(A2, d), sigvec = 1)$biased

# Complete the NLCM statistic as in the simulation loops. It is computed after
# the ASE statistic so that calls made earlier run in their original order.
g2 <- log_moment(A2, 5)
nlcm_dist <- l2norm_squared(g1 - g2)

print(data.frame(mis1, mis2, tstat = That,
                 ase_stat = ase_dist, nlcm_stat = nlcm_dist))



