# SBM Two-Sample Test
#' SBM two-sample test statistic
#'
#' Compares the estimated block-probability matrices of two networks. The
#' blocks of network 2 are first matched to those of network 1 via
#' `matching()` / `mat2perm()`. Each block pair (k, l) with k <= l then
#' contributes `mbar * (B1hat - B2hat)^2 / Sigma2_hat`, where `Sigma2_hat`
#' is computed from the pooled estimate `Bhat`.
#'
#' @param A1,A2 Matrices passed to `block_sums()`; presumably the adjacency
#'   matrices of the two networks.
#' @param z1,z2 Block (community) labels for the nodes of `A1` and `A2`.
#' @param tau Variance threshold: only block pairs with `Sigma2_hat > tau`
#'   enter the sum. Defaults to 0 (the same default as `sbm_tst_old()`).
#'   With that default, block pairs whose pooled estimate is exactly 0 or 1
#'   are excluded.
#' @return A single number: half the sum of the retained contributions.
#'   `NA`/`NaN` contributions (e.g. when a pair count is 0) are dropped.
sbm_tst <- function(A1, A2, z1, z2, tau = 0) {
  # Per-block-pair edge sums for each network.
  S1 <- block_sums(A1, z1)
  S2 <- block_sums(A2, z2)

  # block_sizes() is indexed as a matrix below, so it is assumed to return a
  # K x K matrix of node-pair counts -- TODO confirm.
  m1 <- block_sizes(z1)
  m2 <- block_sizes(z2)

  B1hat <- S1 / m1
  B2hat <- S2 / m2

  # Steps 3 and 4 of the procedure are not needed here because N_r = 1.

  # Reorder network 2's blocks so they line up with network 1's.
  # drop = FALSE keeps 1 x 1 results as matrices when there is a single block.
  perm <- mat2perm(matching(B1hat, B2hat))
  S2p <- S2[perm, perm, drop = FALSE]
  m2p <- m2[perm, perm, drop = FALSE]
  B2hat_p <- B2hat[perm, perm, drop = FALSE]

  # Pooled estimate across both networks and its variance estimate.
  Bhat <- (S1 + S2p) / (m1 + m2p)
  Sigma2_hat <- Bhat * (1 - Bhat)

  # 1 / (1/m1 + 1/m2p), i.e. half the harmonic mean of the two pair counts.
  mbar <- 1 / (1 / m1 + 1 / m2p)

  mat <- mbar / Sigma2_hat * (B1hat - B2hat_p)^2

  # Each unordered block pair once (diagonal included), above the threshold.
  keep <- upper.tri(mat, diag = TRUE) & Sigma2_hat > tau
  sum(mat[keep], na.rm = TRUE) / 2
}


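# Older version of sbm_tst(): aligns labels by relabelling z1 rather than by
# permuting network 2's block sums.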
#' SBM two-sample test statistic (older variant)
#'
#' Computes the same form of statistic as `sbm_tst()`, but aligns labels
#' differently. The block matrices from `estim_dcsbm()` are matched, `z1` is
#' relabelled into `z2`'s labelling, and the block sums of network 1 are
#' recomputed under the new labels.
#'
#' @param A1,A2 Matrices passed to `estim_dcsbm()` and `block_sums()`;
#'   presumably the adjacency matrices of the two networks.
#' @param z1,z2 Block labels for the nodes of `A1` and `A2`.
#' @param tau Only block pairs with a pooled variance estimate strictly
#'   greater than `tau` contribute.
#' @return Half the sum, over block pairs (k, l) with k <= l that pass the
#'   `tau` filter, of `mbar * (S1/m1 - S2/m2)^2 / Sigma2_hat`. `NA`/`NaN`
#'   contributions are ignored.
sbm_tst_old <- function(A1, A2, z1, z2, tau = 0) {
  est_1 <- estim_dcsbm(A1, z1)$B
  est_2 <- estim_dcsbm(A2, z2)$B

  # Map each of z1's labels onto the matching label in z2's labelling.
  label_map <- mat2perm(t(matching(est_1, est_2)))
  z1_aligned <- label_map[z1]

  sums_1 <- block_sums(A1, z1_aligned)
  sums_2 <- block_sums(A2, z2)
  sizes_1 <- block_sizes(z1_aligned)
  sizes_2 <- block_sizes(z2)

  # Pooled estimate and its variance estimate p * (1 - p).
  pooled <- (sums_1 + sums_2) / (sizes_1 + sizes_2)
  var_hat <- pooled * (1 - pooled)

  eff_n <- 1 / (1 / sizes_1 + 1 / sizes_2)

  contrib <- eff_n / var_hat * (sums_1 / sizes_1 - sums_2 / sizes_2)^2
  keep <- upper.tri(contrib, diag = TRUE) & var_hat > tau
  sum(contrib[keep], na.rm = TRUE) / 2
}
