# Gaussian line model
# Generates misclassification-rate plots: misclassification vs. n (via the
# commented-out `runs` grid) and misclassification vs. delta (the grid in use).
library(MASS)
library(dplyr)
library(tidyr)
library(ggplot2)
set.seed(123)

source("methods.R")

# Draw `n` i.i.d. points from a bivariate normal centred at the origin with
# covariance sig * I, i.e. independent coordinates each of variance `sig`.
# Returns an n x 2 matrix from MASS::mvrnorm (a plain length-2 vector when
# n == 1).
noise1 <- function(n, sig){
  iso_cov <- diag(sig, nrow = 2)
  mvrnorm(n, mu = numeric(2), Sigma = iso_cov)
}

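# Noisy version: the second coordinate also gets variance 0.7 and covariance
# 0.7 with the first (swap in for the noiseless noise2 below).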
# noise2 <- function(n){
#   mvrnorm(n, mu = c(0,0), Sigma = matrix(c(5,0.7,0.7,0.7), nrow = 2))
# }

# noiseless version
noise2 <- function(n){
  mvrnorm(n, mu = c(0,0), Sigma = matrix(c(5,0,0,0), nrow = 2))
}

# Simulation settings ----
n <- 3000     # sample size of each simulated data set
nrep <- 32    # replicates per grid cell
sig <- 1      # per-coordinate noise variance of cluster 1 (see noise1())

# Alternative grid varying n instead of delta (set `delta` to a single
# separation value before uncommenting):
# runs <- expand.grid(delta = delta, n = round(10^seq(2, 4, length.out = 5)), rep = 1:nrep)
runs <- expand.grid(delta = seq(0.8, 10, length.out = 12), n = n, rep = 1:nrep)
K <- 2        # true number of clusters (labels are drawn from 1:K)

nstart <- 20  # k-means random restarts, passed to every method
niter <- 200  # k-means maximum iterations, passed to every method

# Clustering methods to compare, keyed by legend label. Every entry takes
# (x, K, niter, nstart) and returns kmeans()'s cluster-id vector for x.
# "L = <m>" runs k-means with m centers: "L = 2" uses K itself, the others
# ignore K and deliberately fit a fixed, larger number of centers.
kmeans_with_centers <- function(n_centers) {
  force(n_centers)
  function(x, K, niter, nstart) {
    centers <- if (is.null(n_centers)) K else n_centers
    kmeans(x, centers, iter.max = niter, nstart = nstart)$cluster
  }
}

methods <- list(
  "L = 2" = kmeans_with_centers(NULL),
  "L = 4" = kmeans_with_centers(4),
  "L = 10" = kmeans_with_centers(10)
)
# methods[["L = 30, 5"]] = function(x, K, niter, nstart) {
#   iterated_kmeans(x, K, L_schedule = c(30, 5),
#                   niter = niter, nstart = nstart)$cluster
# }
# methods[["multi-stage"]] = function(x, K, niter, nstart) {
#   iterated_kmeans(x, K, L_schedule = c(30, 10, K),
#                   niter = niter, nstart = nstart)$cluster
# }
mtd_names <- names(methods)

# Run every (delta, n, rep) configuration in `runs` in parallel. Each job
# simulates a two-cluster data set -- cluster 1: noise1() around (0, delta);
# cluster 2: noise2() around the origin -- applies every method in `methods`
# and records its misclassification rate and elapsed time.
# Result `res`: one data.frame row per (job, method).
#
# Reproducibility: under the default RNG kind, forked mclapply() workers are
# re-seeded non-reproducibly, so set.seed() alone does not fix the results.
# Each job therefore gets its own seed, drawn here in the parent process;
# results then no longer depend on core count or scheduling.
job_seeds <- sample.int(.Machine$integer.max, nrow(runs))

# mclapply() cannot fork on Windows, where only mc.cores = 1 is allowed.
# Elsewhere, cap at the original 32 cores but never exceed what exists.
n_cores <- if (.Platform$OS.type == "windows") {
  1L
} else {
  min(32L, parallel::detectCores(), na.rm = TRUE)
}

# Simulate one data set (row j of `runs`) and evaluate all methods on it.
run_one_job <- function(j) {
  set.seed(job_seeds[j])
  n <- runs[j, "n"]
  delta <- runs[j, "delta"]
  # Labels are sorted so they line up with the rows of `x` below
  # (all cluster-1 rows first, then cluster-2 rows).
  labels <- sort(sample(seq_len(K), size = n, replace = TRUE, prob = rep(1, K)))
  ns <- tabulate(labels, nbins = K)

  c1 <- matrix(rep(c(0, delta), ns[1]), nrow = ns[1], byrow = TRUE)
  c2 <- matrix(0, nrow = ns[2], ncol = 2)
  x <- rbind(c1 + noise1(ns[1], sig), c2 + noise2(ns[2]))

  do.call(rbind, lapply(seq_along(methods), function(mi) {
    dt <- as.numeric(system.time(
      zh <- methods[[mi]](x, K, niter = niter, nstart = nstart)
    )["elapsed"])
    data.frame(rep = runs[j, "rep"], n = n, delta = delta,
               method = mtd_names[mi],
               mis_rate = compute_mis_rate(zh, labels, n),
               elapsed_time = dt)
  }))
}

job_results <- parallel::mclapply(seq_len(nrow(runs)), run_one_job,
                                  mc.cores = n_cores)

# mclapply() returns a "try-error" (or NULL if a worker died) instead of
# stopping; surface that rather than rbind-ing garbage into `res`.
failed <- which(vapply(job_results, function(r) {
  is.null(r) || inherits(r, "try-error")
}, logical(1)))
if (length(failed) > 0) {
  first_err <- job_results[[failed[1]]]
  detail <- if (inherits(first_err, "try-error")) {
    as.character(first_err)
  } else {
    "no result returned (worker may have died)"
  }
  stop("Simulation failed for ", length(failed), " run(s); first failure at ",
       "row ", failed[1], " of `runs`: ", detail, call. = FALSE)
}
res <- do.call(rbind, job_results)

# Average replicates: mean and SD of the misclassification rate per
# (n, delta, method) cell. `.groups = "drop"` returns an ungrouped result
# instead of one still grouped by (n, delta).
res2 <- res %>%
  group_by(n, delta, method) %>%
  summarize(avg_mis = mean(mis_rate), sd_mis = sd(mis_rate), .groups = "drop")
# Legend order follows the order methods were defined in; deriving the
# levels from `mtd_names` keeps newly added methods from becoming NA.
res2$method <- factor(res2$method, levels = mtd_names)
cbbPalette <- c("#E69F00", "#D55E00", "#009E73", "#56B4E9", "#F0E442", "#0072B2", "#CC79A7")
delta_plot <- res2 %>%
  ggplot(aes(delta, avg_mis, color = method)) +
  geom_point(size = 3) +
  # `linewidth` replaces the line `size` aesthetic (deprecated in ggplot2 3.4.0).
  geom_line(linewidth = 1.5) +
  ylab("Average Misclassification Rate") +
  # scale_x_continuous(trans="log10") +
  scale_colour_manual(values = cbbPalette) +
  # Error bars span one standard deviation across replicates.
  geom_errorbar(aes(ymin = avg_mis - sd_mis, ymax = avg_mis + sd_mis), width = .05) +
  theme_minimal(base_size = 28) +
  ggplot2::theme(
    legend.background = ggplot2::element_blank(),
    legend.title = ggplot2::element_blank(),
    legend.position = c(0.9, 0.9),
    legend.text = ggplot2::element_text(size = 25)
  ) +
  ggplot2::guides(colour = ggplot2::guide_legend(keywidth = 2, keyheight = 1.25))
# Print explicitly so the plot renders (and becomes ggsave()'s last plot)
# even when this script is source()d rather than run interactively.
print(delta_plot)

# ggsave("line_gaussian_noise_0.7_delta_mis_n=3000.pdf")
# ggsave("line_gaussian_delta_mis_n=3000.pdf")
