library(ggplot2)
library(latex2exp)
library(pracma)
library(tilting)
library(matrixStats)

############################################################
################### OPTIMIZATION METHODS ###################
############################################################

############################################################
###################### SPLITSGD
############################################################

# SplitSGD: constant-step SGD whose step size is cut by a splitting diagnostic.
#
# The run alternates between (a) up to `t1` ordinary epochs, each one shuffled
# pass over the rows of `x`, and (b) one diagnostic epoch. In a diagnostic
# epoch two threads start from the current iterate and consume alternating
# observations of a fresh shuffle (odd positions -> thread 1, even -> thread 2).
# The epoch is cut into windows of floor(n / w) observations and, for each
# window, the inner product of the two threads' displacements is recorded.
# If more than a fraction `q` of these inner products is negative, the step
# size is multiplied by `gamma` and `t1` becomes as.integer(t1 / gamma).
# After the diagnostic epoch the two threads are averaged.
#
# Arguments:
#   x        numeric matrix, one observation per row.
#   y        responses, indexed like the rows of `x`.
#   t1       number of ordinary epochs run before each diagnostic epoch.
#   ep_max   total number of epochs (ordinary + diagnostic).
#   eta      initial step size.
#   w        number of windows a diagnostic epoch is divided into.
#   q        threshold on the fraction of negative window inner products.
#   gamma    factor applied to the step size when the diagnostic fires.
#   theta0   starting parameter vector (length ncol(x)).
#   my_grad  function(theta, x_row, y_i) returning the per-observation gradient.
#
# Returns a list with
#   theta             matrix whose first row is theta0, followed by one row
#                     per epoch.
#   diagnostic_truth  logical vector, one entry per diagnostic epoch
#                     (TRUE = step size was reduced).
#   stepsize          numeric vector with one entry per row of `theta`: eta
#                     for theta0, the step used for an ordinary epoch, and the
#                     step in force after the decision for a diagnostic epoch.
SplitSGD <- function(x, y, t1, ep_max, eta, w, q, gamma, theta0, my_grad) {
  n <- nrow(x)
  d <- ncol(x)
  stepsize <- eta
  diagnostic_truth <- c()
  all_stepsize <- eta
  ep <- 0

  theta <- matrix(theta0, ncol = d, nrow = 1)
  theta_temp <- theta0

  # Observations per diagnostic window. Using floor() keeps the diagnostic
  # working when w does not divide n (trailing observations are then simply
  # not part of any window).
  window_len <- max(1, floor(n / w))
  n_windows <- n %/% window_len

  while (ep < ep_max) {
    # ---- ordinary epochs ----
    for (i in seq_len(t1)) {
      if (ep >= ep_max) {
        break
      }
      for (row in sample(n, n)) {
        theta_temp <- theta_temp - stepsize * my_grad(theta_temp, x[row, ], y[row])
      }
      theta <- rbind(theta, theta_temp)
      all_stepsize <- c(all_stepsize, stepsize)
      ep <- ep + 1
    }
    if (ep >= ep_max) {
      break
    }

    # ---- diagnostic epoch: two threads on alternating observations ----
    theta_temp1 <- theta_temp
    theta_temp2 <- theta_temp
    theta_in1 <- theta_temp1
    theta_in2 <- theta_temp2
    Qi <- numeric(n_windows)
    idx <- sample(n, n)
    for (id in seq_len(n)) {
      row <- idx[id]
      if (id %% 2 == 1) {
        theta_temp1 <- theta_temp1 - stepsize * my_grad(theta_temp1, x[row, ], y[row])
      } else {
        theta_temp2 <- theta_temp2 - stepsize * my_grad(theta_temp2, x[row, ], y[row])
      }

      # A window closes after every `window_len` observations (id is 1-based,
      # so the boundary is id %% window_len == 0).
      if (id %% window_len == 0) {
        Qi[id %/% window_len] <- sum((theta_temp1 - theta_in1) *
                                       (theta_temp2 - theta_in2))
        theta_in1 <- theta_temp1
        theta_in2 <- theta_temp2
      }
    }

    decay <- sum(Qi < 0) > q * length(Qi)
    diagnostic_truth <- c(diagnostic_truth, decay)
    if (decay) {
      t1 <- as.integer(t1 / gamma)
      stepsize <- stepsize * gamma
    }

    theta_temp <- (theta_temp1 + theta_temp2) / 2
    theta <- rbind(theta, theta_temp)
    # Record before the loop can terminate so that `stepsize` always has one
    # entry per row of `theta`, also when the run ends on a diagnostic epoch.
    all_stepsize <- c(all_stepsize, stepsize)
    ep <- ep + 1
  }

  list(theta = theta,
       diagnostic_truth = diagnostic_truth,
       stepsize = all_stepsize)
}


############################################################
###################### SGD
############################################################

# Plain SGD with a polynomially decaying step size.
#
# Runs `ep_max` epochs; each epoch is one shuffled pass over the rows of `x`.
# The step used at global iteration `iter` (counted across epochs, starting
# at 1) is eta / iter^alpha.
#
# Arguments:
#   x        numeric matrix, one observation per row.
#   y        responses, indexed like the rows of `x`.
#   ep_max   number of epochs (0 returns the starting point only).
#   eta      initial step size.
#   alpha    decay exponent of the step size.
#   theta0   starting parameter vector.
#   my_grad  function(theta, x_row, y_i) returning the per-observation gradient.
#
# Returns a list with
#   theta      (ep_max + 1) x length(theta0) matrix; row 1 is theta0 and row
#              k + 1 is the iterate after epoch k.
#   step_size  numeric vector of length ep_max + 1: eta, then the last step
#              used in each epoch. (Note the element name is `step_size`,
#              whereas SplitSGD and my_sgd_half return `stepsize`.)
my_sgd <- function(x, y, ep_max, eta, alpha, theta0, my_grad) {
  n <- nrow(x)

  # Preallocate the trajectory instead of growing it with rbind() every epoch.
  theta <- matrix(NA_real_, nrow = ep_max + 1, ncol = length(theta0))
  theta[1, ] <- theta0
  all_stepsize <- rep(eta, ep_max + 1)

  theta_temp <- theta0
  stepsize <- eta
  iter <- 1

  # seq_len() so that ep_max = 0 runs no epochs (1:0 would run two).
  for (i in seq_len(ep_max)) {
    for (row in sample(n, n)) {
      stepsize <- eta / (iter^alpha)
      theta_temp <- theta_temp - stepsize * my_grad(theta_temp, x[row, ], y[row])
      iter <- iter + 1
    }
    all_stepsize[i + 1] <- stepsize
    theta[i + 1, ] <- theta_temp
  }

  list(theta = theta, step_size = all_stepsize)
}


############################################################
###################### SGD_HALF
############################################################

# SGD with a step size that is halved on a doubling schedule.
#
# Phase k (k = 1, 2, ...) lasts t1 * 2^(k - 1) epochs and uses the step
# eta / 2^(k - 1). As many complete phases as fit into `ep_max` epochs are run
# (n_decay of them); any remaining epochs use the step eta / 2^n_decay.
# Each epoch is one shuffled pass over the rows of `x`.
#
# Arguments:
#   x        numeric matrix, one observation per row.
#   y        responses, indexed like the rows of `x`.
#   t1       length (in epochs) of the first phase.
#   ep_max   total number of epochs.
#   eta      initial step size.
#   theta0   starting parameter vector.
#   my_grad  function(theta, x_row, y_i) returning the per-observation gradient.
#
# Returns a list with
#   theta     matrix with one row for theta0 followed by one row per epoch.
#   stepsize  numeric vector aligned with the rows of `theta`: eta for
#             theta0, then the step used in each epoch.
my_sgd_half <- function(x, y, t1, ep_max, eta, theta0, my_grad) {
  n <- nrow(x)

  # Number of complete phases: largest k with t1 * (2^k - 1) <= ep_max.
  n_decay <- as.integer(log2(ep_max / t1 + 1))

  # Per-epoch step sizes: the complete phases first, then the remainder.
  halvings <- seq_len(max(n_decay, 0L)) - 1
  schedule <- rep(eta / 2^halvings, times = t1 * 2^halvings)
  ep_diff <- ep_max - length(schedule)
  if (ep_diff > 0) {
    schedule <- c(schedule, rep(eta / 2^n_decay, ep_diff))
  }

  # Preallocated trajectory; always a matrix, also when no epoch is run.
  n_epochs <- length(schedule)
  theta <- matrix(NA_real_, nrow = n_epochs + 1, ncol = length(theta0))
  theta[1, ] <- theta0

  theta_temp <- theta0
  for (ep in seq_len(n_epochs)) {
    stepsize <- schedule[ep]
    for (row in sample(n, n)) {
      theta_temp <- theta_temp - stepsize * my_grad(theta_temp, x[row, ], y[row])
    }
    theta[ep + 1, ] <- theta_temp
  }

  list(theta = theta, stepsize = c(eta, schedule))
}


############################################################
######################### LOSSES ###########################
############################################################

# Mean squared error of the linear predictor x %*% t(theta).
#   x      design matrix, one observation per row.
#   y      responses (one per row of x).
#   theta  coefficients as a 1 x ncol(x) row matrix.
# Returns the sum of squared residuals divided by length(y).
loss_lm <- function(x, y, theta) {
  residuals <- y - x %*% t(theta)
  sum(residuals^2) / length(y)
}


# Average logistic negative log-likelihood for 0/1 responses.
#   x      design matrix, one observation per row.
#   y      responses (one per row of x).
#   theta  coefficients as a 1 x ncol(x) row matrix.
# Returns sum(-y * z + log(1 + exp(z))) / length(y) with z = x %*% t(theta).
#
# log(1 + exp(z)) is evaluated as max(z, 0) + log1p(exp(-|z|)). This is the
# same quantity, but it does not overflow to Inf when z is large (exp(z)
# overflows a double for z above roughly 709).
loss_log <- function(x, y, theta) {
  z <- x %*% t(theta)
  softplus <- pmax(z, 0) + log1p(exp(-abs(z)))
  sum(softplus - y * z) / length(y)
}


############################################################
####################### SIMULATIONS ########################
############################################################

############################################################
####################### LINEAR REGRESSION
############################################################
# Synthetic linear-regression data: standard normal design, exponentially
# decaying true coefficients, and additive N(0, sigma^2) noise.
n <- 1000
d <- 20
sigma <- 1
theta_star <- 5 * exp(-0.5 * seq_len(d))
x <- matrix(rnorm(n * d), nrow = n, ncol = d)
y <- as.numeric(x %*% theta_star + rnorm(n, mean = 0, sd = sigma))

# Per-observation gradient of the squared-error loss 0.5 * (th . x1 - y1)^2
# with respect to th: the residual times the feature vector.
getGradient <- function(th, x1, y1) {
  residual <- dot(th, x1) - y1
  x1 * residual
}



# ---- Sensitivity of SplitSGD to (w, q): linear regression ----
# For every (w, q) pair and each of three initial step sizes, SplitSGD is run
# B times; the log of the final training loss is summarised by its median and
# its 2.5% / 97.5% quantiles over the B runs.
B <- 10            # replicates per configuration
t1 <- 4            # ordinary epochs before the first diagnostic
theta0 <- rep(0, d)
gamma <- 0.5       # step-size decay factor
ep_max <- 100      # total epochs per run

eta_large <- 3e-4
eta_medium <- 1e-4
eta_small <- 3e-5

ws <- c(10, 20, 40)
qs <- c(0.35, 0.4, 0.45)

# One slot per (w, q) pair, ordered with q varying fastest. The grid size is
# derived from ws/qs so that editing either vector keeps everything aligned.
n_cfg <- length(ws) * length(qs)

med_dist_large <- numeric(n_cfg)
med_dist_medium <- numeric(n_cfg)
med_dist_small <- numeric(n_cfg)
up_dist_large <- numeric(n_cfg)
up_dist_medium <- numeric(n_cfg)
up_dist_small <- numeric(n_cfg)
low_dist_large <- numeric(n_cfg)
low_dist_medium <- numeric(n_cfg)
low_dist_small <- numeric(n_cfg)

for (j1 in seq_along(ws)) {
  w <- ws[j1]
  for (j2 in seq_along(qs)) {
    q <- qs[j2]
    cfg <- j2 + length(qs) * (j1 - 1)

    # Final log-loss of each replicate (one scalar per run).
    dist_large_temp <- numeric(B)
    dist_medium_temp <- numeric(B)
    dist_small_temp <- numeric(B)

    for (b in seq_len(B)) {
      # Order of the three fits is kept fixed so the random stream is
      # consumed in the same sequence on every run.
      fit_large <- SplitSGD(x, y, t1 = t1, ep_max = ep_max, eta = eta_large,
                            gamma = gamma, w = w, q = q, theta0 = theta0,
                            my_grad = getGradient)
      fit_medium <- SplitSGD(x, y, t1 = t1, ep_max = ep_max, eta = eta_medium,
                             gamma = gamma, w = w, q = q, theta0 = theta0,
                             my_grad = getGradient)
      fit_small <- SplitSGD(x, y, t1 = t1, ep_max = ep_max, eta = eta_small,
                            gamma = gamma, w = w, q = q, theta0 = theta0,
                            my_grad = getGradient)

      # Row ep_max + 1 of theta is the iterate after the last epoch.
      dist_large_temp[b] <- log(loss_lm(x, y, t(fit_large$theta[ep_max + 1, ])))
      dist_medium_temp[b] <- log(loss_lm(x, y, t(fit_medium$theta[ep_max + 1, ])))
      dist_small_temp[b] <- log(loss_lm(x, y, t(fit_small$theta[ep_max + 1, ])))

      print(paste0('w = ', w, ', q = ', q, ' and iter = ', b))
    }

    med_dist_large[cfg] <- median(dist_large_temp)
    up_dist_large[cfg] <- quantile(dist_large_temp, probs = 0.975)
    low_dist_large[cfg] <- quantile(dist_large_temp, probs = 0.025)
    med_dist_medium[cfg] <- median(dist_medium_temp)
    up_dist_medium[cfg] <- quantile(dist_medium_temp, probs = 0.975)
    low_dist_medium[cfg] <- quantile(dist_medium_temp, probs = 0.025)
    med_dist_small[cfg] <- median(dist_small_temp)
    up_dist_small[cfg] <- quantile(dist_small_temp, probs = 0.975)
    low_dist_small[cfg] <- quantile(dist_small_temp, probs = 0.025)
  }
}


# Columns: m* = median, u* = upper (97.5%), l* = lower (2.5%);
# suffix l / m / s = large / medium / small initial step size.
df_sens_lin <- data.frame(x = seq_len(n_cfg),
                          ml = med_dist_large, ul = up_dist_large, ll = low_dist_large,
                          mm = med_dist_medium, um = up_dist_medium, lm = low_dist_medium,
                          ms = med_dist_small, us = up_dist_small, ls = low_dist_small)

# Plot of the linear-regression sensitivity analysis: for each initial step
# size, the median log-loss across the (w, q) grid with a ribbon between the
# 2.5% and 97.5% quantiles.
# To re-plot from a saved data frame instead of re-running the simulation:
#load('ICLR2021_Dataframe_linear_sensitivity_analysis.Rda')
#ep_max = 100
cols <- c("c4" = "#3dc93d", "c2" = "#ff0033", "c3" = "#0000ff", "c1" = "#000000")
my.labs <- list(TeX('$\\eta = 3e-4$'), TeX('$\\eta = 1e-4$'), TeX('$\\eta = 3e-5$'))
ggplot(df_sens_lin, aes(x = x)) +
  # large step size (c1)
  geom_line(aes(y = ml, colour = 'c1'), size = 1.5) +
  geom_ribbon(aes(ymin = ll, ymax = ul, fill = 'c1'), alpha = 0.2) +
  # markers are drawn on every third grid position, offset per series
  geom_point(data = df_sens_lin[c(1, 4, 7), ], aes(x = x, y = ml, colour = 'c1'),
             shape = 18, size = 5) +
  # medium step size (c2)
  geom_line(aes(y = mm, colour = 'c2'), size = 1.5) +
  geom_ribbon(aes(ymin = lm, ymax = um, fill = 'c2'), alpha = 0.2) +
  geom_point(data = df_sens_lin[c(2, 5, 8), ], aes(x = x, y = mm, colour = 'c2'),
             shape = 17, size = 5) +
  # small step size (c3)
  geom_line(aes(y = ms, colour = 'c3'), size = 1.5) +
  geom_ribbon(aes(ymin = ls, ymax = us, fill = 'c3'), alpha = 0.2) +
  geom_point(data = df_sens_lin[c(3, 6, 9), ], aes(x = x, y = ms, colour = 'c3'),
             shape = 16, size = 5) +
  labs(title = 'Linear Regression') +
  ylab(paste0('log(loss) after ', ep_max, ' epochs')) +
  scale_colour_manual(values = cols, labels = my.labs) +
  scale_fill_manual(values = cols) +
  # x positions 1..9 are the (w, q) pairs, q varying fastest
  scale_x_continuous(breaks = 1:9,
                     labels = c('(10, 0.35)', '(10, 0.40)', '(10, 0.45)',
                                '(20, 0.35)', '(20, 0.40)', '(20, 0.45)',
                                '(40, 0.35)', '(40, 0.40)', '(40, 0.45)')) +
  theme_bw() +
  theme(legend.position = c(0.2, 0.8),
        legend.background = element_rect(colour = 'black'),
        legend.key.size = unit(1.5, "cm"),
        legend.text = element_text(size = 30),
        legend.title = element_blank(),
        plot.title = element_text(hjust = 0.5, size = 40),
        axis.text.x = element_text(size = 25, angle = 90),
        axis.text.y = element_text(size = 30, angle = 90),
        axis.title.y = element_text(size = 30),
        axis.title.x = element_blank()) +
  # "none" replaces the deprecated `fill = FALSE`; the colour legend shows the
  # same marker shapes as the three geom_point layers.
  guides(fill = "none", color = guide_legend(override.aes = list(shape = c(18, 17, 16))))


# File name kept in sync with the load() call above.
#save(df_sens_lin, file='ICLR2021_Dataframe_linear_sensitivity_analysis.Rda')
#ggsave('ICLR2021_linear_regression_sensitivity_analysis.png', width = 30, height = 30, units = 'cm', dpi = 300)







############################################################
####################### LOGISTIC REGRESSION
############################################################
# Synthetic logistic-regression data: standard normal design, exponentially
# decaying true coefficients, Bernoulli responses with logistic success
# probabilities.
n <- 1000
d <- 20
sigma <- 1
theta_star <- 5 * exp(-0.5 * seq_len(d))
x <- matrix(rnorm(n * d), nrow = n, ncol = d)
pr <- 1 / (1 + exp(-(x %*% theta_star)))
y <- rbinom(n, size = 1, prob = pr)
# Per-observation gradient of the logistic negative log-likelihood with
# respect to th: x1 * sigmoid(th . x1) - y1 * x1.
getGradient <- function(th, x1, y1) {
  score <- dot(th, x1)
  x1 / (1 + exp(-score)) - y1 * x1
}



# ---- Sensitivity of SplitSGD to (w, q): logistic regression ----
# For every (w, q) pair and each of three initial step sizes, SplitSGD is run
# B times; the log of the final training loss is summarised by its median and
# its 2.5% / 97.5% quantiles over the B runs.
B <- 10            # replicates per configuration
t1 <- 4            # ordinary epochs before the first diagnostic
theta0 <- rep(0, d)
gamma <- 0.5       # step-size decay factor
ep_max <- 100      # total epochs per run

eta_large <- 1e-2
eta_medium <- 3e-3
eta_small <- 1e-3

ws <- c(10, 20, 40)
qs <- c(0.35, 0.4, 0.45)

# One slot per (w, q) pair, ordered with q varying fastest. The grid size is
# derived from ws/qs so that editing either vector keeps everything aligned.
n_cfg <- length(ws) * length(qs)

med_dist_large <- numeric(n_cfg)
med_dist_medium <- numeric(n_cfg)
med_dist_small <- numeric(n_cfg)
up_dist_large <- numeric(n_cfg)
up_dist_medium <- numeric(n_cfg)
up_dist_small <- numeric(n_cfg)
low_dist_large <- numeric(n_cfg)
low_dist_medium <- numeric(n_cfg)
low_dist_small <- numeric(n_cfg)

for (j1 in seq_along(ws)) {
  w <- ws[j1]
  for (j2 in seq_along(qs)) {
    q <- qs[j2]
    cfg <- j2 + length(qs) * (j1 - 1)

    # Final log-loss of each replicate (one scalar per run).
    dist_large_temp <- numeric(B)
    dist_medium_temp <- numeric(B)
    dist_small_temp <- numeric(B)

    for (b in seq_len(B)) {
      # Order of the three fits is kept fixed so the random stream is
      # consumed in the same sequence on every run.
      fit_large <- SplitSGD(x, y, t1 = t1, ep_max = ep_max, eta = eta_large,
                            gamma = gamma, w = w, q = q, theta0 = theta0,
                            my_grad = getGradient)
      fit_medium <- SplitSGD(x, y, t1 = t1, ep_max = ep_max, eta = eta_medium,
                             gamma = gamma, w = w, q = q, theta0 = theta0,
                             my_grad = getGradient)
      fit_small <- SplitSGD(x, y, t1 = t1, ep_max = ep_max, eta = eta_small,
                            gamma = gamma, w = w, q = q, theta0 = theta0,
                            my_grad = getGradient)

      # Row ep_max + 1 of theta is the iterate after the last epoch.
      dist_large_temp[b] <- log(loss_log(x, y, t(fit_large$theta[ep_max + 1, ])))
      dist_medium_temp[b] <- log(loss_log(x, y, t(fit_medium$theta[ep_max + 1, ])))
      dist_small_temp[b] <- log(loss_log(x, y, t(fit_small$theta[ep_max + 1, ])))

      print(paste0('w = ', w, ', q = ', q, ' and iter = ', b))
    }

    med_dist_large[cfg] <- median(dist_large_temp)
    up_dist_large[cfg] <- quantile(dist_large_temp, probs = 0.975)
    low_dist_large[cfg] <- quantile(dist_large_temp, probs = 0.025)
    med_dist_medium[cfg] <- median(dist_medium_temp)
    up_dist_medium[cfg] <- quantile(dist_medium_temp, probs = 0.975)
    low_dist_medium[cfg] <- quantile(dist_medium_temp, probs = 0.025)
    med_dist_small[cfg] <- median(dist_small_temp)
    up_dist_small[cfg] <- quantile(dist_small_temp, probs = 0.975)
    low_dist_small[cfg] <- quantile(dist_small_temp, probs = 0.025)
  }
}


# Columns: m* = median, u* = upper (97.5%), l* = lower (2.5%);
# suffix l / m / s = large / medium / small initial step size.
df_sens_log <- data.frame(x = seq_len(n_cfg),
                          ml = med_dist_large, ul = up_dist_large, ll = low_dist_large,
                          mm = med_dist_medium, um = up_dist_medium, lm = low_dist_medium,
                          ms = med_dist_small, us = up_dist_small, ls = low_dist_small)
# Plot of the logistic-regression sensitivity analysis: for each initial step
# size, the median log-loss across the (w, q) grid with a ribbon between the
# 2.5% and 97.5% quantiles.
# To re-plot from a saved data frame instead of re-running the simulation:
#load('ICLR2021_Dataframe_logistic_sensitivity_analysis.Rda')
#ep_max = 100
cols <- c("c4" = "#3dc93d", "c2" = "#ff0033", "c3" = "#0000ff", "c1" = "#000000")
my.labs <- list(TeX('$\\eta = 1e-2$'), TeX('$\\eta = 3e-3$'), TeX('$\\eta = 1e-3$'))
ggplot(df_sens_log, aes(x = x)) +
  # large step size (c1)
  geom_line(aes(y = ml, colour = 'c1'), size = 1.5) +
  geom_ribbon(aes(ymin = ll, ymax = ul, fill = 'c1'), alpha = 0.2) +
  # markers are drawn on every third grid position, offset per series
  geom_point(data = df_sens_log[c(1, 4, 7), ], aes(x = x, y = ml, colour = 'c1'),
             shape = 18, size = 5) +
  # medium step size (c2)
  geom_line(aes(y = mm, colour = 'c2'), size = 1.5) +
  geom_ribbon(aes(ymin = lm, ymax = um, fill = 'c2'), alpha = 0.2) +
  geom_point(data = df_sens_log[c(2, 5, 8), ], aes(x = x, y = mm, colour = 'c2'),
             shape = 17, size = 5) +
  # small step size (c3)
  geom_line(aes(y = ms, colour = 'c3'), size = 1.5) +
  geom_ribbon(aes(ymin = ls, ymax = us, fill = 'c3'), alpha = 0.2) +
  geom_point(data = df_sens_log[c(3, 6, 9), ], aes(x = x, y = ms, colour = 'c3'),
             shape = 16, size = 5) +
  labs(title = 'Logistic Regression') +
  ylab(paste0('log(loss) after ', ep_max, ' epochs')) +
  scale_colour_manual(values = cols, labels = my.labs) +
  scale_fill_manual(values = cols) +
  # x positions 1..9 are the (w, q) pairs, q varying fastest
  scale_x_continuous(breaks = 1:9,
                     labels = c('(10, 0.35)', '(10, 0.40)', '(10, 0.45)',
                                '(20, 0.35)', '(20, 0.40)', '(20, 0.45)',
                                '(40, 0.35)', '(40, 0.40)', '(40, 0.45)')) +
  theme_bw() +
  theme(legend.position = c(0.2, 0.8),
        legend.background = element_rect(colour = 'black'),
        legend.key.size = unit(1.5, "cm"),
        legend.text = element_text(size = 30),
        legend.title = element_blank(),
        plot.title = element_text(hjust = 0.5, size = 40),
        axis.text.x = element_text(size = 25, angle = 90),
        axis.text.y = element_text(size = 30, angle = 90),
        axis.title.y = element_text(size = 30),
        axis.title.x = element_blank()) +
  # "none" replaces the deprecated `fill = FALSE`; the colour legend shows the
  # same marker shapes as the three geom_point layers.
  guides(fill = "none", color = guide_legend(override.aes = list(shape = c(18, 17, 16))))


# File name kept in sync with the load() call above.
#save(df_sens_log, file='ICLR2021_Dataframe_logistic_sensitivity_analysis.Rda')
#ggsave('ICLR2021_logistic_regression_sensitivity_analysis.png', width = 30, height = 30, units = 'cm', dpi = 300)

