source("R/data_gen.R")
source("R/competing_methods.R")
library(parallel)
library(ggplot2)
library(dplyr)
library(Matrix)
library(RcppHungarian)

# Best one-to-one relabelling of `zh` onto the label space of `z`.
#
# Builds the confusion matrix of the two label vectors and solves the linear
# assignment problem on the cost `1 - overlap / length(z)`. Minimising that
# cost is the same as maximising the total number of agreeing entries.
#
# Returns the assigned column index for every row of the confusion matrix,
# taken from the solver's `pairs`. Callers use it as `perm[zh]` to relabel zh.
# NOTE(review): assumes compute_confusion_matrix(zh, z) puts zh labels on the
# rows and z labels on the columns, and that the solver returns `pairs`
# ordered by row -- confirm.
matching_perm <- function(zh, z) {
  overlap <- compute_confusion_matrix(zh, z)
  cost <- 1 - overlap / length(z)
  assignment <- HungarianSolver(cost)
  assignment[["pairs"]][, 2]
}

# For every network j, compute the relabelling permutation that best matches
# the estimated labels xih[[j]] to the reference labels xi_tru[[j]].
# Returns an unnamed list with one permutation per element of `xih`.
comp_all_matching_perms <- function(xih, xi_tru) {
  n_nets <- length(xih)
  perms <- vector("list", n_nets)
  for (idx in seq_len(n_nets)) {
    perms[[idx]] <- matching_perm(xih[[idx]], xi_tru[[idx]])
  }
  perms
}
# Number of distinct per-network matching permutations.
# A value of 1 means a single global relabelling aligns every network with
# its reference labels; larger values mean the networks' label spaces disagree.
# (An earlier version considered normalising by length(xih).)
comp_matching_score <- function(xih, xi_tru) {
  perms <- comp_all_matching_perms(xih, xi_tru)
  sum(!duplicated(perms))
}

# Simulation settings ----
set.seed(1234)

# Problem size, as passed to gen_rand_nsbm(): n nodes per network, K
# network-level classes, L node-level communities, J networks.
n <- 200
K <- 2
L <- 5
J <- 48

# Generator tuning parameters (meaning defined by gen_rand_nsbm -- confirm).
lambda <- 30
zeta <- 0.3
gam <- 0.1

# NOTE(review): not referenced anywhere in this script.
nreps <- 10
n_cores <- 32
niter <- 100

# Simulate data with gen_rand_nsbm() and unpack the pieces used below:
#   A     - per-network adjacency matrices (stacked with rbind() further down)
#   z_tru - true network-level class of each network
#   xi_tru - true node-level labels, one vector per network
#   eta   - class connectivity matrices, indexed by network class
# NOTE(review): field meanings inferred from how they are used in this script.
out <- gen_rand_nsbm(
  n = n, K = K, L = L, J = J,
  lambda = lambda, gam = gam, zeta = zeta,
  sort_z = TRUE # `TRUE`, not `T`: `T` is an ordinary, reassignable variable
)
# out <- generate_nathans_data(n = n, J = J)
A <- out$A
z_tru <- out$z
xi_tru <- out$xi
eta <- out$eta

# Baseline: no label alignment across networks ----
res <- NULL
out <- spec_net_clust(A, K, L)
xih <- out$xi

# Node-level (xi) NMI per network: good.
res <- rbind(res, data.frame(
  nmi = hsbm::get_slice_nmi(xih, xi_tru),
  method = "xi-NMI SC"
))

# Network-level (z) NMI without aligning labels: poor.
res <- rbind(res, data.frame(
  nmi = nett::compute_mutual_info(z_tru, out$z),
  method = "z-NMI SC no-align"
))
image(out$Si)

# Oracle: use the true connectivity matrices ----
# Network j gets eta[[z_tru[j]]], the matrix of its true class.
Bh <- lapply(z_tru, function(cls) eta[[cls]])
Si <- make_conn_similarity_mat(Bh)
z <- spec_clust(Si, K)

# Perfect z-NMI.
res <- rbind(res, data.frame(
  nmi = nett::compute_mutual_info(z_tru, z),
  method = "z-NMI tru eta"
))
image(Si)


# Align a list of label vectors to the label space of the first one.
#
# The first element is returned unchanged. Every other element is relabelled
# with matching_perm() so that its labels best agree with the first element.
#
# labvec_list: list of label vectors (one per network).
# Returns a list of the same length. Element names are kept.
# An empty or single-element list is returned as is. Previously, iterating
# over 2:m broke for m == 1, because 2:1 is c(2, 1).
align_labvec_list <- function(labvec_list) {
  if (length(labvec_list) == 0L) {
    return(labvec_list)
  }
  labvec_base <- labvec_list[[1]]

  aligned_rest <- lapply(labvec_list[-1], function(labvec_i) {
    perm <- matching_perm(labvec_i, labvec_base)
    perm[labvec_i]
  })
  c(labvec_list[1], aligned_rest)
}
# Align every network's estimated labels to those of network 1 ----
xih2 <- align_labvec_list(xih)
Bh_list <- estim_B_list(A, xih2)
Si <- make_conn_similarity_mat(Bh_list)
z <- spec_clust(Si, K)

# Poor z-NMI.
res <- rbind(res, data.frame(
  nmi = nett::compute_mutual_info(z_tru, z),
  method = "z-NMI align-to-1st"
))
image(Si)


# Oracle alignment: relabel each network's estimate to match its true labels ----
xih3 <- lapply(seq_along(xi_tru), function(j) {
  lab_hat <- xih[[j]]
  matching_perm(lab_hat, xi_tru[[j]])[lab_hat]
})
Bh_list <- estim_B_list(A, xih3)
Si <- make_conn_similarity_mat(Bh_list)
z <- spec_clust(Si, K)

# Perfect z-NMI.
res <- rbind(res, data.frame(
  nmi = nett::compute_mutual_info(z_tru, z),
  method = "z-NMI align-to-tru"
))
image(Si)

# An attempt at joint spectral clustering (SC) ----
# Stack the adjacency matrices of all networks row-wise. Embed every
# (network, node) row with one rank-L truncated SVD, then cluster all rows
# together with k-means. Every network is clustered in the same embedding,
# so no separate per-network alignment step is applied here.
Asvd <- RSpectra::svds(do.call(rbind, A), L)
# Scale the left singular vectors by the singular values. `nrow` is given
# explicitly so this stays correct for L == 1. Otherwise diag() of a
# length-one vector d builds an identity matrix of size floor(d).
US <- Asvd$u %*% diag(Asvd$d, nrow = length(Asvd$d))
# US <- Asvd$u
kclust <- kmeans(US, L, nstart = 25)
xih_vec <- kclust$cluster
# Map each stacked row back to its network. The map is derived from the
# matrices actually stacked, not from the global `n`.
xih4 <- split(xih_vec, rep(seq_along(A), times = vapply(A, nrow, integer(1))))
Bh_list <- estim_B_list(A, xih4)
Si <- make_conn_similarity_mat(Bh_list)
z <- spec_clust(Si, K)

# Poor z-NMI.
res <- rbind(res, data.frame(
  nmi = nett::compute_mutual_info(z_tru, z),
  method = "z-NMI joint SC"
))

# Summary table with the method name as the first column.
print(knitr::kable(relocate(res, method)))

# z1 = xih[[4]]
# z0 = xih[[1]]
# perm = matching_perm(z1, z0)
# clue::cl_agreement(clue::as.cl_hard_partition(z1), clue::as.cl_hard_partition(z0), method = "diag")
# sum(diag(compute_confusion_matrix(perm[z1], z0)))/n