#' Eigenvector representation of a symmetric matrix.
#'
#' Thin wrapper around `RSpectra::eigs_sym()` that discards everything in its
#' result except the eigenvectors.
#'
#' @param A Matrix handed to `RSpectra::eigs_sym()` as its first argument.
#' @param K Handed to `RSpectra::eigs_sym()` as its second positional
#'   argument (the number of eigenpairs requested). Which eigenpairs are
#'   selected is left to that function's defaults.
#' @return The `vectors` component of the `eigs_sym()` result.
get_eig_repr <- function(A, K) {
  eig <- RSpectra::eigs_sym(A, K)
  eig$vectors
}
#' Orthogonal matrix that aligns `Y` to `X`.
#'
#' Takes the SVD `t(X) %*% Y = U D t(V)` and returns `V %*% t(U)`. When `X`
#' and `Y` have the same shape this is the orthogonal Procrustes rotation:
#' the orthogonal `Q` for which `Y %*% Q` is closest to `X` in Frobenius norm.
#'
#' @param X Reference matrix.
#' @param Y Matrix to be aligned; must have as many rows as `X`.
#' @return A matrix with `ncol(Y)` rows and `ncol(X)` columns.
align_ortho_mats <- function(X, Y) {
  cross <- t(X) %*% Y
  decomp <- svd(cross)
  left <- decomp$u
  right <- decomp$v
  right %*% t(left)
}

#' Rotate `Y` so that it lines up with `X`.
#'
#' @param X Reference matrix.
#' @param Y Matrix to be aligned to `X`.
#' @return `Y` right-multiplied by the matrix that
#'   `align_ortho_mats(X, Y)` returns.
align_Y_to_X <- function(X, Y) {
  Y %*% align_ortho_mats(X, Y)
}

#' Recover the diagonal sign matrix of a signed permutation matrix.
#'
#' Evaluates `t(abs(P)) %*% P`. If `P` is a signed permutation matrix (exactly
#' one entry equal to +1 or -1 in every row and column, zeros elsewhere) the
#' product is diagonal, and its k-th diagonal entry is the sign of the
#' nonzero entry in column k of `P`.
#'
#' @param P Numeric matrix.
#' @return An `ncol(P)` by `ncol(P)` matrix.
recover_sign_matrix <- function(P) {
  magnitudes <- abs(P)
  t(magnitudes) %*% P
}

#' Algorithm 1: recover a matrix U_t from each embedding in `Xlist`.
#'
#' For every matrix in `Xlist` the rows are clustered with `kmeans()` (25
#' random starts), using as many centres as the first matrix has columns.
#' The matrix of cluster centres is decomposed by SVD, the left singular
#' vectors are turned into a sign matrix by `recover_sign_matrix()`, and that
#' sign matrix is multiplied by the transposed right singular vectors.
#'
#' NOTE(review): the sign-recovery step presumably relies on the left
#' singular vectors being (close to) a signed permutation matrix -- confirm
#' against the reference for Algorithm 1.
#'
#' `kmeans()` draws random starting centres, so call `set.seed()` beforehand
#' for reproducible output.
#'
#' @param Xlist List of numeric matrices. The number of clusters is taken
#'   from `ncol(Xlist[[1]])` and reused for every element.
#' @return An unnamed list with one matrix per element of `Xlist`; an empty
#'   list when `Xlist` is empty.
recover_Ut_list <- function(Xlist) {
  # An empty input has no first element to read the column count from.
  if (length(Xlist) == 0L) {
    return(list())
  }
  L <- ncol(Xlist[[1]])
  lapply(seq_along(Xlist), function(j) {
    km <- kmeans(Xlist[[j]], L, nstart = 25)
    centers_svd <- svd(km$centers)
    St <- recover_sign_matrix(centers_svd$u)
    St %*% t(centers_svd$v)  # recover U_t
  })
}

#' Algorithm 2: recover a matrix from the cluster centres of each embedding.
#'
#' For every matrix in `Xlist` the rows are clustered with `kmeans()` (20
#' random starts), using as many centres as the first matrix has columns.
#' With `Gam` the matrix of cluster centres and `U` the left singular vectors
#' of `Gam`, the value returned for that matrix is `t(abs(U)) %*% Gam`.
#'
#' The original author's note describes `U` as a signed permutation matrix;
#' under that assumption `t(abs(U))` merely reorders the rows of `Gam`.
#'
#' `kmeans()` draws random starting centres, so call `set.seed()` beforehand
#' for reproducible output.
#'
#' @param Xlist List of numeric matrices. The number of clusters is taken
#'   from `ncol(Xlist[[1]])` and reused for every element.
#' @return An unnamed list with one matrix per element of `Xlist`; an empty
#'   list when `Xlist` is empty.
recover_Ubt_list <- function(Xlist) {
  # An empty input has no first element to read the column count from.
  if (length(Xlist) == 0L) {
    return(list())
  }
  L <- ncol(Xlist[[1]])
  lapply(seq_along(Xlist), function(j) {
    km <- kmeans(Xlist[[j]], L, nstart = 20)
    Gam <- km$centers
    left <- svd(Gam)$u  # expected to be a signed permutation matrix
    t(abs(left)) %*% Gam
  })
}

#' Jointly cluster the rows of several matrices and return per-matrix labels.
#'
#' All matrices in `Xlist` are stacked row-wise and clustered together with
#' `kmeans()` (25 random starts). The resulting label vector is then cut back
#' into one piece per input matrix, using each matrix's own row count, so the
#' matrices may have different numbers of rows.
#'
#' `kmeans()` draws random starting centres, so call `set.seed()` beforehand
#' for reproducible output.
#'
#' @param Xlist Non-empty list of numeric matrices with a common number of
#'   columns.
#' @param L Number of clusters passed to `kmeans()`.
#' @return A list named "1", "2", ... with one integer label vector per
#'   element of `Xlist`; element `j` has `nrow(Xlist[[j]])` entries.
get_labels_from_Xlist <- function(Xlist, L) {
  xi_vec <- kmeans(do.call(rbind, Xlist), L, nstart = 25)$cluster
  sizes <- vapply(Xlist, nrow, integer(1), USE.NAMES = FALSE)
  # Tag every stacked row with the index of the matrix it came from. Fixing
  # the factor levels keeps one (possibly empty) entry per input matrix.
  owner <- factor(
    rep(seq_along(Xlist), times = sizes),
    levels = seq_along(Xlist)
  )
  split(xi_vec, owner)
}

#' Rotate every matrix in `Xlist` into the frame of the first `Ulist` entry.
#'
#' First, one alignment matrix per element of `Ulist` is computed with
#' `align_ortho_mats()`, always using `Ulist[[1]]` as the reference. Then the
#' j-th matrix of `Xlist` is right-multiplied by the j-th alignment matrix.
#'
#' @param Xlist List of matrices to rotate.
#' @param Ulist List of matrices defining the rotations; it needs at least
#'   as many elements as `Xlist`.
#' @return An unnamed list with one rotated matrix per element of `Xlist`.
align_Xlist_by_Ulist <- function(Xlist, Ulist) {
  rotation_for <- function(j) align_ortho_mats(Ulist[[1]], Ulist[[j]])
  rotations <- lapply(seq_along(Ulist), rotation_for)
  apply_rotation <- function(j) Xlist[[j]] %*% rotations[[j]]
  lapply(seq_along(Xlist), apply_rotation)
}