source("R/data_gen.R")
source("R/competing_methods.R")
library(parallel)
library(ggplot2)
library(dplyr)
library(Matrix)
library(RcppHungarian)

# Spectral embedding of a symmetric matrix A: the K leading eigenvectors
# returned by RSpectra::eigs_sym() with its default `which = "LM"` (largest
# magnitude), so eigenvectors of large negative eigenvalues can be selected.
get_eig_repr = function(A, K) {
  decomp = RSpectra::eigs_sym(A, K)
  decomp[["vectors"]]
}
# Orthogonal Procrustes: the orthogonal Q minimising ||X - Y Q||_F.
# Writing svd(t(X) %*% Y) = U D V', the minimiser is Q = V U', so that
# Y %*% Q is the best rotation of Y onto X.
align_ortho_mats = function(X, Y) {
  cross = t(X) %*% Y
  dec = svd(cross)
  dec$v %*% t(dec$u)
}

# Rotate Y onto X: returns Y %*% Q, where Q = align_ortho_mats(X, Y) is the
# orthogonal Procrustes rotation.
align_Y_to_X = function(X, Y) {
  Y %*% align_ortho_mats(X, Y)
}

# Sign matrix of a (near) signed permutation P.
# If P = Pi %*% S with Pi a permutation and S = diag(+-1), then abs(P) = Pi
# and t(abs(P)) %*% P = S. P is used as-is (no rounding), so for an inexact
# P the result is only approximately diagonal.
recover_sign_matrix = function(P) {
  unsigned_t = t(abs(P))
  unsigned_t %*% P
}

# Algorithm 1: estimate the orthogonal factor U_t for every embedding.
#
# For each X_j: cluster its rows into L = ncol(Xlist[[1]]) groups with
# k-means, take the SVD Gam = U D V' of the L-row centroid matrix, and return
# recover_sign_matrix(U) %*% t(V). U is used un-rounded.
#
# Args:
#   Xlist:  list of numeric matrices (one embedding per network); the number
#           of k-means centres is taken from ncol(Xlist[[1]]).
#   nstart: number of random k-means starts (default 20, as before).
# Returns:
#   Unnamed list with one matrix per element of Xlist; list() when Xlist is
#   empty (previously this errored with "subscript out of bounds").
recover_Ut_list = function(Xlist, nstart = 20) {
  if (length(Xlist) == 0L) {
    return(list())
  }
  L = ncol(Xlist[[1]])
  lapply(seq_along(Xlist), function(j) {
    Gam = kmeans(Xlist[[j]], L, nstart = nstart)$centers
    Gam_svd = svd(Gam)
    # t(abs(U)) %*% U: exact sign matrix only when U is a signed permutation.
    St = recover_sign_matrix(Gam_svd$u)
    St %*% t(Gam_svd$v)  # recover U_t
  })
}

# Algorithm 2: estimate the scaled factor for every embedding.
#
# For each X_j: cluster its rows into L = ncol(Xlist[[1]]) groups with
# k-means, take the SVD Gam = U D V' of the centroid matrix, and return
# t(abs(U)) %*% Gam. If U = Pi %*% S is a signed permutation (Pi a
# permutation, S = diag(+-1)), this equals S %*% D %*% t(V). Here U is used
# un-rounded, so that identity holds only approximately.
#
# Args:
#   Xlist:  list of numeric matrices (one embedding per network); the number
#           of k-means centres is taken from ncol(Xlist[[1]]).
#   nstart: number of random k-means starts (default 20, as before).
# Returns:
#   Unnamed list with one matrix per element of Xlist; list() when Xlist is
#   empty (previously this errored with "subscript out of bounds").
recover_Ubt_list = function(Xlist, nstart = 20) {
  if (length(Xlist) == 0L) {
    return(list())
  }
  L = ncol(Xlist[[1]])
  lapply(seq_along(Xlist), function(j) {
    Gam = kmeans(Xlist[[j]], L, nstart = nstart)$centers
    Utt = svd(Gam)$u
    t(abs(Utt)) %*% Gam
  })
}

# Rotate every U_j onto the reference U_1 via align_Y_to_X(); the first
# entry is U_1 rotated onto itself. Returns an unnamed list of the same
# length as Ulist.
align_Ulist_to_first = function(Ulist) {
  out = vector("list", length(Ulist))
  for (j in seq_along(Ulist)) {
    out[[j]] = align_Y_to_X(Ulist[[1]], Ulist[[j]])
  }
  out
}

# Procrustes rotations Q_j = align_ortho_mats(U_1, U_j), i.e. the rotation
# with U_j %*% Q_j closest to U_1. Returns an unnamed list, one per U_j.
get_alignmat_Ulist_to_first = function(Ulist) {
  unname(Map(function(Uj) align_ortho_mats(Ulist[[1]], Uj), Ulist))
}

# Apply to each embedding X_j the rotation that aligns U_j with U_1:
# returns list(X_j %*% R_j) with R_j = align_ortho_mats(U_1, U_j).
# Iterates over Xlist, so Ulist needs at least length(Xlist) entries.
align_Xlist_by_Ulist = function(Xlist, Ulist) {
  Rlist = get_alignmat_Ulist_to_first(Ulist)
  out = vector("list", length(Xlist))
  for (j in seq_along(Xlist)) {
    out[[j]] = Xlist[[j]] %*% Rlist[[j]]
  }
  out
}

# Distance of every U_j from the reference U_1 (the first entry is 0).
#
# Args:
#   Ulist: list of numeric matrices of identical dimensions (used above on
#          the output of align_Ulist_to_first()).
#   type:  matrix norm passed to base::norm(); the default "O" (one norm,
#          maximum absolute column sum) matches the previous behaviour.
# Returns:
#   Numeric vector of length(Ulist). vapply() guarantees numeric(0) for an
#   empty list, where sapply() used to return list().
err_Ulist_to_first = function(Ulist, type = "O") {
  vapply(seq_along(Ulist),
         function(j) norm(Ulist[[1]] - Ulist[[j]], type = type),
         numeric(1))
}

# ---- Simulation setup ----
# Fixed seed so the simulated networks and the later k-means starts are
# reproducible.
set.seed(123)
n = 200      # presumably nodes per network (passed to gen_rand_nsbm)
K = 2        # number of network clusters recovered by spec_clust() below
L = 5        # embedding dimension and number of node communities used below
J = 48       # presumably number of networks (passed to gen_rand_nsbm)
lambda = 25
zeta = .3
gam = .1
niter = 100  # currently not referenced in this script

out = gen_rand_nsbm(n = n, K = K, L = L, J = J, lambda = lambda, gam = gam,
                    zeta = zeta, sort_z = TRUE)
# Alternative data source:
# out = generate_nathans_data(n = n, J = J)
A = out$A        # list of symmetric (presumably adjacency) matrices
z_tru = out$z    # true network labels, compared with z via NMI below
xi_tru = out$xi  # true node labels, one vector per network
eta = out$eta    # connectivity matrices, indexed by network label

# ---- Ground-truth summaries ----
# Node-community sizes in the first network, shown as a diagonal matrix.
nn = table(xi_tru[[1]])
(N = diag(nn))
# Community positions ordered from largest to smallest (not used below).
perm = order(nn, decreasing = TRUE)

# Ideal matrices P_j = Z_j eta[[z_j]] Z_j' (i.e. E[A_j]). Not used by the
# pipeline below, which embeds the observed A instead.
P = lapply(seq_along(xi_tru), function(j) {
  zlab = z_tru[[j]]
  Z = label_vec2mat(xi_tru[[j]])  # presumably the node-membership matrix
  Z %*% eta[[zlab]] %*% t(Z)
})


# ---- Embedding and alignment ----
# Spectral embedding (L leading eigenvectors) of each OBSERVED matrix A_j.
# To embed the ideal matrices E[A_j] instead, replace `A` with `P` here.
Xlist = lapply(A, function(As) get_eig_repr(As, L))

Utlist = recover_Ut_list(Xlist)    # Algorithm 1
Ubtlist = recover_Ubt_list(Xlist)  # Algorithm 2

# Align each recovered factor to the first one and report residual norms.
Utlist_aligned = align_Ulist_to_first(Utlist)
Utlist_aligned
err_Ulist_to_first(Utlist_aligned)

Ubtlist_aligned = align_Ulist_to_first(Ubtlist)
Ubtlist_aligned
err_Ulist_to_first(Ubtlist_aligned)

# Rotate each embedding X_j by the rotation that aligns Ut_j with Ut_1.
# Rt_list = get_alignmat_Ulist_to_first(Utlist)
Xlist_aligned = align_Xlist_by_Ulist(Xlist, Utlist)

# Jointly cluster the nodes of all networks and split labels back per network.
#
# All rows of all embeddings are stacked and clustered into L groups with
# k-means; the resulting label vector is split into one vector per element
# of Xlist, in order.
#
# Args:
#   Xlist:  list of numeric matrices with equal column counts. Row counts may
#           differ between networks; the old code assumed every matrix had
#           nrow(Xlist[[1]]) rows and mis-split otherwise.
#   L:      number of k-means clusters.
#   nstart: number of random k-means starts (default 25, as before).
# Returns:
#   List named "1", "2", ... holding each network's node labels.
get_labels_from_Xlist = function(Xlist, L, nstart = 25) {
  xi_vec = kmeans(do.call(rbind, Xlist), L, nstart = nstart)$cluster
  # One group id per stacked row, sized by each network's own row count.
  sizes = vapply(Xlist, nrow, integer(1))
  graph_id = factor(rep(seq_along(Xlist), times = sizes),
                    levels = seq_along(Xlist))
  split(xi_vec, graph_id)
}

# ---- Node labels, connectivity estimates, network clustering ----
xih = get_labels_from_Xlist(Xlist_aligned, L)  # labels from aligned embeddings
# xih = get_labels_from_Xlist(Xlist, L)        # unaligned alternative
lapply(xih, table)                             # label counts per network

Bh_list = estim_B_list(A, xih)
Si = make_conn_similarity_mat(Bh_list)
z = spec_clust(Si, K)

# NMI between estimated and true network labels (observed to be poor).
print(data.frame(
  nmi = nett::compute_mutual_info(z_tru, z),
  method = "z-NMI new alignement"
))
image(Si)

