library(nett)
library(RSpectra)
library(lpSolve)
library(Matrix)

# Aligning against B2, i.e. assume that B2 = P B1 P^T for some permutation matrix P.

# Solve the linear assignment problem argmax_X tr(X C) over permutation
# matrices X and return the optimal X (as the 0/1 solution matrix of
# lp.assign()).
# lp.assign(M) maximises sum_{i,j} M[i, j] * X[i, j] = tr(M^T X), so handing
# it t(C) turns the objective into tr(C X) = tr(X C).
solve_lap <- function(C) {
  cost <- t(C)
  lp.assign(cost, direction = "max")[["solution"]]
}

# Generate a random K x K permutation matrix by shuffling the columns of the
# identity matrix (column order drawn with sample()).
rand_perm_mat <- function(K) {
  shuffled_cols <- sample(K, K)
  diag(K)[, shuffled_cols]
}

# Estimate the diagonal sign matrix S relating two eigenvector matrices.
#
# Q1 and Q2 are K x K matrices whose COLUMNS are eigenvectors (the layout of
# eigen()$vectors). If B2 = P B1 P^T with distinct eigenvalues, then
# Q2 = P Q1 S for a diagonal sign matrix S; since 1^T P = 1^T, the column
# sums satisfy colSums(Q2) = colSums(Q1) S, so the sign of their ratio
# recovers S.
#
# Returns a K x K base matrix with sign(ratio) on the diagonal and zeros
# elsewhere. Where a column of Q1 sums to (near) zero the sign is not
# identifiable and the entry is unreliable (NaN for 0/0).
get_S <- function(Q1, Q2) {
  K <- nrow(Q1)
  ones <- rep(1, K)
  # t(Q) %*% 1 gives the column sums, i.e. 1^T q_k for every eigenvector q_k.
  S_hat <- as.vector(t(Q2) %*% ones) / as.vector(t(Q1) %*% ones)
  # Build the diagonal directly in base R instead of round-tripping through a
  # sparse Matrix::Diagonal and back; nrow = K keeps K = 1 a 1 x 1 matrix.
  diag(sign(S_hat), nrow = K)
}

# Find the (eps, delta) parameters of a (symmetric) block matrix B.
#
#   eps   = min_k |1^T v_k| over the eigenvectors v_k of B, i.e. the columns
#           of eigen(B)$vectors. These sums are what sign alignment of
#           eigenvectors relies on; they are unchanged (up to sign) under
#           B -> P B P^T because 1^T P = 1^T.
#   delta = smallest gap between consecutive eigenvalues (eigen() returns
#           the eigenvalues of a symmetric matrix in decreasing order);
#           Inf when K = 1, since there is no gap.
#
# Returns c(eps, delta).
get_eps_delta <- function(B) {
  EVD <- eigen(B)
  K <- nrow(B)

  # Sum each eigenvector (column of $vectors), not each coordinate (row).
  sums <- colSums(EVD$vectors)
  eps <- min(abs(sums))

  Lambda <- EVD$values
  delta <- if (K > 1) min(Lambda[-K] - Lambda[-1]) else Inf

  c(eps, delta)
}

# Check "condition 19" (presumably the numbered condition of the accompanying
# derivation -- TODO confirm): draw n uniform labels over the K = nrow(B)
# blocks, simulate a DCSBM with block matrix B, re-estimate B from the true
# labels and compare the Frobenius error against eps * delta / (2 sqrt(2) K),
# with (eps, delta) from get_eps_delta(B).
#
# Returns list(condition = error - bound, eps, delta); a negative
# "condition" means the estimation error is below the bound.
check_estim <- function(B, n) {
  params <- get_eps_delta(B)
  n_blocks <- nrow(B)

  labels <- sample(1:n_blocks, n, replace = TRUE)
  adjacency <- sample_dcsbm(labels, B)
  B_est <- estim_dcsbm(adjacency, labels)$B

  frob_err <- norm(B - B_est, type = "F")
  bound <- params[1] * params[2] / (2 * sqrt(2) * n_blocks)
  list(
    "condition" = frob_err - bound,
    "eps" = params[1],
    "delta" = params[2]
  )
}

# Convert a permutation matrix into a label map: zmap[i] = j where P[i, j] is
# nonzero. If a row has several nonzeros the last column wins; a row with
# none maps to 0. The size comes from P itself (the previous version read an
# undefined/global K).
#
# Returns a numeric vector of length nrow(P).
permutation_mapping <- function(P) {
  P <- as.matrix(P)
  zmap <- numeric(nrow(P))
  hits <- which(P != 0, arr.ind = TRUE)
  # which() reports hits in column-major order, so for a repeated row the
  # later (larger) column overwrites the earlier one, matching a row scan.
  zmap[hits[, 1]] <- hits[, 2]
  zmap
}

# Misclassification rate of the estimated labels zhat against the true labels
# z after relabelling zhat optimally, i.e. min over permutations P of the
# disagreement between Z and Zhat P, where Z and Zhat are the n x K membership
# matrices built by label_vec2mat(). Entry (k, l) of t(Zhat) %*% Z counts the
# nodes with zhat == k and z == l; the label permutation maximising total
# agreement is found by linear assignment on that matrix.
misclassification_rate <- function(z, zhat, K) {
  membership_true <- label_vec2mat(z, K)
  membership_est <- label_vec2mat(zhat, K)
  agreement <- t(membership_est) %*% membership_true

  assignment <- lp.assign(agreement, direction = "max")[["solution"]]
  relabel <- permutation_mapping(assignment)

  mismatches <- sum(z != relabel[zhat])
  mismatches / length(z)
}

# Estimate the permutation matrix P aligning B1hat to B2hat under the model
# B2 = P B1 P^T, so that Pt %*% B1hat %*% t(Pt) is close to B2hat.
#
# eigen() returns eigenvectors as the COLUMNS of $vectors, so with
# Q = eigen(B)$vectors we have B = Q diag(lambda) Q^T. If B2 = P B1 P^T and
# the eigenvalues are distinct, then Q2 = P Q1 S for a diagonal sign matrix S:
#   * colSums(Q2) = colSums(Q1) S (because 1^T P = 1^T), which is what
#     get_S() computes via t(Q) %*% 1;
#   * Q1 S Q2^T = P^T, and solve_lap() maximises tr(X Q1 S Q2^T) over
#     permutations X, whose maximiser is X = P.
# Both identities require the untransposed eigenvector matrices; transposing
# them first produces a different sign matrix and assignment cost.
#
# z1hat, z2hat are kept for interface compatibility and are not used.
matching <- function(B1hat, B2hat, z1hat, z2hat) {
  Q1 <- eigen(B1hat)$vectors
  Q2 <- eigen(B2hat)$vectors

  S_hat <- get_S(Q1, Q2)
  solve_lap(Q1 %*% S_hat %*% t(Q2))
}

# Calculate the two-sample statistic That comparing the block connectivity of
# two graphs after aligning their community labels.
#
# Args:
#   A1, A2: adjacency matrices of the two graphs.
#   z1hat, z2hat: estimated community labels (integers in 1..K) of A1, A2.
#   Pt: K x K permutation matrix (e.g. from matching()); graph-1 labels are
#     relabelled through permutation_mapping(Pt) before comparison.
#   verbose: if TRUE (the default, preserving earlier output), print each
#     per-block term as it is added.
#
# Returns list(That, Bhat), where Bhat is the pooled block estimate
# (S1 + S2) / (m1 + m2) and That sums, over block pairs k <= l,
#   mbar * (S1/m1 - S2/m2)^2 / (Bhat * (1 - Bhat)),  mbar = m1 m2 / (m1 + m2).
get_That <- function(A1, A2, z1hat, z2hat, Pt, verbose = TRUE) {
  # Number of blocks comes from Pt rather than from a global variable.
  K <- nrow(Pt)
  zmap <- permutation_mapping(Pt)
  z1_aligned <- zmap[z1hat]

  # Block edge sums of each graph.
  S_1 <- compute_block_sums(A1, z1_aligned)
  S_2 <- compute_block_sums(A2, z2hat)

  # Block sizes nhat_{., k}. tabulate(nbins = K) always returns K counts;
  # table() would silently drop an empty label and shift the remaining blocks.
  ns1 <- tabulate(z1_aligned, nbins = K)
  ns2 <- tabulate(z2hat, nbins = K)
  # Pair counts: n_k * n_l off the diagonal, n_k * (n_k - 1) on it, floored at
  # 1. diag(x, nrow = K) keeps K = 1 correct (diag(5) is a 5 x 5 identity).
  m1 <- pmax(ns1 %*% t(ns1) - diag(ns1, nrow = K), 1)
  m2 <- pmax(ns2 %*% t(ns2) - diag(ns2, nrow = K), 1)

  Bhat <- (S_1 + S_2) / (m1 + m2)
  # NOTE(review): Sigma_hat is 0 when a pooled estimate is exactly 0 or 1,
  # which makes that block's term non-finite.
  Sigma_hat <- Bhat * (1 - Bhat)

  That <- 0
  for (i in seq_len(K)) {
    for (j in i:K) {
      # mbar is half of the harmonic mean of m1[i, j] and m2[i, j].
      mbar <- m1[i, j] * m2[i, j] / (m1[i, j] + m2[i, j])
      addd <- (mbar / Sigma_hat[i, j]) *
        ((S_1[i, j] / m1[i, j]) - (S_2[i, j] / m2[i, j]))^2
      That <- That + addd
      # Each term is treated as (approximately) chi-squared with 1 df.
      if (verbose) {
        print(addd)
      }
    }
  }
  list("That" = That, "Bhat" = Bhat)
}

# Monte Carlo acceptance rate of the chi-squared two-sample test when both
# graphs are drawn from the same DCSBM with block matrix B.
#
# Each replicate draws independent uniform labels z1, z2 (n nodes, K blocks),
# samples A1, A2 from B, clusters each with spec_clust(), estimates the block
# matrices, aligns them with matching() and computes That with get_That().
# The null is accepted when That is below the upper-alpha quantile of a
# chi-squared distribution with K (K + 1) / 2 degrees of freedom (one per
# block pair k <= l).
#
# Returns the fraction of accepted replicates (NaN when num_exp is 0).
simulation <- function(num_exp, B, K, n, alpha) {
  threshold <- qchisq(p = alpha, df = K * (K + 1) / 2, lower.tail = FALSE)
  acc <- 0
  # seq_len() runs zero replicates for num_exp = 0; 1:0 would run two.
  for (i in seq_len(num_exp)) {
    z1 <- sample(1:K, n, replace = TRUE)
    z2 <- sample(1:K, n, replace = TRUE)

    A1 <- sample_dcsbm(z1, B)
    A2 <- sample_dcsbm(z2, B)

    z1hat <- spec_clust(A1, K, niter = 30)
    z2hat <- spec_clust(A2, K, niter = 30)
    B1hat <- estim_dcsbm(A1, z1hat)$B
    B2hat <- estim_dcsbm(A2, z2hat)$B

    Pt <- matching(B1hat, B2hat, z1hat, z2hat)

    chisq <- get_That(A1, A2, z1hat, z2hat, Pt)$That
    if (chisq < threshold) {
      acc <- acc + 1
    }
  }
  acc / num_exp
}


# ---- Example run ------------------------------------------------------------
# Random symmetric K x K block matrix with entries in [0.2, 0.7].
K <- 2
n <- 1000
B <- matrix(runif(K * K, 0.2, 0.7), nrow = K)
B <- (B + t(B)) / 2
check_estim(B, n)

# Two independent graphs from the same model, with estimated labels and
# estimated block matrices.
z1 <- sample(1:K, n, replace = TRUE)
z2 <- sample(1:K, n, replace = TRUE)

A1 <- sample_dcsbm(z1, B)
A2 <- sample_dcsbm(z2, B)

z1hat <- spec_clust(A1, K, niter = 30)
z2hat <- spec_clust(A2, K, niter = 30)
B1hat <- estim_dcsbm(A1, z1hat)$B
B2hat <- estim_dcsbm(A2, z2hat)$B

# eigen() returns eigenvectors as the COLUMNS of $vectors. The sign estimate
# in get_S() (column sums via t(Q) %*% 1) and the assignment cost
# Q1 S Q2^T both assume that layout, so the matrices are used untransposed.
o1 <- eigen(B1hat)
Q1 <- o1$vectors
L1 <- diag(o1$values)

o2 <- eigen(B2hat)
Q2 <- o2$vectors
L2 <- diag(o2$values)

S <- get_S(Q1, Q2)
Pt <- solve_lap(Q1 %*% S %*% t(Q2))

# Residual of the alignment B2hat ~ Pt B1hat Pt^T; expected to be small when
# the block estimates are accurate.
Pt %*% B1hat %*% t(Pt) - B2hat
B1hat
B2hat
Pt
matching(B1hat, B2hat, z1hat, z2hat)
get_That(A1, A2, z1hat, z2hat, Pt)
misclassification_rate(z1, z1hat, K)
