library(parallel)
library(pbmcapply)
library(ggplot2)
library(dplyr)
library(Matrix)
library(stats)
# library(RcppHungarian)
library(CVXR)
library(rlang)
library(lpSolve)

# Squared Euclidean norm of x: the sum of its squared entries.
# NOTE: despite the name, no square root is taken. Callers that only compare
# two values of norm_vec() get the same ordering as with the true norm.
norm_vec <- function(x) {
  squares <- x^2
  sum(squares)
}

# Uniformly random K x K permutation matrix: the identity with its columns
# shuffled. drop = FALSE keeps the result a matrix when K == 1 (without it,
# eye[, perm] collapses the 1 x 1 case to a plain length-1 vector).
# sample.int(K) makes the same random draw as sample(K, K) for K >= 1.
rand_perm_mat <- function(K) {
  eye <- diag(K)
  perm <- sample.int(K)
  eye[, perm, drop = FALSE]
}

# solving with lpSolve
# Solve the linear assignment problem max_P trace(C %*% P) over permutation
# matrices P using lpSolve::lp.assign.
# lp.assign maximises sum(cost[i, j] * x[i, j]). Passing t(C) makes that sum
# equal trace(C %*% x), the same objective as solve_lin_assign().
# Returns the assignment matrix from lp.assign. Stops with an error if
# lp_solve reports a nonzero status, rather than handing an unchecked
# solution to the caller.
solve_lin_assign_lpsolve <- function(C) {
  lp <- lp.assign(t(C), direction = "max")
  if (lp$status != 0) {
    stop("lp.assign failed with status ", lp$status, call. = FALSE)
  }
  lp$solution
}
# Convex (LP) relaxation of the linear assignment problem, solved with CVXR.
# It maximises trace(C %*% P) over K x K matrices P with P >= 0 and every
# row and column summing to 1 (doubly-stochastic matrices).
# Returns the optimal P as a numeric matrix. It may contain fractional
# entries when the optimum is not unique, so round/project it before
# treating it as a permutation.
# verbose and num_iter are forwarded to CVXR's solve(); the defaults match
# the previously hard-coded values. Warns when the solver status is not
# "optimal".
solve_lin_assign <- function(C, verbose = TRUE, num_iter = 100000) {
  K <- nrow(C)
  P <- Variable(K, K)
  constraints <- list(
    P >= 0,
    sum_entries(P, 1) == 1,
    sum_entries(P, 2) == 1
  )
  prob <- Problem(Maximize(matrix_trace(C %*% P)), constraints)
  result <- solve(prob, verbose = verbose, num_iter = num_iter)
  if (!identical(result$status, "optimal")) {
    warning("CVXR solver status: ", result$status, call. = FALSE)
  }
  result$getValue(P)
}

# min_S ||PQ %*% S - C||_F over diagonal sign matrices S (entries +/-1).
# The columns decouple, so for column i we choose
#   s_i = -1 when ||PQ[, i] + C[, i]||^2 <= ||PQ[, i] - C[, i]||^2
#         (ties go to -1),
#   s_i = +1 otherwise.
# PQ and C must have the same dimensions. The result is an
# ncol(C) x ncol(C) matrix, so non-square inputs are supported.
solve_s <- function(PQ, C) {
  err_keep <- colSums((PQ - C)^2)
  err_flip <- colSums((PQ + C)^2)
  # unname(): do not let column names of C leak into dimnames of S.
  signs <- unname(ifelse(err_keep >= err_flip, -1, 1))
  # nrow = stops diag() from reading a length-1 vector as a matrix size.
  base::diag(signs, nrow = length(signs))
}

# Fix the RNG so the experiment is reproducible.
set.seed(123)
# Global problem size. testaltmin2() takes K as an argument and does not
# read this value.
K <- 50

# One alternating-minimisation run for min_{P, S} ||P %*% Q1 %*% S - Q2||_F,
# starting from the sign matrix S. Each iteration:
#   1. fixes S and solves the assignment problem for the permutation P;
#   2. fixes P and picks the best column signs S.
# Returns list(P = last permutation, cost = final residual Frobenius norm).
altmin_single_start <- function(Q1, Q2, S, max_iter) {
  K <- nrow(Q1)
  P <- matrix(0, K, K)
  for (iter in seq_len(max_iter)) {
    # Maximising trace(Q1 S Q2^T P) = <P Q1 S, Q2>_F minimises the residual
    # for fixed S. round() is a no-op on an exact 0/1 assignment and guards
    # against solver round-off.
    P <- round(solve_lin_assign_lpsolve(Q1 %*% S %*% t(Q2)))
    S <- solve_s(P %*% Q1, Q2)
  }
  list(P = P, cost = norm(P %*% Q1 %*% S - Q2, type = "F"))
}

# Simulation: try to recover a hidden permutation Ps from two sets of
# eigenvectors:
#   - Q1, from a random symmetric K x K matrix B1;
#   - Q2, from B2 = Ps B1 Ps^T + delta, where delta is symmetrised noise
#     built from U(-sig, sig) entries.
# Alternating minimisation runs from 1 + n_restarts random sign
# initialisations, because a single start can get stuck in a local minimum.
# The run with the lowest residual is kept (ties go to the later run).
# Further restarts are skipped once the best residual is <= tol. The default
# tol = 0 matches the original exact-zero check; a small positive tol lets
# near-exact fits stop early.
#
# Returns c(cost, difff):
#   cost  - best residual ||P Q1 S - Q2||_F;
#   difff - squared Frobenius distance sum((P_sol - Ps)^2) to the true
#           permutation.
testaltmin2 <- function(K, sig, max_iter, n_restarts = 3, tol = 0) {
  stopifnot(n_restarts >= 0)

  B1 <- matrix(runif(K^2), K)
  B1 <- (B1 + t(B1)) / 2
  Ps <- rand_perm_mat(K)

  delta <- matrix(runif(K^2, min = -sig, max = sig), K)
  delta <- (delta + t(delta)) / 2
  B2 <- Ps %*% B1 %*% t(Ps) + delta

  Q1 <- eigen(B1)$vectors
  Q2 <- eigen(B2)$vectors

  cost <- Inf
  P_sol <- NULL
  for (attempt in seq_len(1 + n_restarts)) {
    if (attempt > 1 && cost <= tol) {
      break
    }
    # nrow = K keeps diag() from treating a length-1 vector as a size.
    S_init <- diag(sign(runif(K, -1, 1)), nrow = K)
    run <- altmin_single_start(Q1, Q2, S_init, max_iter)
    if (run$cost <= cost) {
      cost <- run$cost
      P_sol <- run$P
    }
  }

  difff <- sum((P_sol - Ps)^2)
  c(cost, difff)
}

# Monte Carlo experiment: run testaltmin2() num_trials times and average its
# two outputs. After each trial it prints the trial index and the running
# mean of the permutation error; at the end it prints the mean final cost.
num_trials <- 100
avg_cost <- 0  # sum, then mean, of the final cost ||P Q1 S - Q2||_F
avg_diff <- 0  # sum, then mean, of sum((P - Ps)^2) (squared Frobenius)
for (i in seq_len(num_trials)) {
  print(i)
  res <- testaltmin2(K = 100, sig = 0.005, max_iter = 20)
  P_diff <- res[2]
  avg_cost <- avg_cost + res[1]
  avg_diff <- avg_diff + P_diff
  print(avg_diff / i)
}
avg_cost <- avg_cost / num_trials
avg_diff <- avg_diff / num_trials
# average final cost function
print(avg_cost)

