library('RSpectra')
library("TopicScore")
library('tidyverse')
library('trimcluster')
library('extremefit')
library('igraph')
library(doParallel)
library(R.matlab)
set.seed(10007)



## SCORE+
# Spectral community detection on an adjacency matrix.
#
# Arguments:
#   A  square adjacency matrix; rowSums(A) are used as the node degrees.
#   k  number of clusters requested from k-means.
#   c  ridge constant; the regularizer is delta = c * max(degree).
#   r  number of leading eigenvectors to use. When NULL it defaults to k
#      (not k + 1), so the latent dimension is always fixed and no
#      eigen-gap based column dropping takes place.
#
# Returns a list with
#   labels   cluster id (1..k) of every node, from kmeans(),
#   ratios   entry-wise ratios of weighted eigenvectors 2..r to the first,
#   delta    the ridge regularizer that was applied,
#   eig.vec  the r eigenvectors returned by RSpectra::eigs,
#   eig.val  the matching eigenvalues.
SCOREplus <- function(A, k, c = 0.1, r = NULL){

  # if r is not given, use k eigenvectors
  if (is.null(r)) {
    r <- k
  }
  # Columns 2..r are divided by column 1, so at least two eigenvectors are
  # required; `2:r` would silently count backwards for r = 1.
  if (length(r) != 1 || is.na(r) || r < 2) {
    stop("SCOREplus needs r >= 2 eigenvectors (r defaults to k)", call. = FALSE)
  }

  degrees <- rowSums(A)
  delta <- c * max(degrees) # tuning parameter for the graph Laplacian
  d.inv <- 1 / sqrt(delta + degrees)

  # graph Laplacian with ridge regularization: scale rows, then columns
  L.delta <- t(d.inv * A) * d.inv

  # get top r eigenvectors
  eig.out <- RSpectra::eigs(L.delta, k = r)

  # Reweight each eigenvector (column) by its eigenvalue. sweep() scales the
  # columns directly and avoids diag(), which treats a length-one vector as
  # a matrix size rather than a diagonal entry.
  eig.vec.w <- sweep(eig.out$vectors, 2, eig.out$values, "*")

  # get ratio matrix
  ratios <- eig.vec.w[, 2:r] / eig.vec.w[, 1]

  # k-means
  labels <- kmeans(ratios, k, nstart = 100, iter.max = 100)$cluster

  list(labels = labels,
       ratios = ratios,
       delta = delta,
       eig.vec = eig.out$vectors,
       eig.val = eig.out$values)
}


#initialization
# Read the adjacency matrix (`A`) and the node labels (`label`) from
# "<name>.mat" in the working directory.
name <- "simmons" #when running the code on Caltech or Polblogs, change the name accordingly
data_mat <- readMat(paste0(name, ".mat"))
D_tot <- data_mat$A
label_tot <- data_mat$label
n_tot <- length(label_tot)


# Distinct label values; their count is used as the number of communities.
fac_list <- unique(label_tot)
k <- length(fac_list)

# Membership indicator: Pi_tot[i, j] is 1 when node i carries the j-th
# distinct label and 0 otherwise.
Pi_tot <- matrix(0, nrow = n_tot, ncol = k)
for (i in c(1:n_tot)) {
  for (j in c(1:k)) {
    Pi_tot[i, j] <- as.numeric(label_tot[i] == fac_list[j])
  }
}

# Random ordering of all node indices; the folds are cut from this ordering.
perm_tot <- sample(n_tot, size = n_tot)



# generate permutation list of [K]
#
# Returns every ordering of the elements of `v` as the rows of a matrix with
# length(v) columns, ordered by the position of the leading element (for
# sorted input: lexicographic order).
# The result is always a matrix, including for a single element (1 x 1) and
# for empty input (1 x 0), so callers can rely on nrow() and [i, j] indexing.
perm <- function(v) {
  nn <- length(v)
  if (nn <= 1) {
    return(matrix(v, nrow = 1, ncol = nn))
  }
  # One block per choice of leading element: that element followed by every
  # permutation of the remaining ones. Blocks are stacked once at the end
  # instead of growing the result with rbind() inside the loop.
  blocks <- lapply(seq_len(nn), function(i) {
    cbind(v[[i]], perm(v[-i]), deparse.level = 0)
  })
  do.call(rbind, blocks)
}
# All orderings of 1:k as produced by perm(); the evaluation code below tries
# each one as a candidate mapping from cluster ids to true label indices.
permk <- perm(1:k)



# Structural similarity between two numeric vectors: their inner product
# divided by the product of their Euclidean norms.
# Returns 0 whenever the inner product is exactly zero, which also covers
# an all-zero vector and so avoids a 0 / 0 division.
cor2 <- function(x, y){
  inner <- sum(x * y)
  if (inner == 0) {
    return(0)
  }
  inner / sqrt(sum(x * x)) / sqrt(sum(y * y))
}



#main code

# Start a 7-worker cluster and register it as the backend for %dopar%.
cl <- makeCluster(7)
registerDoParallel(cl)

# Cross-validation layout: number of folds and nodes per fold.
fold_num <- 10
fold_size <- floor(n_tot / fold_num)

# Fractions of the training nodes whose labels are revealed.
semi_ratio <- c(3, 5, 7) / 10
ratio_len <- length(semi_ratio)


# Error rates, one row per fold and one column per labelled ratio.
res_AngleMinPlus <- matrix(0, nrow = fold_num, ncol = ratio_len) #AngleMin+
res_SNMF <- matrix(0, nrow = fold_num, ncol = ratio_len) #SNMF
res_u <- matrix(0, nrow = fold_num, ncol = ratio_len) #Unsupervised (SCORE+)


# Running times, laid out like the error-rate matrices above.
time_AngleMinPlus <- matrix(0, nrow = fold_num, ncol = ratio_len) #AngleMin+
time_SNMF <- matrix(0, nrow = fold_num, ncol = ratio_len) #SNMF
time_u <- matrix(0, nrow = fold_num, ncol = ratio_len) #Unsupervised (SCORE+)



# Wall-clock seconds elapsed since `start` (a value returned by Sys.time()).
# The unit is pinned to seconds because a plain difftime picks its unit
# automatically, and that unit is lost when stored in a numeric matrix.
elapsed_secs <- function(start) {
  as.numeric(difftime(Sys.time(), start, units = "secs"))
}

# fold_num-fold cross-validation. For every fold the held-out nodes are
# classified one at a time by three methods (unsupervised SCORE+, AngleMin+
# and SNMF); the fold's misclassification rate and running time are stored
# in the res_* / time_* matrices (row = fold, column = labelled ratio).
for (ii in seq_len(fold_num)) {
  # Generating the training and test data
  test_ind_s <- (ii - 1) * fold_size + 1
  test_ind_t <- ii * fold_size
  if (ii == fold_num) {
    # The last fold also takes the nodes left over when n_tot is not a
    # multiple of fold_num.
    test_ind_t <- n_tot
  }
  test_pos <- test_ind_s:test_ind_t
  train_ind <- perm_tot[-test_pos]
  test_ind <- perm_tot[test_pos]

  perm_list_train <- sample(train_ind, length(train_ind))
  n <- length(train_ind)
  n_test <- length(test_ind)


  #Unsupervised (SCORE+)

  # Cluster the training nodes plus one test node jj, align cluster ids with
  # the true labels through the best permutation, and return 1 if jj ends up
  # misclassified and 0 otherwise.
  u_do <- function(jj){
    train_u_ind <- c(train_ind, jj)
    clu_res_label <- SCOREplus(D_tot[train_u_ind, train_u_ind], k)$labels
    true_label <- apply(Pi_tot[train_u_ind, , drop = FALSE], 1, which.max)

    # Row p holds the cluster labels after applying permutation p; count the
    # disagreements with the true labels for all permutations at once.
    relabeled <- permk[, clu_res_label, drop = FALSE]
    truth <- matrix(true_label, nrow = nrow(permk), ncol = n + 1, byrow = TRUE)
    center_comu <- rowSums(relabeled != truth)

    best_perm <- which.min(center_comu)
    1 - (permk[best_perm, clu_res_label[n + 1]] == which.max(Pi_tot[jj, ]))
  }
  temp_u_err <- foreach(jjj = test_ind, .combine='c') %dopar% u_do(jjj)
  # The unsupervised method ignores the labelled ratio, so its error and
  # timing are repeated across all columns of the fold's row.
  res_u[ii, ] <- rep(sum(temp_u_err) / n_test, ratio_len)

  cur.time <- Sys.time()
  SCOREplus(D_tot, k)
  time_u[ii, ] <- elapsed_secs(cur.time)


  for (ratio_num in seq_len(ratio_len)) {
    n_L <- floor(n * semi_ratio[ratio_num])
    n_U <- n - n_L
    labeled_ind <- perm_list_train[seq_len(n_L)]
    unlabeled_ind <- perm_list_train[(n_L + 1):n]


    #AngleMin+
    cur.time <- Sys.time()
    clu_res_label <- SCOREplus(D_tot[unlabeled_ind, unlabeled_ind], k)$labels

    # Columns 1..k hold the known memberships of the labelled nodes and
    # columns (k + 1)..2k the SCORE+ clusters of the unlabelled nodes; rows
    # of all other nodes stay zero.
    Pi_est_AngleMinPlus <- matrix(0, nrow = n_tot, ncol = 2 * k)
    Pi_est_AngleMinPlus[labeled_ind, 1:k] <- Pi_tot[labeled_ind, ]
    Pi_est_AngleMinPlus[cbind(unlabeled_ind, k + clu_res_label)] <- 1

    V_AngleMinPlus <- t(Pi_est_AngleMinPlus[labeled_ind, ]) %*% D_tot[labeled_ind, ] %*% Pi_est_AngleMinPlus

    # Assign test node i to the community whose row of V_AngleMinPlus is most
    # similar (cor2) to the node's aggregated edge profile; returns 1 when the
    # assignment is wrong and 0 when it is right. A node with an all-zero
    # profile is scored as a uniformly random guess.
    AngleMinPlus_do <- function(i){
      tempwi2 <- as.vector(D_tot[i, ] %*% Pi_est_AngleMinPlus)
      if(sum(tempwi2^2) ==0){
        rbinom(n = 1, size = 1, prob = 1 - 1 / k)
      }
      else{
        temp <- rep(0, k)
        for(l in seq_len(k)){
          temp[l] <- cor2(tempwi2, V_AngleMinPlus[l, ])
        }
        1 - Pi_tot[i, which.max(temp)]
      }
    }
    err_AngleMinPlus <- foreach(jjj = test_ind, .combine='c') %dopar% AngleMinPlus_do(jjj)
    res_AngleMinPlus[ii, ratio_num] <- sum(err_AngleMinPlus) / n_test
    time_AngleMinPlus[ii, ratio_num] <- elapsed_secs(cur.time)

    #SNMF
    cur.time <- Sys.time()

    # Run 20 multiplicative updates of H on the training nodes plus test node
    # jj, read each node's community off the largest entry of its row of H,
    # align with the true labels through the best permutation, and return 1
    # if jj is misclassified and 0 otherwise.
    SNMF_do <- function(jj){
      test_h_ind <- c(perm_list_train, jj)
      # O_L[a, b] is 1 when labelled nodes a and b share a community.
      O_L <- matrix(0, nrow = n + 1, ncol = n + 1)
      O_L[1:n_L, 1:n_L] <- Pi_tot[perm_list_train[1:n_L],] %*% t(Pi_tot[perm_list_train[1:n_L], ])
      D_L <- diag(rowSums(O_L))

      # Random positive start, rows normalised to sum to one.
      H <- matrix(1 + abs(rt(n = (n + 1) * k, df = 5)), nrow = n + 1)
      H <- diag(rowSums(H)^{-1}) %*% H
      lambda <- 1

      for(iii in c(1:20)){
        H <- H * (D_tot[test_h_ind, test_h_ind] %*% H + 2 * lambda * O_L %*% H) / (H %*% t(H) %*% H + lambda * D_L %*% H + 0.1)
      }


      if(any(is.na(H))){
        # Updates broke down numerically: score the node as a random guess.
        print('NA detected!')
        rbinom(n = 1, size = 1, prob = 1 - 1 / k)
      }
      else{
        est_label <- apply(H, 1, which.max)
        true_label <- apply(Pi_tot[test_h_ind, , drop = FALSE], 1, which.max)

        # Mismatch count of every candidate relabeling, all at once.
        relabeled <- permk[, est_label, drop = FALSE]
        truth <- matrix(true_label, nrow = nrow(permk), ncol = n + 1, byrow = TRUE)
        center_comu_h <- rowSums(relabeled != truth)
        best_perm_h <- which.min(center_comu_h)

        1 - (permk[best_perm_h, est_label[n + 1]] == which.max(Pi_tot[jj, ]))
      }
    }
    temp_SNMF_err <- foreach(jjj = test_ind, .combine='c') %dopar% SNMF_do(jjj)

    res_SNMF[ii, ratio_num] <- sum(temp_SNMF_err) / n_test

    time_SNMF[ii, ratio_num] <- elapsed_secs(cur.time)
  }
}
# Shut down the parallel workers.
stopCluster(cl)

#save the data
# The whole workspace is written to "Real_<name>.Rdata".
file_path_save <- paste0('Real_', name, '.Rdata')
save.image(file = file_path_save)