setwd("~/two sample test")
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")
source("R/hypot.R")
library(nett)
library(dplyr)
library(tidyr)
# Simulate `num_iter` values of sbm_tst_old() for pairs of networks drawn with
# sample_dcsbm() from block matrices B1 and B2, using community labels
# estimated by spec_clust(). The RNG is consumed in the same order on every
# iteration (z1, z2, A1, A2, z1hat, z2hat) so results are reproducible under
# set.seed(). When `verbose` is TRUE, the iteration number is printed every 50
# iterations.
simulate_test_stats <- function(B1, B2, num_iter, n, prob, niter,
                                verbose = FALSE) {
  K <- nrow(B1)
  stats <- numeric(num_iter)
  for (iter in seq_len(num_iter)) {
    if (verbose && iter %% 50 == 0) {
      print(iter)
    }
    z1 <- sample(seq_len(K), n, replace = TRUE, prob = prob)
    z2 <- sample(seq_len(K), n, replace = TRUE, prob = prob)
    A1 <- sample_dcsbm(z1, B1)
    A2 <- sample_dcsbm(z2, B2)

    z1hat <- spec_clust(A1, K, niter = niter)
    z2hat <- spec_clust(A2, K, niter = niter)

    stats[iter] <- sbm_tst_old(A1, A2, z1hat, z2hat)
  }
  stats
}

#' Estimate size and power of the two-sample SBM test by simulation
#'
#' Draws `num_iter` pairs of networks under the null (both generated from `B`)
#' and `num_iter` pairs under the alternative (`B` vs `B + diag(eps, K)`).
#' Each network's labels are estimated with spec_clust(). The statistic from
#' sbm_tst_old() is compared with the upper-`alpha` quantile of a chi-squared
#' distribution with K(K+1)/2 degrees of freedom.
#'
#' NOTE: the count names follow this script's own convention, in which a
#' "positive" means "the test did NOT reject":
#'   - tp: null pairs with statistic < threshold (not rejected)
#'   - fn: null pairs with statistic >= threshold (rejected)
#'   - fp: alternative pairs with statistic < threshold (not rejected)
#'   - tn: alternative pairs with statistic >= threshold (rejected)
#'
#' @param B K x K block matrix passed to sample_dcsbm().
#' @param eps Amount added to the diagonal of `B` under the alternative.
#' @param num_iter Number of simulated pairs per hypothesis.
#' @param n Number of nodes in each simulated network.
#' @param prob Community membership probabilities; must have length K.
#'   Defaults to the original two-community setting.
#' @param alpha Level used to compute the chi-squared rejection threshold.
#' @param niter Passed to spec_clust() as its `niter` argument.
#' @return Numeric vector `c(tp, fn, tn, fp, eps)`.
test2 <- function(B, eps, num_iter, n, prob = c(0.4, 0.6), alpha = 0.05,
                  niter = 30) {
  K <- nrow(B)
  if (length(prob) != K) {
    stop("`prob` must have one entry per community (length ", K, ").",
         call. = FALSE)
  }
  Beps <- B + diag(eps, K)
  threshold <- qchisq(alpha, df = K * (K + 1) / 2, lower.tail = FALSE)

  # Null hypothesis: both networks share B (progress printed every 50 iters).
  null_stats <- simulate_test_stats(B, B, num_iter, n, prob, niter,
                                    verbose = TRUE)
  tp <- sum(null_stats < threshold)
  fn <- num_iter - tp

  # Alternative: the second network is generated from the perturbed Beps.
  alt_stats <- simulate_test_stats(B, Beps, num_iter, n, prob, niter)
  fp <- sum(alt_stats < threshold)
  tn <- num_iter - fp

  c(tp, fn, tn, fp, eps)
}

# Simulation settings ----
B <- matrix(c(0.5, 0.2, 0.2, 0.5), 2, 2)

set.seed(1000)
eps_arr <- c(0.01, 0.02, 0.05, 0.1)
n_arr <- c(100, 200, 500, 1000)
num_iter <- 1000

# One row per (eps, n) setting. expand.grid() varies its first argument
# fastest, so n changes fastest and eps slowest. This matches an eps-outer /
# n-inner nested loop, so the RNG stream after set.seed() is consumed in the
# same order. Rows are collected in a list and bound once, rather than
# growing a data frame with rbind() inside the loop.
settings <- expand.grid(n = n_arr, eps = eps_arr)
rows <- lapply(seq_len(nrow(settings)), function(k) {
  c(test2(B, settings$eps[k], num_iter = num_iter, settings$n[k]),
    settings$n[k])
})
results <- as.data.frame(do.call(rbind, rows))
colnames(results) <- c("tp", "fn", "tn", "fp", "eps", "n")
# Summary table ----
# tpr:   share of null pairs the test did not reject, tp / (tp + fn).
# p_est: share of alternative pairs the test rejected, 1 - fp / (tn + fp).
results2 <- as_tibble(results)

# Output column order: n, then tpr / p_est for each eps in turn. The names are
# built from the eps values themselves (matching pivot_wider()'s
# "<value>, eps = <eps>" naming), so the layout stays correct when the eps
# grid changes. A fixed positional select() only fits exactly four eps values.
eps_levels <- unique(results2$eps)
eps_cols <- as.vector(rbind(
  paste0("tpr, eps = ", eps_levels),
  paste0("p_est, eps = ", eps_levels)
))

res <- results2 %>%
  mutate(tpr = tp / (tp + fn), p_est = 1 - fp / (tn + fp)) %>%
  select(eps, n, tpr, p_est) %>%
  pivot_wider(
    names_from = eps,
    values_from = c(tpr, p_est),
    names_sep = ", eps = "
  ) %>%
  select(n, all_of(eps_cols))
