# Setup ----
# Start from an empty workspace and run from the project root so the relative
# paths below resolve. NOTE(review): rm(list = ls()) and a hard-coded setwd()
# tie the script to one machine's layout; an RStudio project or here::here()
# would be more portable.
rm(list = ls())
setwd("~/two sample test")

# Project-local helpers. The functions used below that are not defined in this
# script (sbm_tst_old, ase, fast_mmd, log_moment, l2norm_squared, get_labels,
# get_roc) presumably come from these files or from the packages loaded next.
source("R/matching.R")
source("R/hypot.R")
source("R/ASE.R")
source("R/roc_revisions.R")
source("R/FPR_TPR.R")
source("R/get_labels.R")
source("R/nclm.R")

library(nett)
library(Rcpp)
library(Matrix)
library(RcppArmadillo)
library(parallel)
library(ggplot2)

# Compile and load the C++ graph samplers (sample_graphon_airoldi() and,
# presumably, sample_graphon_airoldi_alt()).
Rcpp::sourceCpp("src/sample_graphon_airoldi.cpp")

# Leave one core free for the rest of the system. detectCores() can return NA
# (platform cannot tell) and yields 0 here on a single-core machine; mclapply()
# stops unless mc.cores >= 1, so always keep at least one worker.
core_count <- max(1L, detectCores() - 1L, na.rm = TRUE)

# Simulation parameters ----
set.seed(1000)  # seed for the main R session's RNG

# Graph size and the `rho` argument shared by both graphon samplers.
# NOTE(review): rho is presumably a sparsity/scale factor -- confirm in
# src/sample_graphon_airoldi.cpp.
n <- 1e4
rho <- 0.05

# Monte Carlo size and per-statistic settings.
num_reps <- 100   # replicates under the null and under the alternative
K <- 2            # passed to get_labels(); presumably the number of communities
d_ase <- K        # embedding dimension handed to ase()
d_nclm <- 20      # second argument of log_moment()
tau <- 0          # last argument of sbm_tst_old()

# Alternative model: extra arguments for sample_graphon_airoldi_alt().
eps <- 0.05
delta <- 0.2

# "bcdc settings": tuning values forwarded to get_labels().
alpha <- beta <- 0.7
niter <- 100

# Monte Carlo simulation ----

# Compute the three two-sample statistics for one pair of graphs.
#   rep_id    replicate index, stored in the `rep` column.
#   seed      RNG seed set before any random draw of this replicate.
#   sample_A2 function(u) returning the second adjacency matrix from latent
#             positions `u`; this is the only difference between the null
#             and the alternative runs.
# Returns a one-row data.frame: rep, sbmts_stat, ase_stat, nclm_stat.
simulate_replicate <- function(rep_id, seed, sample_A2) {
  set.seed(seed)
  A1 <- sample_graphon_airoldi(runif(n, 0, 1), rho)
  A2 <- sample_A2(runif(n, 0, 1))

  # Alternative labeller kept for reference:
  # z1hat <- spec_clust(A1, K, niter = 50)
  # z2hat <- spec_clust(A2, K, niter = 50)
  z1hat <- get_labels(A1, K, alpha, beta, niter)
  z2hat <- get_labels(A2, K, alpha, beta, niter)

  That <- sbm_tst_old(A1, A2, z1hat, z2hat, tau)
  ase_dist <- fast_mmd(ase(A1, d_ase), ase(A2, d_ase),
                       nBasis = 2^9, sigvec = 1)$biased
  nclm <- l2norm_squared(log_moment(A1, d_nclm) - log_moment(A2, d_nclm))

  data.frame(rep = rep_id,
             sbmts_stat = That,
             ase_stat = ase_dist,
             nclm_stat = nclm)
}

# Run one replicate per element of `seeds` in parallel and row-bind the results.
# Stops with an informative error if any worker failed. mclapply() does not
# stop on worker errors: it returns a "try-error" for an erroring worker and
# NULL for a killed one. do.call(rbind, ...) would silently mangle or drop
# those rows.
run_replicates <- function(seeds, sample_A2) {
  results <- mclapply(seq_along(seeds), function(rep_id) {
    simulate_replicate(rep_id, seeds[[rep_id]], sample_A2)
  }, mc.cores = core_count)

  failed <- vapply(results,
                   function(x) is.null(x) || inherits(x, "try-error"),
                   logical(1))
  if (any(failed)) {
    first_bad <- which(failed)[1L]
    bad <- results[[first_bad]]
    reason <- if (is.null(bad)) {
      "no result returned (worker may have been killed)"
    } else {
      trimws(as.character(bad)[1L])
    }
    stop(sprintf("%d of %d replicates failed; first failure (rep %d): %s",
                 sum(failed), length(results), first_bad, reason),
         call. = FALSE)
  }
  do.call(rbind, results)
}

# Why per-replicate seeds: set.seed() in the parent does not make mclapply()
# workers reproducible. Only RNGkind("L'Ecuyer-CMRG") streams are reproducible
# (see ?mcparallel), and those are reset on every mclapply() call, so the null
# and alternative runs would reuse identical streams. Drawing all seeds here,
# from the seeded parent and before any worker runs, fixes both problems:
# - results are reproducible and should not depend on core_count;
# - the null and alternative runs use distinct seeds.
rep_seeds <- sample.int(.Machine$integer.max, 2L * num_reps)

# Null: both graphs from the same graphon.
res <- run_replicates(
  rep_seeds[seq_len(num_reps)],
  function(u) sample_graphon_airoldi(u, rho)
)

# Alternative: second graph from the perturbed graphon.
res_alt <- run_replicates(
  rep_seeds[num_reps + seq_len(num_reps)],
  function(u) sample_graphon_airoldi_alt(u, rho, eps, delta)
)

# ROC curve for each statistic from its null and alternative draws.
# The result must expose $FPR and $TPR (used for plotting below).
# NOTE(review): argument order presumably is (null scores, alternative scores)
# -- confirm in R/roc_revisions.R / R/FPR_TPR.R.
sbm_roc <- get_roc(res[["sbmts_stat"]], res_alt[["sbmts_stat"]])
ase_roc <- get_roc(res[["ase_stat"]], res_alt[["ase_stat"]])
nclm_roc <- get_roc(res[["nclm_stat"]], res_alt[["nclm_stat"]])


# ROC comparison plot ----
# Colour-blind-friendly palette. Unnamed values are matched to the method
# labels in sorted order: ASE, NCLM, SBM-TS.
cbbPalette <- c("#56B4E9", "#E69F00", "#000000", "#009E73", "#F0E442",
                "#0072B2", "#D55E00", "#CC79A7")

# Long format: one row per ROC point, tagged with its method. A single
# geom_line() layer then draws every curve with a mapped colour.
roc_df <- rbind(
  data.frame(FPR = sbm_roc$FPR, TPR = sbm_roc$TPR, method = "SBM-TS"),
  data.frame(FPR = ase_roc$FPR, TPR = ase_roc$TPR, method = "ASE"),
  data.frame(FPR = nclm_roc$FPR, TPR = nclm_roc$TPR, method = "NCLM")
)

roc_plot <- ggplot(roc_df, aes(x = FPR, y = TPR, colour = method)) +
  geom_line(linewidth = 2) +
  # Chance-level reference diagonal (y = x), drawn once and dashed.
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  scale_colour_manual(values = cbbPalette) +
  scale_x_continuous(limits = c(-0.01, 1.01), expand = c(0, 0)) +
  scale_y_continuous(limits = c(-0.01, 1.01), expand = c(0, 0)) +
  coord_fixed(ratio = 1) +
  labs(x = "FPR", y = "TPR") +
  theme_bw() +
  theme(
    text = element_text(size = 16),
    legend.background = element_blank(),
    legend.title = element_blank(),
    # Numeric position places the legend inside the panel. NOTE(review):
    # ggplot2 >= 3.5.0 prefers legend.position = "inside" together with
    # legend.position.inside = c(0.7, 0.2).
    legend.position = c(0.7, 0.2),
    legend.text = element_text(size = 15)
  ) +
  guides(colour = guide_legend(keywidth = 4, keyheight = 1))

# Print explicitly: ggplot objects are not auto-printed when the script is
# run with source() without echo.
print(roc_plot)



# Null distribution of the SBM-TS statistic ----
null_hist <- ggplot(res, aes(x = sbmts_stat)) +
  # bins = 30 is ggplot2's default; stated explicitly so the binning choice
  # is visible and the default-bins message is not emitted.
  geom_histogram(bins = 30, colour = "black", fill = "white") +
  ggtitle("Distribution of T under null")

# Print explicitly so the plot also appears when the script is source()d.
print(null_hist)

