
library(parallel)
library(pbmcapply)
library(ggplot2)
library(dplyr)
library(Matrix)
# library(RcppHungarian)
library(CVXR)
library(rlang)

# matching_perm = function(C) {
#   out = HungarianSolver(C)
#   perm = out$pairs[,2]
#   perm
# }
# Run `fun` over every combination of parameter values, with replicates.
#
# Builds the Cartesian product of `rep = seq_len(nrep)` and the named vectors
# in `...` with expand.grid(), calls `fun` once per grid row in parallel via
# pbmclapply(), and row-binds the per-row results into one data frame.
#
# Args:
#   fun:    function called with one named argument per `...` parameter (the
#           replicate index `rep` is NOT passed). Its return value must be
#           something data.frame() can bind next to the parameters, e.g. a
#           named list of scalars.
#   nrep:   number of replicates per parameter combination.
#   ncores: number of worker processes (pbmclapply's `mc.cores`).
#   ...:    named vectors of parameter values to cross.
# Returns:
#   A data frame with columns `rep`, the parameters, then `fun`'s outputs.
#   If the grid is empty (e.g. nrep = 0) the empty grid itself is returned.
run_grid_test <- function(fun, nrep = 1, ncores = 1, ...) {
  # seq_len(): nrep = 0 gives an empty grid (1:0 would give reps c(1, 0)).
  # stringsAsFactors = FALSE: character parameters reach `fun` as strings,
  # not factors. KEEP.OUT.ATTRS = FALSE: the grid metadata is not needed.
  runs <- expand.grid(
    rep = seq_len(nrep), ...,
    KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE
  )
  if (nrow(runs) == 0L) {
    return(runs)
  }

  rows <- pbmclapply(
    seq_len(nrow(runs)),
    mc.cores = ncores,
    FUN = function(i) {
      # One named element per grid column; unlike runs[i, ] this never
      # collapses to an unnamed vector when the grid has a single column.
      param_list <- lapply(runs, function(col) col[i])
      # `rep` is always the first grid column: it only labels the replicate.
      out <- do.call(fun, param_list[-1])
      data.frame(param_list, out)
    }
  )
  do.call(rbind, rows)
}

# Toy objective used to smoke-test run_grid_test(): returns the product (`z`)
# and the difference (`u`) of `x` and `y`. The replicate index `rep` is part
# of the signature but not used.
fun <- function(rep, x, y) {
  prod_xy <- x * y
  diff_xy <- x - y
  list(z = prod_xy, u = diff_xy)
}
# Smoke test: 3 replicates over the 2 x 3 grid of (x, y); prints the result.
run_grid_test(fun, x = 4:5, y = -1:1, nrep = 3, ncores = 4)


# Solve a linear assignment problem through its LP relaxation with CVXR.
#
# Maximises trace(C %*% P) over K x K doubly stochastic matrices P (entries
# >= 0, every row and every column summing to 1). When the optimum is unique
# it is a vertex of that polytope, i.e. a permutation matrix, returned up to
# the solver's numerical tolerance.
#
# Args:
#   C: square numeric score matrix.
# Returns:
#   The K x K solution matrix P. Since trace(C %*% P) = sum(t(C) * P), entry
#   P[j, i] is the weight placed on score C[i, j].
# Warns if the solver does not report status "optimal".
solve_lin_assign <- function(C) {
  C <- as.matrix(C)
  if (nrow(C) != ncol(C)) {
    stop("`C` must be a square matrix, got ", nrow(C), " x ", ncol(C),
         call. = FALSE)
  }
  K <- nrow(C)
  P <- Variable(K, K)
  constraints <- list(
    P >= 0,
    sum_entries(P, 1) == 1,
    sum_entries(P, 2) == 1
  )
  result <- solve(Problem(Maximize(matrix_trace(C %*% P)), constraints))
  # Surface solver failures instead of silently returning NA / a poor point.
  if (!identical(result$status, "optimal")) {
    warning("solve_lin_assign: solver status is '", result$status, "'",
            call. = FALSE)
  }
  result$getValue(P)
}

# Draw a uniformly random K x K permutation matrix.
#
# Args:
#   K: matrix dimension.
# Returns:
#   The K x K identity with its columns shuffled by sample(K, K); a 0/1
#   matrix with exactly one 1 in each row and column.
rand_perm_mat <- function(K) {
  eye <- diag(K)
  perm <- sample(K, K)
  # drop = FALSE: for K = 1, eye[, perm] would collapse to a bare scalar.
  eye[, perm, drop = FALSE]
}


# Sanity check: recover a hidden permutation from eigenvectors.
# B2 = Ps B1 Ps^T is B1 with rows/columns permuted, so (assuming B1 has
# distinct eigenvalues) its eigenvectors are Ps U1 up to one sign per column.
# Dividing the column sums of U2 by those of U1 recovers those signs
# (assumes no column of U1 sums to ~0), after which U3 %*% t(U1) == Ps.
set.seed(123)
K <- 5
B1 <- matrix(runif(K^2), K)
B1 <- (B1 + t(B1)) / 2  # symmetrise
out <- eigen(B1)
U1 <- out$vectors
Ps <- rand_perm_mat(K)
B2 <- Ps %*% B1 %*% t(Ps)
U2 <- eigen(B2)$vectors

Sh <- (t(U2) %*% rep(1, K)) / (t(U1) %*% rep(1, K))
# nrow = K: with K = 1, diag(<scalar>) would build an identity of that size.
U3 <- U2 %*% diag(as.vector(Sh), nrow = K)
Ps - round(U3 %*% t(U1))  # expected to print a K x K zero matrix


# Error of the LP assignment solution under uniform noise.
#
# Perturbs the noiseless score matrix U %*% t(V) with i.i.d.
# Uniform(-sig, sig) entries, solves the assignment LP with
# solve_lin_assign(), and measures the distance to the true permutation.
#
# Args:
#   sig:    half-width of the uniform noise.
#   U, V:   matrices whose product U %*% t(V) is the noiseless score matrix;
#           default to the script-level globals `U1` and `U3`.
#   P_true: true permutation matrix; defaults to the global `Ps`.
# Returns:
#   list(err = ): squared Frobenius distance between `P_true` and the
#   solver's matrix.
solve_noisy_prob <- function(sig, U = U1, V = U3, P_true = Ps) {
  k <- nrow(U)
  noise <- matrix(runif(k^2, min = -sig, max = sig), k)
  # Alternative noise models:
  # noise <- sig * matrix(rnorm(k^2), k)
  # noise <- sig * matrix(runif(k^2), k)
  C <- U %*% t(V) + noise
  list(err = sum((P_true - solve_lin_assign(C))^2))
}

# Monte Carlo sweep: 100 replicates at each of 11 noise levels in [0, 1].
# NOTE(review): with the default Mersenne-Twister RNG, forked workers are
# likely reseeded independently, so these results are probably not
# reproducible from set.seed(123); consider RNGkind("L'Ecuyer-CMRG") and
# confirm how pbmclapply seeds its workers.
res <- run_grid_test(
  fun = solve_noisy_prob,
  nrep = 100,
  ncores = 3,
  sig = seq(0, 1, length.out = 11)
)

# Mean error per noise level (the outer parentheses also print it).
(avg_res <- res %>%
  group_by(sig) %>%
  summarise(err = mean(err)))

avg_res %>%
  ggplot() +
  geom_line(aes(sig, err))



