# Session setup ----
# Clear every (non-hidden) object from the global environment so the script
# starts from a clean workspace.
# NOTE(review): this also wipes anything the caller had loaded; consider
# running the script in a fresh R session instead.
rm(list=ls())
# All relative paths below (the source() calls) are resolved against this
# directory, so the script only runs where this folder exists.
setwd("~/two sample test")
# Project helper files. ase() and u_stat(), used further down, are not defined
# in this script and presumably come from these files -- confirm which one.
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
source("R/ASE.R")
# Third-party packages; ggplot2 is used for the plot at the end of the script.
library(nett)
library(ggplot2)

# Load the DD graph dataset ----
# File format reference: https://chrsmrrs.github.io/datasets/docs/datasets/

# (1) DD_A.txt (m lines): sparse, block-diagonal adjacency matrix covering all
#     graphs at once. Each line is one edge given as "row, col", i.e. a pair
#     of node ids.
A <- read.csv(
  file = "~/two sample test/datasets/DD/DD_A.txt",
  header = FALSE
)

# (2) DD_graph_indicator.txt (n lines): line i holds the graph id of the node
#     whose node id is i.
graph_indicator <- read.table(
  file = "~/two sample test/datasets/DD/DD_graph_indicator.txt",
  quote = "\"",
  comment.char = ""
)[["V1"]]

# (3) DD_graph_labels.txt (N lines): line i holds the class label of the graph
#     whose graph id is i.
graph_labels <- read.table(
  file = "~/two sample test/datasets/DD/DD_graph_labels.txt",
  quote = "\"",
  comment.char = ""
)[["V1"]]


# Build per-graph adjacency matrices ----

# One sparse adjacency matrix over every node in the dataset. `dims` is given
# explicitly so that trailing nodes without any edge are still part of the
# matrix; without it the matrix ends at the largest index in the edge list and
# the last graph's block could run out of bounds.
num_nodes <- length(graph_indicator)
Amat <- Matrix::sparseMatrix(
  i = A$V1,
  j = A$V2,
  x = rep(1, length(A$V1)),
  dims = c(num_nodes, num_nodes)
)
num_graphs <- length(graph_labels)

# Number of nodes in each graph, taken as run lengths of the graph indicator.
# This is only a per-graph count when the nodes of each graph are stored
# contiguously, so check that the number of runs matches the number of graphs.
n <- rle(graph_indicator)$lengths
stopifnot(length(n) == num_graphs)

# First and last node id of every graph's diagonal block.
block_end <- cumsum(n)
block_start <- block_end - n + 1

# Replace the edge list in `A` by a list with one adjacency matrix per graph.
# drop = FALSE keeps a single-node graph as a 1 x 1 matrix instead of a scalar.
A <- lapply(seq_len(num_graphs), function(g) {
  nodes <- block_start[g]:block_end[g]
  Amat[nodes, nodes, drop = FALSE]
})

# Show the first graph as a quick sanity check.
A[[1]]

# Simulate null and alternative test statistics ----

# Graph ids belonging to each of the two classes.
class1 <- which(graph_labels == 1)
class2 <- which(graph_labels == 2)

# Number of random graph pairs per scenario. This must be defined before the
# result vectors are allocated.
num_iter <- 150

# Draw `num_iter` random graph pairs and return u_stat() for each pair.
#
# graphs   : list of adjacency matrices, indexed by graph id.
# pool1    : graph ids the first graph of each pair is drawn from.
# pool2    : graph ids the second graph of each pair is drawn from.
# num_iter : number of pairs to draw.
# d        : value passed to ase() as `d`
#            (presumably the embedding dimension -- confirm in the ase() source).
#
# Returns a numeric vector of length `num_iter`.
sample_u_stats <- function(graphs, pool1, pool2, num_iter, d = 5) {
  same_pool <- identical(pool1, pool2)
  stopifnot(
    length(pool1) >= 1,
    length(pool2) >= 1,
    !same_pool || length(pool1) >= 2
  )

  stats <- numeric(num_iter)
  for (iter in seq_len(num_iter)) {
    # Progress indicator every 10 iterations.
    if (iter %% 10 == 0) {
      print(iter)
    }
    # Index with sample.int() rather than sample(pool, 1): sample() treats a
    # length-one numeric `pool` as the range 1:pool.
    if (same_pool) {
      # Both graphs come from the same class: draw without replacement so a
      # graph is never compared with itself.
      pair <- pool1[sample.int(length(pool1), 2L)]
    } else {
      pair <- c(
        pool1[sample.int(length(pool1), 1L)],
        pool2[sample.int(length(pool2), 1L)]
      )
    }
    embed1 <- ase(graphs[[pair[1]]], d = d)
    embed2 <- ase(graphs[[pair[2]]], d = d)
    stats[iter] <- u_stat(embed1, embed2)
  }
  stats
}

# Both graphs from class 1.
Uhat <- sample_u_stats(A, class1, class1, num_iter)

# One graph from class 1, one from class 2.
UhatAlt <- sample_u_stats(A, class1, class2, num_iter)

# Empirical rates for the ROC-style curve ----

# Put both samples of statistics in ascending order.
Uhat <- sort(Uhat)
UhatAlt <- sort(UhatAlt)

# Utpr[k] = k / num_iter, i.e. the rank of the k-th smallest value in Uhat,
# rescaled to (0, 1].
Utpr <- seq(from = 1, to = num_iter, by = 1) / num_iter

# Ufpr[k] = share of the UhatAlt values lying strictly below the k-th smallest
# value in Uhat.
Ufpr <- vapply(
  seq_len(num_iter),
  function(k) sum(UhatAlt < Uhat[k]) / num_iter,
  numeric(1)
)

# Plot Utpr against Ufpr on the unit square, with the diagonal y = x drawn as
# a reference line.
ggplot(data = NULL, mapping = aes(x = Ufpr, y = Utpr)) +
  geom_line() +
  geom_abline(slope = 1, intercept = 0) +
  lims(x = c(0, 1), y = c(0, 1))

# Collect the curve in a data frame. The vectors computed above are named
# `Utpr` / `Ufpr`; the bare names `tpr` / `fpr` are not defined in this script,
# so they are used here only as column names. The third column is the constant
# 0 and keeps the default name `V3`.
Udf <- data.frame(tpr = Utpr, fpr = Ufpr, V3 = 0)
