setwd("~/two sample test")
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
library(nett)
library(ggplot2)
library(dplyr)

# Run `num_iter` replicates of one hypothesis and return the test statistic
# computed from spectral-clustering labels (`est`) and from the true labels
# (`oracle`). The first network is drawn from `B_first`, the second from
# `B_second`. Draw order per replicate matches the original script
# (z1, z2, A1, A2, clustering) so seeded runs consume the RNG identically.
.oracletest_simulate <- function(B_first, B_second, K, n, num_iter, niter) {
  est <- numeric(num_iter)
  oracle <- numeric(num_iter)
  for (iter in seq_len(num_iter)) {
    if (iter %% 50 == 0) {
      print(iter)
    }
    z1 <- sample(seq_len(K), n, replace = TRUE)
    z2 <- sample(seq_len(K), n, replace = TRUE)
    A1 <- sample_dcsbm(z1, B_first)
    A2 <- sample_dcsbm(z2, B_second)

    # Same clustering settings under both hypotheses; otherwise the ROC
    # would partly measure a difference in clustering configuration.
    z1hat <- spec_clust(A1, K, niter = niter)
    z2hat <- spec_clust(A2, K, niter = niter)

    est[iter] <- sbm_tst_old(A1, A2, z1hat, z2hat)
    oracle[iter] <- sbm_tst_old(A1, A2, z1, z2)
  }
  list(est = est, oracle = oracle)
}

# Build one ROC data.frame from null and alternative statistics.
# Threshold i is the i-th smallest null statistic:
#   tpr[i] = i / m, the fraction of null statistics <= that threshold
#            (ignoring ties);
#   fpr[i] = the fraction of alternative statistics strictly below it.
# NOTE(review): if the test rejects for LARGE statistics, these columns are
# 1 - FPR and 1 - TPR respectively, i.e. the conventional ROC mirrored about
# the line x + y = 1. The area under the curve is unchanged; confirm the
# intended labelling.
.oracletest_roc <- function(null_stat, alt_stat, sigma, oracle) {
  m <- length(null_stat)
  thresholds <- sort(null_stat)
  n_alt_below <- vapply(thresholds, function(t) sum(alt_stat < t), numeric(1))
  data.frame(
    tpr = seq_len(m) / m,
    fpr = n_alt_below / m,
    sigma = sigma,
    oracle = oracle
  )
}

# Monte Carlo ROC comparison of the two-sample SBM test using estimated
# labels versus true ("oracle") labels.
#
# Args:
#   B1: K x K connectivity matrix. Both networks use it under the null, and
#       the first network uses it under the alternative.
#   sigma: scale passed to rnorm_symmetric_matrix(K, 0, sigma) to perturb B1
#          into the alternative matrix B2 (presumably a noise sd; confirm).
#   num_iter: number of replicates per hypothesis (must be >= 1).
#   n: number of nodes per network.
#   niter: forwarded to spec_clust() under BOTH hypotheses (the original
#          script used 30 for the alternative only).
#
# Returns a data.frame with columns tpr, fpr, sigma (factor) and oracle
# (FALSE: labels from spec_clust; TRUE: true labels), num_iter rows each.
oracletest <- function(B1, sigma, num_iter, n, niter = 30) {
  stopifnot(num_iter >= 1)
  K <- nrow(B1)
  B2 <- B1 + rnorm_symmetric_matrix(K, 0, sigma)

  # Null: both samples share B1. Alternative: second sample uses B2.
  null_stats <- .oracletest_simulate(B1, B1, K, n, num_iter, niter)
  alt_stats <- .oracletest_simulate(B1, B2, K, n, num_iter, niter)

  res <- rbind(
    .oracletest_roc(null_stats$est, alt_stats$est, sigma, FALSE),
    .oracletest_roc(null_stats$oracle, alt_stats$oracle, sigma, TRUE)
  )
  res$sigma <- as.factor(res$sigma)
  res
}

set.seed(1001)
K <- 7
B1 <- runif_symmetric_matrix(K, 0.2, 0.7)
n_sigma <- 5
sigma <- seq(0.001, 0.01, length.out = n_sigma)
num_iter <- 100
n <- 1000

# One ROC data.frame per perturbation level, combined once at the end instead
# of growing `results` with rbind() inside a loop. rbind.data.frame merges the
# per-call factor levels of `sigma`. Calls run in order of `sigma`, so the
# seeded RNG stream is consumed exactly as in a sequential loop.
results <- do.call(
  rbind,
  lapply(sigma, function(s) oracletest(B1, s, num_iter, n))
)

roc_plot <- ggplot(results, aes(x = fpr, y = tpr)) +
  geom_line(aes(linetype = oracle, color = sigma)) +
  geom_abline(intercept = 0, slope = 1) +
  xlim(0, 1) +
  ylim(0, 1) +
  ggtitle(paste0("n = ", n, ", K = ", K))
# Print explicitly: a bare ggplot object is not drawn when the script is
# run via source() without print.eval.
print(roc_plot)



