# Circle-torus model
# Misclassification rate vs. delta plots (in this script the swept parameter is r)
library(dplyr)
library(tidyr)
library(ggplot2)
library(MASS)

source("methods.R")
source("signals.R")

# ---- Experiment configuration ----
# The simulation below runs in forked workers via parallel::mclapply. Under
# the default generator those workers do not inherit a reproducible stream
# from set.seed(); selecting "L'Ecuyer-CMRG" first makes the seed govern the
# per-worker streams (for a fixed number of workers).
RNGkind("L'Ecuyer-CMRG")
set.seed(123)

K <- 2         # cluster count handed to every method as its `K` argument
n <- 3000      # sample size per simulated data set
nstart <- 20   # forwarded to kmeans(nstart = )
niter <- 200   # forwarded to kmeans(iter.max = )
nrep <- 32     # independent repetitions per grid point

# noise_sig = 1 to create noisy version
noise_sig <- 0
R <- 3

# One row per (r, repetition) combination; R, noise_sig and n are constant
# columns so that each row fully describes a single simulation run.
runs <- expand.grid(
  r = seq(0.1, 10, length.out = 12),
  R = R,
  noise_sig = noise_sig,
  n = n,
  rep = seq_len(nrep)
)

# Registry of clustering methods under comparison, keyed by the label shown
# in the plot legend. Every entry has the signature (x, K, niter, nstart) and
# returns the vector of cluster assignments produced by kmeans().
# Only "L = 2" uses the K it is given; the other entries fix the number of
# centers to the value in their label.
methods <- list(
  "L = 2" = function(x, K, niter, nstart) {
    fit <- kmeans(x, centers = K, iter.max = niter, nstart = nstart)
    fit$cluster
  },
  "L = 4" = function(x, K, niter, nstart) {
    fit <- kmeans(x, centers = 4, iter.max = niter, nstart = nstart)
    fit$cluster
  },
  "L = 10" = function(x, K, niter, nstart) {
    fit <- kmeans(x, centers = 10, iter.max = niter, nstart = nstart)
    fit$cluster
  }
)

# Disabled variants, kept for reference. Uncommenting any of them appends the
# entry to the registry built above.
# methods[["L = 20"]] = function(x, K, niter, nstart) {
#   kmeans(x, 20, iter.max = niter, nstart = nstart)$cluster
# }
# methods[["L = 50"]] = function(x, K, niter, nstart) {
#   kmeans(x, 50, iter.max = niter, nstart = nstart)$cluster
# }
# methods[["L = 100"]] = function(x, K, niter, nstart) {
#   kmeans(x, 100, iter.max = niter, nstart = nstart)$cluster
# }

# methods[["L = 10, 4"]] = function(x, K, niter, nstart) {
#   iterated_kmeans(x, K, L_schedule = c(10, 4),
#                   niter = niter, nstart = nstart)$cluster
# }
# methods[["multi-stage"]] = function(x, K, niter, nstart) {
#   iterated_kmeans(x, K, L_schedule = c(30, 10, K),
#                   niter = niter, nstart = nstart)$cluster
# }

# Labels in registry order; used to tag each result row with its method.
mtd_names <- names(methods)

# ---- Run the simulation grid in parallel ----
# Worker count: at most 32 (the original setting), never more than the
# machine reports, and exactly 1 on Windows, where mclapply cannot fork and
# rejects mc.cores > 1. detectCores() may return NA, hence na.rm = TRUE.
n_cores <- if (.Platform$OS.type == "windows") {
  1L
} else {
  max(1L, min(32L, parallel::detectCores(), na.rm = TRUE))
}

# Each task handles one row of `runs`: it simulates a data set, applies every
# registered method to it, and returns one result row per method.
res_list <- parallel::mclapply(seq_len(nrow(runs)), function(j) {
  n <- runs[j, "n"]
  r <- runs[j, "r"]
  R <- runs[j, "R"]
  noise_sig <- runs[j, "noise_sig"]

  # dat = line_circle(n, delta = delta, line_sig = 1, noise_sig = 0, pri = c(1,1))
  # dat = line_circle(n, delta = delta, line_sig = 1, noise_sig = 1, pri = c(1,1))
  # cicle_torus() is not defined in this file (presumably it comes from one of
  # the sourced scripts); the code below relies on its result having columns
  # x, y, z and label.
  dat <- cicle_torus(n, R = R, r = r, noise_sig = noise_sig, pri = c(1, 1))
  x <- as.matrix(dat[, c("x", "y", "z")])

  do.call(rbind, lapply(seq_along(methods), function(mi) {
    # Time the clustering call only; `zh` is assigned in this function's
    # frame by the `<-` inside system.time().
    dt <- as.numeric(system.time(
      zh <- methods[[mi]](x, K, niter = niter, nstart = nstart)
    )["elapsed"])
    data.frame(
      rep = runs[j, "rep"], n = n, R = R, r = r, noise_sig = noise_sig,
      method = mtd_names[mi],
      mis_rate = compute_mis_rate(zh, dat$label, n),
      elapsed_time = dt
    )
  }))
}, mc.cores = n_cores)

# mclapply does not stop on a failing task: the element becomes a "try-error"
# object (or NULL if the worker process died). Binding those into the result
# would silently corrupt or shrink it, so fail loudly instead.
run_failed <- vapply(
  res_list,
  function(el) is.null(el) || inherits(el, "try-error"),
  logical(1)
)
if (any(run_failed)) {
  stop(
    sum(run_failed), " of ", length(res_list),
    " simulation runs failed; first failing run index: ",
    which(run_failed)[1],
    call. = FALSE
  )
}

res <- do.call(rbind, res_list)

# ---- Summarise and plot ----
# Mean and standard deviation of the misclassification rate over repetitions,
# per grid point and method. ungroup() drops the grouping that summarize()
# would otherwise leave on the result.
res2 <- res %>%
  group_by(n, R, r, noise_sig, method) %>%
  summarize(avg_mis = mean(mis_rate), sd_mis = sd(mis_rate)) %>%
  ungroup()

# Fix the legend order to the order in which the methods were registered.
res2$method <- factor(res2$method,
                      levels = names(methods)) # c("L = 2", "L = 4", "L = 10"))

# Colour-blind-friendly palette; colours are consumed in factor-level order.
cbbPalette <- c("#E69F00", "#D55E00", "#009E73", "#56B4E9", "#F0E442",
                "#0072B2", "#CC79A7")

# NOTE(review): ggplot2 >= 3.4.0 prefers `linewidth` over `size` for lines and
# >= 3.5.0 prefers `legend.position.inside` for numeric legend positions; the
# older spellings are kept here so the script also runs on earlier versions.
res2 %>% ggplot(aes(r, avg_mis, color = method)) +
# res2 %>% ggplot(aes(R, avg_mis, color = method)) +
  geom_point(size = 3) +
  geom_line(size = 1.5) +
  ylab("Average Misclassification Rate") +
  #scale_y_continuous(trans="log10") +
  scale_colour_manual(values = cbbPalette) +
  # Error bars span one standard deviation either side of the mean.
  geom_errorbar(aes(ymin = avg_mis - sd_mis, ymax = avg_mis + sd_mis),
                width = .05) +
  theme_minimal(base_size = 18) +
  theme(
    legend.background = element_blank(),
    legend.title = element_blank(),
    legend.position = c(0.9, 0.9),
    legend.text = element_text(size = 18)
  ) +
  guides(colour = guide_legend(keywidth = 2, keyheight = 1.25))

# Write the most recently built plot to a PDF whose name reflects whether
# noise was added. The file names (including their spelling) are kept as-is
# because other material may refer to them.
ggsave(
  filename = if (noise_sig > 0) {
    "circle-torus-miss-delta-noisy.pdf"
  } else {
    "circle-torus-miss-delta-noisless.pdf"
  },
  width = 7
)



