# circle-torus model
# L vs misclassification plots
library(MASS)
library(dplyr)
library(tidyr)
library(ggplot2)

source("methods.R")
source("signals.R")

# Simulation configuration ----

# The simulation below runs on forked workers via parallel::mclapply(). Per the
# parallel package documentation, seeded results are only reproducible across
# forked workers when the L'Ecuyer-CMRG generator is selected, so choose it
# before seeding.
RNGkind("L'Ecuyer-CMRG")
set.seed(123)

K <- 2  # not referenced again in this script; kept in case sourced code reads it

# Numbers of k-means centers to try (log-spaced, rounded to integers).
Ls <- round(10^seq(0.5, 1.8, length.out = 12))
# Sample sizes (log-spaced, rounded to integers).
ns <- round(10^seq(2, 4.5, length.out = 5))

nrep <- 32     # replications per sample size
nstart <- 20   # passed to kmeans() as `nstart`
niter <- 200   # passed to kmeans() as `iter.max`
result <- NULL # not referenced again in this script

# One row per (sample size, replication); r, R and noise_sig are held fixed and
# forwarded to cicle_torus() by the simulation loop.
runs <- expand.grid(r = 2, R = 10, noise_sig = 1, n = ns, rep = seq_len(nrep))

# Run the simulation ----
#
# For every row of `runs`, draw one data set with cicle_torus() and cluster its
# (x, y, z) coordinates with kmeans() once per value in `Ls`. Each fit yields
# one output row holding the configuration, the misclassification rate returned
# by compute_mis_rate(), and the wall-clock time of the kmeans() call.
#
# To debug serially, replace parallel::mclapply with lapply and drop `mc.cores`.
sim_chunks <- parallel::mclapply(seq_len(nrow(runs)), function(j) {
  cfg <- runs[j, ]
  n <- cfg$n
  r <- cfg$r
  R <- cfg$R
  noise_sig <- cfg$noise_sig

  dat <- cicle_torus(n, R = R, r = r, noise_sig = noise_sig, pri = c(1, 1))
  x <- as.matrix(dat[, c("x", "y", "z")])

  per_L <- lapply(Ls, function(L) {
    # `zh` is assigned inside system.time() so the fit is both timed and kept.
    dt <- as.numeric(system.time(
      zh <- kmeans(x, L, iter.max = niter, nstart = nstart)$cluster
    )["elapsed"])
    data.frame(
      rep = cfg$rep, L = L,
      n = n, R = R, r = r, noise_sig = noise_sig,
      mis_rate = compute_mis_rate(zh, dat$label, n),
      elapsed_time = dt
    )
  })
  do.call(rbind, per_L)
}, mc.cores = 32)

# mclapply() does not stop when a worker fails: a job that errors comes back as
# a "try-error" object and a job whose process died comes back as NULL. Binding
# those into `res` would silently drop or corrupt rows, so fail loudly instead.
failed <- vapply(
  sim_chunks,
  function(chunk) is.null(chunk) || inherits(chunk, "try-error"),
  logical(1)
)
if (any(failed)) {
  stop(
    sum(failed), " of ", length(sim_chunks),
    " simulation runs did not return a result (first failing run: ",
    which(failed)[1], ")",
    call. = FALSE
  )
}

res <- do.call(rbind, sim_chunks)

# Summarise ----
#
# Mean and standard deviation of the misclassification rate over replications,
# for each configuration and each L. The result is explicitly ungrouped so that
# later operations on `res2` are not silently applied per group.
res2 <- res %>%
  group_by(n, R, r, noise_sig, L) %>%
  summarize(avg_mis = mean(mis_rate), sd_mis = sd(mis_rate)) %>%
  ungroup()

# Ordered legend label ("n = <size>") whose levels follow the order of `ns`
# rather than the alphabetical order of the label strings.
res2$n_factor <- factor(
  paste("n =", res2$n),
  levels = paste("n =", ns),
  ordered = TRUE
)
# Plot ----
#
# Average misclassification rate against L / sqrt(n * log(n)) on a log10
# x-axis, one coloured curve per sample size, with +/- one standard deviation
# error bars. (Earlier variants of this plot used plain L, or R, on the x-axis.)
cbbPalette <- c(
  "#E69F00", "#D55E00", "#009E73", "#56B4E9",
  "#F0E442", "#0072B2", "#CC79A7"
)

ggplot(
  res2,
  aes(x = L / sqrt(n * log(n)), y = avg_mis, colour = n_factor)
) +
  geom_point(size = 3) +
  geom_line(size = 1.5) +
  ylab("Average Misclassification Rate") +
  scale_x_continuous(trans = "log10") +
  scale_colour_manual(values = cbbPalette) +
  geom_errorbar(
    aes(ymin = avg_mis - sd_mis, ymax = avg_mis + sd_mis),
    width = 0.05
  ) +
  theme_minimal(base_size = 28) +
  theme(
    legend.background = element_blank(),
    legend.title = element_blank(),
    legend.position = c(0.2, 0.2),
    legend.text = element_text(size = 25)
  ) +
  guides(colour = guide_legend(keywidth = 2, keyheight = 1.25))


# Save ----
#
# Write the most recently built plot to a PDF named after the noise setting.
# The noise level is read from the `runs` grid because `noise_sig` exists only
# inside the simulation worker function, not at the top level of this script.
# Only the circle-torus output files are written by this script.
is_noisy <- any(runs$noise_sig > 0)
if (is_noisy) {
  ggsave("circle-torus-miss-vs-L-noisy.pdf", width = 7)
} else {
  ggsave("circle-torus-miss-vs-L-noisless.pdf", width = 7)
}



