# Line passing circle model
# Misclassification rate vs. delta plots
library(MASS)
library(dplyr)
library(tidyr)
library(ggplot2)

source("methods.R")
source("signals.R")

# ---- Simulation configuration ----
# The simulation below runs inside parallel::mclapply(). With R's default RNG
# every forked worker is re-seeded (its .Random.seed is dropped), so
# set.seed() alone would not make the results repeatable. Selecting
# "L'Ecuyer-CMRG" gives each worker its own stream derived from the seed.
RNGkind("L'Ecuyer-CMRG")
set.seed(123)
nrep <- 32      # replications per delta value
n <- 3000       # first argument of line_circle() (presumably the sample size)
K <- 2          # cluster count handed to every method
nstart <- 20    # passed to kmeans() as nstart
niter <- 200    # passed to kmeans() as iter.max
noise_sig <- 1  # passed to line_circle(); also selects the output file name
# One row per (delta, replication): 12 log-spaced delta values from 0.8 to 10.
runs <- expand.grid(
  delta = 10^seq(log10(0.8), 1, length.out = 12),
  n = n,
  rep = seq_len(nrep)
)

# Clustering methods under comparison. Every entry has the signature
# (x, K, niter, nstart) and returns the kmeans() cluster label of each row of
# x. The name "L = m" is the number of centers fitted: "L = 2" fits K centers,
# while the other two ignore K and always fit 4 or 10 centers.
methods <- list(
  "L = 2" = function(x, K, niter, nstart) {
    fit <- kmeans(x, centers = K, iter.max = niter, nstart = nstart)
    fit$cluster
  },
  "L = 4" = function(x, K, niter, nstart) {
    fit <- kmeans(x, centers = 4, iter.max = niter, nstart = nstart)
    fit$cluster
  },
  "L = 10" = function(x, K, niter, nstart) {
    fit <- kmeans(x, centers = 10, iter.max = niter, nstart = nstart)
    fit$cluster
  }
)
# Disabled variants built on iterated_kmeans():
# methods[["L = 10, 4"]] = function(x, K, niter, nstart) {
#   iterated_kmeans(x, K, L_schedule = c(10, 4),
#                   niter = niter, nstart = nstart)$cluster
# }
# methods[["multi-stage"]] = function(x, K, niter, nstart) {
#   iterated_kmeans(x, K, L_schedule = c(30, 10, K),
#                   niter = niter, nstart = nstart)$cluster
# }
# Method labels, in list order; used to tag the result rows.
mtd_names <- names(methods)

# Run every method on every (delta, replication) row of `runs`, in parallel.
# Each job simulates one data set with line_circle(), clusters its x/y/z
# columns with each entry of `methods`, and returns one data.frame row per
# method holding the misclassification rate and the elapsed clustering time
# in seconds. For serial debugging, swap parallel::mclapply for lapply and
# drop the mc.cores argument.
sim_chunks <- parallel::mclapply(seq_len(nrow(runs)), function(j) {
  n <- runs[j, "n"]
  delta <- runs[j, "delta"]

  dat <- line_circle(n, delta = delta, line_sig = 1, noise_sig = noise_sig,
                     pri = c(1, 1))
  x <- as.matrix(dat[, c("x", "y", "z")])

  do.call(rbind, lapply(seq_along(methods), function(mi) {
    # Time only the clustering call; zh holds the estimated labels.
    dt <- as.numeric(system.time(
      zh <- methods[[mi]](x, K, niter = niter, nstart = nstart)
    )["elapsed"])
    data.frame(rep = runs[j, "rep"], n = n, delta = delta,
               method = mtd_names[mi],
               # compute_mis_rate() comes from a sourced file; presumably it
               # compares zh with the true labels -- confirm there.
               mis_rate = compute_mis_rate(zh, dat$label, n),
               elapsed_time = dt)
  }))
}, mc.cores = 32)

# mclapply() does not stop on worker errors: a failed job comes back as a
# "try-error" object and a killed worker as NULL. Fail loudly instead of
# letting rbind() silently drop or mangle those entries.
failed <- vapply(sim_chunks, function(chunk) !is.data.frame(chunk), logical(1))
if (any(failed)) {
  first_bad <- sim_chunks[[which(failed)[1L]]]
  detail <- if (inherits(first_bad, "try-error")) {
    as.character(first_bad)
  } else {
    "worker returned no result"
  }
  stop(sum(failed), " of ", length(sim_chunks),
       " simulation runs failed; first failure: ", detail, call. = FALSE)
}
res <- do.call(rbind, sim_chunks)

# Average the misclassification rate over replications for every
# (n, delta, method) cell; sd_mis feeds the error bars of the plot.
# ungroup() drops the leftover (n, delta) grouping that summarize() keeps.
res2 <- res %>%
  group_by(n, delta, method) %>%
  summarize(avg_mis = mean(mis_rate), sd_mis = sd(mis_rate)) %>%
  ungroup()
# Order the factor like the `methods` list. Taking the levels from mtd_names
# (instead of repeating the labels here) keeps any newly enabled method from
# being turned into NA.
res2$method <- factor(res2$method, levels = mtd_names)
# Manual colour palette; scale_color_manual() hands the entries to the method
# levels in order, so only as many colours as there are methods get used.
cbbPalette <- c(
  "#E69F00", "#D55E00", "#009E73", "#56B4E9",
  "#F0E442", "#0072B2", "#CC79A7"
)
# Mean misclassification rate against delta, one coloured curve per method,
# with +/- one standard deviation error bars and the legend inside the panel.
ggplot(res2, aes(x = delta, y = avg_mis, colour = method)) +
  geom_point(size = 3) +
  geom_line(size = 1.5) +
  labs(y = "Average Misclassification Rate") +
  # scale_x_continuous(trans = "log10") +
  scale_color_manual(values = cbbPalette) +
  geom_errorbar(
    aes(ymin = avg_mis - sd_mis, ymax = avg_mis + sd_mis),
    width = 0.05
  ) +
  theme_minimal(base_size = 18) +
  theme(
    legend.background = element_blank(),
    legend.title = element_blank(),
    legend.position = c(0.8, 0.8),
    legend.text = element_text(size = 18)
  ) +
  guides(colour = guide_legend(keywidth = 2, keyheight = 1.25))

# Save the most recent plot as a PDF of width 7; the file name records whether
# the data were simulated with noise (noise_sig > 0) or without.
ggsave(
  filename = if (noise_sig > 0) {
    "line-circle-miss-delta-noisy.pdf"
  } else {
    "line-circle-miss-delta-noisless.pdf"
  },
  width = 7
)


