setwd("~/two sample test")
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
library(nett)
library(ggplot2)

#' Simulate the points of an empirical ROC-style curve for the two-sample test
#'
#' Draws `num_iter` pairs of networks where both are sampled from `B1` (the
#' "null" pairs) and `num_iter` pairs where the second network is sampled from
#' a perturbed matrix `B2 = B1 + rnorm_symmetric_matrix(K, 0, sigma)` (the
#' "alternative" pairs). The statistic `sbm_tst_old()` is computed for every
#' pair. Empirical quantiles of the null statistics are used as thresholds,
#' and for each threshold the share of alternative statistics that fall
#' strictly below it is recorded.
#'
#' `rnorm_symmetric_matrix()`, `sample_dcsbm()`, `spec_clust()` and
#' `sbm_tst_old()` are defined outside this function (sourced files / attached
#' packages); their exact semantics are not visible here.
#'
#' @param B1 Square connectivity matrix; `nrow(B1)` is used as the number of
#'   communities `K`.
#' @param sigma Third argument passed to `rnorm_symmetric_matrix()` when
#'   building the perturbation (presumably the noise standard deviation --
#'   TODO confirm). Also copied verbatim into the `sigma` output column.
#' @param num_iter Number of simulated network pairs per scenario (>= 1).
#' @param num_thr Number of thresholds, i.e. rows of the result (>= 1).
#' @param n Number of nodes per network (length of each label vector).
#' @param labels If `TRUE`, the sampled labels are passed to the test directly;
#'   otherwise labels are estimated with `spec_clust(A, K, niter = 30)`.
#'
#' @return A numeric matrix with `num_thr` rows and columns `fpr`, `tpr`,
#'   `sigma`:
#'   * `tpr`: the quantile levels `seq(0, 1, length.out = num_thr)` at which
#'     the null statistics were thresholded;
#'   * `fpr`: the fraction of alternative statistics strictly below the
#'     corresponding null quantile;
#'   * `sigma`: the input `sigma`, recycled.
#'
#'   NOTE(review): the column names are kept for backward compatibility.
#'   Whether they really are false/true positive rates depends on the
#'   rejection direction of `sbm_tst_old()` -- confirm against its definition.
roc_curve <- function(B1, sigma, num_iter, num_thr, n, labels = FALSE) {
  # `1:num_iter` would silently iterate over c(1, 0) for num_iter = 0, so
  # reject degenerate sizes up front instead of returning a bogus curve.
  stopifnot(
    is.numeric(num_iter), length(num_iter) == 1L, num_iter >= 1,
    is.numeric(num_thr), length(num_thr) == 1L, num_thr >= 1
  )

  K <- nrow(B1)
  # The perturbed matrix is drawn once and shared by all alternative pairs.
  B2 <- B1 + rnorm_symmetric_matrix(K, 0, sigma)
  use_true_labels <- isTRUE(labels)

  # Draw one pair of networks (first from B1, second from `B_second`) and
  # return the test statistic. The order of random draws matches the
  # original inline code: z1, z2, A1, A2, then clustering.
  simulate_stat <- function(B_second) {
    z1 <- sample.int(K, n, replace = TRUE)
    z2 <- sample.int(K, n, replace = TRUE)
    A1 <- sample_dcsbm(z1, B1)
    A2 <- sample_dcsbm(z2, B_second)

    if (use_true_labels) {
      z1hat <- z1
      z2hat <- z2
    } else {
      z1hat <- spec_clust(A1, K, niter = 30)
      z2hat <- spec_clust(A2, K, niter = 30)
    }

    stat <- sbm_tst_old(A1, A2, z1hat, z2hat)
    # Fail fast with a clear message rather than propagating NA counts.
    stopifnot(length(stat) == 1L, !is.na(stat))
    stat
  }

  # Null ----
  That_null <- numeric(num_iter)
  for (iter in seq_len(num_iter)) {
    if (iter %% 50 == 0) {
      print(iter)  # progress indicator
    }
    That_null[iter] <- simulate_stat(B1)
  }

  tpr <- seq(0, 1, length.out = num_thr)
  # names = FALSE keeps the "0%", "50%", ... labels out of the counts below
  # (and therefore out of the rownames of the returned matrix).
  thresholds <- quantile(That_null, probs = tpr, names = FALSE)

  # Alternative ----
  fp <- numeric(num_thr)
  tn <- numeric(num_thr)
  for (iter in seq_len(num_iter)) {
    That <- simulate_stat(B2)
    below <- That < thresholds  # compared against all thresholds at once
    fp <- fp + below
    tn <- tn + !below
  }

  fpr <- fp / (tn + fp)
  cbind(fpr, tpr, sigma)
}


