library(ggplot2)
library(latex2exp)
library(pracma)
library(tilting)
library(matrixStats)

# Run single-sample stochastic gradient descent and record the iterate
# every `nsave` steps.
#
# Args:
#   x       Design matrix, one observation per row.
#   y       Response vector, one entry per row of `x`.
#   model   Loss to optimise: "lm" (half squared error), "log" (logistic
#           loss with y coded 0/1) or "svm" (L2-regularised hinge loss with
#           y coded -1/+1).
#   t1      Number of SGD iterations; 0 performs no update at all.
#   eta     Base learning rate.
#   alpha   Decay exponent; the step at iteration i is eta / i^alpha.
#   C       Hinge-loss weight, only evaluated when model == "svm" (so it may
#           be omitted for the other models).
#   theta0  Starting parameter vector, length ncol(x).
#   nsave   Save the current iterate every `nsave` iterations.
#
# Returns:
#   A list with
#     theta      matrix whose first row is `theta0` and whose following rows
#                are the iterates saved at iterations nsave, 2*nsave, ...
#                (always a matrix, even when nothing beyond theta0 is saved);
#     step_size  vector whose first entry is `eta`, followed by the step size
#                used at each saved iteration.
my_sgd <- function(x, y, model, t1, eta, alpha, C, theta0, nsave) {
  n <- nrow(x)
  d <- ncol(x)
  stopifnot(
    is.character(model), length(model) == 1,
    length(y) == n,
    length(theta0) == d,
    t1 >= 0,
    nsave > 0
  )

  # Per-observation (sub)gradient of the selected loss. sum(t * x1) is the
  # inner product of the parameter and feature vectors.
  getGradient <- switch(
    model,
    lm = function(t, x1, y1) {
      x1 * (sum(t * x1) - y1)
    },
    log = function(t, x1, y1) {
      # y coded as 0 / 1
      -y1 * x1 + x1 / (1 + exp(-sum(t * x1)))
    },
    svm = function(t, x1, y1) {
      if (y1 * sum(t * x1) >= 1) {
        t                      # margin satisfied: only the L2 penalty acts
      } else {
        t - C * y1 * x1        # margin violated: add the hinge term
      }
    },
    stop("unknown model '", model, "'; expected \"lm\", \"log\" or \"svm\"",
         call. = FALSE)
  )

  # Preallocate the outputs. The number of saved iterates is counted with the
  # same condition the loop uses, so the two can never disagree.
  n_saved <- sum(seq_len(t1) %% nsave == 0)
  theta <- matrix(NA_real_, nrow = n_saved + 1, ncol = length(theta0),
                  dimnames = list(NULL, names(theta0)))
  theta[1, ] <- theta0
  all_stepsize <- numeric(n_saved + 1)
  all_stepsize[1] <- eta

  theta_temp <- theta0
  slot <- 1
  # seq_len() (not 1:t1) so that t1 == 0 runs zero iterations.
  for (i in seq_len(t1)) {
    idx <- sample(n, 1)
    stepsize <- eta / (i^alpha)
    theta_temp <- theta_temp -
      stepsize * getGradient(theta_temp, x[idx, ], y[idx])
    if (i %% nsave == 0) {
      slot <- slot + 1
      theta[slot, ] <- theta_temp
      all_stepsize[slot] <- stepsize
    }
  }

  list(theta = theta, step_size = all_stepsize)
}



# ---- Type I error experiment: data and constants ----
n <- 1000                        # number of observations
p <- 10                          # number of features
theta_star <- rep(1, times = p)  # parameter used to generate the data
theta0 <- rep(1, times = p)      # SGD starting point (equal to theta_star)
sigma <- 1                       # noise sd, used only by the disabled lm variant
x <- matrix(rnorm(n * p), nrow = n, ncol = p)

# Linear-regression variant (disabled):
# y <- as.numeric(x %*% theta_star + rnorm(n, 0, sigma))
# model = 'lm'
# eta1 = 0.05

# Logistic-regression variant (active): y is Bernoulli with success
# probability given by the logistic function of x %*% theta_star.
pr <- 1 / (1 + exp(-(x %*% theta_star)))
y <- rbinom(n, size = 1, prob = pr)
model <- "log"
eta1 <- 0.5

l <- 10                 # SGD iterations between two saved iterates
B <- 1000               # Monte Carlo replications per window count
ws <- c(10, 25, 100)    # numbers of windows w to compare

# One row per replication, one column per entry of ws; filled by the
# simulation loop.
prop_neg <- matrix(NA, nrow = B, ncol = length(ws))


# Monte Carlo loop: for every window count w and every replication, run two
# independent SGD threads from theta0 for l * w iterations (iterate saved
# every l steps), then record the fraction of the w windows in which the two
# threads' increments have a negative inner product.
for (i in seq_along(ws)) {
  w <- ws[i]
  win <- seq_len(w)
  for (my_B in seq_len(B)) {
    sgd1 <- my_sgd(x, y, model = model, t1 = l * w, eta = eta1, alpha = 0,
                   theta0 = theta0, nsave = l)
    sgd2 <- my_sgd(x, y, model = model, t1 = l * w, eta = eta1, alpha = 0,
                   theta0 = theta0, nsave = l)

    # Row k holds (saved iterate k) - (saved iterate k + 1) for each thread.
    inc1 <- sgd1$theta[win, , drop = FALSE] - sgd1$theta[win + 1, , drop = FALSE]
    inc2 <- sgd2$theta[win, , drop = FALSE] - sgd2$theta[win + 1, , drop = FALSE]

    # Only the sign is used below, so the inner products are left
    # unnormalised: dividing by the (non-negative) norms cannot change the
    # sign, and skipping it avoids NaN when an increment is exactly zero.
    dot_prod <- rowSums(inc1 * inc2)
    prop_neg[my_B, i] <- sum(dot_prod < 0) / w
  }
}

# Grid of thresholds q. Alternative grids kept for reference:
# qs = c(0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45)
# qs = seq(10, 49)/100
qs <- seq(10, 90) / 100

# Empirical rate: fraction of the B replications whose proportion of
# negative inner products falls strictly below q, per window count.
type_I_error <- matrix(NA, nrow = length(qs), ncol = length(ws))
for (j in seq_along(ws)) {
  type_I_error[, j] <- vapply(
    qs,
    function(q) sum(prop_neg[, j] < q) / B,
    numeric(1)
  )
}
# type_I_error

# Binomial(w, 1/2) reference for the SAME strict event used above:
#   P(X / w < q) = P(X < q * w) = P(X <= ceiling(q * w) - 1).
# Plain pbinom(q * w, ...) would give P(X <= q * w) and overshoot the
# empirical curve by P(X = q * w) whenever q * w is an integer (every grid
# point when w = 100). The 1e-8 guard keeps products such as 0.7 * 10, which
# are an integer only up to rounding error, from being rounded up by ceiling().
strict_binom_cdf <- function(w) {
  pbinom(q = ceiling(qs * w - 1e-8) - 1, size = w, prob = 0.5)
}

# One row per q; w1..w3 are the empirical rates and cdf1..cdf3 the binomial
# references for ws[1], ws[2], ws[3].
df_err <- data.frame(
  x = qs,
  w1 = type_I_error[, 1],
  w2 = type_I_error[, 2],
  w3 = type_I_error[, 3],
  cdf1 = strict_binom_cdf(ws[1]),
  cdf2 = strict_binom_cdf(ws[2]),
  cdf3 = strict_binom_cdf(ws[3])
)



# Figure: empirical rates (thick lines) against the binomial reference
# curves (thin lines), one colour per window count.
#
# `cols` maps the constant colour keys used inside aes() to hex colours;
# the key "c4" is defined but no layer uses it.
# NOTE(review): recent ggplot2 releases deprecate `size` for lines
# (`linewidth`), `guides(fill = FALSE)` ("none") and a numeric
# `legend.position`; they are kept here so the figure is unchanged on the
# installed version -- confirm before modernising.
cols <- c("c4" = "#3dc93d", "c2" = "#ff0033", "c3" = "#0000ff", "c1" = "#000000")
my.labs <- list(TeX("$w = 10$"), TeX("$w = 25$"), TeX("$w = 100$"))

# Black-and-white base theme with an in-panel boxed legend and large fonts.
type1_theme <- theme_bw() +
  theme(
    legend.position = c(0.22, 0.8),
    legend.background = element_rect(colour = "black"),
    legend.key.size = unit(1.5, "cm"),
    legend.text = element_text(size = 25),
    legend.title = element_blank(),
    plot.title = element_text(hjust = 0.5, size = 25),
    axis.text.x = element_text(size = 18),
    axis.text.y = element_text(size = 18),
    axis.title.y = element_text(size = 20),
    axis.title.x = element_text(size = 20)
  )

ggplot(data = df_err, mapping = aes(x = x)) +
  # w = ws[1]
  geom_line(mapping = aes(y = w1, colour = "c1"), size = 1.5) +
  geom_line(mapping = aes(y = cdf1, colour = "c1"), size = 0.5) +
  # w = ws[2]
  geom_line(mapping = aes(y = w2, colour = "c2"), size = 1.5) +
  geom_line(mapping = aes(y = cdf2, colour = "c2"), size = 0.5) +
  # w = ws[3]
  geom_line(mapping = aes(y = w3, colour = "c3"), size = 1.5) +
  geom_line(mapping = aes(y = cdf3, colour = "c3"), size = 0.5) +
  # Title for the linear-regression variant:
  # labs(title = TeX('Linear Regression with $\\eta = 0.05$')) +
  labs(title = TeX("Logistic Regression with $\\eta = 0.5$")) +
  xlab("q") +
  ylab("Probability of type I error") +
  scale_colour_manual(values = cols, labels = my.labs) +
  scale_fill_manual(values = cols) +
  # Optional explicit axis breaks:
  # scale_y_continuous(breaks = c(0, 0.1, 0.2, 0.3, 0.4, 0.5), labels = c(0, 0.1, 0.2, 0.3, 0.4, 0.5)) +
  # scale_x_continuous(breaks = 1:length(qs), labels = qs) +
  type1_theme +
  guides(
    fill = FALSE,
    color = guide_legend(override.aes = list(shape = c(18, 17, 16)))
  )







# cols <- c("c4"="#3dc93d","c2"="#ff0033","c3"="#0000ff", "c1"="#000000")
# my.labs = list(TeX('$w = 10$'), TeX('$w = 25$'), TeX('$w = 100$'))
# ggplot(df_err, aes(x=x)) + 
#   geom_line(aes(y=w1 - cdf1, colour = 'c1'), size = 1.5) +
#   # geom_point(data = df_err[c(1,3),], aes(x=x, y=w1, colour = 'c1'), shape = 18, size = 5) +
#   geom_line(aes(y=w2 - cdf2, colour = 'c2'), size = 1.5) +
#   # geom_point(data = df_err[c(2,4),], aes(x=x, y=w2, colour = 'c2'), shape = 17, size = 5) +
#   geom_line(aes(y=w3 - cdf3, colour = 'c3'), size = 1.5) +
#   # geom_point(data = df_err[c(1,3),], aes(x=x, y=w3, colour = 'c3'), shape = 16, size = 5) +
#   labs(title = 'Linear Regression') +
#   xlab('q') +
#   ylab('Probability of type I error') +
#   scale_colour_manual(values=cols, labels = my.labs) +
#   scale_fill_manual(values=cols) +
#   # scale_y_continuous(breaks = c(0, 0.1, 0.2, 0.3, 0.4, 0.5), labels = c(0, 0.1, 0.2, 0.3, 0.4, 0.5)) +
#   #  scale_x_continuous(breaks = 1:length(qs), labels = qs) +
#   theme_bw() +
#   theme(legend.position = c(0.22, 0.8), 
#         legend.background = element_rect(colour = 'black'),
#         legend.key.size=unit(1.5,"cm"), 
#         legend.text=element_text(size=30),
#         legend.title=element_blank(),
#         plot.title = element_text(hjust = 0.5, size = 30),
#         axis.text.x = element_text(size=25),
#         axis.text.y = element_text(size=25),
#         axis.title.y = element_text(size=30),
#         axis.title.x = element_text(size=30)) +
#   guides(fill=FALSE, color = guide_legend(override.aes = list(shape = c(18,17,16))))
