setwd("~/two sample test")
source("R/accuracy.R")
source("R/matching.R")
source("R/generate.R")
source("R/checks.R")

library(dplyr)
# Monte Carlo error counts for the two-sample DCSBM test.
#
# Each of the `num_iter` iterations runs two independent trials:
#   * alternative trial: graph 1 is drawn from `B`, graph 2 from
#     `B + diag(eps, K)`, so the connectivity matrices differ;
#   * null trial: both graphs are drawn from `B`.
# In each trial, community labels are sampled with probabilities `prob`.
# Each graph is then clustered with spec_clust() and its block matrix
# estimated with estim_dcsbm(). The two estimates are matched, and the
# statistic from get_That() is compared with the upper-`alpha` quantile of
# a chi-squared distribution with K(K+1)/2 degrees of freedom.
#
# Label convention: a "positive" is an outcome where the test does NOT
# reject, i.e. the two graphs are declared to come from the same model:
#   tp = null trial, not rejected        fn = null trial, rejected
#   tn = alternative trial, rejected     fp = alternative trial, not rejected
# Hence the empirical power is tn / (tn + fp) and the empirical size is
# fn / (tp + fn).
#
# Args:
#   B:        K x K block connectivity matrix.
#   eps:      amount added to the diagonal of `B` under the alternative.
#   num_iter: number of iterations (each runs one alternative and one null
#             trial); 0 runs nothing.
#   n:        number of nodes per graph.
#   prob:     length-K community membership probabilities
#             (default c(0.4, 0.6), i.e. suitable for K = 2).
#   alpha:    significance level of the test (default 0.05).
#
# Returns: numeric vector c(tp, fn, tn, fp, eps).
#
# NOTE(review): relies on sample_dcsbm(), spec_clust(), estim_dcsbm(),
# matching() and get_That(), presumably defined in the sourced R/ files.
test1 <- function(B, eps, num_iter, n, prob = c(0.4, 0.6), alpha = 0.05) {
  K <- nrow(B)
  if (length(prob) != K) {
    stop("`prob` must have one entry per community (length ", K, "), got ",
         length(prob), call. = FALSE)
  }
  Beps <- B + diag(eps, K)
  threshold <- qchisq(alpha, df = K * (K + 1) / 2, lower.tail = FALSE)

  # Draws one pair of graphs (the first from B, the second from B2), fits
  # both, and returns TRUE when the test fails to reject, i.e. the statistic
  # falls below the threshold. The call order matches the original inline
  # code, so the random-number stream is consumed identically.
  accepts_null <- function(B2) {
    z1 <- sample(seq_len(K), n, replace = TRUE, prob = prob)
    z2 <- sample(seq_len(K), n, replace = TRUE, prob = prob)
    A1 <- sample_dcsbm(z1, B)
    A2 <- sample_dcsbm(z2, B2)

    z1hat <- spec_clust(A1, K, niter = 30)
    z2hat <- spec_clust(A2, K, niter = 30)

    B1hat <- estim_dcsbm(A1, z1hat)$B
    B2hat <- estim_dcsbm(A2, z2hat)$B

    Pt <- matching(B1hat, B2hat)
    get_That(A1, A2, z1hat, z2hat, Pt) < threshold
  }

  tp <- 0
  fn <- 0
  tn <- 0
  fp <- 0
  # seq_len() so that num_iter = 0 performs no iterations (1:0 would run two).
  for (iter in seq_len(num_iter)) {
    if (iter %% 50 == 0) {
      print(iter)
    }
    # Alternative is true: failing to reject is a miss (fp).
    if (accepts_null(Beps)) {
      fp <- fp + 1
    } else {
      tn <- tn + 1
    }
    # Null is true: failing to reject is correct (tp).
    if (accepts_null(B)) {
      tp <- tp + 1
    } else {
      fn <- fn + 1
    }
  }
  c(tp, fn, tn, fp, eps)
}

# Simulation study ----
# Two-block connectivity matrix shared by both samples under the null;
# test1() adds eps to its diagonal for the alternative.
B <- matrix(c(0.5, 0.2, 0.2, 0.5), 2, 2)

set.seed(1000)
eps_arr <- c(0.01, 0.02, 0.05, 0.1)
n_arr <- c(100, 200, 500, 1000)

# One row per (eps, n) setting. expand.grid() varies its first column
# fastest, so listing n first keeps eps as the outer loop and n as the inner
# loop. This is the order in which the RNG stream is consumed after
# set.seed(). Results are collected in a list and bound once, instead of
# growing a data frame with rbind() on every iteration.
settings <- expand.grid(n = n_arr, eps = eps_arr)
rows <- lapply(seq_len(nrow(settings)), function(k) {
  c(test1(B, settings$eps[k], num_iter = 1000, settings$n[k]), settings$n[k])
})
results <- as.data.frame(do.call(rbind, rows))
colnames(results) <- c("tp", "fn", "tn", "fp", "eps", "n")
results

# Empirical power per setting ----
# Under test1()'s labelling, tn counts rejections and fp counts
# non-rejections among alternative trials, so tn / (tn + fp) is the
# empirical power. The long table is printed in full (print() on a tibble
# returns its input invisibly, so the pipe continues). It is then reshaped
# to one row per n and one column per eps.
# pivot_wider() lives in tidyr, which this file does not attach, so it is
# namespaced explicitly.
results <- as_tibble(results)
cr <- results %>%
  mutate(power = tn / (tn + fp)) %>%
  select(eps, n, power) %>%
  print(n = Inf) %>%
  tidyr::pivot_wider(names_from = eps, values_from = power)
