library(ggplot2)
library(latex2exp)
library(pracma)
library(tilting)
library(matrixStats)


#' Splitting diagnostic for detecting stationarity of constant-step SGD.
#'
#' The procedure works in rounds. Each round first runs `t1` ordinary SGD
#' epochs (one pass over a fresh random permutation of the rows of `x`), and
#' then one "split" epoch: two copies of the current iterate are updated on
#' alternating samples of a fresh permutation (odd positions feed thread 1,
#' even positions feed thread 2). The split epoch is cut into `w` consecutive
#' windows; for every window the inner product of the two threads'
#' displacements over that window is recorded. If more than a fraction `q` of
#' those inner products is negative, the function stops and reports the number
#' of epochs used so far. Otherwise the two threads are averaged and the next
#' round starts from that average.
#'
#' @param x Numeric matrix with one observation per row.
#' @param y Responses, indexed like the rows of `x`.
#' @param t1 Number of single-thread epochs run before every split epoch
#'   (`0` means every epoch is a split epoch).
#' @param ep_max Epoch budget. It is checked only at the start of a round, so
#'   the returned count can exceed `ep_max` by up to `t1` epochs.
#' @param eta Constant learning rate.
#' @param w Number of windows per split epoch.
#' @param q Stop when the share of negative window inner products exceeds `q`.
#' @param theta0 Starting parameter vector.
#' @param my_grad Function `(theta, x_row, y_value)` returning the stochastic
#'   gradient for a single observation.
#' @return The number of epochs run when the function stopped.
stationarity_Split <- function(x, y, t1, ep_max, eta, w, q, theta0, my_grad) {
  n <- nrow(x)
  ep <- 0
  theta <- theta0

  # Position (within a split epoch) of the last sample of each window.
  # Integer arithmetic makes the windows tile 1..n exactly: when w divides n
  # every window holds n / w samples and the last one ends at sample n; when
  # it does not, window sizes differ by at most one. A floating-point test on
  # `id %% (n / w)` would silently misalign or skip windows.
  win_end <- unique((seq_len(w) * as.numeric(n)) %/% w)
  win_end <- win_end[win_end >= 1]
  closes_window <- logical(n)
  closes_window[win_end] <- TRUE

  while (ep < ep_max) {
    # Warm-up: plain single-thread SGD epochs. seq_len() keeps t1 = 0 from
    # running any epoch (1:0 would iterate twice).
    for (i in seq_len(t1)) {
      idx <- sample(n, n)
      for (id in seq_len(n)) {
        theta <- theta - eta * my_grad(theta, x[idx[id], ], y[idx[id]])
      }
      ep <- ep + 1
    }

    # Split epoch: both threads start from the same iterate.
    theta1 <- theta
    theta2 <- theta
    # Thread positions at the start of the current window.
    anchor1 <- theta1
    anchor2 <- theta2
    Qi <- numeric(length(win_end))
    k <- 0
    idx <- sample(n, n)
    for (id in seq_len(n)) {
      if (id %% 2 == 1) {
        theta1 <- theta1 - eta * my_grad(theta1, x[idx[id], ], y[idx[id]])
      } else {
        theta2 <- theta2 - eta * my_grad(theta2, x[idx[id], ], y[idx[id]])
      }

      if (closes_window[id]) {
        # Inner product of the two threads' displacements over this window.
        k <- k + 1
        Qi[k] <- sum((theta1 - anchor1) * (theta2 - anchor2))
        anchor1 <- theta1
        anchor2 <- theta2
      }
    }

    ep <- ep + 1

    # Too many windows in which the threads moved in opposing directions:
    # declare stationarity.
    if (sum(Qi < 0) > q * length(Qi)) {
      break
    }
    theta <- (theta1 + theta2) / 2
  }
  ep
}


#' Pflug-style stationarity diagnostic for constant-step SGD.
#'
#' Runs SGD one epoch at a time (each epoch is a pass over a fresh random
#' permutation of the rows of `x`) while accumulating the inner products of
#' consecutive stochastic gradients. The running sum is never reset and is
#' inspected only at the end of an epoch; the function stops as soon as it is
#' non-positive or the epoch budget is exhausted.
#'
#' @param x Numeric matrix with one observation per row.
#' @param y Responses, indexed like the rows of `x`.
#' @param ep_max Maximum number of epochs.
#' @param eta Constant learning rate.
#' @param theta0 Starting parameter vector.
#' @param my_grad Function `(theta, x_row, y_value)` returning the stochastic
#'   gradient for a single observation.
#' @return The number of epochs run when the function stopped.
stationarity_Pflug <- function(x, y, ep_max, eta, theta0, my_grad) {
  n_obs <- nrow(x)
  theta <- theta0
  # The "previous" gradient starts at zero, so the first term adds nothing.
  prev_grad <- rep(0, ncol(x))
  running_sum <- 0
  epochs_done <- 0

  repeat {
    if (epochs_done >= ep_max) {
      break
    }

    perm <- sample(n_obs, n_obs)
    for (k in 1:n_obs) {
      row <- perm[k]
      grad <- my_grad(theta, x[row, ], y[row])
      theta <- theta - eta * grad
      running_sum <- running_sum + dot(prev_grad, grad)
      prev_grad <- grad
    }
    epochs_done <- epochs_done + 1

    # Successive gradients no longer point the same way on balance.
    if (running_sum <= 0) {
      break
    }
  }

  epochs_done
}









####################### #######################
######################### Linear
# Linear-regression experiment. Simulates one data set, then for every
# combination of starting point ('close' / 'far') and learning rate
# ('large' / 'small') runs both diagnostics B times from the same theta0 and
# writes the epoch counts to one CSV file per combination
# (row 1 = Pflug, row 2 = Split).

n <- 1000
d <- 20
sigma <- 1
theta_star <- 5 * exp(-0.5 * seq(1, d))
x <- matrix(rnorm(n * d, sd = 1), n, d)
y <- as.numeric(x %*% theta_star + rnorm(n, 0, sigma))
# Per-observation gradient of the squared loss 0.5 * (<th, x1> - y1)^2.
my_grad <- function(th, x1, y1) { x1 * (dot(th, x1) - y1) }

# Settings for stationarity_Split: warm-up epochs, windows, threshold.
t1 <- 4
w <- 20
q <- 0.4

# Repetitions per configuration.
B <- 100

# Epoch budget for both diagnostics.
ep_max <- 1000

for (start in c('close', 'far')) {
  for (lr in c('large', 'small')) {
    # 'far' centres theta0 on theta_star with its coordinates reversed;
    # 'close' centres it on theta_star itself. theta0 is drawn once per
    # configuration and shared by all B repetitions.
    if (start == 'far') {
      theta0 <- rnorm(d, mean = 5 * exp(-0.5 * seq(d, 1)), sd = 0.1)
    } else {
      theta0 <- rnorm(d, mean = theta_star, sd = 0.1)
    }
    if (lr == 'large') {
      eta <- 1e-3
    } else {
      eta <- 1e-4
    }

    print(paste0('start = ', start, ' and eta = ', eta))

    ep_pflug <- numeric(B)
    ep_split <- numeric(B)

    for (b in 1:B) {
      pf <- stationarity_Pflug(x, y, ep_max = ep_max, eta = eta,
                               theta0 = theta0, my_grad = my_grad)
      ep_pflug[b] <- pf
      print(paste0('Pflug, iter = ', b, ', epochs = ', pf))

      sp <- stationarity_Split(x, y, t1 = t1, ep_max = ep_max, eta = eta,
                               w = w, q = q, theta0 = theta0,
                               my_grad = my_grad)
      ep_split[b] <- sp
      print(paste0('Split, iter = ', b, ', epochs = ', sp))
    }

    # write.csv() ignores `col.names` (with a warning), so it is not passed;
    # the file keeps its header row, like the logistic experiment's output.
    write.csv(matrix(c(ep_pflug, ep_split), nrow = 2, byrow = TRUE),
              file = paste0("epochs_lin_lr_", lr, "_theta0_", start, ".csv"),
              row.names = FALSE)
  }
}



####################### #######################
####################### Logistic
# Logistic-regression experiment. Same layout as a grid over starting point
# and learning rate: simulate one data set, run both diagnostics B times per
# configuration from a shared theta0, and write one CSV per configuration
# (row 1 = Pflug, row 2 = Split).
n <- 1000
d <- 20
sigma <- 1
theta_star <- 5 * exp(-0.5 * seq_len(d))
x <- matrix(rnorm(n * d, sd = 1), n, d)
# Success probabilities under the logistic model, then Bernoulli responses.
pr <- 1 / (1 + exp(-(x %*% theta_star)))
y <- rbinom(n, 1, pr)
# Per-observation gradient of the logistic negative log-likelihood.
my_grad <- function(th, x1, y1) {
  lin <- dot(th, x1)
  -y1 * x1 + x1 / (1 + exp(-lin))
}

# Settings for stationarity_Split: warm-up epochs, windows, threshold.
t1 <- 4
w <- 20
q <- 0.4

# Repetitions per configuration and epoch budget.
B <- 100
ep_max <- 1000


for (start in c('close', 'far')) {
  for (lr in c('large', 'small')) {
    # 'far' centres theta0 on theta_star with its coordinates reversed.
    theta0 <- rnorm(
      d,
      mean = if (start == 'far') 5 * exp(-0.5 * rev(seq_len(d))) else theta_star,
      sd = 0.1
    )
    eta <- if (lr == 'large') 1e-2 else 1e-3

    print(paste0('start = ', start, ' and eta = ', eta))

    ep_pflug <- numeric(B)
    ep_split <- numeric(B)

    for (b in seq_len(B)) {
      pf <- stationarity_Pflug(x, y, ep_max = ep_max, eta = eta,
                               theta0 = theta0, my_grad = my_grad)
      ep_pflug[b] <- pf
      print(paste0('Pflug, iter = ', b, ', epochs = ', pf))

      sp <- stationarity_Split(x, y, t1 = t1, ep_max = ep_max, eta = eta,
                               w = w, q = q, theta0 = theta0,
                               my_grad = my_grad)
      ep_split[b] <- sp
      print(paste0('Split, iter = ', b, ', epochs = ', sp))
    }

    # Two rows, B columns; no dimnames, so the header is V1..VB.
    epochs <- rbind(ep_pflug, ep_split, deparse.level = 0)
    out_file <- paste0("epochs_log_lr_", lr, "_theta0_", start, ".csv")
    write.csv(epochs, file = out_file, row.names = FALSE)

  }
}


