# Spectral clustering of a single graph.
#
# Embeds `A` with ase(A, d) and runs k-means with `K` centers on the rows of
# that embedding (NOTE(review): assumes ase() returns one row per node —
# confirm), allowing at most `niter` iterations.
#
# Args:
#   A: graph passed unchanged to ase().
#   K: number of clusters (the `centers` argument of kmeans()).
#   d: embedding dimension passed to ase().
#   niter: maximum number of k-means iterations (`iter.max`).
# Returns:
#   The `cluster` component of the kmeans() fit: an integer vector of
#   cluster labels in 1:K.
spec_clust_ase <- function(A, K, d, niter = 50) {
  embedding <- ase(A, d)
  fit <- kmeans(embedding, centers = K, iter.max = niter)
  fit[["cluster"]]
}

# Aggregate block-level edge sums and block sizes over a sample of graphs.
#
# For each graph a block matrix is estimated with estim_dcsbm(). Every graph
# after the first is then aligned against the first graph: matching() compares
# its estimated block matrix with the first graph's, and mat2perm() turns the
# result into a relabelling of that graph's community labels. The per-graph
# block_sums() and block_sizes() are computed on the aligned labels and summed.
#
# Args:
#   Alist: list of adjacency matrices, one per graph (must be non-empty).
#   zhatlist: list of estimated label vectors, parallel to Alist.
# Returns:
#   Unnamed list of length 2:
#     [[1]] sum over graphs of block_sums(A, aligned labels);
#     [[2]] sum over graphs of block_sizes(aligned labels).
agg_graphs <- function(Alist, zhatlist) {
  sample_size <- length(Alist)
  stopifnot(sample_size >= 1, length(zhatlist) == sample_size)

  # NOTE(review): assumes estim_dcsbm() returns a list whose first element is
  # the estimated block matrix — confirm. Taking [[1]] explicitly is robust to
  # the length of that list, unlike indexing a simplified mapply() result.
  Bhlist <- Map(function(A, z) estim_dcsbm(A, z)[[1]], Alist, zhatlist)

  # Align every other graph against the first graph.
  Ptlist <- lapply(Bhlist[-1], matching, Bhlist[[1]])

  agg_block_sum <- block_sums(Alist[[1]], zhatlist[[1]])
  agg_block_size <- block_sizes(zhatlist[[1]])

  # seq_len(n)[-1] is empty for a single graph; 2:n would iterate over
  # c(2, 1) and index past the end of Ptlist.
  for (i in seq_len(sample_size)[-1]) {
    zmap <- mat2perm(Ptlist[[i - 1]])
    # Relabel graph i's communities through the matched permutation.
    z_aligned <- zmap[zhatlist[[i]]]
    agg_block_sum <- agg_block_sum + block_sums(Alist[[i]], z_aligned)
    agg_block_size <- agg_block_size + block_sizes(z_aligned)
  }
  list(agg_block_sum, agg_block_size)
}

# Two-sample test statistic comparing aggregated SBM block estimates.
#
# Each sample is aggregated with agg_graphs(). The second sample's aggregates
# are then conjugated by the transposed matching() permutation so that its
# blocks line up with the first sample's. For every entry on or above the
# diagonal, the contribution is
#   mbar / (Bhat * (1 - Bhat)) * (S1 / m1 - S2 / m2)^2,
# where Bhat = (S1 + S2) / (m1 + m2) and mbar = 1 / (1 / m1 + 1 / m2).
# Entries whose pooled variance Bhat * (1 - Bhat) is not greater than `tau`
# are excluded, and NA/NaN contributions are ignored.
#
# Args:
#   A1list, A2list: lists of adjacency matrices for the two samples.
#   z1hatlist, z2hatlist: estimated label vectors parallel to A1list/A2list.
#   tau: variance threshold for including an entry (default 0).
# Returns:
#   A single numeric value, the summed statistic.
agg_sbm_test <- function(A1list, A2list, z1hatlist, z2hatlist, tau = 0) {
  agg1 <- agg_graphs(A1list, z1hatlist)
  agg2 <- agg_graphs(A2list, z2hatlist)

  # Align sample 2's blocks to sample 1's using the aggregated block estimates.
  perm <- t(matching(agg1[[1]] / agg1[[2]], agg2[[1]] / agg2[[2]]))
  permute_blocks <- function(M) perm %*% M %*% t(perm)

  sum1 <- agg1[[1]]
  sum2 <- permute_blocks(agg2[[1]])
  size1 <- agg1[[2]]
  size2 <- permute_blocks(agg2[[2]])

  pooled <- (sum1 + sum2) / (size1 + size2)
  pooled_var <- pooled * (1 - pooled)
  eff_size <- 1 / (1 / size1 + 1 / size2)

  terms <- eff_size / pooled_var * (sum1 / size1 - sum2 / size2)^2
  keep <- upper.tri(terms, diag = TRUE) & pooled_var > tau
  sum(terms[keep], na.rm = TRUE)
}

# Average pairwise MMD between spectral embeddings of two graph samples.
#
# Each graph is embedded with ase(A, d), or with spec_repr(A, d) when
# `ase_repr` is FALSE. The `$biased` component of fast_mmd() is then averaged
# over every pair (i, j) with i drawn from sample 1 and j from sample 2. The
# sample sizes are taken from the inputs rather than from a global variable,
# so the two samples may differ in size.
#
# Args:
#   A1list, A2list: non-empty lists of adjacency matrices.
#   d: embedding dimension passed to ase()/spec_repr().
#   sigma: passed as the third argument of fast_mmd() (presumably the kernel
#     bandwidth — confirm against fast_mmd's definition).
#   ase_repr: TRUE to embed with ase(), FALSE to embed with spec_repr().
# Returns:
#   Mean of fast_mmd(...)$biased over all length(A1list) * length(A2list)
#   pairs.
agg_ase_dist <- function(A1list, A2list, d, sigma = 1, ase_repr = TRUE) {
  stopifnot(length(A1list) > 0, length(A2list) > 0)

  embed <- if (ase_repr) ase else spec_repr
  Ase1list <- lapply(A1list, embed, d)
  Ase2list <- lapply(A2list, embed, d)

  total <- 0
  for (i in seq_along(Ase1list)) {
    for (j in seq_along(Ase2list)) {
      total <- total + fast_mmd(Ase1list[[i]], Ase2list[[j]], sigma)$biased
    }
  }
  total / (length(Ase1list) * length(Ase2list))
}


# Average pairwise squared distance between log-moment summaries of two
# graph samples.
#
# Each graph is summarised with log_moment(A, d). l2norm_squared() of the
# difference is averaged over every pair (i, j) with i from sample 1 and j
# from sample 2, which matches the all-pairs averaging of agg_ase_dist().
# The inner index now selects from sample 2; previously `[[i]]` was used, so
# `j` went unused. The sample sizes are taken from the inputs rather than
# from a global variable.
#
# Args:
#   A1list, A2list: non-empty lists of adjacency matrices.
#   d: passed to log_moment() (NOTE(review): presumably the number of
#     moments — confirm).
# Returns:
#   Mean of l2norm_squared(LogM1list[[i]] - LogM2list[[j]]) over all
#   length(A1list) * length(A2list) pairs.
agg_nlcm_dist <- function(A1list, A2list, d) {
  stopifnot(length(A1list) > 0, length(A2list) > 0)

  LogM1list <- lapply(A1list, log_moment, d)
  LogM2list <- lapply(A2list, log_moment, d)

  total <- 0
  for (i in seq_along(LogM1list)) {
    for (j in seq_along(LogM2list)) {
      total <- total + l2norm_squared(LogM1list[[i]] - LogM2list[[j]])
    }
  }
  total / (length(LogM1list) * length(LogM2list))
}






