rm(list=ls())
setwd("~/two sample test")
source("R/matching.R")
source("R/hypot.R")
source("R/agg_functions.R")
source("R/ASE.R")
source("R/nlcm.R")
source("R/FPR_TPR.R")
library(nett)
library(ggplot2)

library(tidyverse)
library(igraph)
library(Matrix)

# Load the COLLAB data: one global edge list over all nodes, a per-node graph
# indicator, and a per-graph class label.
edge_list <- read_csv("data/COLLAB/COLLAB_A.txt", col_names = FALSE) %>%
  as.matrix()
# graph_ind[v] is the id of the graph that node v belongs to.
graph_ind <- read_csv("data/COLLAB/COLLAB_graph_indicator.txt", col_names = FALSE)$X1
labels <- read_csv("data/COLLAB/COLLAB_graph_labels.txt", col_names = FALSE)$X1

n_graph <- max(graph_ind)
n_node <- length(graph_ind)

# Build the graph with an explicit vertex count. graph_from_edgelist() would
# only create max(edge_list) vertices, dropping trailing isolated nodes and
# breaking the alignment between vertex ids and graph_ind.
# simplify() removes the duplicate (i, j)/(j, i) entries and any self-loops.
g <- make_graph(as.vector(t(edge_list)), n = n_node, directed = FALSE) %>%
  simplify()

# Group vertex ids by graph in one pass (O(V)) instead of scanning the whole
# indicator vector once per graph. Explicit levels keep one entry per graph id
# (even an empty one), so A_list[[i]] is graph i.
vids_by_graph <- split(seq_len(n_node), factor(graph_ind, levels = seq_len(n_graph)))
A_list <- unname(lapply(vids_by_graph, function(vids) {
  as_adjacency_matrix(induced_subgraph(g, vids))
}))

# Graph indices belonging to each class label.
class1 <- which(labels == 1)
class2 <- which(labels == 2)
class3 <- which(labels == 3)

################################################################################
# Visual sanity check: draw one randomly chosen class-3 graph.
# Index into class3 instead of calling sample(class3, 1): sample() treats a
# single number n as 1:n, which would silently pick an unrelated graph.
rand_idx <- class3[sample.int(length(class3), 1)]
A <- A_list[[rand_idx]]
gr <- igraph::graph_from_adjacency_matrix(A, mode = "undirected")
# Zero margins for the network plot; restore afterwards so later base-graphics
# plots are not affected by this global setting.
old_par <- par(mar = c(0, 0, 0, 0))
out <- nett::plot_net(gr)
par(old_par)
################################################################################

set.seed(10000)
K <- 2
num_reps <- 200
threshold <- 0
sample_size <- 10
d <- 2

# Run n_reps replicates of the two-sample comparison. Each replicate draws
# n_sample graph indices from pool1 and n_sample from pool2 (without
# replacement within each draw). It then records:
#   - tstat:     the second element returned by agg_sbm_test()
#   - ase_stat:  agg_ase_dist()
#   - nlcm_stat: agg_nlcm_dist()
# Returns a data.frame with one row per replicate.
# Progress and the sampled indices are printed, as before.
simulate_reps <- function(pool1, pool2, n_reps = num_reps, n_sample = sample_size) {
  rows <- lapply(seq_len(n_reps), function(rep_id) {
    print(rep_id)
    sm1 <- sample(pool1, n_sample)
    sm2 <- sample(pool2, n_sample)
    print(sm1)
    print(sm2)
    A1list <- A_list[sm1]
    A2list <- A_list[sm2]

    test <- agg_sbm_test(A1list, A2list, K, threshold)
    TruncatedThat <- test[2]
    ase_dist <- agg_ase_dist(A1list, A2list, d, sigma = 1)
    nlcm_dist <- agg_nlcm_dist(A1list, A2list, d)

    data.frame(rep = rep_id, tstat = TruncatedThat, ase_stat = ase_dist, nlcm_stat = nlcm_dist)
  })
  do.call(rbind, rows)
}

# Null: both samples come from class 2.
res <- simulate_reps(class2, class2)
# Alternative: class 2 vs. class 3. Run after the null so the RNG stream
# (seeded above) matches the original ordering.
res_alt <- simulate_reps(class2, class3)

# Build ROC curves by contrasting null replicates (res) with alternative
# replicates (res_alt) for each statistic.
Tstat_ROC <- get_fpr_tpr(res$tstat, res_alt$tstat)
NLCM_ROC <- get_fpr_tpr(res$nlcm_stat, res_alt$nlcm_stat)
ASE_ROC <- get_fpr_tpr(res$ase_stat, res_alt$ase_stat)

# Stack the curves into one long data frame so colour is mapped from a proper
# grouping column rather than from free-floating vectors inside aes().
roc_df <- rbind(
  data.frame(method = "SBM-TS", fpr = Tstat_ROC$fpr, tpr = Tstat_ROC$tpr),
  data.frame(method = "NLCM", fpr = NLCM_ROC$fpr, tpr = NLCM_ROC$tpr),
  data.frame(method = "ASE", fpr = ASE_ROC$fpr, tpr = ASE_ROC$tpr)
)

cbbPalette <- c("#56B4E9",  "#E69F00", "#000000", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7")
# Same colour assignment as before (alphabetical: ASE, NLCM, SBM-TS), now explicit.
method_colours <- c(ASE = cbbPalette[1], NLCM = cbbPalette[2], "SBM-TS" = cbbPalette[3])

roc_plot <- ggplot(roc_df, aes(x = fpr, y = tpr, colour = method)) +
  geom_line(linewidth = 2) +
  ggplot2::scale_colour_manual(values = method_colours) +
  ggplot2::theme_bw() +
  ggplot2::coord_fixed(ratio = 1) +
  # Chance diagonal; drawn once, dashed (a second solid abline used to hide it).
  ggplot2::geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  ggplot2::scale_x_continuous(limits = c(-0.01, 1.01), expand = c(0, 0)) +
  ggplot2::scale_y_continuous(limits = c(-0.01, 1.01), expand = c(0, 0)) +
  ggplot2::theme(
    legend.background = ggplot2::element_blank(),
    legend.title = ggplot2::element_blank(),
    legend.position = c(0.7, 0.2),
    legend.text = ggplot2::element_text(size = 15),
    text = ggplot2::element_text(size = 16)
  ) +
  ggplot2::guides(colour = ggplot2::guide_legend(keywidth = 4, keyheight = 1)) +
  xlab("FPR") +
  ylab("TPR")

# Print explicitly: top-level ggplot objects are not auto-printed under
# source() or Rscript.
print(roc_plot)


################################################################################



