# Script setup: reset the workspace, move to the project directory, and load
# the project helpers and third-party packages used below.

# Remove every (non-hidden) object from the global environment.
# NOTE(review): this also wipes the workspace of anyone who source()s the file.
rm(list=ls())
# All relative paths below are resolved against this directory.
# NOTE(review): hard-coded, user-specific path; the script only runs as-is on
# a machine that has this directory.
setwd("~/two sample test")
# Project helper scripts, loaded relative to the working directory set above.
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
source("R/agg_functions.R")
source("R/ASE.R")
source("R/nlcm.R")
source("R/FPR_TPR.R")
# Third-party packages.
library(nett)
library(ggplot2)

# tidyverse attaches ggplot2 as well, so the explicit ggplot2 load above is
# redundant but harmless.
library(tidyverse)
library(igraph)
library(Matrix)

# Input graphs: the first six Star Wars networks followed by all Game of
# Thrones networks from the `networkdata` package.
# remotes::install_github("schochastics/networkdata")
sw_graphs <- networkdata::starwars[1:6]
got_graphs <- networkdata::got
G <- c(sw_graphs, got_graphs)
n_graph <- length(G)

# One adjacency matrix per graph, in the same order as G.
# as_adjacency_matrix() is the current name of the as_adj() alias.
A_list <- lapply(seq_len(n_graph), function(i) as_adjacency_matrix(G[[i]]))

# Class labels (1 = Star Wars, 2 = Game of Thrones). The counts are taken from
# the data so that `labels` always has exactly one entry per graph in G.
labels <- rep(c(1, 2), times = c(length(sw_graphs), length(got_graphs)))

# Positions in G / A_list that belong to each class.
class1 <- which(labels == 1)
class2 <- which(labels == 2)

################################################################################

# Simulation settings.

# Fix the RNG state so the sampled graph pairs are reproducible.
set.seed(100)

# Number of replications drawn for each of the null and alternative settings.
num_reps <- 200

# K is handed to spec_clust(); d is handed to ase() and log_moment().
K <- 5
d <- 5

# NOTE(review): `threshold` is not referenced anywhere in the visible script.
threshold <- 0.001

################################################################################

# Compute the three two-sample statistics for one pair of adjacency matrices.
#
# rep_id  Replication index, stored in the `rep` column of the result.
# A1, A2  The two adjacency matrices to compare (elements of A_list).
# K       Second argument passed to spec_clust().
# d       Second argument passed to ase() and log_moment().
#
# Returns a data.frame with columns rep, tstat, ase_stat and nlcm_stat.
# The helper calls are made in a fixed order (spec_clust for A1, then A2,
# then the three statistics) so that any RNG use inside them is reproducible.
two_sample_stats_row <- function(rep_id, A1, A2, K, d) {
  z1 <- spec_clust(A1, K)
  z2 <- spec_clust(A2, K)
  test <- sbm_tst_old(A1, A2, z1, z2)

  ase_dist <- fast_mmd(ase(A1, d), ase(A2, d), sigvec = 1)$biased
  nlcm_dist <- l2norm_squared(log_moment(A1, d) - log_moment(A2, d))

  data.frame(
    rep = rep_id,
    tstat = test,
    ase_stat = ase_dist,
    nlcm_stat = nlcm_dist
  )
}

# Null setting: both graphs are distinct members of class 1.
res <- do.call(rbind, lapply(seq_len(num_reps), function(rep_id) {
  print(rep_id)
  idx <- sample(class1, 2, replace = FALSE)
  print(idx)
  two_sample_stats_row(rep_id, A_list[[idx[1]]], A_list[[idx[2]]], K, d)
}))

################################################################################

# Alternative setting: one graph from class 1 and one from class 2.
# Indexing with sample.int() draws the same values as sample(x, 1) but stays
# correct when a class holds a single index (sample(7, 1) would draw from 1:7).
res_alt <- do.call(rbind, lapply(seq_len(num_reps), function(rep_id) {
  print(rep_id)
  idx1 <- class1[sample.int(length(class1), 1)]
  idx2 <- class2[sample.int(length(class2), 1)]
  two_sample_stats_row(rep_id, A_list[[idx1]], A_list[[idx2]], K, d)
}))

# Histogram bin edges of width `width` running from 0 to `upper`, widened in
# whole bins when needed so that every finite value of `x` is covered.
# hist() stops with an error if the breaks do not span the data, so a fixed
# range fails as soon as one statistic falls outside it. When the data
# already lie in [0, upper] the result equals seq(0, upper, by = width).
hist_breaks <- function(x, upper, width = 10) {
  x <- x[is.finite(x)]
  lo <- 0
  hi <- upper
  if (length(x) > 0) {
    lo <- min(lo, floor(min(x) / width) * width)
    hi <- max(hi, ceiling(max(x) / width) * width)
  }
  seq(from = lo, to = hi, by = width)
}

# Distribution of the SBM test statistic under the null and the alternative.
hist(res$tstat, breaks = hist_breaks(res$tstat, upper = 150))
hist(res_alt$tstat, breaks = hist_breaks(res_alt$tstat, upper = 2000))

################################################################################

# Build ROC curves

# Distinct values of each statistic under the null (res) and the alternative
# (res_alt).
# NOTE(review): these six vectors are not used again in the visible script.
tstat <- unique(res$tstat)
alt_tstat <- unique(res_alt$tstat)
ase_stat <- unique(res$ase_stat)
alt_ase_stat <- unique(res_alt$ase_stat)
nlcm_stat <- unique(res$nlcm_stat)
alt_nlcm_stat <- unique(res_alt$nlcm_stat)


# get_fpr_tpr() is given the null statistics first and the alternative
# statistics second; its result is read through `$fpr` and `$tpr` below.
Tstat_ROC <- get_fpr_tpr(res$tstat, res_alt$tstat)
NLCM_ROC <- get_fpr_tpr(res$nlcm_stat, res_alt$nlcm_stat)
ASE_ROC <- get_fpr_tpr(res$ase_stat, res_alt$ase_stat)

# One line per method, plus the default geom_abline() (intercept 0, slope 1)
# as the diagonal reference.
roc_plot <- ggplot(NULL) +
  geom_line(aes(x = Tstat_ROC$fpr, y = Tstat_ROC$tpr, colour = "SBM Test")) +
  geom_line(aes(x = NLCM_ROC$fpr, y = NLCM_ROC$tpr, colour = "NLCM")) +
  geom_line(aes(x = ASE_ROC$fpr, y = ASE_ROC$tpr, colour = "ASE")) +
  xlim(0, 1) +
  ylim(0, 1) +
  geom_abline() +
  xlab("FPR") +
  ylab("TPR") +
  ggtitle("SW vs GOT")

# Print explicitly: a ggplot object is only drawn when printed, and top-level
# values are not auto-printed when the script is run through source().
print(roc_plot)
