library(ggplot2)
library(latex2exp)
library(pracma)
library(tilting)
library(matrixStats)

############################################################
################### OPTIMIZATION METHODS ###################
############################################################

############################################################
###################### SPLITSGD
############################################################

# SplitSGD: SGD with a splitting diagnostic that decides when to decay the
# step size.
#
# Alternates between
#   (i)  t1 plain SGD epochs at the current step size, and
#   (ii) one diagnostic epoch: two copies of the iterate are updated on the
#        odd- and even-indexed positions of a fresh permutation of the data.
#        The epoch is cut into windows of win_len positions; at the end of each
#        window the inner product of the two copies' displacements over that
#        window is recorded. If more than a fraction q of those inner products
#        are negative, the step size is multiplied by gamma and t1 is divided
#        by gamma (truncated to an integer).
# After a diagnostic epoch the run continues from the average of the two
# copies.
#
# Args:
#   x, y     design matrix (n x d) and response vector (length n).
#   t1       number of plain SGD epochs before the first diagnostic.
#   ep_max   total number of epochs (diagnostic epochs included).
#   eta      initial step size.
#   w        number of windows per diagnostic epoch.
#   q        fraction of negative window inner products above which the
#            step size is decayed.
#   gamma    multiplicative step-size decay factor.
#   theta0   starting parameter vector (length d).
#   my_grad  function(theta, x_row, y_i) returning a per-sample gradient.
#
# Returns a list with
#   theta             (ep_max + 1) x d matrix; row 1 is theta0, row k + 1 is
#                     the iterate after epoch k.
#   diagnostic_truth  logical vector, one entry per diagnostic epoch, TRUE
#                     when that diagnostic decayed the step size.
#   stepsize          length ep_max + 1, aligned with the rows of theta:
#                     entry 1 is eta, entry k + 1 is the step size in effect
#                     at the end of epoch k.
SplitSGD <- function(x, y, t1, ep_max, eta, w, q, gamma, theta0, my_grad) {
  n <- nrow(x)
  d <- ncol(x)
  stepsize <- eta
  diagnostic_truth <- logical(0)
  # Integer window length: windows close at win_len, 2 * win_len, ..., so
  # every window has the same number of steps and the boundary test is exact
  # even when w does not divide n.
  win_len <- max(1L, n %/% w)

  # Preallocated trajectory and step-size record (one slot per epoch + start).
  theta <- matrix(NA_real_, nrow = ep_max + 1, ncol = d)
  theta[1, ] <- theta0
  all_stepsize <- rep(NA_real_, ep_max + 1)
  all_stepsize[1] <- eta

  ep <- 0
  theta_temp <- theta0
  while (ep < ep_max) {
    # Phase (i): t1 plain SGD epochs at the current step size.
    for (i in seq_len(t1)) {
      idx <- sample(n, n)
      for (id in seq_len(n)) {
        theta_temp <- theta_temp -
          stepsize * my_grad(theta_temp, x[idx[id], ], y[idx[id]])
      }
      ep <- ep + 1
      theta[ep + 1, ] <- theta_temp
      all_stepsize[ep + 1] <- stepsize
      if (ep >= ep_max) break
    }
    if (ep >= ep_max) break

    # Phase (ii): one diagnostic epoch with two interleaved copies.
    theta1 <- theta_temp
    theta2 <- theta_temp
    start1 <- theta1
    start2 <- theta2
    Qi <- numeric(0)
    idx <- sample(n, n)
    for (id in seq_len(n)) {
      if (id %% 2 == 1) {
        theta1 <- theta1 - stepsize * my_grad(theta1, x[idx[id], ], y[idx[id]])
      } else {
        theta2 <- theta2 - stepsize * my_grad(theta2, x[idx[id], ], y[idx[id]])
      }
      if (id %% win_len == 0) {
        # Inner product of the two copies' displacements over this window.
        Qi <- c(Qi, sum((theta1 - start1) * (theta2 - start2)))
        start1 <- theta1
        start2 <- theta2
      }
    }

    decay <- sum(Qi < 0) > q * length(Qi)
    diagnostic_truth <- c(diagnostic_truth, decay)
    if (decay) {
      t1 <- as.integer(t1 / gamma)
      stepsize <- stepsize * gamma
    }

    theta_temp <- (theta1 + theta2) / 2
    ep <- ep + 1
    theta[ep + 1, ] <- theta_temp
    # Recorded even when this is the last epoch, so that stepsize always
    # has one entry per row of theta.
    all_stepsize[ep + 1] <- stepsize
  }

  list(theta = theta, diagnostic_truth = diagnostic_truth,
       stepsize = all_stepsize)
}


############################################################
###################### SGD
############################################################

# Plain SGD with a polynomially decaying step size.
#
# At the iter-th sample update (iter counts across epochs, starting at 1) the
# step size is eta / iter^alpha; alpha = 0 gives a constant step size.
#
# Args:
#   x, y     design matrix (n x d) and response vector (length n).
#   ep_max   number of epochs (full passes over a fresh permutation).
#   eta      base step size.
#   alpha    decay exponent.
#   theta0   starting parameter vector (length d).
#   my_grad  function(theta, x_row, y_i) returning a per-sample gradient.
#
# Returns a list with
#   theta      (ep_max + 1) x d matrix; row 1 is theta0, row k + 1 is the
#              iterate after epoch k. Always a matrix, even when ep_max = 0.
#   step_size  length ep_max + 1; entry 1 is eta, entry k + 1 is the step
#              size of the last update in epoch k.
#   stepsize   the same vector as step_size, under the field name used by
#              SplitSGD and my_sgd_half.
my_sgd <- function(x, y, ep_max, eta, alpha, theta0, my_grad) {
  n <- nrow(x)
  theta <- matrix(NA_real_, nrow = ep_max + 1, ncol = ncol(x))
  theta[1, ] <- theta0
  all_stepsize <- numeric(ep_max + 1)
  all_stepsize[1] <- eta

  iter <- 1
  theta_temp <- theta0
  # seq_len() so that ep_max = 0 runs no epochs (1:0 would run two).
  for (ep in seq_len(ep_max)) {
    idx <- sample(n, n)
    for (id in seq_len(n)) {
      stepsize <- eta / iter^alpha
      theta_temp <- theta_temp -
        stepsize * my_grad(theta_temp, x[idx[id], ], y[idx[id]])
      iter <- iter + 1
    }
    theta[ep + 1, ] <- theta_temp
    all_stepsize[ep + 1] <- stepsize
  }

  list(theta = theta, step_size = all_stepsize, stepsize = all_stepsize)
}


############################################################
###################### SGD_HALF
############################################################

# SGD with a step-halving schedule.
#
# Phase k (k = 1, ..., n_decay) runs t1 * 2^(k - 1) epochs at step size
# eta / 2^(k - 1), where n_decay = floor(log2(ep_max / t1 + 1)) is the largest
# number of complete phases that fit in ep_max epochs. Any leftover epochs run
# at eta / 2^n_decay.
#
# Args:
#   x, y     design matrix (n x d) and response vector (length n).
#   t1       length (in epochs) of the first phase; must be >= 1.
#   ep_max   total number of epochs.
#   eta      initial step size.
#   theta0   starting parameter vector (length d).
#   my_grad  function(theta, x_row, y_i) returning a per-sample gradient.
#
# Returns a list with
#   theta     (ep_max + 1) x d matrix; row 1 is theta0, row k + 1 is the
#             iterate after epoch k. Always a matrix, even when ep_max = 0.
#   stepsize  length ep_max + 1; entry 1 is eta, entry k + 1 is the step size
#             used during epoch k.
my_sgd_half <- function(x, y, t1, ep_max, eta, theta0, my_grad) {
  stopifnot(t1 >= 1)
  n <- nrow(x)
  n_decay <- as.integer(log2(ep_max / t1 + 1))

  # Per-epoch step sizes: the halving phases, then the remainder.
  k <- seq_len(n_decay)
  schedule <- rep(eta / 2^(k - 1), times = t1 * 2^(k - 1))
  n_rest <- max(0, ep_max - length(schedule))
  schedule <- c(schedule, rep(eta / 2^n_decay, n_rest))

  theta <- matrix(NA_real_, nrow = length(schedule) + 1, ncol = ncol(x))
  theta[1, ] <- theta0
  theta_temp <- theta0
  for (ep in seq_along(schedule)) {
    stepsize <- schedule[ep]
    idx <- sample(n, n)
    for (id in seq_len(n)) {
      theta_temp <- theta_temp -
        stepsize * my_grad(theta_temp, x[idx[id], ], y[idx[id]])
    }
    theta[ep + 1, ] <- theta_temp
  }

  list(theta = theta, stepsize = c(eta, schedule))
}


############################################################
######################### LOSSES ###########################
############################################################

# Mean squared error of the linear predictor x %*% theta against y.
#
# theta may be a plain length-d vector, a 1 x d row matrix (the simulation
# code passes t(theta_row)), or a d x 1 column matrix.
loss_lm <- function(x, y, theta) {
  residuals <- y - x %*% as.numeric(theta)
  mean(residuals^2)
}


# Average logistic loss mean(-y * z + log(1 + exp(z))), z = x %*% theta.
# Presumably y is coded 0/1 (the simulation draws it with rbinom(n, 1, p)).
#
# log(1 + exp(z)) is evaluated as pmax(z, 0) + log1p(exp(-|z|)). This is
# algebraically identical but does not overflow to Inf when z is large.
# theta may be a plain vector, a 1 x d row matrix, or a d x 1 column.
loss_log <- function(x, y, theta) {
  z <- x %*% as.numeric(theta)
  softplus <- pmax(z, 0) + log1p(exp(-abs(z)))
  mean(-y * z + softplus)
}



############################################################
####################### SIMULATIONS ########################
############################################################

############################################################
####################### LINEAR REGRESSION
############################################################

n <- 1000
d <- 20
sigma <- 1
theta_star <- 5 * exp(-0.5 * seq_len(d))
x <- matrix(rnorm(n * d, sd = 1), n, d)
y <- as.numeric(x %*% theta_star + rnorm(n, 0, sigma))

# Per-sample gradient of 0.5 * (<th, x1> - y1)^2.
# sum(th * x1) replaces pracma::dot(): the result is the same for plain
# vectors, but it skips dot()'s argument checks and apply() call. This
# matters because the function runs once per sample per epoch for every
# method, replicate and learning rate.
getGradient <- function(th, x1, y1) {
  x1 * (sum(th * x1) - y1)
}



B <- 10          # replicates per initial learning rate
nw <- 20         # SplitSGD: windows per diagnostic epoch
q <- 0.4         # SplitSGD: fraction of negative window products that decays
t1 <- 4          # SplitSGD: initial plain-SGD epochs; Half: first phase length
theta0 <- rep(0, d)
gamma <- 0.5     # SplitSGD: step-size decay factor
ep_max <- 100    # epochs per run

etas <- c(1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2)

# Per initial learning rate, summaries over the B replicates of the log
# training loss after ep_max epochs: median (med_*), 97.5% quantile (up_*)
# and 2.5% quantile (low_*).
# Suffixes: split = SplitSGD, bot = my_sgd_half, dec = decreasing-rate SGD,
# cost = constant-rate SGD.
med_dist_split <- numeric(length(etas))
med_dist_bot <- numeric(length(etas))
med_dist_dec <- numeric(length(etas))
med_dist_cost <- numeric(length(etas))
up_dist_split <- numeric(length(etas))
up_dist_bot <- numeric(length(etas))
up_dist_dec <- numeric(length(etas))
up_dist_cost <- numeric(length(etas))
low_dist_split <- numeric(length(etas))
low_dist_bot <- numeric(length(etas))
low_dist_dec <- numeric(length(etas))
low_dist_cost <- numeric(length(etas))

for (j in seq_along(etas)) {
  eta <- etas[j]
  # Final-epoch log loss: one row per replicate, one column per method
  # (split, bot, dec, cost). Only this final value is summarised.
  final_log_loss <- matrix(NA_real_, nrow = B, ncol = 4)

  for (b in seq_len(B)) {
    fit <- SplitSGD(x, y, t1 = t1, ep_max = ep_max, eta = eta, gamma = gamma,
                    w = nw, q = q, theta0 = theta0, my_grad = getGradient)
    bot <- my_sgd_half(x, y, t1 = t1, ep_max = ep_max, eta = eta,
                       theta0 = theta0, my_grad = getGradient)
    # Decreasing rate 20 * eta / sqrt(iter) and constant rate eta.
    s1 <- my_sgd(x, y, ep_max = ep_max, eta = eta * 20, alpha = 0.5,
                 theta0 = theta0, my_grad = getGradient)
    s2 <- my_sgd(x, y, ep_max = ep_max, eta = eta, alpha = 0,
                 theta0 = theta0, my_grad = getGradient)

    final_log_loss[b, ] <- c(
      log(loss_lm(x, y, t(fit$theta[ep_max + 1, ]))),
      log(loss_lm(x, y, t(bot$theta[ep_max + 1, ]))),
      log(loss_lm(x, y, t(s1$theta[ep_max + 1, ]))),
      log(loss_lm(x, y, t(s2$theta[ep_max + 1, ])))
    )

    print(paste0("stepsize = ", eta, " and iter = ", b))
  }

  med_dist_split[j] <- median(final_log_loss[, 1])
  med_dist_bot[j] <- median(final_log_loss[, 2])
  med_dist_dec[j] <- median(final_log_loss[, 3])
  med_dist_cost[j] <- median(final_log_loss[, 4])
  up_dist_split[j] <- quantile(final_log_loss[, 1], probs = 0.975)
  up_dist_bot[j] <- quantile(final_log_loss[, 2], probs = 0.975)
  up_dist_dec[j] <- quantile(final_log_loss[, 3], probs = 0.975)
  up_dist_cost[j] <- quantile(final_log_loss[, 4], probs = 0.975)
  low_dist_split[j] <- quantile(final_log_loss[, 1], probs = 0.025)
  low_dist_bot[j] <- quantile(final_log_loss[, 2], probs = 0.025)
  low_dist_dec[j] <- quantile(final_log_loss[, 3], probs = 0.025)
  low_dist_cost[j] <- quantile(final_log_loss[, 4], probs = 0.025)
}



df_lin <- data.frame(x = seq_along(etas),
                     m1 = med_dist_split, ub1 = up_dist_split, lb1 = low_dist_split,
                     m2 = med_dist_bot, ub2 = up_dist_bot, lb2 = low_dist_bot,
                     m3 = med_dist_dec, ub3 = up_dist_dec, lb3 = low_dist_dec,
                     m4 = med_dist_cost, ub4 = up_dist_cost, lb4 = low_dist_cost)

#load('ICLR2021_Dataframe_linear_different_lr.Rda')
#ep_max = 100
#etas = c(1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2)
# Colour keys: c1 = SplitSGD (m1), c2 = Half (m2), c3 = Decreasing (m3),
# c4 = Constant (m4).
cols <- c("c4" = "#3dc93d", "c2" = "#ff0033", "c3" = "#0000ff", "c1" = "#000000")
my.labs <- list(TeX("SplitSGD"), TeX("Half"), TeX("Decreasing"), TeX("Constant"))
# For each method: median log loss (line + points) per initial learning rate,
# with a shaded band between the 2.5% and 97.5% quantiles.
ggplot(df_lin, aes(x = x)) +
  geom_line(aes(y = m1, colour = "c1")) +
  geom_ribbon(aes(ymin = lb1, ymax = ub1, fill = "c1"), alpha = 0.2) +
  geom_point(aes(y = m1, colour = "c1"), shape = 17, size = 5) +
  geom_line(aes(y = m2, colour = "c2")) +
  geom_ribbon(aes(ymin = lb2, ymax = ub2, fill = "c2"), alpha = 0.2) +
  geom_point(aes(y = m2, colour = "c2"), shape = 16, size = 5) +
  geom_line(aes(y = m3, colour = "c3")) +
  geom_ribbon(aes(ymin = lb3, ymax = ub3, fill = "c3"), alpha = 0.2) +
  geom_point(aes(y = m3, colour = "c3"), shape = 15, size = 5) +
  geom_line(aes(y = m4, colour = "c4")) +
  geom_ribbon(aes(ymin = lb4, ymax = ub4, fill = "c4"), alpha = 0.2) +
  geom_point(aes(y = m4, colour = "c4"), shape = 18, size = 5) +
  labs(title = "Linear Regression") +
  xlab("Initial Learning Rate") +
  ylab(paste0("log(loss) after ", ep_max, " epochs")) +
  scale_colour_manual(values = cols, labels = my.labs) +
  scale_fill_manual(values = cols) +
  scale_x_continuous(breaks = seq_along(etas), labels = etas) +
  theme_bw() +
  theme(legend.position = c(0.8, 0.8),
        legend.background = element_rect(colour = "black"),
        legend.key.size = unit(1.5, "cm"),
        legend.text = element_text(size = 30),
        legend.title = element_blank(),
        plot.title = element_text(hjust = 0.5, size = 40),
        axis.text.x = element_text(size = 25),
        axis.text.y = element_text(size = 30, angle = 90),
        axis.title = element_text(size = 30)) +
  # "none" hides the fill legend; a logical FALSE is deprecated in ggplot2.
  guides(fill = "none",
         color = guide_legend(override.aes = list(shape = c(17, 16, 15, 18))))


#save(df_lin, file='Neurips20_Dataframe_linear_different_lr.Rda')
#ggsave('ICLR2021_linear_regression_change_initial_lr.png', width = 25, height = 25, units = 'cm', dpi = 300)












############################################################
####################### LOGISTIC REGRESSION
############################################################

n <- 1000
d <- 20
sigma <- 1
theta_star <- 5 * exp(-0.5 * seq_len(d))
x <- matrix(rnorm(n * d, sd = 1), n, d)
pr <- 1 / (1 + exp(-(x %*% theta_star)))
y <- rbinom(n, 1, pr)

# Per-sample gradient of the logistic loss -y1 * z + log(1 + exp(z)),
# z = <th, x1>, i.e. x1 * (sigmoid(z) - y1).
# sum(th * x1) replaces pracma::dot(): the result is the same for plain
# vectors, without dot()'s per-call argument checks and apply() overhead
# in this innermost loop.
getGradient <- function(th, x1, y1) {
  -y1 * x1 + x1 / (1 + exp(-sum(th * x1)))
}



B <- 10          # replicates per initial learning rate
nw <- 20         # SplitSGD: windows per diagnostic epoch
q <- 0.4         # SplitSGD: fraction of negative window products that decays
t1 <- 4          # SplitSGD: initial plain-SGD epochs; Half: first phase length
theta0 <- rep(0, d)
gamma <- 0.5     # SplitSGD: step-size decay factor
ep_max <- 100    # epochs per run

etas <- c(1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1, 3e-1)

# Per initial learning rate, summaries over the B replicates of the log
# training loss after ep_max epochs: median (med_*), 97.5% quantile (up_*)
# and 2.5% quantile (low_*).
# Suffixes: split = SplitSGD, bot = my_sgd_half, dec = decreasing-rate SGD,
# cost = constant-rate SGD.
med_dist_split <- numeric(length(etas))
med_dist_bot <- numeric(length(etas))
med_dist_dec <- numeric(length(etas))
med_dist_cost <- numeric(length(etas))
up_dist_split <- numeric(length(etas))
up_dist_bot <- numeric(length(etas))
up_dist_dec <- numeric(length(etas))
up_dist_cost <- numeric(length(etas))
low_dist_split <- numeric(length(etas))
low_dist_bot <- numeric(length(etas))
low_dist_dec <- numeric(length(etas))
low_dist_cost <- numeric(length(etas))

for (j in seq_along(etas)) {
  eta <- etas[j]
  # Final-epoch log loss: one row per replicate, one column per method
  # (split, bot, dec, cost). Only this final value is summarised.
  final_log_loss <- matrix(NA_real_, nrow = B, ncol = 4)

  for (b in seq_len(B)) {
    fit <- SplitSGD(x, y, t1 = t1, ep_max = ep_max, eta = eta, gamma = gamma,
                    w = nw, q = q, theta0 = theta0, my_grad = getGradient)
    bot <- my_sgd_half(x, y, t1 = t1, ep_max = ep_max, eta = eta,
                       theta0 = theta0, my_grad = getGradient)
    # Decreasing rate 20 * eta / sqrt(iter) and constant rate eta.
    s1 <- my_sgd(x, y, ep_max = ep_max, eta = eta * 20, alpha = 0.5,
                 theta0 = theta0, my_grad = getGradient)
    s2 <- my_sgd(x, y, ep_max = ep_max, eta = eta, alpha = 0,
                 theta0 = theta0, my_grad = getGradient)

    final_log_loss[b, ] <- c(
      log(loss_log(x, y, t(fit$theta[ep_max + 1, ]))),
      log(loss_log(x, y, t(bot$theta[ep_max + 1, ]))),
      log(loss_log(x, y, t(s1$theta[ep_max + 1, ]))),
      log(loss_log(x, y, t(s2$theta[ep_max + 1, ])))
    )

    print(paste0("stepsize = ", eta, " and iter = ", b))
  }

  med_dist_split[j] <- median(final_log_loss[, 1])
  med_dist_bot[j] <- median(final_log_loss[, 2])
  med_dist_dec[j] <- median(final_log_loss[, 3])
  med_dist_cost[j] <- median(final_log_loss[, 4])
  up_dist_split[j] <- quantile(final_log_loss[, 1], probs = 0.975)
  up_dist_bot[j] <- quantile(final_log_loss[, 2], probs = 0.975)
  up_dist_dec[j] <- quantile(final_log_loss[, 3], probs = 0.975)
  up_dist_cost[j] <- quantile(final_log_loss[, 4], probs = 0.975)
  low_dist_split[j] <- quantile(final_log_loss[, 1], probs = 0.025)
  low_dist_bot[j] <- quantile(final_log_loss[, 2], probs = 0.025)
  low_dist_dec[j] <- quantile(final_log_loss[, 3], probs = 0.025)
  low_dist_cost[j] <- quantile(final_log_loss[, 4], probs = 0.025)
}



df_log <- data.frame(x = seq_along(etas),
                     m1 = med_dist_split, ub1 = up_dist_split, lb1 = low_dist_split,
                     m2 = med_dist_bot, ub2 = up_dist_bot, lb2 = low_dist_bot,
                     m3 = med_dist_dec, ub3 = up_dist_dec, lb3 = low_dist_dec,
                     m4 = med_dist_cost, ub4 = up_dist_cost, lb4 = low_dist_cost)

#load('ICLR2021_Dataframe_logistic_different_lr.Rda')
#ep_max = 100
#etas = c(1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1, 3e-1)
# Colour keys: c1 = SplitSGD (m1), c2 = Half (m2), c3 = Decreasing (m3),
# c4 = Constant (m4).
cols <- c("c4" = "#3dc93d", "c2" = "#ff0033", "c3" = "#0000ff", "c1" = "#000000")
my.labs <- list(TeX("SplitSGD"), TeX("Half"), TeX("Decreasing"), TeX("Constant"))
# For each method: median log loss (line + points) per initial learning rate,
# with a shaded band between the 2.5% and 97.5% quantiles.
ggplot(df_log, aes(x = x)) +
  geom_line(aes(y = m1, colour = "c1")) +
  geom_ribbon(aes(ymin = lb1, ymax = ub1, fill = "c1"), alpha = 0.2) +
  geom_point(aes(y = m1, colour = "c1"), shape = 17, size = 5) +
  geom_line(aes(y = m2, colour = "c2")) +
  geom_ribbon(aes(ymin = lb2, ymax = ub2, fill = "c2"), alpha = 0.2) +
  geom_point(aes(y = m2, colour = "c2"), shape = 16, size = 5) +
  geom_line(aes(y = m3, colour = "c3")) +
  geom_ribbon(aes(ymin = lb3, ymax = ub3, fill = "c3"), alpha = 0.2) +
  geom_point(aes(y = m3, colour = "c3"), shape = 15, size = 5) +
  geom_line(aes(y = m4, colour = "c4")) +
  geom_ribbon(aes(ymin = lb4, ymax = ub4, fill = "c4"), alpha = 0.2) +
  geom_point(aes(y = m4, colour = "c4"), shape = 18, size = 5) +
  labs(title = "Logistic Regression") +
  xlab("Initial Learning Rate") +
  ylab(paste0("log(loss) after ", ep_max, " epochs")) +
  scale_colour_manual(values = cols, labels = my.labs) +
  scale_fill_manual(values = cols) +
  scale_x_continuous(breaks = seq_along(etas), labels = etas) +
  theme_bw() +
  theme(legend.position = c(0.5, 0.8),
        legend.background = element_rect(colour = "black"),
        legend.key.size = unit(1.5, "cm"),
        legend.text = element_text(size = 30),
        legend.title = element_blank(),
        plot.title = element_text(hjust = 0.5, size = 40),
        axis.text.x = element_text(size = 25),
        axis.text.y = element_text(size = 30, angle = 90),
        axis.title = element_text(size = 30)) +
  # "none" hides the fill legend; a logical FALSE is deprecated in ggplot2.
  guides(fill = "none",
         color = guide_legend(override.aes = list(shape = c(17, 16, 15, 18))))


#save(df_log, file='Neurips20_Dataframe_logistic_different_lr.Rda')
#ggsave('ICLR2021_logistic_regression_change_initial_lr.png', width = 25, height = 25, units = 'cm', dpi = 300)


