# Setup ----
# Start from an empty workspace and switch to the project root, because the
# helper scripts below are sourced by paths relative to it.
# NOTE(review): rm(list = ls()) and setwd() tie the script to this machine's
# directory layout; consider launching it from the project root instead.
rm(list = ls())
setwd("~/two sample test")

# Project helpers, sourced in this order into the global environment
# (source() evaluates in the global environment by default).
invisible(lapply(
  file.path("R", c(
    "accuracy.R", "matching.R", "generate.R",
    "checks.R", "hypot.R", "agg_functions.R"
  )),
  source
))

library(nett)
library(ggplot2)
# Load a graph-classification dataset stored in the TU Dortmund text format:
# https://chrsmrrs.github.io/datasets/docs/datasets/
# Change `dataset_name` to run the same analysis on another dataset laid out
# the same way (files named <NAME>_A.txt, <NAME>_graph_indicator.txt, ...).
dataset_name <- "DD"
dataset_dir <- file.path("~/two sample test/datasets", dataset_name)

# Path of one component file of the dataset, e.g. suffix "A" -> ".../DD_A.txt".
dataset_file <- function(suffix, dir = dataset_dir, name = dataset_name) {
  file.path(dir, paste0(name, "_", suffix, ".txt"))
}

# (1) DS_A.txt (m lines): sparse, block-diagonal adjacency matrix for all
#     graphs. Each line is one edge "row, col", i.e. "node_id, node_id".
A <- read.csv(dataset_file("A"), header = FALSE)

# (2) DS_graph_indicator.txt (n lines): graph identifier of every node across
#     all graphs; the value on line i is the graph_id of the node with node_id i.
graph_indicator <- read.table(dataset_file("graph_indicator"),
                              quote = "\"", comment.char = "")$V1

# (3) DS_graph_labels.txt (N lines): class label of every graph; the value on
#     line i is the class label of the graph with graph_id i.
graph_labels <- read.table(dataset_file("graph_labels"),
                           quote = "\"", comment.char = "")$V1


# Split the block-diagonal adjacency matrix into one adjacency matrix per
# graph, stored in the list `A` (A[[i]] is the adjacency matrix of graph i).
num_nodes <- length(graph_indicator)

# Give the sparse matrix its full size explicitly: without `dims`,
# sparseMatrix() sizes it by the largest node id that appears in an edge, which
# is too small (and breaks the slicing below) when trailing nodes are isolated.
Amat <- Matrix::sparseMatrix(A$V1, A$V2, x = rep(1, length(A$V1)),
                             dims = c(num_nodes, num_nodes))
num_graphs <- length(graph_labels)

# Node counts per graph. This assumes graph_indicator is sorted, so each graph
# is one contiguous run of node ids; fail loudly if that does not hold.
n <- rle(graph_indicator)$lengths
stopifnot(length(n) == num_graphs)

# Graph i occupies node ids start_idx[i]:end_idx[i].
end_idx <- cumsum(n)
start_idx <- end_idx - n + 1

A <- lapply(seq_len(num_graphs), function(i) {
  idx <- start_idx[i]:end_idx[i]
  # drop = FALSE keeps single-node graphs as 1 x 1 matrices, not scalars.
  Amat[idx, idx, drop = FALSE]
})

# Sanity check: show the first graph's adjacency matrix.
A[[1]]

# Two-sample test simulation ----
# Null: both samples are drawn from class 1. Alternative: class 1 vs class 2.

# Run `num_iter` aggregated SBM two-sample tests. Each iteration draws
# `sample_size` graph ids without replacement from `pool1`, then separately
# from `pool2`, and applies agg_sbm_test() to the matching adjacency matrices
# in `graphs`. Prints the iteration number every 10 iterations.
# Returns list(stat = test[1] values, truncated = test[2] values).
# NOTE(review): test[1] / test[2] are presumably the plain and truncated
# statistics returned by agg_sbm_test() (R/agg_functions.R) -- confirm.
# NOTE(review): when pool1 and pool2 are the same set, the two samples may
# share graphs; confirm that overlap is intended for the null design.
run_agg_tests <- function(pool1, pool2, sample_size, num_iter, K, threshold,
                          graphs) {
  stat <- numeric(num_iter)
  truncated <- numeric(num_iter)
  for (iter in seq_len(num_iter)) {
    if (iter %% 10 == 0) {
      print(iter)
    }
    sm1 <- sample(pool1, sample_size)
    sm2 <- sample(pool2, sample_size)
    test <- agg_sbm_test(graphs[sm1], graphs[sm2], K, threshold)
    stat[iter] <- test[1]
    truncated[iter] <- test[2]
  }
  list(stat = stat, truncated = truncated)
}

# Empirical ROC-style curve from null and alternative statistics. For the i-th
# smallest null statistic t: tpr = i / length(null_stats), and fpr = fraction
# of alternative statistics strictly below t. The null runs are thus treated
# as the "positive" class (accept the null when the statistic is <= t).
roc_points <- function(null_stats, alt_stats) {
  thresholds <- sort(null_stats)
  n_alt_below <- vapply(
    thresholds,
    function(t) sum(alt_stats < t, na.rm = TRUE),
    integer(1)
  )
  data.frame(
    tpr = seq_along(thresholds) / length(null_stats),
    fpr = n_alt_below / length(alt_stats)
  )
}

class1 <- which(graph_labels == 1)
class2 <- which(graph_labels == 2)

set.seed(100)
K <- 6
num_iter <- 150
threshold <- 0.1

sample_sizes <- c(10, 20, 50, 100)

# For each sample size in turn: all null runs, then all alternative runs.
roc_by_size <- vector("list", length(sample_sizes))
for (j in seq_along(sample_sizes)) {
  sample_size <- sample_sizes[j]
  null_runs <- run_agg_tests(class1, class1, sample_size, num_iter, K,
                             threshold, A)
  alt_runs <- run_agg_tests(class1, class2, sample_size, num_iter, K,
                            threshold, A)
  roc <- roc_points(null_runs$truncated, alt_runs$truncated)
  roc$sample_size <- rep(sample_size, nrow(roc))
  roc_by_size[[j]] <- roc
}

# Seeding rbind() with a zero-row template keeps `results` a 3-column data
# frame even if `sample_sizes` is empty (do.call(rbind, list()) is NULL).
results <- do.call(rbind, c(
  list(data.frame(tpr = numeric(0), fpr = numeric(0),
                  sample_size = numeric(0))),
  roc_by_size
))

# Plot one ROC-style curve per sample size ----
colnames(results) <- c("tpr", "fpr", "n_sample")
results$n_sample <- as.factor(results$n_sample)

# The diagonal marks chance level: a statistic that cannot tell the null
# samples from the alternative ones.
p <- ggplot(results, aes(x = fpr, y = tpr, color = n_sample)) +
  geom_line() +
  geom_abline(intercept = 0, slope = 1) +
  xlim(0, 1) +
  ylim(0, 1) +
  labs(color = "sample size")

# Assigning a ggplot object does not draw it; print explicitly so the figure
# also appears when the script is run via source() or Rscript.
print(p)

