# Start from a clean workspace: removes every non-hidden object in the
# environment this script is run in.
rm(list=ls())
# The relative "R/..." paths passed to source() below are resolved against
# this working directory.
setwd("~/two sample test")
# Project-local helper scripts.
# NOTE(review): presumably these define the helpers called further down
# (runif_symmetric_matrix, rnorm_symmetric_matrix, sbm_tst_old, log_moment,
# ase, fast_mmd, ...) -- which file defines which is not visible from here.
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
source("R/agg_functions.R")
source("R/ASE.R")
source("R/FPR_TPR.R")
# nett: presumably supplies sample_dcsbm() and spec_clust() -- TODO confirm.
library(nett)
# parallel: supplies detectCores() and mclapply().
library(parallel)

# Number of forked workers handed to mclapply(): every detected core but one.
# detectCores() can return NA when the count cannot be determined, and a
# single-core host would yield 0; mclapply() requires mc.cores >= 1, so the
# value is clamped to at least one worker.
core_count <- max(1, detectCores() - 1, na.rm = TRUE)

# Squared l2 norm of a real or complex input: the sum of squared moduli of
# its entries.
#
# X: numeric or complex vector (or matrix).
# Returns a single non-negative double; 0 for zero-length input.
#
# For vectors this equals the previous Re(sum(Conj(t(X)) %*% X)) without
# forming a matrix product. For a matrix the previous formula summed *all*
# entries of the Gram matrix (off-diagonal terms included), which is not a
# norm; this version returns the squared Frobenius norm instead.
l2norm_squared <- function(X) sum(Mod(X)^2)

# ---- Simulation settings ----

# Base RNG seed; outer iteration i of the main loop calls set.seed(seed + i).
seed <- 999

# Network model.
K <- 3          # labels are drawn from 1:K; spec_clust() is asked for K clusters
n <- 1e4        # number of labels drawn per network (length of z1 / z2)
rho <- 0.1      # factor that multiplies both B1 and B2
sigma <- 0.005  # third argument of rnorm_symmetric_matrix() when perturbing B1
                # (presumably the noise sd -- confirm against its definition)

# Replicates per condition (null / alternative) within each outer iteration.
num_reps <- 100

# Settings of the test statistics.
d_ase <- K                         # second argument of ase()
d_nlcm <- 20                       # second argument of log_moment()
J_arr <- c(2:5, 10, 15, 20)        # prefix lengths of the log_moment() output
                                   # that are compared between the two graphs
bw <- c(0.01, 0.1, 1, 10, 100)     # passed to fast_mmd() as sigvec

# ---- Monte Carlo study: null vs. alternative ----
#
# For each of 50 outer iterations a pair of connectivity matrices (B1, B2) is
# drawn, then num_reps replicate graph pairs are simulated in parallel under
#   * the null        : both graphs sampled with B1;
#   * the alternative : first graph with B1, second with B2.
# One row per replicate is appended to null.csv / alt.csv.
#
# NOTE(review): set.seed(seed + i) fixes the draw of B1/B2 in the parent
# process. Whether the per-replicate draws inside the forked mclapply()
# workers are reproducible depends on the RNG kind in use (see the parallel
# package documentation on "L'Ecuyer-CMRG") -- confirm this is intended, since
# the seed is recorded in the output.

# Results directory; created up front so that a missing folder does not
# discard a finished batch of replicates at write time.
out_dir <- "~/two sample test/res/K3"
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

# Simulate one pair of graphs and compute every test statistic for it.
#
# rep_id   : replicate index, stored in the output row.
# B_first  : connectivity matrix used to sample the first graph.
# B_second : connectivity matrix used to sample the second graph.
# seed_id  : seed identifier of the outer iteration, stored in the output row.
#
# Returns a one-row data.frame: seed, rep, tstat, then the entries of the
# fast_mmd()$biased result, then one NLCM distance per entry of J_arr.
simulate_replicate <- function(rep_id, B_first, B_second, seed_id) {
  # Progress trace emitted from inside the worker.
  if (rep_id %% 5 == 0) {
    print(rep_id)
  }

  # Independent uniform label vectors for the two graphs.
  z1 <- sample(seq_len(K), n, replace = TRUE)
  z2 <- sample(seq_len(K), n, replace = TRUE)
  A1 <- sample_dcsbm(z1, B_first)
  A2 <- sample_dcsbm(z2, B_second)

  z1hat <- spec_clust(A1, K, niter = 50)
  z2hat <- spec_clust(A2, K, niter = 50)

  That <- sbm_tst_old(A1, A2, z1hat, z2hat)

  # Squared distance between the first J entries of the two log_moment()
  # outputs, for every J in J_arr.
  g1 <- log_moment(A1, d_nlcm)
  g2 <- log_moment(A2, d_nlcm)
  nlcm_dist <- vapply(J_arr, function(J) {
    keep <- seq_len(J)
    l2norm_squared(g1[keep] - g2[keep])
  }, numeric(1))

  ase_dist <- fast_mmd(ase(A1, d_ase), ase(A2, d_ase),
                       nBasis = 2^9, sigvec = bw)$biased

  # t() turns each statistic vector into a single row so that its entries
  # become separate columns.
  data.frame(seed = seed_id,
             rep = rep_id,
             tstat = That,
             bw = t(ase_dist),
             J_arr = t(nlcm_dist))
}

# Run num_reps replicates of one condition in parallel and stack the rows.
#
# mclapply() does not raise worker errors: a failed job comes back as a
# "try-error" object (or NULL if the child process died). Such elements are
# dropped with a warning instead of being bound into the result table.
#
# Returns a data.frame with one row per successful replicate, or NULL when
# every replicate failed.
run_condition <- function(B_first, B_second, seed_id, label) {
  rows <- mclapply(seq_len(num_reps), function(rep_id) {
    simulate_replicate(rep_id, B_first, B_second, seed_id)
  }, mc.cores = core_count)

  failed <- vapply(rows, function(r) {
    is.null(r) || inherits(r, "try-error")
  }, logical(1))

  if (any(failed)) {
    errs <- Filter(function(r) inherits(r, "try-error"), rows)
    detail <- if (length(errs) > 0) {
      paste0("; first error: ", trimws(as.character(errs[[1]])))
    } else {
      ""
    }
    warning(sprintf("seed %s (%s): dropped %d of %d failed replicates%s",
                    seed_id, label, sum(failed), length(rows), detail),
            call. = FALSE)
  }

  do.call(rbind, rows[!failed])
}

# Append result rows to a CSV file. No header line is written, so the column
# order is defined solely by simulate_replicate().
append_rows <- function(rows, path) {
  if (is.null(rows)) {
    return(invisible(NULL))
  }
  write.table(rows, path,
              append = TRUE,
              sep = ",",
              col.names = FALSE,
              row.names = FALSE,
              quote = FALSE)
}

for (i in 1:50) {
  set.seed(seed + i)

  # B2 is a small symmetric perturbation of B1; both are then scaled by rho.
  B1 <- runif_symmetric_matrix(K, 0.2, 0.7)
  B2 <- B1 + rnorm_symmetric_matrix(K, 0, sigma)

  B1 <- rho * B1
  B2 <- rho * B2

  # Null: both graphs from B1. Alternative: second graph from B2.
  res <- run_condition(B1, B1, seed + i, "null")
  res_alt <- run_condition(B1, B2, seed + i, "alternative")

  append_rows(res, file.path(out_dir, "null.csv"))
  append_rows(res_alt, file.path(out_dir, "alt.csv"))
}
