source("R/matching.R")
source("R/generate.R")
source("R/get_labels.R")
library(Matrix)
library(RSpectra)


library(nett)
library(Rcpp)
library(RcppArmadillo)
library(parallel)

library(ggplot2)

#' Cross-validation loss over a range of candidate community counts
#'
#' For every `K` in `minK:maxK`, labels are obtained from
#' `get_labels(A_cv, K, 0.4, 0.4, 100)`, a K x K matrix `B_cv` is formed as
#' `block_sums(A_cv, z_cv) / block_sizes_upd(z_cv, K)`, and the loss is the sum
#' over the rows of `S` of `-log(B_cv[a, b])`, where `a` and `b` are the labels
#' of the two node indices stored in that row.
#'
#' @param S Two-column matrix of node index pairs, one pair per row
#'   (presumably the held-out edges returned by `modified_A()` -- confirm
#'   against the caller).
#' @param A_cv Adjacency matrix passed to `get_labels()` and `block_sums()`.
#' @param minK Smallest number of communities to try.
#' @param maxK Largest number of communities to try.
#' @return Numeric vector with one loss value per `K` in `minK:maxK`. A
#'   zero-row `S` yields a loss of 0 for every `K`.
simple_cv <- function(S, A_cv, minK = 2, maxK = 10) {
  n_pairs <- nrow(S)

  vapply(minK:maxK, function(K) {
    # NOTE(review): the meaning of the 0.4, 0.4, 100 arguments is defined by
    # get_labels(), which is not visible here -- confirm before tuning them.
    z_cv <- get_labels(A_cv, K, 0.4, 0.4, 100)
    # Alternatives kept from the original author:
    # z_cv <- spec_clust(A_cv, K)
    # B_cv <- estim_dcsbm(A_cv, z_cv)$B
    B_cv <- block_sums(A_cv, z_cv) / block_sizes_upd(z_cv, K)

    # seq_len() makes an empty S contribute nothing instead of iterating over
    # the bogus sequence c(1, 0) and failing on S[1, 1].
    pair_losses <- vapply(seq_len(n_pairs), function(i) {
      -log(B_cv[z_cv[S[i, 1]], z_cv[S[i, 2]]])
    }, numeric(1))

    sum(pair_losses)
  }, numeric(1))
}

#' Pairwise block-size matrix for a label vector
#'
#' Counts how many entries of `z` equal each label `1, ..., K` (call the
#' counts `n_1, ..., n_K`) and returns the K x K matrix whose off-diagonal
#' entry (k, l) is `n_k * n_l` and whose diagonal entry (k, k) is
#' `n_k * (n_k - 1) / 2`.
#'
#' @param z Vector of labels; its entries are compared against `1, ..., K`.
#' @param K Number of labels to count.
#' @return A K x K numeric matrix (0 x 0 when `K` is 0).
block_sizes_upd <- function(z, K) {
  # seq_len() keeps K = 0 from expanding to the sequence c(1, 0).
  ns <- vapply(seq_len(K), function(k) sum(z == k), numeric(1))

  # Outer product of the counts: entry (k, l) is n_k * n_l.
  sizes <- tcrossprod(ns)
  # Diagonal: n_k^2 becomes (n_k^2 - n_k) / 2.
  diag(sizes) <- (diag(sizes) - ns) / 2
  sizes
}

#' Hold out a random subset of edges from an adjacency matrix
#'
#' Locates every entry of `A` equal to 1, draws `n_edges_cv` of those
#' positions without replacement, and returns a copy of `A` in which each
#' drawn position (i, j) and its mirror (j, i) are set to 0.
#'
#' @param A Adjacency matrix; entries equal to 1 are treated as edges.
#' @param n_edges_cv Number of edge positions to draw and remove.
#' @return An unnamed list: element 1 is the modified adjacency matrix,
#'   element 2 is the `n_edges_cv` x 2 matrix of drawn (row, col) positions.
modified_A <- function(A, n_edges_cv = 150) {
  # Choose at random a subset of edges to remove.
  edge_set <- which(A == 1, arr.ind = TRUE)
  # NOTE(review): if A is symmetric, every undirected edge is listed twice in
  # edge_set ((i, j) and (j, i)), so S can contain both orientations of the
  # same edge -- confirm whether that is intended.
  #
  # Sample row indices of edge_set directly, so the population size is right
  # even when A is not strictly 0/1; drop = FALSE keeps S a two-column matrix
  # when only one edge is drawn.
  picked <- sample.int(nrow(edge_set), n_edges_cv, replace = FALSE)
  S <- edge_set[picked, , drop = FALSE]

  # Create a modified adjacency matrix for cv: zero both orientations of every
  # drawn position using matrix indexing.
  A_cv <- A
  A_cv[S] <- 0
  A_cv[S[, 2:1, drop = FALSE]] <- 0
  list(A_cv, S)
}
