# rm(list = ls())
library(ExtDist)
library(ggplot2)
library(latex2exp)
library(ggtext)

# Experiment selectors. Both values are embedded in the input/output file
# names; the plot title reports d * (d + 1) / 2 rather than d itself.
# NOTE: using `class` as a variable name does not break calls to class(),
# because R looks up function bindings separately when calling.
d <- 5
class <- 2

# Directory settings, left blank here -- set them per machine.
# `res_path` is pasted directly in front of the CSV file name (no separator
# is inserted), and `plot_path` is passed on to ggsave().
# NOTE(review): `path` is not referenced in the lines visible here.
path <- ""
res_path <- ""
plot_path <- ""

# Load the results table for the selected class / d and split it into one
# data frame per method. Columns are addressed by position: column 2 is the
# x-axis value (eps), columns 3 and 4 are the per-method means, and columns
# 5 and 6 are the matching spreads used for the +/- std band.
# NOTE(review): presumably columns 3/5 belong to "EWG" and 4/6 to "RL" (that
# is how the plot labels res_1 / res_2) -- confirm against the CSV writer.
res <- read.csv(
  file = paste0(res_path, "octmnist_GDP_le_", class, "_", d, ".csv")
)
# res <- read.csv(file = paste0(res_path, "GDP_alt_",d, ".csv"))

# Fail loudly rather than with an opaque "undefined columns selected" error
# when the file does not have the positional layout assumed below.
if (ncol(res) < 6) {
  stop("Expected at least 6 columns in the results CSV, found ", ncol(res),
       call. = FALSE)
}

# Keep at most the first 10 rows. head() is used instead of res[1:10, ],
# which pads with all-NA rows when the file has fewer than 10 rows.
res <- head(res, 10)
res_1 <- data.frame(eps = res[, 2], mean = res[, 3], std = res[, 5])
res_2 <- data.frame(eps = res[, 2], mean = res[, 4], std = res[, 6])

# res_1 <- data.frame(eps = res[, 2], mean = res[, 7], std = res[, 9])
# res_2 <- data.frame(eps = res[, 2], mean = res[, 8], std = res[, 10])

# Single source of truth for each method's colour and marker, so the ribbon
# fills cannot drift out of sync with the line/point colours in the legend.
method_colours <- c(EWG = "red", RL = "blue")
method_shapes <- c(EWG = 16, RL = 17)

# For each method: a line through the means, a translucent +/- 1 std ribbon
# and point markers. x = eps and y = mean are inherited from the top-level
# aes(); res_1 is labelled "EWG" and res_2 "RL".
p <- ggplot(res_1, aes(x = eps, y = mean)) +
  geom_line(data = res_1, aes(color = "EWG")) +
  geom_ribbon(
    data = res_1,
    aes(ymin = mean - std, ymax = mean + std),
    fill = method_colours[["EWG"]], alpha = 0.3
  ) +
  geom_point(data = res_1, aes(color = "EWG", shape = "EWG"), size = 2) +
  geom_line(data = res_2, aes(color = "RL")) +
  geom_ribbon(
    data = res_2,
    aes(ymin = mean - std, ymax = mean + std),
    fill = method_colours[["RL"]], alpha = 0.3
  ) +
  geom_point(data = res_2, aes(color = "RL", shape = "RL"), size = 2) +
  scale_color_manual(
    name = "", breaks = names(method_colours), values = method_colours
  ) +
  scale_shape_manual(
    name = "", breaks = names(method_shapes), values = method_shapes
  ) +
  # The title reports d * (d + 1) / 2, not d itself.
  ggtitle(paste0("Class ", class, "; d = ", d * (d + 1) / 2)) +
  theme_classic() +
  theme(
    plot.title = element_text(size = 20),
    # NOTE(review): a numeric legend.position is deprecated since ggplot2
    # 3.5.0 (use legend.position = "inside" + legend.position.inside); kept
    # as-is so the script still runs on older ggplot2 versions.
    legend.position = c(0.8, 0.8),
    text = element_text(size = 20)
  ) +
  ylab("Riemannian distance") +
  xlab(latex2exp::TeX("$\\mu$"))
p


# Save the figure as a 10 x 10 cm PNG whose name uses the same
# "octmnist_GDP_le_<class>_<d>" stem as the input CSV.
# ggsave() joins a non-NULL `path` with file.path(), and file.path("", f)
# yields "/f" -- i.e. an empty plot_path would write to the filesystem root.
# Passing NULL makes an empty plot_path mean the working directory instead.
# The plot is passed explicitly rather than relying on last_plot().
ggsave(
  filename = paste0("octmnist_GDP_le_", class, "_", d, ".png"),
  plot = p,
  path = if (nzchar(plot_path)) plot_path else NULL,
  width = 10, height = 10, units = "cm"
)