library(parallel)
library(pbmcapply)
library(ggplot2)
library(dplyr)
library(Matrix)
library(stats)
# library(RcppHungarian)
library(CVXR)
library(rlang)
library(lpSolve)
library(nett)


# Squared Euclidean norm of a numeric vector, i.e. sum(x_i^2).
# NOTE: despite the name, no square root is taken. Callers that only compare
# norms (e.g. sign selection) are unaffected, since sqrt is monotone.
norm_vec <- function(x) {
  squares <- x^2
  sum(squares)
}

# Draw a uniformly random K x K permutation matrix.
#
# The identity's columns are reordered by a random permutation of 1..K, so
# every row and every column contains exactly one 1.
#
# Args:
#   K: matrix size (positive integer).
# Returns: a K x K numeric matrix. `drop = FALSE` keeps the result a matrix
#   even when K == 1; previously it collapsed to a bare scalar.
rand_perm_mat <- function(K) {
  eye <- diag(K)
  perm <- sample(K, K)
  eye[, perm, drop = FALSE]
}

# Solve the linear assignment problem  max_P trace(C %*% P)  over permutation
# matrices P, using lpSolve::lp.assign.
#
# lp.assign maximises sum_ij M[i, j] * X[i, j]. Passing M = t(C) turns that
# objective into sum_ij C[j, i] * X[i, j] = trace(C %*% X).
#
# Args:
#   C: square cost matrix.
# Returns: the `$solution` assignment matrix reported by lp.assign.
solve_lin_assign_lpsolve <- function(C) {
  assignment <- lp.assign(t(C), direction = "max")
  assignment$solution
}

# Convex relaxation of the assignment problem, solved with CVXR:
#   maximise trace(C %*% P)  s.t.  P >= 0, row sums = 1, column sums = 1
# i.e. optimisation over doubly stochastic matrices (the Birkhoff polytope).
# Solver output is printed (verbose = TRUE). The returned value is the
# numeric solution for P and is not rounded to a permutation here.
#
# Args:
#   C: square K x K matrix.
# Returns: K x K numeric matrix from result$getValue().
solve_lin_assign <- function(C) {
  K <- nrow(C)
  P <- Variable(K, K)
  constraints <- list(
    P >= 0,
    sum_entries(P, 1) == 1,
    sum_entries(P, 2) == 1
  )
  objective <- Maximize(matrix_trace(C %*% P))
  result <- solve(
    Problem(objective, constraints),
    verbose = TRUE,
    num_iter = 100000
  )
  result$getValue(P)
}

# Optimal sign matrix for  min_S || PQ %*% S - C ||_F  with S = diag(+-1).
#
# Right-multiplying by S flips the sign of individual columns, so the problem
# separates per column: column i keeps +1 unless
# ||PQ[, i] - C[, i]||^2 >= ||PQ[, i] + C[, i]||^2, in which case it is
# flipped to -1. Ties resolve to -1, as in the original loop.
#
# Args:
#   PQ: matrix with the same dimensions as C (typically P %*% Q1).
#   C:  target matrix (typically Q2).
# Returns: ncol(C) x ncol(C) diagonal matrix with +-1 on the diagonal.
solve_s <- function(PQ, C) {
  K <- ncol(C)
  # unname() so that column names on the inputs do not leak into dimnames.
  flip <- unname(colSums((PQ - C)^2) >= colSums((PQ + C)^2))
  stopifnot(!anyNA(flip))
  S <- ifelse(flip, -1, 1)
  # nrow = K avoids base diag()'s length-one special case, where diag(-1)
  # errors and diag(1) means "1x1 identity" rather than "this diagonal".
  diag(S, nrow = K)
}

# Random K x K sign matrix diag(+-1). `nrow = K` keeps K == 1 safe, because
# diag() of a length-one vector otherwise means "identity of that size".
random_sign_diag <- function(K) {
  diag(sign(runif(K, -1, 1)), nrow = K)
}

# One alternating-minimisation run of  min_{P, S} || P Q1 S - Q2 ||_F
# from the sign initialisation S. It alternates the P-step (linear assignment)
# and the S-step (per-column sign choice) for at most max_iter (>= 1)
# iterations. It stops early once S is unchanged, because P is a function of
# S alone, so every later iteration would repeat the same (P, S).
# Returns list(P = permutation, S = signs, cost = Frobenius residual).
altmin_run_once <- function(Q1, Q2, S, max_iter) {
  for (iter in seq_len(max_iter)) {
    # round() snaps the reported assignment to exact 0/1 on every run
    # (previously only restarts were rounded).
    P <- round(solve_lin_assign_lpsolve(Q1 %*% S %*% t(Q2)))
    S_next <- solve_s(P %*% Q1, Q2)
    unchanged <- all(S_next == S)
    S <- S_next
    if (unchanged) break
  }
  list(P = P, S = S, cost = norm(P %*% Q1 %*% S - Q2, type = "F"))
}

# Simulation check for recovering a hidden community permutation.
#
# Draws a random symmetric K x K matrix B1 and a random permutation matrix Ps,
# and sets B2 = Ps B1 Ps^T. It samples one n-node network from each via
# nett::sample_dcsbm, estimates each matrix from spectral-clustering labels
# (nett::spec_clust, nett::estim_dcsbm), and then aligns the eigenvectors of
# the two estimates:
#
#   min_{P permutation, S = diag(+-1)} || P Q1 S - Q2 ||_F
#
# The problem is non-convex, so extra runs from random sign initialisations
# are tried while the residual stays above `tol`. The best run is kept; ties
# go to the later run.
#
# Args:
#   K:          number of communities.
#   max_iter:   maximum alternating iterations per run (must be >= 1).
#   n:          number of nodes in each sampled network.
#   n_restarts: number of extra random restarts (default 3, the number
#               previously hard-coded).
#   tol:        residual regarded as an exact fit; no further restarts then.
#
# Returns: numeric vector c(cost, diff_P, diff_B):
#   cost   = best residual || P Q1 S - Q2 ||_F,
#   diff_P = sum((P_sol - Ps)^2), squared Frobenius distance to the truth,
#   diff_B = one-norm (norm type "O") of P_sol B1 P_sol^T - B2.
testaltmin_SBM <- function(K, max_iter, n, n_restarts = 3, tol = 1e-10) {
  stopifnot(max_iter >= 1, n_restarts >= 0)

  B1 <- matrix(runif(K^2), K)
  B1 <- (B1 + t(B1)) / 2
  Ps <- rand_perm_mat(K)
  B2 <- Ps %*% B1 %*% t(Ps)

  # Sample an n-node network with connectivity B1 and estimate B1 from it.
  z1 <- sample(seq_len(K), n, replace = TRUE)
  A1 <- sample_dcsbm(z1, B1)
  z1_h <- spec_clust(A1, K)
  B1_h <- estim_dcsbm(A1, z1_h)$B

  # Same for B2.
  z2 <- sample(seq_len(K), n, replace = TRUE)
  A2 <- sample_dcsbm(z2, B2)
  z2_h <- spec_clust(A2, K)
  B2_h <- estim_dcsbm(A2, z2_h)$B

  # Eigenvectors are defined only up to sign, which is what S accounts for.
  # NOTE(review): assumes B1_h / B2_h are symmetric so the vectors are real.
  Q1 <- eigen(B1_h)$vectors
  Q2 <- eigen(B2_h)$vectors

  best <- altmin_run_once(Q1, Q2, random_sign_diag(K), max_iter)
  for (restart in seq_len(n_restarts)) {
    # Probably stuck in a local minimum unless the fit is (numerically) exact.
    if (best$cost <= tol) break
    candidate <- altmin_run_once(Q1, Q2, random_sign_diag(K), max_iter)
    if (candidate$cost <= best$cost) best <- candidate
  }

  P_sol <- best$P
  diff_P <- sum((P_sol - Ps)^2)
  diff_B <- norm(P_sol %*% B1 %*% t(P_sol) - B2, type = "O")
  c(best$cost, diff_P, diff_B)
}


# Monte Carlo driver. It runs `num_trials` independent testaltmin_SBM()
# experiments and prints the trial index and running averages after each one:
# diff_P, then cost, then diff_B. The final averages are printed at the end.
#   cost   = \norm{P Q1 S - Q2}
#   diff_P = \norm{P - Ps}^2
#   diff_B = \norm{P B1 P^T - B2}
set.seed(123)
num_trials <- 100
totals <- c(cost = 0, diff_P = 0, diff_B = 0)
for (i in seq_len(num_trials)) {
  print(i)
  res <- testaltmin_SBM(K = 50, max_iter = 20, n = 500)
  totals <- totals + res
  running <- totals / i
  print(running[["diff_P"]])
  print(running[["cost"]])
  print(running[["diff_B"]])
}
# Final averages. The previous version referenced an undefined `avg_diff`
# here and errored before reporting diff_P / diff_B.
avg_cost <- totals[["cost"]] / num_trials
avg_diff_P <- totals[["diff_P"]] / num_trials
avg_diff_B <- totals[["diff_B"]] / num_trials
print(avg_cost)
print(avg_diff_P)
print(avg_diff_B)
