# Two-sample graph testing experiment on the DD dataset: loads the graphs,
# runs Monte Carlo trials of two test statistics and plots ROC-style curves.

# Remove every (non-hidden) object from the current workspace.
# NOTE(review): rm(list = ls()) and setwd() tie the script to the caller's
# session and to a fixed home-directory layout; consider dropping both and
# running the script from the project root instead.
rm(list=ls())
# The relative "R/..." paths below are resolved against this directory.
setwd("~/two sample test")
# Project helper files. agg_sbm_test(), ase() and fast_mmd(), which are called
# further down, are presumably defined in these files (or in nett) -- TODO
# confirm which file provides which function.
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
source("R/agg_functions.R")
source("R/ASE.R")
# Attached packages; Matrix is used below through the Matrix:: prefix and is
# therefore not attached here.
library(nett)
library(ggplot2)
# ---- Load the DD dataset and split it into one adjacency matrix per graph ----
# File format reference: https://chrsmrrs.github.io/datasets/docs/datasets/

# (1) DS_A.txt (m lines): sparse (block diagonal) adjacency matrix for all
#     graphs; each line is one entry given as (row, col), i.e.
#     (node_id, node_id).
edges <- read.csv("~/two sample test/datasets/DD/DD_A.txt", header = FALSE)

# (2) DS_graph_indicator.txt (n lines): column vector of graph identifiers for
#     all nodes of all graphs; the value in the i-th line is the graph_id of
#     the node with node_id i.
graph_indicator <- read.table(
  "~/two sample test/datasets/DD/DD_graph_indicator.txt",
  quote = "\"", comment.char = ""
)$V1

# (3) DS_graph_labels.txt (N lines): class labels for all graphs in the
#     dataset; the value in the i-th line is the class label of the graph with
#     graph_id i.
graph_labels <- read.table(
  "~/two sample test/datasets/DD/DD_graph_labels.txt",
  quote = "\"", comment.char = ""
)$V1

#node_labels <- read.table("~/two sample test/datasets/DDY/DD_node_labels.txt", quote="\"", comment.char="")

num_graphs <- length(graph_labels)
num_nodes <- length(graph_indicator)

# Graph ids index the lines of the label file, so they must lie in
# 1..num_graphs; anything else would be dropped silently by the split below.
if (!all(graph_indicator %in% seq_len(num_graphs))) {
  stop("graph_indicator contains ids outside 1..", num_graphs, call. = FALSE)
}

# Give the matrix explicit dimensions. Left to itself, sparseMatrix() sizes the
# result by the largest index in the edge list, so trailing nodes without any
# edge would be cut off and slicing the last graph would run out of bounds.
# max() keeps the matrix large enough even if the edge list mentions a node id
# beyond the indicator file.
mat_size <- max(num_nodes, edges$V1, edges$V2)
Amat <- Matrix::sparseMatrix(
  edges$V1, edges$V2,
  x = rep(1, length(edges$V1)),
  dims = c(mat_size, mat_size)
)

# Node ids of every graph, in graph-id order. Grouping by the id itself (rather
# than by runs of equal ids) keeps A[[i]] aligned with graph_labels[i] even if
# the nodes of a graph are not stored contiguously.
node_ids <- split(
  seq_len(num_nodes),
  factor(graph_indicator, levels = seq_len(num_graphs))
)

# Number of nodes per graph.
n <- unname(lengths(node_ids))

# A[[i]] is the adjacency matrix of graph i; drop = FALSE keeps a matrix even
# for a single-node graph.
A <- lapply(unname(node_ids), function(idx) Amat[idx, idx, drop = FALSE])


# ---- Experiment configuration ----

# Graph ids (positions in `graph_labels`, and hence in `A`) of each class.
class1 <- which(graph_labels == 1)
class2 <- which(graph_labels == 2)

# Fix the RNG state so that the random graph samples are reproducible.
set.seed(100)

K <- 6                 # passed to agg_sbm_test(); presumably #communities -- confirm
num_iter <- 150        # Monte Carlo repetitions per hypothesis
threshold <- 0.1       # passed to agg_sbm_test(); meaning defined there
sample_sizes <- c(10)  # number of graphs drawn for each of the two samples
d <- 5                 # passed to ase() and fast_mmd(); presumably embedding dim

# Accumulators for the ROC-style points of the two methods. They are declared
# with the columns that are appended later, so the frames are well-formed (and
# plottable) even if no rows are ever added.
results <- data.frame(
  tpr = numeric(0), fpr = numeric(0), sample_size = numeric(0)
)
dresults <- data.frame(
  dtpr = numeric(0), dfpr = numeric(0), sample_size = numeric(0)
)

# ---- Monte Carlo comparison of the two statistics ----

# Mean of fast_mmd(...)$biased over all pairs made of one embedding from
# `emb1` and one from `emb2`.
mean_pairwise_mmd <- function(emb1, emb2, d) {
  total <- 0
  for (i in seq_along(emb1)) {
    for (j in seq_along(emb2)) {
      total <- total + fast_mmd(emb1[[i]], emb2[[j]], d)$biased
    }
  }
  total / (length(emb1) * length(emb2))
}

# Run `num_iter` trials. Each trial draws `sample_size` graph ids from `pool1`
# and from `pool2`, then records the first two elements returned by
# agg_sbm_test() and the mean pairwise MMD between the ase() embeddings.
# Returns a list with numeric vectors `stat`, `truncated` and `mmd`.
run_trials <- function(graphs, pool1, pool2, sample_size, num_iter,
                       K, threshold, d) {
  stat <- numeric(num_iter)
  truncated <- numeric(num_iter)
  mmd <- numeric(num_iter)

  for (iter in seq_len(num_iter)) {
    # Progress indicator every 10 trials.
    if (iter %% 10 == 0) {
      print(iter)
    }

    # NOTE(review): the two draws are independent, so when pool1 and pool2 are
    # the same pool the two samples can share graphs -- confirm this is the
    # intended null model.
    sm1 <- sample(pool1, sample_size)
    sm2 <- sample(pool2, sample_size)
    A1list <- graphs[sm1]
    A2list <- graphs[sm2]

    test <- agg_sbm_test(A1list, A2list, K, threshold)
    stat[iter] <- test[1]
    truncated[iter] <- test[2]

    Ase1list <- lapply(A1list, ase, d)
    Ase2list <- lapply(A2list, ase, d)
    mmd[iter] <- mean_pairwise_mmd(Ase1list, Ase2list, d)
  }

  list(stat = stat, truncated = truncated, mmd = mmd)
}

# One data frame of curve points per sample size; bound together after the loop
# instead of growing `results` / `dresults` on every pass.
roc_parts <- vector("list", length(sample_sizes))
droc_parts <- vector("list", length(sample_sizes))

for (k in seq_along(sample_sizes)) {
  sample_size <- sample_sizes[k]

  # Null hypothesis: both samples come from class 1.
  null_runs <- run_trials(A, class1, class1, sample_size, num_iter,
                          K, threshold, d)
  That <- null_runs$stat
  TruncatedThat <- sort(null_runs$truncated)
  dist <- sort(null_runs$mmd)

  # Alternative: class 1 against class 2.
  alt_runs <- run_trials(A, class1, class2, sample_size, num_iter,
                         K, threshold, d)
  ThatAlt <- alt_runs$stat
  TruncatedThatAlt <- sort(alt_runs$truncated)
  distAlt <- sort(alt_runs$mmd)

  # Each sorted null statistic serves as a cut-off. `tpr` is its rank fraction
  # among the null values and `fpr` the fraction of alternative values strictly
  # below it.
  # NOTE(review): if the tests reject for LARGE statistics, these are
  # (1 - TPR, 1 - FPR), i.e. the usual ROC curve mirrored about the
  # anti-diagonal (same area, different axis meaning) -- confirm the naming.
  idx <- seq_len(num_iter)
  tpr <- idx / num_iter
  dtpr <- idx / num_iter
  fpr <- vapply(
    idx,
    function(i) sum(TruncatedThatAlt < TruncatedThat[i]) / num_iter,
    numeric(1)
  )
  dfpr <- vapply(
    idx,
    function(i) sum(distAlt < dist[i]) / num_iter,
    numeric(1)
  )

  roc_parts[[k]] <- data.frame(
    tpr = tpr, fpr = fpr, sample_size = rep(sample_size, num_iter)
  )
  droc_parts[[k]] <- data.frame(
    dtpr = dtpr, dfpr = dfpr, sample_size = rep(sample_size, num_iter)
  )
}

# Append all curve points to the accumulators in one step each.
results <- rbind(results, do.call(rbind, roc_parts))
dresults <- rbind(dresults, do.call(rbind, droc_parts))

# ---- Report ----
# Everything is printed explicitly: top-level auto-printing is switched off
# when a script is run through source(), in which case bare `dresults` or a
# bare ggplot expression would produce no output at all.
print(dresults)
print(results)

# Curve for the embedding/MMD distance. `group = sample_size` draws one line
# per sample size instead of joining all points into a single path.
dist_curve <- ggplot(dresults, aes(x = dfpr, y = dtpr, group = sample_size)) +
  geom_line() +
  geom_abline(intercept = 0, slope = 1) +  # chance diagonal for reference
  xlim(0, 1) +
  ylim(0, 1)
print(dist_curve)

# Curve for the truncated statistic returned by agg_sbm_test().
sbm_curve <- ggplot(results, aes(x = fpr, y = tpr, group = sample_size)) +
  geom_line() +
  geom_abline(intercept = 0, slope = 1) +  # chance diagonal for reference
  xlim(0, 1) +
  ylim(0, 1)
print(sbm_curve)