library(ggplot2)
library(latex2exp)
library(pracma)
library(tilting)
library(matrixStats)


# Plain stochastic gradient descent with a polynomially decaying step size.
#
# Each epoch visits every row of `x` exactly once, in a fresh random order
# (sample() without replacement), and applies
#   theta <- theta - eta / k^alpha * my_grad(theta, x[i, ], y[i])
# where k counts single-sample updates across all epochs, starting at 1.
# `alpha = 0` therefore gives a constant step size `eta`.
#
# Args:
#   x:       matrix with one observation per row.
#   y:       responses, indexed in parallel with the rows of `x`.
#   ep_max:  number of epochs (full passes over the data); 0 returns only the
#            starting point.
#   eta:     base step size.
#   alpha:   decay exponent of the step size.
#   theta0:  starting parameter vector.
#   my_grad: function(theta, x_row, y_value) returning the per-sample gradient.
#            `x_row` is the selected row of `x` with dimensions dropped.
#
# Returns a list with
#   theta:     (ep_max + 1) x length(theta0) matrix; row 1 is theta0 and row
#              e + 1 is the iterate at the end of epoch e.
#   step_size: length ep_max + 1 vector; element 1 is eta and element e + 1
#              is the last step size used during epoch e.
my_sgd <- function(x, y, ep_max, eta, alpha, theta0, my_grad) {
  n <- nrow(x)
  stopifnot(n >= 1, ep_max >= 0)

  # Preallocate the trajectory instead of growing it with rbind()/c().
  theta <- matrix(NA_real_, nrow = ep_max + 1, ncol = length(theta0))
  theta[1, ] <- theta0
  all_stepsize <- numeric(ep_max + 1)
  all_stepsize[1] <- eta

  theta_temp <- theta0
  iter <- 1
  stepsize <- eta
  # seq_len() so that ep_max = 0 runs no epochs (1:0 would run two).
  for (epoch in seq_len(ep_max)) {
    idx <- sample(n, n)
    for (id in seq_len(n)) {
      stepsize <- eta / (iter^alpha)
      theta_temp <- theta_temp -
        stepsize * my_grad(theta_temp, x[idx[id], ], y[idx[id]])
      iter <- iter + 1
    }
    all_stepsize[epoch + 1] <- stepsize
    theta[epoch + 1, ] <- theta_temp
  }

  list(theta = theta, step_size = all_stepsize)
}


# Mean squared residual of a linear model, evaluated for several parameter
# vectors at once.
#
# `theta` is expected to be a matrix holding one parameter vector per row
# (the layout of `my_sgd()$theta`). The result has one value per row of
# `theta`: mean((y - x %*% theta_row)^2).
loss_lm <- function(x, y, theta) {
  fitted <- x %*% t(theta)     # n x nrow(theta), one column per candidate
  sq_resid <- (y - fitted)^2   # y is recycled down every column
  colSums(sq_resid) / length(y)
}


# Average logistic-regression negative log-likelihood, evaluated for several
# parameter vectors at once.
#
# With linear predictor z = x %*% t(theta), each observation contributes
#   -y * z + log(1 + exp(z)).
# The softplus term log(1 + exp(z)) is evaluated as
#   max(z, 0) + log1p(exp(-|z|)),
# which is algebraically identical but never overflows: exp(z) becomes Inf
# for z above roughly 709, and the naive form would then return Inf.
#
# `theta` is expected to be a matrix holding one parameter vector per row
# (the layout of `my_sgd()$theta`); the result has one average loss per row.
loss_log <- function(x, y, theta) {
  n <- length(y)
  z <- x %*% t(theta)  # n x nrow(theta), one column per candidate
  # pmax() keeps the dim attribute of its first argument, so this stays a matrix.
  softplus <- pmax(z, 0) + log1p(exp(-abs(z)))
  colSums(-y * z + softplus) / n
}



#################### Linear Regression
# Simulate a linear model, run constant-step SGD (alpha = 0) B times from the
# same starting point, and plot the per-epoch log training loss: the mean over
# replicates with an empirical 2.5%-97.5% band.
n <- 1000
d <- 20
sigma <- 1
theta_star <- 5 * exp(-0.5 * seq(1, d))
x <- matrix(rnorm(n * d, sd = 1), n, d)
y <- as.numeric(x %*% theta_star + rnorm(n, 0, sigma))
# Per-sample gradient of 0.5 * (x1 . th - y1)^2; dot() comes from pracma.
getGradient <- function(th, x1, y1) {
  x1 * (dot(th, x1) - y1)
}

B <- 2

# 'close': theta_star plus small noise.
# 'far':   theta_star's coordinates in reverse order, plus small noise.
# start <- 'close'
start <- 'far'

if (start == 'far') {
  theta0 <- rnorm(d, mean = rev(theta_star), sd = 0.1)
} else {
  theta0 <- rnorm(d, mean = theta_star, sd = 0.1)
}

ep_max <- 200
eta <- 1e-3

# One row per replicate, one column per epoch (column 1 = starting point).
# Named log_loss rather than `dist` so stats::dist() is not masked.
log_loss <- matrix(NA_real_, nrow = B, ncol = ep_max + 1)

for (b in seq_len(B)) {
  sgd <- my_sgd(x, y, ep_max = ep_max, eta = eta, alpha = 0,
                theta0 = theta0, my_grad = getGradient)
  log_loss[b, ] <- log(loss_lm(x, y, sgd$theta))
  message(paste0('iter = ', b))
}

# Mean curve (not a median) and empirical 95% band across replicates.
# With a small B the quantiles only interpolate between the few runs.
avg <- colMeans(log_loss)
up <- colQuantiles(log_loss, probs = 0.975)
low <- colQuantiles(log_loss, probs = 0.025)

# Plot epochs skip..ep_max; epoch e is stored in column e + 1.
skip <- 5
keep <- (skip + 1):(ep_max + 1)
df_lin <- data.frame(x = skip:ep_max,
                     m4 = avg[keep],
                     ub4 = up[keep],
                     lb4 = low[keep])


cols <- c("c1" = "#000000")
my.labs <- list(TeX('Constant'))
p_lin <- ggplot(df_lin, aes(x = x)) +
  geom_line(aes(y = m4, colour = 'c1')) +
  geom_ribbon(aes(ymin = lb4, ymax = ub4, fill = 'c1'), alpha = 0.2) +
  labs(title = paste0('Linear Regression with lr = ', eta, ' and theta0 ', start)) +
  xlab('Epochs') + ylab('log(loss)') +
  scale_colour_manual(values = cols, labels = my.labs) +
  scale_fill_manual(values = cols) +
  theme_bw() +
  theme(legend.position = 'none',
        plot.title = element_text(hjust = 0.5, size = 18),
        axis.text.x = element_text(size = 15),
        axis.text.y = element_text(size = 15, angle = 90),
        axis.title = element_text(size = 15)) +
  # "none" replaces the deprecated guides(fill = FALSE).
  guides(fill = "none")
# Print explicitly so the plot is also drawn when the script is source()d.
print(p_lin)











##################### Logistic Regression
# Simulate Bernoulli responses from a logistic model, run constant-step SGD
# (alpha = 0) B times from the same starting point, and plot the per-epoch log
# training loss: the mean over replicates with an empirical 2.5%-97.5% band.
n <- 1000
d <- 20
theta_star <- 5 * exp(-0.5 * seq(1, d))
x <- matrix(rnorm(n * d, sd = 1), n, d)
pr <- 1 / (1 + exp(-(x %*% theta_star)))
y <- rbinom(n, 1, pr)
# Per-sample gradient of -y1 * z + log(1 + exp(z)) with z = th . x1;
# dot() comes from pracma.
getGradient <- function(th, x1, y1) {
  -y1 * x1 + x1 / (1 + exp(-dot(th, x1)))
}

B <- 3
# 'close': theta_star plus small noise.
# 'far':   theta_star's coordinates in reverse order, plus small noise.
start <- 'close'
# start <- 'far'

if (start == 'far') {
  theta0 <- rnorm(d, mean = rev(theta_star), sd = 0.1)
} else {
  theta0 <- rnorm(d, mean = theta_star, sd = 0.1)
}

ep_max <- 100
eta <- 1e-3

# One row per replicate, one column per epoch (column 1 = starting point).
# Named log_loss rather than `dist` so stats::dist() is not masked.
log_loss <- matrix(NA_real_, nrow = B, ncol = ep_max + 1)

for (b in seq_len(B)) {
  sgd <- my_sgd(x, y, ep_max = ep_max, eta = eta, alpha = 0,
                theta0 = theta0, my_grad = getGradient)
  log_loss[b, ] <- log(loss_log(x, y, sgd$theta))
  message(paste0('iter = ', b))
}

# Mean curve (not a median) and empirical 95% band across replicates.
# With a small B the quantiles only interpolate between the few runs.
avg <- colMeans(log_loss)
up <- colQuantiles(log_loss, probs = 0.975)
low <- colQuantiles(log_loss, probs = 0.025)

# Plot epochs skip..ep_max; epoch e is stored in column e + 1.
skip <- 0
keep <- (skip + 1):(ep_max + 1)
df_log <- data.frame(x = skip:ep_max,
                     m4 = avg[keep],
                     ub4 = up[keep],
                     lb4 = low[keep])


cols <- c("c1" = "#000000")
my.labs <- list(TeX('Constant'))
p_log <- ggplot(df_log, aes(x = x)) +
  geom_line(aes(y = m4, colour = 'c1')) +
  geom_ribbon(aes(ymin = lb4, ymax = ub4, fill = 'c1'), alpha = 0.2) +
  labs(title = paste0('Logistic Regression with lr = ', eta, ' and theta0 ', start)) +
  xlab('Epochs') + ylab('log(loss)') +
  scale_colour_manual(values = cols, labels = my.labs) +
  scale_fill_manual(values = cols) +
  theme_bw() +
  theme(legend.position = 'none',
        plot.title = element_text(hjust = 0.5, size = 18),
        axis.text.x = element_text(size = 15),
        axis.text.y = element_text(size = 15, angle = 90),
        axis.title = element_text(size = 15)) +
  # "none" replaces the deprecated guides(fill = FALSE).
  guides(fill = "none")
# Print explicitly so the plot is also drawn when the script is source()d.
print(p_log)



