library(parallel)
library(foreach)
library(dplyr)

source("R/data_gen.R")
source("R/competing_methods.R")
source("tests/utils.R")

# ---- Simulation settings ----
# These globals are read directly by simulate_run() below.
n <- 200      # `n` argument of gen_rand_nsbm() (presumably nodes per network -- confirm)
K <- 2        # number of network clusters requested from spec_clust()
L <- 5        # label count passed to get_eig_repr() / get_labels_from_Xlist()
J <- 48       # number of networks generated per run
zeta <- 1     # `zeta` argument of gen_rand_nsbm()
gam <- 0.1    # `gam` argument of gen_rand_nsbm()
niter <- 100  # replicates per lambda value
lamax <- 40   # largest lambda on the grid

# Grid of lambda values (shown as "Average degree" in the plots): 10 points,
# log-spaced between 10 and lamax, rounded to two decimals.
lambdas <- round(
  10^seq(log10(10), log10(lamax), length.out = 10),
  digits = 2
)
nlam <- length(lambdas)

# One row per (lambda index, replicate) pair; simulate_run(j) reads row j.
runs <- expand.grid(lam_idx = seq_len(nlam), itr = seq_len(niter))


# Run a single simulation replicate.
#
# Reads row `j` of the global `runs` grid to pick the lambda value and the
# replicate number, generates networks with gen_rand_nsbm(), estimates
# per-network labels and network clusters spectrally, and estimates one
# connection matrix per true network cluster.
#
# Args:
#   j: row index into the global `runs` data frame.
#
# Returns:
#   A one-row data.frame with columns
#     lambda       - lambda value used for this run
#     iter         - replicate number
#     slice_nmi    - hsbm::get_slice_nmi() of estimated vs. true labels
#     clust_nmi    - nett::compute_mutual_info() of true vs. estimated clusters
#     eta_diff_tru - norm() (default type) of eta[[1]] - eta[[2]]
#     eta_diff_est - norm() (default type) of the estimated counterpart
#     nn_diff      - TRUE when the true label counts of the first network
#                    take exactly L distinct values
#
# Relies on the globals n, K, L, J, zeta, gam, lambdas and runs.
simulate_run <- function(j) {
  # set.seed(j)
  idx <- runs[j, "lam_idx"]
  itr <- runs[j, "itr"]
  lambda <- lambdas[idx]

  out <- gen_rand_nsbm(
    n = n, K = K, L = L, J = J,
    lambda = lambda, gam = gam, zeta = zeta, sort_z = TRUE
  )
  A <- out$A
  z_tru <- out$z
  xi_tru <- out$xi
  eta <- out$eta
  nn <- table(xi_tru[[1]])

  # spectral clustering and alignment
  Xlist <- lapply(A, function(As) get_eig_repr(As, L))
  Utlist <- recover_Ut_list(Xlist)
  Xlist_aligned <- align_Xlist_by_Ulist(Xlist, Utlist)

  # network clustering
  xih <- get_labels_from_Xlist(Xlist_aligned, L)
  Bh_list <- estim_B_list(A, xih)
  Si <- make_conn_similarity_mat(Bh_list)
  z <- spec_clust(Si, K)

  # Matrix of label-count products with the counts themselves removed from
  # the diagonal. `nrow` is given explicitly because diag() called with a
  # single number builds an identity matrix of that size instead of a 1 x 1
  # matrix.
  count_pairs <- function(labels) {
    cnt <- as.vector(table(labels))
    tcrossprod(cnt) - diag(cnt, nrow = length(cnt))
  }

  # Pooled connection-matrix estimate over the networks indexed by `nets`:
  # summed block sums divided by summed pair counts.
  estimate_eta <- function(nets) {
    block_sum <- Reduce(`+`, lapply(nets, function(s) {
      compute_block_sums(A[[s]], xih[[s]])
    }))
    pair_count <- Reduce(`+`, lapply(nets, function(s) count_pairs(xih[[s]])))
    block_sum / pair_count
  }

  # estimate connection matrix
  # works only for two clusters
  # NOTE(review): treats the first n1 networks as the first true cluster;
  # presumably guaranteed by sort_z = TRUE -- confirm in gen_rand_nsbm().
  n1 <- as.integer(table(z_tru)[1])
  if (n1 >= J) {
    # (n1 + 1):J would count backwards and index past the end of A.
    stop("all networks share one true cluster label; ",
         "cannot estimate two connection matrices", call. = FALSE)
  }
  etah1 <- estimate_eta(seq_len(n1))
  etah2 <- estimate_eta(seq(n1 + 1L, J))

  # collecting all results
  data.frame(
    lambda = lambda,
    iter = itr,
    slice_nmi = hsbm::get_slice_nmi(xih, xi_tru),
    clust_nmi = nett::compute_mutual_info(z_tru, z),
    eta_diff_tru = norm(eta[[1]] - eta[[2]]),
    eta_diff_est = norm(etah1 - etah2),
    nn_diff = length(unique(nn)) == L
  )
}
# Use all but one core, capped at 20, and never fewer than one worker:
# detectCores() can return NA or 1, either of which would make
# `detectCores() - 1` an invalid mc.cores value.
n_cores_detected <- detectCores()
if (is.na(n_cores_detected)) {
  n_cores_detected <- 1L
}
CPU_CORES_TO_USE <- max(1L, min(20L, n_cores_detected - 1L))

run_outputs <- mclapply(
  seq_len(nrow(runs)),
  simulate_run,
  mc.cores = CPU_CORES_TO_USE
)

# delete runs with errors
# Failed runs come back from mclapply() as non-data-frame values (try-error
# objects, or NULL). Dropping them before binding keeps the result columns in
# their original types instead of coercing everything to character.
run_ok <- vapply(run_outputs, is.data.frame, logical(1))
if (!any(run_ok)) {
  stop("all ", length(run_ok), " simulation runs failed", call. = FALSE)
}
if (!all(run_ok)) {
  message(sum(!run_ok), " of ", length(run_ok),
          " simulation runs failed and were dropped")
}
result_raw <- do.call(rbind, run_outputs[run_ok])
result <- result_raw
# result <- do.call(rbind, lapply(1:nrow(runs), simulate_run))

# number of runs whose nn_diff flag is TRUE
sum(as.logical(result$nn_diff))

# Box plots of each metric against lambda. Columns are passed through
# as.numeric() so the plots also work when `result` holds character columns.
boxplot(as.numeric(result$clust_nmi) ~ result$lambda,
        xlab = "Average degree",
        ylab = "Network clustering NMI")

boxplot(as.numeric(result$eta_diff_tru) ~ result$lambda,
        xlab = "Average degree",
        ylab = "True eta difference")

boxplot(as.numeric(result$eta_diff_est) ~ result$lambda,
        xlab = "Average degree",
        ylab = "Estimated eta difference")

# Estimated minus true eta difference; labelled distinctly from the plot of
# the estimate alone.
boxplot((as.numeric(result$eta_diff_est) - as.numeric(result$eta_diff_tru)) ~ result$lambda,
        xlab = "Average degree",
        ylab = "Estimated minus true eta difference")

boxplot(as.numeric(result$slice_nmi) ~ result$lambda,
        xlab = "Average degree",
        ylab = "Slice NMI")




# Mean, standard deviation and run count of the network clustering NMI for
# each lambda value.
clust_res <- result %>%
  mutate(clust_nmi = as.numeric(clust_nmi)) %>%
  group_by(lambda) %>%
  summarize(clust_nmi_mean = mean(clust_nmi),
            clust_nmi_sd = sd(clust_nmi),
            count = n())

# Mean NMI against lambda with +/- one standard deviation error bars.
# ggplot2 functions are namespace-qualified (as with hsbm:: and nett:: above)
# so the plot does not depend on ggplot2 having been attached elsewhere.
clust_res %>%
  ggplot2::ggplot(ggplot2::aes(x = as.numeric(lambda), y = clust_nmi_mean)) +
  ggplot2::geom_line() +
  ggplot2::geom_point(size = 5) +
  ggplot2::theme_bw() +
  ggplot2::theme(text = ggplot2::element_text(size = 20)) +
  ggplot2::geom_errorbar(
    ggplot2::aes(ymin = clust_nmi_mean - clust_nmi_sd,
                 ymax = clust_nmi_mean + clust_nmi_sd),
    width = .2
  ) +
  ggplot2::ylab("Network clustering NMI") +
  ggplot2::xlab("Average Degree") +
  ggplot2::guides(fill = ggplot2::guide_legend(keywidth = 0.25,
                                               keyheight = 0.25,
                                               default.unit = "inch"))
