library(ggplot2)
library(latex2exp)
library(matrixStats)

# INPUTS ----
# Change the input file names to match the experiment being plotted.

# Read a header-less CSV of per-epoch values as a plain numeric vector.
# unlist() flattens the data frame, so the values may be stored either as a
# single row or as a single column.
read_curve <- function(path) {
  as.numeric(unlist(
    read.csv(path, header = FALSE, stringsAsFactors = FALSE),
    use.names = FALSE
  ))
}

adam <- read_curve("resnet_adam_1.csv")
sgd <- read_curve("resnet_sgd_1.csv")
splitsgd <- read_curve("resnet_splitsgd_1.csv")
sgd_manual <- read_curve("resnet_sgd_3stage.csv")
fdr <- read_curve("resnet_fdr_0.1.csv")

# Derive the number of epochs from the data instead of hard-coding it, and
# fail loudly if the runs disagree (data.frame() would otherwise error or
# silently recycle shorter vectors).
n_epochs <- length(adam)
stopifnot(all(lengths(list(sgd, splitsgd, sgd_manual, fdr)) == n_epochs))

df_res <- data.frame(
  x = seq_len(n_epochs),
  a = adam,
  s_m = sgd_manual,
  s = sgd,
  f = fdr,
  ss = splitsgd
)


# PLOTS ----

# Colour per series key. The plot below maps each series to one of these
# keys; keys not drawn there are harmless.
cols <- c(
  "1_bl_l" = "#9ecae1", "2_re_l" = "#fc9272", "3_gr_l" = "#a1d99b",
  "4_bk_l" = "#bdbdbd", "5_bk" = "#000000", "6_bl" = "#045a8d",
  "7_re" = "#bd0026", "8_gr" = "#007e00"
)

# Legend label per series key. The list is named with the same keys as
# `cols` so scale_colour_manual() matches labels to the series by name. An
# unnamed list is matched by position against only the keys present in the
# plot, which mislabels the legend whenever some keys are unused.
my_labs <- list(
  "1_bl_l" = TeX("SGD $\\eta = 0.1$"),
  "2_re_l" = TeX("SGD $\\eta = 0.01$"),
  "3_gr_l" = TeX("Adam $\\eta = 0.0003$"),
  "4_bk_l" = TeX("FDR $\\eta = 0.1$"),
  "5_bk" = TeX("FDR $\\eta = 0.01$"),
  "6_bl" = TeX("SplitSGD $\\eta = 0.1$"),
  "7_re" = TeX("SplitSGD $\\eta = 0.01$"),
  "8_gr" = TeX("SGD manual decay")
)

# Spacing, in epochs, between consecutive markers drawn on each curve.
jump <- 24

# Markers are drawn every `jump` epochs. Each series starts at a different
# offset, presumably so markers of overlapping curves stay apart.
n_epochs <- nrow(df_res)
marker_rows <- function(start) df_res[seq(start, n_epochs, by = jump), ]

ggplot(df_res) +
  geom_line(aes(x = x, y = s, colour = "1_bl_l"), size = 1) +
  geom_point(
    data = marker_rows(7), aes(x = x, y = s, colour = "1_bl_l"),
    shape = 16, size = 5
  ) +
  geom_line(aes(x = x, y = a, colour = "3_gr_l"), size = 1) +
  geom_point(
    data = marker_rows(1), aes(x = x, y = a, colour = "3_gr_l"),
    shape = 17, size = 6
  ) +
  geom_line(aes(x = x, y = f, colour = "4_bk_l"), size = 1) +
  geom_point(
    data = marker_rows(7), aes(x = x, y = f, colour = "4_bk_l"),
    shape = 18, size = 5
  ) +
  geom_line(aes(x = x, y = ss, colour = "6_bl"), size = 1) +
  geom_point(
    data = marker_rows(13), aes(x = x, y = ss, colour = "6_bl"),
    shape = 16, size = 5
  ) +
  geom_line(aes(x = x, y = s_m, colour = "8_gr"), size = 1) +
  geom_point(
    data = marker_rows(4), aes(x = x, y = s_m, colour = "8_gr"),
    shape = 17, size = 5
  ) +
  xlab("Epochs") + ylab(TeX("Test Accuracy")) +
  ggtitle("Resnet on CIFAR-10") +
  # Name the labels by colour key so they are matched to the series actually
  # drawn. Unnamed, they would be assigned by position to the 5 present keys
  # and mislabel the legend (e.g. Adam would get an SGD label).
  scale_colour_manual(values = cols, labels = setNames(my_labs, names(cols))) +
  # Zoom the y axis without discarding data: ylim() would turn out-of-range
  # values into NA, breaking the lines and emitting warnings.
  coord_cartesian(ylim = c(80, 94)) +
  theme_bw() +
  theme(
    legend.position = c(0.70, 0.3),
    legend.background = element_rect(colour = "black"),
    legend.key.size = unit(1, "cm"), legend.text = element_text(size = 32),
    legend.title = element_blank(),
    axis.text.x = element_text(size = 32),
    axis.text.y = element_text(size = 32),
    axis.title = element_text(size = 32),
    plot.title = element_text(hjust = 0.5, size = 35)
  ) +
  # One legend marker shape per drawn series, in key order
  # (1_bl_l, 3_gr_l, 4_bk_l, 6_bl, 8_gr); keep in sync with the geom_point
  # shapes above.
  guides(colour = guide_legend(
    override.aes = list(shape = c(16, 17, 18, 16, 17))
  ))

