library(foreach)
library(doParallel)
library(mvtnorm)

# FedEM
#' Federated EM algorithm for finite GMM with fixed covariance
#'
#' Runs a federated stochastic EM for a Gaussian mixture whose components
#' share one covariance matrix that is kept fixed. At every iteration each
#' server draws a mini batch of its own rows, computes the mini-batch
#' sufficient statistics, and sends the difference with respect to the current
#' central statistics and its own memory term; the central server aggregates
#' these differences, updates the statistics and maps them back to the mixture
#' parameters.
#'
#' @param Ylist list of locally observed data (list of length n:nb of servers, each element of dimension N_i x p, N_i samples on server i and dimension p)
#' @param groups vector of length \sum_{i=1}^nN_i indicating on which server each data point is located. Kept for backward compatibility: the number of servers and the per-server sample sizes are read from `Ylist` itself, so this argument is not used.
#' @param L number of mixtures (number of categories)
#' @param alpha_init initial vector (length L) of mixture (categories) proportion
#' @param mu_init initial mean vectors (L x p)
#' @param Sigma_init initial covariance matrix (p x p)
#' @param maxiter maximum number of iterations (integer)
#' @param gamma learning rate
#' @param nbatch size of mini batches (must not exceed the smallest N_i)
#' @param alpha learning rate for memory update
#' @param beta momentum
#' @param thresh convergence tolerance on the squared norm of the update field
#'
#' @return A list with the following elements
#' \item{alpha}{list with one vector of mixture proportions (length L) per iteration}
#' \item{mu}{inferred vectors of means (L x p)}
#' \item{Sigma}{covariance matrix used by the algorithm (p x p)}
#' \item{normH}{numeric vector with the squared norm of the update field at each iteration}
FedEM_gmm <- function(Ylist,
                      groups,
                      L = 10,
                      alpha_init = NULL,
                      mu_init = NULL,
                      Sigma_init = NULL,
                      maxiter = 1e3,
                      gamma = 1e-3,
                      nbatch = 20,
                      alpha = 1e-2,
                      beta = 1e-2,
                      thresh = 1e-3) {
  # Derive the problem dimensions from the data instead of relying on
  # variables named `n` and `p` existing in the calling/global environment.
  Ylist <- lapply(Ylist, as.matrix)
  n <- length(Ylist)
  if (n < 1) {
    stop("`Ylist` must contain at least one server.", call. = FALSE)
  }
  p <- ncol(Ylist[[1]])
  if (any(vapply(Ylist, ncol, integer(1)) != p)) {
    stop("All elements of `Ylist` must have the same number of columns.",
         call. = FALSE)
  }
  # Number of rows held by each server, aligned with the order of `Ylist`.
  Nc <- vapply(Ylist, nrow, integer(1))
  if (nbatch < 1 || nbatch > min(Nc)) {
    stop("`nbatch` must be between 1 and the smallest server sample size.",
         call. = FALSE)
  }

  # Initial mixture parameters.
  hatalpha <- if (is.null(alpha_init)) rep(1 / L, L) else alpha_init
  if (is.null(mu_init)) {
    # Mean of the first server's data, replicated L times and jittered.
    hatmu <- matrix(rep(colMeans(Ylist[[1]]), L), nrow = L, byrow = TRUE)
    hatmu <- hatmu + matrix(rnorm(p * L), nrow = L)
  } else {
    hatmu <- mu_init
  }
  hatSigma <- if (is.null(Sigma_init)) diag(p) else Sigma_init

  # Central sufficient statistics, update field and central memory term.
  hatS1 <- rep(0, L)
  hatS2 <- matrix(0, nrow = L, ncol = p)
  H1 <- rep(0, L)
  H2 <- matrix(0, nrow = L, ncol = p)
  V1central <- rep(0, L)
  V2central <- matrix(0, nrow = L, ncol = p)
  # Per-server memory terms and last transmitted differences.
  wrk_list <- lapply(seq_len(n), function(i) {
    list(
      V1 = rep(0, L),
      V2 = matrix(0, nrow = L, ncol = p),
      Delta1 = rep(0, L),
      Delta2 = matrix(0, nrow = L, ncol = p)
    )
  })

  # Posterior component probabilities of the rows of Y (one row per
  # observation, one column per component). Defined locally so that foreach
  # ships it to the workers together with the other local variables.
  gmm_resp <- function(Y, prop, mu, Sigma) {
    dens <- vapply(seq_len(nrow(mu)), function(l) {
      prop[l] * mvtnorm::dmvnorm(Y, mean = mu[l, ], sigma = Sigma)
    }, numeric(nrow(Y)))
    # vapply drops to a vector when Y has a single row.
    dens <- matrix(dens, nrow = nrow(Y))
    dens / rowSums(dens)
  }

  # One cluster for the whole run, released even if an iteration fails.
  # Leave one core free, but always start at least one worker.
  cores <- detectCores()
  n_workers <- if (is.na(cores)) 1L else max(1L, min(n, cores - 1L))
  cl <- makeCluster(n_workers)
  on.exit(stopCluster(cl), add = TRUE)
  registerDoParallel(cl)

  alpha_list <- vector("list", maxiter)
  obj <- numeric(maxiter)
  iter <- 0
  # With distribution (one task per server)
  repeat {
    iter <- iter + 1
    cat("\r",
        "iteration ",
        iter,
        "/",
        maxiter,
        " - ",
        round(100 * iter / maxiter),
        "%")

    wrk_list <- foreach(i = seq_len(n), .packages = "mvtnorm") %dopar% {
      # Mini batch drawn without replacement among the rows of server i.
      batch <- sample.int(Nc[i], nbatch)
      Yb <- Ylist[[i]][batch, , drop = FALSE]
      resp <- gmm_resp(Yb, hatalpha, hatmu, hatSigma)
      # Mini-batch sufficient statistics: mean responsibility per component
      # and mean of responsibility-weighted observations (L x p).
      St1 <- colMeans(resp)
      St2 <- crossprod(resp, Yb) / nbatch
      # Difference sent to the central server, corrected by the local memory.
      Delta1 <- St1 - hatS1 - wrk_list[[i]]$V1
      Delta2 <- St2 - hatS2 - wrk_list[[i]]$V2
      list(
        St1 = St1,
        St2 = St2,
        Delta1 = Delta1,
        Delta2 = Delta2,
        V1 = wrk_list[[i]]$V1 + alpha * Delta1,
        V2 = wrk_list[[i]]$V2 + alpha * Delta2
      )
    }

    # Update statistic hatS at central server
    meanDelta1 <- Reduce("+", lapply(wrk_list, function(x) x$Delta1)) / n
    meanDelta2 <- Reduce("+", lapply(wrk_list, function(x) x$Delta2)) / n
    H1 <- beta * H1 + meanDelta1 + V1central
    H2 <- beta * H2 + meanDelta2 + V2central
    # Squared Euclidean norm of H1 plus squared Frobenius norm of H2.
    val <- sum(H1^2) + sum(H2^2)
    obj[iter] <- val
    hatS1 <- hatS1 + gamma * H1
    hatS2 <- hatS2 + gamma * H2
    V1central <- V1central + alpha * meanDelta1
    V2central <- V2central + alpha * meanDelta2
    # Compute map T(hatS)
    hatalpha <- hatS1 / sum(hatS1)
    hatmu <- sweep(hatS2, 1, hatS1, "/")
    alpha_list[[iter]] <- hatalpha

    # Stop once the update field is small enough or the budget is spent; the
    # parameters of the current iteration have already been updated.
    if (val <= thresh || iter >= maxiter) {
      break
    }
  }

  list(
    alpha = alpha_list[seq_len(iter)],
    mu = hatmu,
    Sigma = hatSigma,
    normH = obj[seq_len(iter)]
  )
}

#VR-FedEM
#' Variance-reduced federated EM algorithm for finite GMM
#'
#' Federated EM algorithm for finite GMM with a SPIDER-like variance-reduced
#' estimate of the sufficient statistics. The run starts with a burn-in made of
#' plain mini-batch federated EM iterations. It then alternates inner loops, in
#' which each server updates its local statistics with the difference between
#' the mini-batch statistics at the current and at the previous parameters,
#' and a full pass in which each server recomputes its statistics on all of
#' its rows.
#'
#' @param Ylist list of locally observed data (list of length n:nb of servers, each element of dimension N_i x p, N_i samples on server i and dimension p)
#' @param groups vector of length \sum_{i=1}^nN_i indicating on which server each data point is located. Kept for backward compatibility: the number of servers and the per-server sample sizes are read from `Ylist` itself, so this argument is not used.
#' @param L number of mixtures (number of categories)
#' @param alpha_init initial vector (length L) of mixture (categories) proportion
#' @param mu_init initial mean vectors (L x p)
#' @param Sigma_init initial covariance matrix (p x p)
#' @param kout number of iterations for outer loop
#' @param kin number of iterations for inner loop
#' @param gamma learning rate
#' @param nbatch size of mini batches (must not exceed the smallest N_i)
#' @param alpha learning rate for memory update
#' @param beta momentum (used during the burn-in only)
#' @param thresh convergence tolerance on the squared norm of the update field
#' @param burnin number of burn-in iterations run before the first outer loop
#'
#'
#' @return A list with the following elements
#' \item{alpha}{list with one vector of mixture proportions (length L) per burn-in and inner iteration}
#' \item{mu}{inferred vectors of means (L x p)}
#' \item{Sigma}{covariance matrix in use when the algorithm stopped (p x p)}
#' \item{normH}{list of numeric vectors with the squared norm of the update field: first element for the burn-in, then one element per outer iteration}
FedSpiderEM_gmm <- function(Ylist,
                            groups,
                            L = 10,
                            alpha_init = NULL,
                            mu_init = NULL,
                            Sigma_init = NULL,
                            kout = 10,
                            kin = 10,
                            gamma = 1e-3,
                            nbatch = 20,
                            alpha = 1e-2,
                            beta = 1e-2,
                            thresh = 1e-3,
                            burnin = 70) {
  # Derive the problem dimensions from the data instead of relying on
  # variables named `n` and `p` existing in the calling/global environment.
  Ylist <- lapply(Ylist, as.matrix)
  n <- length(Ylist)
  if (n < 1) {
    stop("`Ylist` must contain at least one server.", call. = FALSE)
  }
  p <- ncol(Ylist[[1]])
  if (any(vapply(Ylist, ncol, integer(1)) != p)) {
    stop("All elements of `Ylist` must have the same number of columns.",
         call. = FALSE)
  }
  # Number of rows held by each server, aligned with the order of `Ylist`.
  Nc <- vapply(Ylist, nrow, integer(1))
  if (nbatch < 1 || nbatch > min(Nc)) {
    stop("`nbatch` must be between 1 and the smallest server sample size.",
         call. = FALSE)
  }
  # Uncentered empirical second moment of the pooled data.
  Y_ref <- do.call("rbind", Ylist)
  Sigma_ref <- crossprod(Y_ref) / nrow(Y_ref)

  # Initial mixture parameters.
  hatalpha <- if (is.null(alpha_init)) rep(1 / L, L) else alpha_init
  if (is.null(mu_init)) {
    # Mean of the first server's data, replicated L times and jittered.
    hatmu <- matrix(rep(colMeans(Ylist[[1]]), L), nrow = L, byrow = TRUE)
    hatmu <- hatmu + matrix(rnorm(p * L), nrow = L)
  } else {
    hatmu <- mu_init
  }
  hatSigma <- if (is.null(Sigma_init)) diag(p) else Sigma_init

  # Central sufficient statistics, update field and central memory term.
  hatS1 <- rep(0, L)
  hatS2 <- matrix(0, nrow = L, ncol = p)
  H1 <- rep(0, L)
  H2 <- matrix(0, nrow = L, ncol = p)
  V1central <- rep(0, L)
  V2central <- matrix(0, nrow = L, ncol = p)
  # Per-server memory terms and last transmitted differences.
  wrk_list <- lapply(seq_len(n), function(i) {
    list(
      V1 = rep(0, L),
      V2 = matrix(0, nrow = L, ncol = p),
      Delta1 = rep(0, L),
      Delta2 = matrix(0, nrow = L, ncol = p)
    )
  })

  # Posterior component probabilities of the rows of Y (one row per
  # observation, one column per component). Defined locally so that foreach
  # ships it to the workers together with the other local variables.
  gmm_resp <- function(Y, prop, mu, Sigma) {
    dens <- vapply(seq_len(nrow(mu)), function(l) {
      prop[l] * mvtnorm::dmvnorm(Y, mean = mu[l, ], sigma = Sigma)
    }, numeric(nrow(Y)))
    # vapply drops to a vector when Y has a single row.
    dens <- matrix(dens, nrow = nrow(Y))
    dens / rowSums(dens)
  }

  # One cluster for the whole run, released even if an iteration fails.
  # Leave one core free, but always start at least one worker.
  cores <- detectCores()
  n_workers <- if (is.na(cores)) 1L else max(1L, min(n, cores - 1L))
  cl <- makeCluster(n_workers)
  on.exit(stopCluster(cl), add = TRUE)
  registerDoParallel(cl)

  alpha_list <- list()
  alpha_counter <- 0

  # ---- Burn-in: plain mini-batch federated EM iterations ----
  obj_burn <- NULL
  iter <- 0
  repeat {
    iter <- iter + 1
    cat("\r",
        "burn in ",
        iter,
        "/",
        burnin,
        " - ",
        round(100 * iter / burnin),
        "%")

    wrk_list <- foreach(i = seq_len(n), .packages = "mvtnorm") %dopar% {
      # Mini batch drawn without replacement among the rows of server i.
      batch <- sample.int(Nc[i], nbatch)
      Yb <- Ylist[[i]][batch, , drop = FALSE]
      resp <- gmm_resp(Yb, hatalpha, hatmu, hatSigma)
      St1 <- colMeans(resp)
      St2 <- crossprod(resp, Yb) / nbatch
      # Difference sent to the central server, corrected by the local memory.
      Delta1 <- St1 - hatS1 - wrk_list[[i]]$V1
      Delta2 <- St2 - hatS2 - wrk_list[[i]]$V2
      list(
        St1 = St1,
        St2 = St2,
        Delta1 = Delta1,
        Delta2 = Delta2,
        V1 = wrk_list[[i]]$V1 + alpha * Delta1,
        V2 = wrk_list[[i]]$V2 + alpha * Delta2
      )
    }

    # Update statistic hatS at central server (with momentum beta)
    meanDelta1 <- Reduce("+", lapply(wrk_list, function(x) x$Delta1)) / n
    meanDelta2 <- Reduce("+", lapply(wrk_list, function(x) x$Delta2)) / n
    H1 <- beta * H1 + meanDelta1 + V1central
    H2 <- beta * H2 + meanDelta2 + V2central
    # Squared Euclidean norm of H1 plus squared Frobenius norm of H2.
    obj_burn <- c(obj_burn, sum(H1^2) + sum(H2^2))
    hatS1 <- hatS1 + gamma * H1
    hatS2 <- hatS2 + gamma * H2
    V1central <- V1central + alpha * meanDelta1
    V2central <- V2central + alpha * meanDelta2
    # Compute map T(hatS)
    hatalpha <- hatS1 / sum(hatS1)
    hatmu <- sweep(hatS2, 1, hatS1, "/")
    alpha_counter <- alpha_counter + 1
    alpha_list[[alpha_counter]] <- hatalpha

    if (iter >= burnin) {
      break
    }
  }

  # ---- Variance-reduced phase ----
  # tmpalpha / tmpmu hold the parameters of the previous inner iteration.
  tmpalpha <- hatalpha
  tmpmu <- hatmu
  obj <- list(obj_burn)
  # The convergence test of the first inner iteration uses this placeholder,
  # not the last burn-in value.
  val <- 1
  iter_in <- 0
  # Slot 1 of `obj` is the burn-in, so outer iterations are numbered from 2.
  iter_out <- 1
  flag_out <- TRUE
  while (flag_out) {
    iter_out <- iter_out + 1
    if (iter_out >= kout) {
      flag_out <- FALSE
    }
    cat(
      "\r",
      "outer iteration ",
      iter_out,
      "/",
      kout,
      " - ",
      "inner iteration ",
      iter_in,
      "/",
      kin,
      " - ",
      round(100 * iter_out / kout),
      "%"
    )

    flag_in <- TRUE
    iter_in <- 0
    obj_in <- NULL
    while (flag_in) {
      # Convergence is tested on the value of the previous iteration; the
      # current iteration still runs before both loops stop.
      if (val <= thresh) {
        flag_in <- FALSE
        flag_out <- FALSE
      }
      iter_in <- iter_in + 1
      if (iter_in >= kin) {
        flag_in <- FALSE
      }
      cat(
        "\r",
        "outer iteration ",
        iter_out,
        "/",
        kout,
        " - ",
        "inner iteration ",
        iter_in,
        "/",
        kin,
        " - ",
        round(100 * iter_out / kout),
        "%"
      )

      wrk_list <- foreach(i = seq_len(n), .packages = "mvtnorm") %dopar% {
        batch <- sample.int(Nc[i], nbatch)
        Yb <- Ylist[[i]][batch, , drop = FALSE]
        # Responsibilities on the same mini batch at the current and at the
        # previous parameters.
        resp_new <- gmm_resp(Yb, hatalpha, hatmu, hatSigma)
        resp_old <- gmm_resp(Yb, tmpalpha, tmpmu, hatSigma)
        resp_diff <- resp_new - resp_old
        # Local statistics: previous value plus the mini-batch increment.
        St1 <- wrk_list[[i]]$St1 + colMeans(resp_diff)
        St2 <- wrk_list[[i]]$St2 + crossprod(resp_diff, Yb) / nbatch
        Delta1 <- St1 - hatS1 - wrk_list[[i]]$V1
        Delta2 <- St2 - hatS2 - wrk_list[[i]]$V2
        list(
          St1 = St1,
          St2 = St2,
          Delta1 = Delta1,
          Delta2 = Delta2,
          V1 = wrk_list[[i]]$V1 + alpha * Delta1,
          V2 = wrk_list[[i]]$V2 + alpha * Delta2
        )
      }

      # Update statistic hatS at central server (no momentum in this phase)
      meanDelta1 <- Reduce("+", lapply(wrk_list, function(x) x$Delta1)) / n
      meanDelta2 <- Reduce("+", lapply(wrk_list, function(x) x$Delta2)) / n
      H1 <- meanDelta1 + V1central
      H2 <- meanDelta2 + V2central
      val <- sum(H1^2) + sum(H2^2)
      obj_in <- c(obj_in, val)
      hatS1 <- hatS1 + gamma * H1
      hatS2 <- hatS2 + gamma * H2
      V1central <- V1central + alpha * meanDelta1
      V2central <- V2central + alpha * meanDelta2
      # Compute map T(hatS), remembering the previous parameters first
      tmpalpha <- hatalpha
      tmpmu <- hatmu
      hatalpha <- hatS1 / sum(hatS1)
      alpha_counter <- alpha_counter + 1
      alpha_list[[alpha_counter]] <- hatalpha
      hatmu <- sweep(hatS2, 1, hatS1, "/")
      # NOTE(review): from the first inner iteration on, the covariance is
      # replaced by the pooled uncentered second moment, although the header
      # describes a fixed covariance. Kept as in the original; confirm intent.
      hatSigma <- Sigma_ref
    }
    obj[[iter_out]] <- obj_in

    # Full pass: every server recomputes its statistics on all of its rows;
    # differences and memory terms are carried over unchanged.
    wrk_list <- foreach(i = seq_len(n), .packages = "mvtnorm") %dopar% {
      resp <- gmm_resp(Ylist[[i]], hatalpha, hatmu, hatSigma)
      list(
        St1 = colMeans(resp),
        St2 = crossprod(resp, Ylist[[i]]) / Nc[i],
        Delta1 = wrk_list[[i]]$Delta1,
        Delta2 = wrk_list[[i]]$Delta2,
        V1 = wrk_list[[i]]$V1,
        V2 = wrk_list[[i]]$V2
      )
    }
  }

  list(
    alpha = alpha_list,
    mu = hatmu,
    Sigma = hatSigma,
    normH = obj
  )
}
