# Script setup: load project helper files and required packages.
# NOTE(review): setwd() with a hard-coded, user-specific path makes this script
# depend on the author's directory layout; the relative source() paths below
# are resolved against it.
setwd("~/two sample test")
# Presumably these provide block_sums() and block_sizes(), which
# loglikelihood_loss() calls but this file does not define -- TODO confirm.
source("R/matching.R")
source("R/generate.R")
library(Matrix)    # sparse matrices; Matrix::summary() / Matrix::sparseMatrix()
library(RSpectra)  # svds(): truncated SVD used for spectral clustering

# Number of node pairs available for estimating each block probability in the
# cross-validation split used by loglikelihood_loss().
#
# Arguments:
#   z  - community labels; not used by the computation, kept for interface
#        compatibility with existing callers.
#   n1 - community sizes in the first half S1 (only the first K entries are
#        used; missing entries give NA).
#   n2 - community sizes in the second half S2 (same convention as n1).
#   K  - number of communities.
#
# Returns a K x K double matrix with
#   ns[i, j] = n1[i] * (n1[j] + n2[j])                  for i != j
#   ns[i, i] = (n1[i] - 1) * n1[i] / 2 + n1[i] * n2[i]
# i.e. ordered (S1 node, any node) pairs off the diagonal, and unordered
# within-S1 pairs plus S1 x S2 pairs on the diagonal. K = 0 gives a 0 x 0
# matrix.
block_sizes_cv <- function(z, n1, n2, K) {
  idx <- seq_len(K)
  # Work in double precision: products of integer counts (e.g. from table())
  # overflow R's 32-bit integers to NA for communities of a few tens of
  # thousands of nodes.
  n1 <- as.numeric(n1[idx])
  n2 <- as.numeric(n2[idx])

  ns <- outer(n1, n1 + n2)
  diag(ns) <- (n1 - 1) * n1 / 2 + n1 * n2
  ns
}

# Cross-validated negative log-likelihood of a K-community stochastic block
# model fitted to adjacency matrix A.
#
# The nodes are split into S1 (the first floor(n / 2) nodes) and S2 (the
# rest). Labels for all n nodes come from spectral clustering of the S1 rows
# of A: the top-K right singular vectors of A[S1, ] are clustered by k-means.
# Block probabilities are estimated from the edges incident to S1 only, and
# the loss is the negative binomial log-likelihood of the held-out S2 x S2
# block under those estimates.
#
# Arguments:
#   A   - n x n adjacency matrix, expected to be a Matrix-package sparse
#         matrix (its entries are read through Matrix::summary()).
#         Presumably binary, undirected and with a zero diagonal -- TODO
#         confirm against callers.
#   K   - number of communities.
#   tol - estimated block probabilities are clamped to [tol, 1 - tol] so the
#         logarithms stay finite.
#
# Returns a single number: minus the sum of the block log-likelihood terms
# over community pairs (k, l) with k <= l.
#
# Note: kmeans() starts from random centres, so the result depends on the RNG
# state.
loglikelihood_loss <- function(A, K, tol = 1e-5) {
  n <- nrow(A)
  # Integer split: with n / 2, an odd n gives fractional indices and the
  # last node silently falls out of S2.
  n_half <- n %/% 2
  S1 <- seq_len(n_half)
  S2 <- (n_half + 1):n

  A1 <- A[S1, , drop = FALSE]
  A22 <- A[S2, S2, drop = FALSE]

  A1_svd <- svds(A1, K)
  z <- kmeans(A1_svd$v, K)$cluster
  z1 <- z[S1]
  z2 <- z[S2]

  # Community sizes in each half. tabulate(nbins = K) keeps a zero for a
  # community absent from that half, whereas table() would drop it and shift
  # the remaining counts; doubles avoid integer overflow in block_sizes_cv().
  N1 <- as.numeric(tabulate(z1, nbins = K))
  N2 <- as.numeric(tabulate(z2, nbins = K))

  # Sum the entries of sparse M over each (row community, column community)
  # pair, returning a dense K x K matrix. sparseMatrix() adds up repeated
  # (i, j) pairs; fixing dims keeps all K rows/columns even when a label
  # never occurs among M's rows or columns. Pattern matrices have no x column
  # in their summary, so each stored entry then counts as 1.
  block_totals <- function(M, row_z, col_z) {
    s <- Matrix::summary(M)
    x <- if (is.null(s$x)) rep(1, length(s$i)) else as.numeric(s$x)
    as.matrix(Matrix::sparseMatrix(
      i = row_z[s$i], j = col_z[s$j], x = x, dims = c(K, K)
    ))
  }

  temp1 <- block_totals(A1, z1, z)
  # An edge between two S1 nodes of the same community appears in A1 at both
  # (a, b) and (b, a), so diag(temp1) counts it twice, while block_sizes_cv()
  # counts such pairs once. The S1 x S1 block totals count these edges twice
  # as well, so subtracting half of their diagonal leaves one count per edge.
  # The S1 x S1 submatrix is taken from A1 rather than A: row subsetting
  # yields general storage, so summary() lists both triangles even if A is
  # stored as a symmetric matrix (for which summary() lists one triangle).
  temp11 <- block_totals(A1[, S1, drop = FALSE], z1, z1)
  temp <- temp1 - diag(diag(temp11), K, K) / 2

  Bh <- temp / block_sizes_cv(z, N1, N2, K)
  # NOTE(review): a community with no S1 members gives 0 / 0 = NaN here,
  # which the clamp below does not remove -- TODO choose a fallback.
  Bh <- pmin(pmax(Bh, tol), 1 - tol)

  # NOTE(review): block_sums()/block_sizes() come from the sourced helper
  # files; assumed to return K x K matrices indexed like Bh -- TODO confirm.
  block_sums_A22 <- block_sums(A22, z2)
  block_sizes_A22 <- block_sizes(z2)

  ll <- block_sums_A22 * log(Bh) +
    (block_sizes_A22 - block_sums_A22) * log(1 - Bh)
  -sum(ll[upper.tri(ll, diag = TRUE)])
}


