# Remove every (non-hidden) object from the current environment before the
# script defines its own variables.
# NOTE(review): this also deletes whatever the caller had in their workspace;
# running the script in a fresh R session would make this line unnecessary.
rm(list=ls())
# The relative paths used by the source()/sourceCpp() calls below ("R/...",
# "src/...") are resolved against this directory.
# NOTE(review): hard-coded, user-specific location; the script fails here on
# any machine where this directory does not exist.
setwd("~/two sample test")
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
source("R/ASE.R")
source("R/roc_revisions.R")
source("R/FPR_TPR.R")
library(nett)
library(Rcpp)
library(Matrix)
library(RcppArmadillo)
library(parallel)
library(ggplot2)

Rcpp::sourceCpp("src/sample_graphon.cpp")
# Number of worker processes handed to mclapply(): leave one core free for the
# rest of the system, but never go below one. detectCores() returns NA when
# the core count cannot be determined and 1 on a single-core machine; a plain
# `detectCores() - 1` would then be NA or 0, neither of which is a valid
# `mc.cores` value.
core_count <- max(1L, detectCores() - 1L, na.rm = TRUE)

# ---- Graphon anchors and simulation settings ----

# Number of anchor pairs (x_i, y_i) and the dimension of each anchor.
num_anc_points <- 10
d <- 1
set.seed(0)

# Standard-normal anchors. All draws are taken in one call and laid out so
# that column i of `anchor_draws` holds the d values of x_i followed by the d
# values of y_i; filling column by column therefore consumes the random
# stream in the order x_1, y_1, x_2, y_2, ...
anchor_draws <- matrix(rnorm(2 * d * num_anc_points, 0, 1), nrow = 2 * d)
x <- t(anchor_draws[seq_len(d), , drop = FALSE])
y <- t(anchor_draws[d + seq_len(d), , drop = FALSE])
rm(anchor_draws)

# One standard-normal weight alpha_i per anchor pair.
alpha <- rnorm(num_anc_points, 0, 1)

# Settings consumed by the simulation further down.
n <- 1000        # first argument of sample_graphon(); presumably graph size
num_reps <- 100  # number of Monte Carlo replicates per scenario
K <- 2           # passed to spec_clust(); presumably number of clusters
d_ase <- 10      # passed to ase(); presumably embedding dimension
rho <- 0.1       # passed to sample_graphon() -- TODO confirm meaning
bw <- 1          # passed to sample_graphon(); presumably kernel bandwidth
tau <- 0         # passed to sbm_tst_old() / sbm_tst_v2() -- TODO confirm

# ---- Monte Carlo replicates under the null and the alternative ----

# Perturbation added to `alpha` for the second graph under the alternative.
# It is drawn here, before the generator is switched below, so that it keeps
# continuing the parent-process stream that produced `alpha`.
eps <- rnorm(num_anc_points, 0, 1)

# mclapply() only hands its forked workers reproducible random streams when
# the L'Ecuyer-CMRG generator is active; under the default generator the
# workers are seeded unpredictably, so the set.seed() call earlier in the
# script does not make the replicates repeatable.
# NOTE(review): this covers draws made through R's RNG -- confirm that the
# compiled sample_graphon() uses it. Streams are assigned per worker, so
# results should be compared at a fixed `core_count`.
RNGkind("L'Ecuyer-CMRG")

#' Run `num_reps` two-sample replicates in parallel.
#'
#' Each replicate samples a first graph with weights `alpha` and a second
#' graph with weights `alpha_second`, clusters both graphs with spec_clust(),
#' and evaluates the two SBM statistics and the ASE/MMD statistic.
#'
#' @param alpha_second Weight vector for the second graph: `alpha` itself for
#'   the null scenario, a perturbed copy for the alternative.
#' @param seed Seed set in the parent process before forking. The two
#'   scenarios must use different seeds; otherwise back-to-back mclapply()
#'   calls would start from the same parent state and reuse the same worker
#'   streams.
#' @return A data.frame with one row per replicate and columns `rep`,
#'   `sbmts_stat`, `sbmts_v2_stat` and `ase_stat`.
run_replicates <- function(alpha_second, seed) {
  set.seed(seed)

  per_rep <- mclapply(seq_len(num_reps), function(rep_id) {
    A1 <- sample_graphon(n, x, y, alpha, bw, rho)
    A2 <- sample_graphon(n, x, y, alpha_second, bw, rho)

    z1hat <- spec_clust(A1, K, niter = 50)
    z2hat <- spec_clust(A2, K, niter = 50)

    That <- sbm_tst_old(A1, A2, z1hat, z2hat, tau)
    That_v2 <- sbm_tst_v2(A1, A2, z1hat, z2hat, tau)
    # Debugging aid kept from the original script:
    # if (That > 100){
    #   saveRDS(A1, paste0("bad_graphon_data/A1_",rep_id,".RDS"))
    #   saveRDS(A2, paste0("bad_graphon_data/A2_",rep_id,".RDS"))
    # }
    ase_dist <- fast_mmd(
      ase(A1, d_ase), ase(A2, d_ase),
      nBasis = 2^9, sigvec = 1
    )$biased

    #log_progress()
    data.frame(
      rep = rep_id,
      sbmts_stat = That,
      sbmts_v2_stat = That_v2,
      ase_stat = ase_dist
    )
  }, mc.cores = core_count)

  # mclapply() does not stop on a worker error: the affected elements come
  # back as "try-error" objects (or NULL if a worker died). Fail loudly here
  # instead of letting rbind() build a corrupted result table.
  failed <- !vapply(per_rep, is.data.frame, logical(1))
  if (any(failed)) {
    first_bad <- per_rep[[which(failed)[1L]]]
    reason <- if (inherits(first_bad, "try-error")) {
      trimws(as.character(first_bad))
    } else {
      "worker returned no result"
    }
    stop(
      sum(failed), " of ", num_reps, " replicates failed; first failure: ",
      reason,
      call. = FALSE
    )
  }

  do.call(rbind, per_rep)
}

# Null: both graphs come from the same graphon.
res <- run_replicates(alpha, seed = 1)
# Alternative: the second graph uses the perturbed weights alpha + eps.
res_alt <- run_replicates(alpha + eps, seed = 2)


# ---- ROC curves: null statistics vs alternative statistics ----
sbm_roc <- get_roc(res$sbmts_stat, res_alt$sbmts_stat)
sbm_v2_roc <- get_roc(res$sbmts_v2_stat, res_alt$sbmts_v2_stat)
ase_roc <- get_roc(res$ase_stat, res_alt$ase_stat)

# Stack the three curves into one long data frame so that ggplot maps columns
# of its `data` argument instead of `obj$col` vectors inside aes().
roc_by_method <- list(
  "SBM Test" = sbm_roc,
  "SBM_v2 Test" = sbm_v2_roc,
  "ASE" = ase_roc
)
roc_curves <- do.call(rbind, lapply(names(roc_by_method), function(label) {
  roc <- roc_by_method[[label]]
  data.frame(
    FPR = roc$FPR,
    TPR = roc$TPR,
    method = rep(label, length(roc$FPR))
  )
}))

roc_plot <- ggplot(roc_curves, aes(x = FPR, y = TPR, colour = method)) +
  geom_line() +
  xlim(0, 1) +
  ylim(0, 1) +
  geom_abline() +  # default intercept 0 / slope 1: the chance diagonal
  labs(
    x = "FPR",
    y = "TPR",
    colour = "Test",
    title = "Graphon Testing ROC"
  )

# Print explicitly: a bare ggplot object at top level is not drawn when the
# script is run through source().
print(roc_plot)
 

# Histogram of the original SBM statistic across the null replicates.
# `bins = 30` spells out the value geom_histogram() falls back to anyway,
# which avoids its "pick a better value" message.
null_hist <- ggplot(res, aes(x = sbmts_stat)) +
  geom_histogram(bins = 30, color = "black", fill = "white") +
  ggtitle("Distribution of T under null")

# Print explicitly: a bare ggplot object at top level is not drawn when the
# script is run through source().
print(null_hist)
