rm(list=ls())
setwd("~/two sample test")
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
source("R/agg_functions.R")
source("R/ASE.R")
source("R/FPR_TPR.R")
library(nett)
library(ggplot2)
library(igraph)

# Squared L2 (Frobenius) norm of a numeric or complex vector or matrix:
# the sum of the squared moduli of all entries.
#
# The previous form Re(sum(Conj(t(X)) %*% X)) equals this only for plain
# vectors or single-column matrices. For an n x p matrix it summed every entry
# of the p x p Gram matrix (off-diagonal inner products included), and for a
# 1 x n row matrix it returned |sum(X)|^2. Summing Mod(X)^2 is correct for all
# shapes and avoids forming the Gram matrix at all.
#
# @param X Numeric or complex vector, or matrix.
# @return A non-negative double scalar.
l2norm_squared <- function(X) {
  sum(Mod(X)^2)
}

# Simulation configuration ----
seed <- 999       # base seed; batch i reseeds with seed + i
rho <- 0.1        # multiplier applied to both connectivity matrices
K <- 20           # number of communities (labels drawn from 1..K)
sigma <- 0.005    # noise scale passed to rnorm_symmetric_matrix for B2
num_reps <- 100   # replicates per hypothesis within each batch
n <- 1e4          # number of node labels drawn per graph

# Embedding dimension (ASE) and moment order (log_moment) both track K.
d_nlcm <- K
d_ase <- d_nlcm

# Monte Carlo experiment ----
# For each batch, draw a connectivity matrix B1 and a perturbed copy
# B2 = B1 + symmetric noise, then simulate `num_reps` graph pairs under the
# null (both graphs from B1) and under the alternative (B1 vs B2). Each
# replicate records spectral-clustering misclassification rates and three
# test statistics; each batch is appended to the null/alt CSV files.

num_batches <- 10

# Number of basis functions for fast_mmd, shared by BOTH hypotheses.
# Previously only the null passed nBasis = 2^9 while the alternative used
# fast_mmd's default, so the two ase_stat columns were not computed the same
# way.
n_basis <- 2^9

out_dir <- "~/two sample test/res/K20/sparse"
null_path <- file.path(out_dir, "null.csv")
alt_path <- file.path(out_dir, "alt.csv")

# Ensure the output folder exists so write.table does not fail on a missing
# directory only after the (long) simulation has already run.
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

# Simulate one graph pair (first graph from `B_first`, second from
# `B_second`) and return a one-row data.frame of statistics.
simulate_pair <- function(rep_id, B_first, B_second, seed_value) {
  if (rep_id %% 5 == 0) {
    print(rep_id)
  }
  z1 <- sample(seq_len(K), n, replace = TRUE)
  z2 <- sample(seq_len(K), n, replace = TRUE)
  A1 <- sample_dcsbm(z1, B_first)
  A2 <- sample_dcsbm(z2, B_second)

  z1hat <- spec_clust(A1, K, niter = 50)
  z2hat <- spec_clust(A2, K, niter = 50)

  mis1 <- misclassification_rate(z1, z1hat)
  mis2 <- misclassification_rate(z2, z2hat)

  That <- sbm_tst_old(A1, A2, z1hat, z2hat)
  # OracleThat <- sbm_tst_old(A1, A2, z1, z2)
  g1 <- log_moment(A1, d_nlcm)
  g2 <- log_moment(A2, d_nlcm)
  nlcm_dist <- l2norm_squared(g1 - g2)
  ase_dist <- fast_mmd(ase(A1, d_ase), ase(A2, d_ase),
                       nBasis = n_basis, sigvec = 1)$biased

  data.frame(seed = seed_value, rep = rep_id, mis1 = mis1, mis2 = mis2,
             tstat = That, ase_stat = ase_dist, nlcm_stat = nlcm_dist)
}

# Run `num_reps` replicates for one (B_first, B_second) pair and stack them.
run_replicates <- function(B_first, B_second, seed_value) {
  rows <- lapply(seq_len(num_reps), simulate_pair,
                 B_first = B_first, B_second = B_second,
                 seed_value = seed_value)
  do.call(rbind, rows)
}

# Append a results table to `path` with no header or row names, so that
# successive batches accumulate in one CSV file.
append_results <- function(df, path) {
  write.table(df, path,
              append = TRUE,
              sep = ",",
              col.names = FALSE,
              row.names = FALSE,
              quote = FALSE)
}

for (i in seq_len(num_batches)) {
  batch_seed <- seed + i
  set.seed(batch_seed)

  B1 <- runif_symmetric_matrix(K, 0.2, 0.7)
  B2 <- B1 + rnorm_symmetric_matrix(K, 0, sigma)

  # Both matrices (including B2's perturbation) are scaled by rho.
  B1 <- rho * B1
  B2 <- rho * B2

  # Null: both graphs drawn from B1.
  res <- run_replicates(B1, B1, batch_seed)

  # Alternative: second graph drawn from the perturbed B2.
  res_alt <- run_replicates(B1, B2, batch_seed)

  append_results(res, null_path)
  append_results(res_alt, alt_path)
}



