library(foreach)
library(doParallel)
library(mvtnorm)

#' Quant
#'
#' Deterministic uniform quantizer. `vec` is divided by
#' `norm(vec, type = "2")` and every coordinate is replaced by the index of the
#' nearest level of the grid `seq(-1, 1, length.out = 2^bits)`. The compressed
#' value is recovered as `norm * seq(-1, 1, length.out = 2^bits)[quant]`.
#'
#' @param vec numeric vector (or matrix) to compress. For a matrix,
#'   `norm(type = "2")` is the spectral norm (largest singular value).
#' @param bits number of bits for each coordinate of the vector
#'
#' @return A list with the following elements
#' \item{quant}{integer grid indices, one per coordinate (column-major order
#'   for a matrix)}
#' \item{norm}{the norm `vec` was divided by (0 for an all-zero input)}
#' @export
#'
#' @examples
#' q <- Quant(c(3, 4), bits = 2)
#' q$norm * seq(-1, 1, length.out = 4)[q$quant]
Quant <- function(vec, bits) {
  if (anyNA(vec)) {
    stop("`vec` must not contain missing values", call. = FALSE)
  }
  refs <- seq(-1, 1, length.out = 2^bits)
  vec_norm <- norm(vec, type = "2")
  # An all-zero input would give 0/0 = NaN and which.min() would return
  # integer(0); dividing by 1 instead maps every coordinate to the level
  # nearest 0, and the reconstruction 0 * refs[quant] is still exactly 0.
  divisor <- if (vec_norm > 0) vec_norm else 1
  vec_quantized <- vapply(
    vec / divisor,
    function(x) which.min(abs(refs - x)),
    integer(1)
  )
  list(quant = vec_quantized, norm = vec_norm)
}

# EM
#' Classical EM algorithm for finite GMM with fixed covariance
#'
#' Runs `maxiter` EM iterations. Only the mixture proportions and the means
#' are updated; the covariance stays at its initial value.
#'
#' @param Y observed data (n x p matrix: n samples in rows, dimension p)
#' @param L number of mixtures (number of categories)
#' @param alpha_init initial vector (length L) of mixture proportions;
#'   defaults to uniform weights
#' @param mu_init initial mean vectors (L x p); defaults to the column means
#'   of `Y` plus standard normal noise for each component
#' @param Sigma_init covariance matrix (p x p), held fixed; defaults to the
#'   identity
#' @param maxiter number of iterations (integer)
#'
#' @return A list with the following elements
#' \item{alpha}{inferred vector of mixture proportions (length L)}
#' \item{mu}{inferred vectors of means (L x p)}
#' \item{Sigma}{covariance matrix (p x p), equal to its initial value}
EM_gmm <- function(Y,
                   L = 10,
                   alpha_init = NULL,
                   mu_init = NULL,
                   Sigma_init = NULL,
                   maxiter = 1e3) {
  Y <- as.matrix(Y)
  N <- nrow(Y)
  p <- ncol(Y)
  hatalpha <- if (is.null(alpha_init)) rep(1 / L, L) else alpha_init
  if (is.null(mu_init)) {
    # One copy of the per-coordinate sample mean per component, plus noise to
    # break the symmetry between components.
    hatmu <- matrix(colMeans(Y), nrow = L, ncol = p, byrow = TRUE) +
      matrix(rnorm(p * L), nrow = L)
  } else {
    hatmu <- mu_init
  }
  hatSigma <- if (is.null(Sigma_init)) diag(p) else Sigma_init

  for (t in seq_len(maxiter)) {
    cat("\r", "iteration ", t, "/", maxiter, " - ",
        round(100 * t / maxiter), "%")
    # E-step: resp[i, l] is proportional to alpha_l * N(y_i; mu_l, Sigma).
    # Log-densities are shifted by their row maximum before exponentiating:
    # the shift cancels in the normalisation but avoids 0/0 when every
    # density of a sample underflows.
    logdens <- matrix(0, nrow = N, ncol = L)
    for (l in seq_len(L)) {
      logdens[, l] <- mvtnorm::dmvnorm(Y, mean = hatmu[l, ],
                                       sigma = hatSigma, log = TRUE)
    }
    resp <- sweep(exp(logdens - apply(logdens, 1, max)), 2, hatalpha, "*")
    resp <- resp / rowSums(resp)
    St1 <- colMeans(resp)          # length L
    St2 <- crossprod(resp, Y) / N  # L x p: mean of resp[i, l] * y_i
    # M-step (the covariance is intentionally kept fixed)
    hatalpha <- St1 / sum(St1)
    hatmu <- sweep(St2, 1, St1, "/")
  }
  list(alpha = hatalpha, mu = hatmu, Sigma = hatSigma)
}


# naiveEM
#' Naive federated EM algorithm for finite GMM with fixed covariance
#'
#' Federated EM with quantized client-to-server messages and no control
#' variates. At every iteration each client draws a mini-batch, computes its
#' responsibility-weighted sufficient statistics and sends the quantized
#' difference with the current central statistic (see `Quant()`). The server
#' averages the messages, takes a momentum step on the central statistic
#' `hatS` and maps it back to parameters. The covariance is never updated.
#'
#' @param Ylist list of locally observed data: one N_i x p matrix per client
#'   (n = `length(Ylist)` clients, dimension p)
#' @param groups vector of length sum(N_i) giving the client of each data
#'   point. Not used: client sizes are read from `nrow(Ylist[[i]])`; the
#'   argument is kept for backward compatibility.
#' @param L number of mixtures (number of categories)
#' @param alpha_init initial vector (length L) of mixture proportions;
#'   defaults to uniform weights
#' @param mu_init initial mean vectors (L x p); defaults to the column means
#'   of `Ylist[[1]]` plus standard normal noise
#' @param Sigma_init covariance matrix (p x p), held fixed; defaults to the
#'   identity
#' @param maxiter maximum number of iterations (integer)
#' @param gamma step size of the central statistic update
#' @param nbatch mini-batch size per client (capped at the client's sample
#'   size)
#' @param alpha ignored: this variant keeps no memory terms
#' @param beta momentum coefficient
#' @param thresh convergence tolerance on the squared norm of the aggregated
#'   update H
#' @param bits number of bits per coordinate used for quantization
#'
#' @return A list with the following elements
#' \item{alpha}{inferred vector of mixture proportions (length L)}
#' \item{mu}{inferred vectors of means (L x p)}
#' \item{Sigma}{covariance matrix (p x p), equal to its initial value}
#' \item{normH}{squared norm of the aggregated update H at each iteration}
#' \item{normh}{squared norm of h(hatS) at each iteration: the average over
#'   clients of each client's full-data statistic at the current parameters,
#'   minus hatS}
naiveEM_gmm <- function(Ylist,
                        groups,
                        L = 10,
                        alpha_init = NULL,
                        mu_init = NULL,
                        Sigma_init = NULL,
                        maxiter = 1e3,
                        gamma = 1e-3,
                        nbatch = 20,
                        alpha = 0,
                        beta = 1e-2,
                        thresh = 1e-3,
                        bits = 8) {
  n <- length(Ylist)                     # number of clients
  p <- ncol(Ylist[[1]])                  # data dimension
  Nc <- vapply(Ylist, nrow, integer(1))  # sample size of each client
  refs <- seq(-1, 1, length.out = 2^bits)

  # Responsibility-weighted sufficient statistics on the rows of Y:
  # S1 = mean responsibilities (length L), S2 = mean of resp * y (L x p).
  # Log-densities are shifted by their row maximum so the normalisation does
  # not hit 0/0 on underflow. The environment is reset to base so shipping
  # the helper to the workers does not serialise this whole function frame.
  gmm_stats <- function(Y, w, mu, Sigma) {
    Y <- as.matrix(Y)
    logdens <- matrix(0, nrow = nrow(Y), ncol = length(w))
    for (l in seq_along(w)) {
      logdens[, l] <- mvtnorm::dmvnorm(Y, mean = mu[l, ], sigma = Sigma,
                                       log = TRUE)
    }
    resp <- sweep(exp(logdens - apply(logdens, 1, max)), 2, w, "*")
    resp <- resp / rowSums(resp)
    list(S1 = colMeans(resp), S2 = crossprod(resp, Y) / nrow(Y))
  }
  environment(gmm_stats) <- baseenv()

  hatalpha <- if (is.null(alpha_init)) rep(1 / L, L) else alpha_init
  if (is.null(mu_init)) {
    hatmu <- matrix(colMeans(Ylist[[1]]), nrow = L, ncol = p, byrow = TRUE) +
      matrix(rnorm(p * L), nrow = L)
  } else {
    hatmu <- mu_init
  }
  hatSigma <- if (is.null(Sigma_init)) diag(p) else Sigma_init
  hatS1 <- rep(0, L)
  hatS2 <- matrix(0, nrow = L, ncol = p)
  H1 <- rep(0, L)
  H2 <- matrix(0, nrow = L, ncol = p)

  # One worker pool for the whole run, always released (even on error).
  cl <- parallel::makeCluster(
    max(1L, min(n, parallel::detectCores() - 1L, na.rm = TRUE))
  )
  on.exit({
    parallel::stopCluster(cl)
    foreach::registerDoSEQ()
  }, add = TRUE)
  registerDoParallel(cl)

  obj <- NULL
  normh <- NULL
  for (iter in seq_len(maxiter)) {
    cat("\r", "iteration ", iter, "/", maxiter, " - ",
        round(100 * iter / maxiter), "%")

    # Client side: mini-batch statistics and quantized message. Without
    # memory terms the message is simply the quantized S - hatS.
    wrk_list <- foreach(k = seq_len(n), .packages = "mvtnorm",
                        .export = "Quant") %dopar% {
      batch <- sample.int(Nc[k], min(nbatch, Nc[k]))
      st <- gmm_stats(Ylist[[k]][batch, , drop = FALSE],
                      hatalpha, hatmu, hatSigma)
      q1 <- Quant(st$S1 - hatS1, bits)
      q2 <- Quant(st$S2 - hatS2, bits)
      list(
        St1 = st$S1,
        St2 = st$S2,
        Delta1 = q1$norm * refs[q1$quant],
        Delta2 = q2$norm * matrix(refs[q2$quant], nrow = L)
      )
    }

    # Server side: momentum step on the central statistic, then map T(hatS).
    Delta1 <- Reduce(`+`, lapply(wrk_list, `[[`, "Delta1")) / n
    Delta2 <- Reduce(`+`, lapply(wrk_list, `[[`, "Delta2")) / n
    H1 <- beta * H1 + Delta1
    H2 <- beta * H2 + Delta2
    val <- norm(H1, type = "2")^2 + norm(H2, type = "F")^2
    obj <- c(obj, val)
    hatS1 <- hatS1 + gamma * H1
    hatS2 <- hatS2 + gamma * H2
    hatalpha <- hatS1 / sum(hatS1)
    hatmu <- sweep(hatS2, 1, hatS1, "/")

    # Diagnostic h(hatS): full-data statistic at T(hatS) minus hatS.
    h_list <- foreach(k = seq_len(n), .packages = "mvtnorm") %dopar% {
      gmm_stats(Ylist[[k]], hatalpha, hatmu, hatSigma)
    }
    h_S1 <- Reduce(`+`, lapply(h_list, `[[`, "S1")) / n
    h_S2 <- Reduce(`+`, lapply(h_list, `[[`, "S2")) / n
    normh <- c(normh, norm(h_S1 - hatS1, type = "2")^2 +
                 norm(h_S2 - hatS2, type = "F")^2)

    if (val <= thresh) {
      break
    }
  }
  list(
    alpha = hatalpha,
    mu = hatmu,
    Sigma = hatSigma,
    normH = obj,
    normh = normh
  )
}



# FedEM
#' Federated EM algorithm for finite GMM with fixed covariance
#'
#' Federated EM with quantized client-to-server messages and memory terms.
#' At every iteration each client draws a mini-batch, computes its
#' responsibility-weighted sufficient statistics and sends the quantized
#' value of (local statistic - central statistic - client memory V_i), see
#' `Quant()`. Each client memory moves by `alpha` times its own message. The
#' server memory moves by `alpha` times the averaged message. The central
#' statistic `hatS` takes a momentum step. The covariance is never updated.
#'
#' @param Ylist list of locally observed data: one N_i x p matrix per client
#'   (n = `length(Ylist)` clients, dimension p)
#' @param groups vector of length sum(N_i) giving the client of each data
#'   point. Not used: client sizes are read from `nrow(Ylist[[i]])`; the
#'   argument is kept for backward compatibility.
#' @param L number of mixtures (number of categories)
#' @param alpha_init initial vector (length L) of mixture proportions;
#'   defaults to uniform weights
#' @param mu_init initial mean vectors (L x p); defaults to the column means
#'   of `Ylist[[1]]` plus standard normal noise
#' @param Sigma_init covariance matrix (p x p), held fixed; defaults to the
#'   identity
#' @param maxiter maximum number of iterations (integer)
#' @param gamma step size of the central statistic update
#' @param nbatch mini-batch size per client (capped at the client's sample
#'   size)
#' @param alpha learning rate for memory update
#' @param beta momentum coefficient
#' @param thresh convergence tolerance on the squared norm of the aggregated
#'   update H
#' @param bits number of bits per coordinate used for quantization
#'
#' @return A list with the following elements
#' \item{alpha}{inferred vector of mixture proportions (length L)}
#' \item{mu}{inferred vectors of means (L x p)}
#' \item{Sigma}{covariance matrix (p x p), equal to its initial value}
#' \item{normH}{squared norm of the aggregated update H at each iteration}
#' \item{normh}{squared norm of h(hatS) at each iteration: the average over
#'   clients of each client's full-data statistic at the current parameters,
#'   minus hatS}
FedEM_gmm <- function(Ylist,
                      groups,
                      L = 10,
                      alpha_init = NULL,
                      mu_init = NULL,
                      Sigma_init = NULL,
                      maxiter = 1e3,
                      gamma = 1e-3,
                      nbatch = 20,
                      alpha = 1e-2,
                      beta = 1e-2,
                      thresh = 1e-3,
                      bits = 8) {
  n <- length(Ylist)                     # number of clients
  p <- ncol(Ylist[[1]])                  # data dimension
  Nc <- vapply(Ylist, nrow, integer(1))  # sample size of each client
  refs <- seq(-1, 1, length.out = 2^bits)

  # Responsibility-weighted sufficient statistics on the rows of Y:
  # S1 = mean responsibilities (length L), S2 = mean of resp * y (L x p).
  # Log-densities are shifted by their row maximum so the normalisation does
  # not hit 0/0 on underflow. The environment is reset to base so shipping
  # the helper to the workers does not serialise this whole function frame.
  gmm_stats <- function(Y, w, mu, Sigma) {
    Y <- as.matrix(Y)
    logdens <- matrix(0, nrow = nrow(Y), ncol = length(w))
    for (l in seq_along(w)) {
      logdens[, l] <- mvtnorm::dmvnorm(Y, mean = mu[l, ], sigma = Sigma,
                                       log = TRUE)
    }
    resp <- sweep(exp(logdens - apply(logdens, 1, max)), 2, w, "*")
    resp <- resp / rowSums(resp)
    list(S1 = colMeans(resp), S2 = crossprod(resp, Y) / nrow(Y))
  }
  environment(gmm_stats) <- baseenv()

  hatalpha <- if (is.null(alpha_init)) rep(1 / L, L) else alpha_init
  if (is.null(mu_init)) {
    hatmu <- matrix(colMeans(Ylist[[1]]), nrow = L, ncol = p, byrow = TRUE) +
      matrix(rnorm(p * L), nrow = L)
  } else {
    hatmu <- mu_init
  }
  hatSigma <- if (is.null(Sigma_init)) diag(p) else Sigma_init
  hatS1 <- rep(0, L)
  hatS2 <- matrix(0, nrow = L, ncol = p)
  H1 <- rep(0, L)
  H2 <- matrix(0, nrow = L, ncol = p)
  # Server memory (average of the client memories).
  V1central <- rep(0, L)
  V2central <- matrix(0, nrow = L, ncol = p)
  # Per-client memory terms.
  wrk_list <- lapply(seq_len(n), function(k) {
    list(V1 = rep(0, L), V2 = matrix(0, nrow = L, ncol = p))
  })

  # One worker pool for the whole run, always released (even on error).
  cl <- parallel::makeCluster(
    max(1L, min(n, parallel::detectCores() - 1L, na.rm = TRUE))
  )
  on.exit({
    parallel::stopCluster(cl)
    foreach::registerDoSEQ()
  }, add = TRUE)
  registerDoParallel(cl)

  obj <- NULL
  normh <- NULL
  for (iter in seq_len(maxiter)) {
    cat("\r", "iteration ", iter, "/", maxiter, " - ",
        round(100 * iter / maxiter), "%")

    # Client side: mini-batch statistics, quantized message, memory update.
    wrk_list <- foreach(k = seq_len(n), .packages = "mvtnorm",
                        .export = "Quant") %dopar% {
      batch <- sample.int(Nc[k], min(nbatch, Nc[k]))
      st <- gmm_stats(Ylist[[k]][batch, , drop = FALSE],
                      hatalpha, hatmu, hatSigma)
      q1 <- Quant(st$S1 - hatS1 - wrk_list[[k]]$V1, bits)
      q2 <- Quant(st$S2 - hatS2 - wrk_list[[k]]$V2, bits)
      Delta1 <- q1$norm * refs[q1$quant]
      Delta2 <- q2$norm * matrix(refs[q2$quant], nrow = L)
      list(
        St1 = st$S1,
        St2 = st$S2,
        Delta1 = Delta1,
        Delta2 = Delta2,
        V1 = wrk_list[[k]]$V1 + alpha * Delta1,
        V2 = wrk_list[[k]]$V2 + alpha * Delta2
      )
    }

    # Server side: momentum step on hatS, memory update, then map T(hatS).
    Delta1 <- Reduce(`+`, lapply(wrk_list, `[[`, "Delta1")) / n
    Delta2 <- Reduce(`+`, lapply(wrk_list, `[[`, "Delta2")) / n
    H1 <- beta * H1 + Delta1 + V1central
    H2 <- beta * H2 + Delta2 + V2central
    val <- norm(H1, type = "2")^2 + norm(H2, type = "F")^2
    obj <- c(obj, val)
    hatS1 <- hatS1 + gamma * H1
    hatS2 <- hatS2 + gamma * H2
    V1central <- V1central + alpha * Delta1
    V2central <- V2central + alpha * Delta2
    hatalpha <- hatS1 / sum(hatS1)
    hatmu <- sweep(hatS2, 1, hatS1, "/")

    # Diagnostic h(hatS): full-data statistic at T(hatS) minus hatS.
    h_list <- foreach(k = seq_len(n), .packages = "mvtnorm") %dopar% {
      gmm_stats(Ylist[[k]], hatalpha, hatmu, hatSigma)
    }
    h_S1 <- Reduce(`+`, lapply(h_list, `[[`, "S1")) / n
    h_S2 <- Reduce(`+`, lapply(h_list, `[[`, "S2")) / n
    normh <- c(normh, norm(h_S1 - hatS1, type = "2")^2 +
                 norm(h_S2 - hatS2, type = "F")^2)

    if (val <= thresh) {
      break
    }
  }
  list(
    alpha = hatalpha,
    mu = hatmu,
    Sigma = hatSigma,
    normH = obj,
    normh = normh
  )
}

#VR-FedEM
#' Variance-reduced federated EM (SPIDER-type) for finite GMM with fixed
#' covariance
#'
#' Each outer iteration starts with a full pass. In it, every client
#' recomputes its sufficient statistics on all its local data at the current
#' parameters.
#'
#' Each inner iteration then:
#' - updates those local statistics by the mini-batch difference between the
#'   statistics at the current parameters and at the previous parameters;
#' - sends the quantized value of (local statistic - central statistic -
#'   client memory), see `Quant()`;
#' - lets the server update the central statistic `hatS`, the memory terms
#'   and the parameters (no momentum).
#'
#' The covariance is never updated.
#'
#' @param Ylist list of locally observed data: one N_i x p matrix per client
#'   (n = `length(Ylist)` clients, dimension p)
#' @param groups vector of length sum(N_i) giving the client of each data
#'   point. Not used: client sizes are read from `nrow(Ylist[[i]])`; the
#'   argument is kept for backward compatibility.
#' @param L number of mixtures (number of categories)
#' @param alpha_init initial vector (length L), used as the initial central
#'   statistic hatS1 (the initial proportions are hatS1 / sum(hatS1));
#'   defaults to uniform weights
#' @param mu_init initial mean vectors (L x p); by default the initial means
#'   are (column means of `Ylist[[1]]` + standard normal noise) divided
#'   row-wise by the initial hatS1
#' @param Sigma_init covariance matrix (p x p), held fixed; defaults to the
#'   identity
#' @param kout number of iterations for outer loop
#' @param kin number of iterations for inner loop
#' @param gamma step size of the central statistic update
#' @param nbatch mini-batch size per client (capped at the client's sample
#'   size)
#' @param alpha learning rate for memory update
#' @param beta ignored: this variant uses no momentum
#' @param thresh convergence tolerance on the squared norm of the aggregated
#'   update H; reaching it stops both loops
#' @param bits number of bits per coordinate used for quantization
#'
#' @return A list with the following elements
#' \item{alpha}{inferred vector of mixture proportions (length L)}
#' \item{mu}{inferred vectors of means (L x p)}
#' \item{Sigma}{covariance matrix (p x p), equal to its initial value}
#' \item{normH}{list with, for each outer iteration, the squared norms of the
#'   aggregated update H over its inner iterations}
#' \item{normh}{same layout, squared norms of h(hatS): the average over
#'   clients of each client's full-data statistic at the current parameters,
#'   minus hatS}
FedSpiderEM_gmm <- function(Ylist,
                            groups,
                            L = 10,
                            alpha_init = NULL,
                            mu_init = NULL,
                            Sigma_init = NULL,
                            kout = 10,
                            kin = 10,
                            gamma = 1e-3,
                            nbatch = 20,
                            alpha = 1e-2,
                            beta = 1e-2,
                            thresh = 1e-3,
                            bits = 8) {
  n <- length(Ylist)                     # number of clients
  p <- ncol(Ylist[[1]])                  # data dimension
  Nc <- vapply(Ylist, nrow, integer(1))  # sample size of each client
  refs <- seq(-1, 1, length.out = 2^bits)

  # Responsibility-weighted sufficient statistics on the rows of Y:
  # S1 = mean responsibilities (length L), S2 = mean of resp * y (L x p).
  # Log-densities are shifted by their row maximum so the normalisation does
  # not hit 0/0 on underflow. The environment is reset to base so shipping
  # the helper to the workers does not serialise this whole function frame.
  gmm_stats <- function(Y, w, mu, Sigma) {
    Y <- as.matrix(Y)
    logdens <- matrix(0, nrow = nrow(Y), ncol = length(w))
    for (l in seq_along(w)) {
      logdens[, l] <- mvtnorm::dmvnorm(Y, mean = mu[l, ], sigma = Sigma,
                                       log = TRUE)
    }
    resp <- sweep(exp(logdens - apply(logdens, 1, max)), 2, w, "*")
    resp <- resp / rowSums(resp)
    list(S1 = colMeans(resp), S2 = crossprod(resp, Y) / nrow(Y))
  }
  environment(gmm_stats) <- baseenv()

  # The algorithm is driven by the central statistic hatS; the parameters
  # are always T(hatS) = (hatS1 / sum(hatS1), hatS2 / hatS1 row-wise).
  hatS1 <- if (is.null(alpha_init)) rep(1 / L, L) else alpha_init
  if (is.null(mu_init)) {
    hatS2 <- matrix(colMeans(Ylist[[1]]), nrow = L, ncol = p, byrow = TRUE) +
      matrix(rnorm(p * L), nrow = L)
  } else {
    # Chosen so that T(hatS) returns exactly mu_init as the means.
    hatS2 <- sweep(mu_init, 1, hatS1, "*")
  }
  hatalpha <- hatS1 / sum(hatS1)
  hatmu <- sweep(hatS2, 1, hatS1, "/")
  hatSigma <- if (is.null(Sigma_init)) diag(p) else Sigma_init
  # Server memory and per-client memory terms.
  V1central <- rep(0, L)
  V2central <- matrix(0, nrow = L, ncol = p)
  wrk_list <- lapply(seq_len(n), function(k) {
    list(V1 = rep(0, L), V2 = matrix(0, nrow = L, ncol = p))
  })

  # One worker pool for the whole run, always released (even on error).
  cl <- parallel::makeCluster(
    max(1L, min(n, parallel::detectCores() - 1L, na.rm = TRUE))
  )
  on.exit({
    parallel::stopCluster(cl)
    foreach::registerDoSEQ()
  }, add = TRUE)
  registerDoParallel(cl)

  obj_out <- list()
  normh_out <- list()
  converged <- FALSE
  for (iter_out in seq_len(kout)) {
    # Full pass: exact local statistics at the current parameters; the
    # memory terms are carried over unchanged.
    wrk_list <- foreach(k = seq_len(n), .packages = "mvtnorm") %dopar% {
      st <- gmm_stats(Ylist[[k]], hatalpha, hatmu, hatSigma)
      list(St1 = st$S1, St2 = st$S2,
           V1 = wrk_list[[k]]$V1, V2 = wrk_list[[k]]$V2)
    }
    # The control variate restarts from the parameters used by the full pass,
    # so the first inner correction of every outer iteration is zero.
    tmpalpha <- hatalpha
    tmpmu <- hatmu

    obj_in <- NULL
    normh_in <- NULL
    for (iter_in in seq_len(kin)) {
      cat("\r", "outer iteration ", iter_out, "/", kout, " - ",
          "inner iteration ", iter_in, "/", kin, " - ",
          round(100 * iter_out / kout), "%")

      # Client side: SPIDER update of the local statistics on a mini-batch,
      # quantized message and memory update.
      wrk_list <- foreach(k = seq_len(n), .packages = "mvtnorm",
                          .export = "Quant") %dopar% {
        batch <- sample.int(Nc[k], min(nbatch, Nc[k]))
        Yb <- Ylist[[k]][batch, , drop = FALSE]
        st_new <- gmm_stats(Yb, hatalpha, hatmu, hatSigma)
        st_old <- gmm_stats(Yb, tmpalpha, tmpmu, hatSigma)
        St1 <- wrk_list[[k]]$St1 + st_new$S1 - st_old$S1
        St2 <- wrk_list[[k]]$St2 + st_new$S2 - st_old$S2
        q1 <- Quant(St1 - hatS1 - wrk_list[[k]]$V1, bits)
        q2 <- Quant(St2 - hatS2 - wrk_list[[k]]$V2, bits)
        Delta1 <- q1$norm * refs[q1$quant]
        Delta2 <- q2$norm * matrix(refs[q2$quant], nrow = L)
        list(
          St1 = St1,
          St2 = St2,
          Delta1 = Delta1,
          Delta2 = Delta2,
          V1 = wrk_list[[k]]$V1 + alpha * Delta1,
          V2 = wrk_list[[k]]$V2 + alpha * Delta2
        )
      }

      # Server side: step on hatS, memory update, then map T(hatS) while
      # remembering the previous parameters for the next SPIDER correction.
      Delta1 <- Reduce(`+`, lapply(wrk_list, `[[`, "Delta1")) / n
      Delta2 <- Reduce(`+`, lapply(wrk_list, `[[`, "Delta2")) / n
      H1 <- Delta1 + V1central
      H2 <- Delta2 + V2central
      val <- norm(H1, type = "2")^2 + norm(H2, type = "F")^2
      obj_in <- c(obj_in, val)
      converged <- val <= thresh
      hatS1 <- hatS1 + gamma * H1
      hatS2 <- hatS2 + gamma * H2
      V1central <- V1central + alpha * Delta1
      V2central <- V2central + alpha * Delta2
      tmpalpha <- hatalpha
      tmpmu <- hatmu
      hatalpha <- hatS1 / sum(hatS1)
      hatmu <- sweep(hatS2, 1, hatS1, "/")

      # Diagnostic h(hatS): full-data statistic at T(hatS) minus hatS.
      h_list <- foreach(k = seq_len(n), .packages = "mvtnorm") %dopar% {
        gmm_stats(Ylist[[k]], hatalpha, hatmu, hatSigma)
      }
      h_S1 <- Reduce(`+`, lapply(h_list, `[[`, "S1")) / n
      h_S2 <- Reduce(`+`, lapply(h_list, `[[`, "S2")) / n
      normh_in <- c(normh_in, norm(h_S1 - hatS1, type = "2")^2 +
                      norm(h_S2 - hatS2, type = "F")^2)

      if (converged) {
        break
      }
    }
    obj_out[[iter_out]] <- obj_in
    normh_out[[iter_out]] <- normh_in
    if (converged) {
      break
    }
  }
  list(
    alpha = hatalpha,
    mu = hatmu,
    Sigma = hatSigma,
    normH = obj_out,
    normh = normh_out
  )
}
