source("~/GMM_mnist_server/GMM_algos_mnist.R")

# Load the preprocessed MNIST data. The .Rdata file is expected to define an
# object `mnist`; its `data` component is used as the design matrix.
load(file = "~/GMM_mnist_server/data/mnist_reduced.Rdata")
mnist_reduce <- mnist$data

# First experiment: let the algorithm run and record ||H||^2 at each
# iteration.

# Experiment constants. They are kept as script-level variables because the
# sourced algorithm file may read them.
n <- 100      # number of local servers
N <- 7 * 1e4  # number of observations used in the experiment
L <- 10       # number of rows drawn for the initial means (see mu_init)
p <- 20       # NOTE(review): presumably ncol(mnist_reduce) -- confirm

# Fail early with a clear message rather than deep inside sample()/aggregate.
if (N > nrow(mnist_reduce)) {
  stop(
    "N (", N, ") exceeds the number of available rows (",
    nrow(mnist_reduce), ").",
    call. = FALSE
  )
}
if (N %% n != 0) {
  stop("N must be a multiple of the number of servers n.", call. = FALSE)
}

# Shuffle the rows: draw N row indices without replacement.
# NOTE(review): no set.seed() is called here, so the shuffle differs between
# runs unless the sourced file sets a seed.
Y <- mnist_reduce[sample.int(nrow(mnist_reduce), N), ]

# Assign consecutive runs of `ngroup` shuffled rows to each server.
ngroup <- N / n
groups <- as.factor(rep(seq_len(n), each = ngroup))

# Number of observations held by each server, in server order.
Nc <- as.numeric(table(groups))

# One data block per server. drop = FALSE keeps a two-dimensional block even
# if a server holds a single row.
Ylist <- lapply(seq_len(n), function(server) {
  Y[which(groups == server), , drop = FALSE]
})

# Initialize parameters.

# Initial means: L rows of Y drawn at random without replacement.
mu_init <- Y[sample.int(N, L), ]

# Initial covariance: empirical covariance of the pooled sample with divisor
# N (not N - 1). The data are centred once and crossprod() forms
# t(centred) %*% centred without materialising the transpose. local() keeps
# the large centred copy out of the global environment.
Sigma_init <- local({
  centred <- sweep(Y, 2, colMeans(Y))
  crossprod(centred) / N
})

# Apply FedEM and save its diagnostics.

# Make sure the output directory exists before starting the (long) run, so
# the results are not lost at write time.
results_dir <- "~/GMM_mnist_server/results"
dir.create(results_dir, showWarnings = FALSE, recursive = TRUE)

# FedEM_gmm() comes from the sourced algorithm file. L is passed from the
# script-level constant so it always matches the number of rows in mu_init.
tt <- FedEM_gmm(
  Ylist,
  groups,
  L = L,
  maxiter = 1e4,
  thresh = 1e-20,
  gamma = 1e-3,
  nbatch = 20,
  mu_init = mu_init,
  Sigma_init = Sigma_init
)

# Write the normH component (presumably ||H||^2 per iteration -- confirm
# against FedEM_gmm).
fileFedEM <- file.path(results_dir, "FedEM_mnist_normH.txt")
write.table(tt$normH, file = fileFedEM)

# tt$alpha is a list; stack its elements row-wise into a single table.
write.table(
  do.call(rbind, tt$alpha),
  file = file.path(results_dir, "FedEM_mnist_alpha.txt")
)

# Apply VR-FedEM (FedSpiderEM) and save its diagnostics.

# Make sure the output directory exists before starting the (long) run, so
# the results are not lost at write time.
results_dir <- "~/GMM_mnist_server/results"
dir.create(results_dir, showWarnings = FALSE, recursive = TRUE)

# FedSpiderEM_gmm() comes from the sourced algorithm file. L is passed from
# the script-level constant so it always matches the number of rows in
# mu_init.
ttSpid <- FedSpiderEM_gmm(
  Ylist = Ylist,
  groups = groups,
  L = L,
  kout = 150,
  kin = 35,
  thresh = 1e-20,
  gamma = 1e-3,
  nbatch = 20,
  mu_init = mu_init,
  Sigma_init = Sigma_init
)

# normH is flattened with unlist() before writing (presumably a nested list
# over outer/inner iterations -- confirm against FedSpiderEM_gmm).
fileFedSpidEM <- file.path(results_dir, "FedSpidEM_mnist_normH_simu.txt")
write.table(unlist(ttSpid$normH), file = fileFedSpidEM)

# ttSpid$alpha is a list; stack its elements row-wise into a single table.
write.table(
  do.call(rbind, ttSpid$alpha),
  file = file.path(results_dir, "FedSpidEM_mnist_alpha_simu.txt")
)

