library(ggplot2)
library(tidyverse)
library(gdata)
library(gridExtra)
library(MASS)
library(latex2exp)
library(gdata)
library(hrbrthemes)
library(patchwork)
library(xtable)
library(data.table)
# investigating the difference between the variational posterior mean
# and the full posterior mean
#### HELPER FUNCTIONS ####
expand.grid.extra = function(params_1, params_2){
  # Row-wise Cartesian product of two parameter tables: every row of
  # params_2 is appended to every row of params_1, with params_2 varying
  # fastest, i.e. row (i-1)*nrow(params_2) + j is c(params_1[i,], params_2[j,]).
  #
  # params_1, params_2: matrices or data frames.
  # Returns a matrix with nrow(params_1)*nrow(params_2) rows, columns named
  # c(colnames(params_1), colnames(params_2)). A zero-row input now yields a
  # zero-row result instead of failing inside a 1:0 loop.
  n1 = dim(params_1)[1]
  n2 = dim(params_2)[1]
  left = as.matrix(params_1)[rep(seq_len(n1), each = n2), , drop = FALSE]
  right = as.matrix(params_2)[rep(seq_len(n2), times = n1), , drop = FALSE]
  final_params = cbind(left, right)
  # The element-wise loop never produced row names; drop any inherited ones.
  dimnames(final_params) = NULL
  colnames(final_params) = c(colnames(params_1), colnames(params_2))
  return(final_params)
}
#### KERNELS ####
bm_kern = function(x, y, cn = 1){
  # Rescaled Brownian-motion covariance cn * min(x, y).
  # Note: min() reduces over all of x and y, returning a single value.
  smallest = min(x, y)
  return(cn * smallest)
}
mat_kern_3_2 = function(x, y, cn = 1){
  # Matern-3/2 covariance with lengthscale cn, evaluated at the Euclidean
  # distance between x and y.
  dist = sqrt(sum((x - y)^2))
  scaled = sqrt(3) * dist / cn
  (1 + scaled) * exp(-scaled)
}
mat_kern_5_2 = function(x, y, cn = 1){
  # Matern-5/2 covariance with lengthscale cn, evaluated at the Euclidean
  # distance between x and y.
  dist = sqrt(sum((x - y)^2))
  linear = sqrt(5) * dist / cn
  quadratic = 5 * dist^2 / (3 * cn^2)
  (1 + linear + quadratic) * exp(-linear)
}
mat_kern_1_2 = function(x, y, cn = 1){
  # Matern-1/2 (exponential) covariance exp(-|x - y| / cn), using the
  # Euclidean distance for vector inputs.
  dist = sqrt(sum((x - y)^2))
  exp(-dist / cn)
}
se_kern = function(x, y, cn = 1){
  # Squared-exponential covariance exp(-(x - y)^2 / cn^2), element-wise in
  # x and y; cn is the lengthscale.
  sq_diff = (x - y)^2
  exp(-sq_diff / cn^2)
}

se_kern_multi_d = function(x, y, cn = 1){
  # Squared-exponential covariance on R^d, exp(-||x - y||^2 / cn^2);
  # cn is the lengthscale.
  dist = sqrt(sum((x - y)^2))
  exp(-dist^2 / cn^2)
}

lin_kern = function(x, y){
  # Linear covariance with constant offset: 1 + x * y (element-wise).
  x * y + 1
}

#### BASIS FUNCTIONS ####
basis_functions_x = function(K, xs){
  # Evaluate a Fourier basis at the points xs.
  # K is the (even) number of trigonometric columns. The result has
  # 1 + K columns: a constant column of ones, then for k = 1, ..., K/2 the
  # pair sqrt(2)*cos(2*pi*k*x), sqrt(2)*sin(2*pi*k*x).
  freqs = seq(1, K/2, by=1)*2*pi
  phase = as.matrix(xs) %*% t(as.matrix(freqs))
  n_freq = ncol(phase)
  trig = cbind(sqrt(2) * cos(phase), sqrt(2) * sin(phase))
  # Reorder columns to cos_1, sin_1, cos_2, sin_2, ...
  pair_order = as.vector(rbind(seq_len(n_freq), n_freq + seq_len(n_freq)))
  trig = trig[, pair_order, drop = FALSE]
  cbind(as.matrix(rep(1, length(xs))), trig)
}
compute_fn_from_coefs_x = function(coefs, xs){
  # Evaluate the Fourier expansion with coefficients coefs at xs.
  # coefs[1] multiplies the constant column of basis_functions_x; the
  # remaining length(coefs) - 1 entries multiply the trigonometric columns.
  # Returns a length(xs) x 1 matrix.
  n_trig = length(coefs) - 1
  basis_functions_x(n_trig, xs) %*% coefs
}
f0_xs = function(xs, alpha){
  # True regression function for the 1-d simulations, centred at 1/2
  # (f0(1/2) = 0 for both supported cases):
  #   alpha <= 1 : |x - 1/2|^alpha
  #   alpha == 2 : sign(x - 1/2) * |x - 1/2|^2
  # Any other alpha is an error (it used to return NULL silently).
  if(alpha <= 1){
    return(abs(xs - 1/2)^alpha)
  }else if(alpha == 2){
    # The exponent is alpha; this previously referenced an undefined `alpha_2`.
    return(sign(xs - 1/2) * abs(xs - 1/2)^alpha)
  }
  stop('f0_xs: unsupported alpha = ', alpha, ' (use alpha <= 1 or alpha == 2)')
}

# f0_xs = function(xs, alpha){
#   if(alpha == 1){
#     return(sign(xs - 1/2)*(abs(xs - 1/2)^2) )
#   }else if(alpha == 2){
#     return( abs(xs - 1/2)^3)
#   }
# }

f0_xs_multi_d = function(Xs, alpha){
  # True regression function for the multi-d simulations: ||x||^alpha,
  # evaluated for each row of Xs. A plain vector is treated as one point.
  pts = if(is.matrix(Xs)) Xs else matrix(Xs, nrow = 1)
  radial_power = function(row) sqrt(sum(row^2))^alpha
  apply(pts, 1, radial_power)
}

#### VARIATIONAL POSTERIOR FUNCTIONS ####

vp_mean_variance = function(x_, m_, svd_, k_, xns_, y_, cn_ = 1, sigma=1){
  # Mean and variance at the single point x_ of the rank-m_ variational
  # posterior built from the leading m_ singular vectors/values of K_nn:
  #   A    = V_m diag(1 / (d_k + sigma^2)) V_m^T
  #   mean = k_n(x)^T A y,   var = k(x, x) - k_n(x)^T A k_n(x)
  # x_: scalar input; xns_: vector of 1-d design points; k_(a, b, cn = .):
  # kernel; svd_: svd() of K_nn; y_: responses at xns_.
  # Returns c(mean, variance).
  keep = seq_len(m_)
  eta_ks = 1/(svd_$d[keep] + sigma^2)
  # drop = FALSE keeps V_m a matrix when m_ = 1 (and the old diag(eta_ks)
  # would have built an identity matrix of size eta_ks instead).
  V_m = svd_$u[, keep, drop = FALSE]
  k_n_x = sapply(xns_, function(y) k_(y, x_, cn = cn_))
  # Project onto V_m first: same quadratic forms in O(n m), no n x n matrix.
  proj_k = as.vector(crossprod(V_m, k_n_x))
  proj_y = as.vector(crossprod(V_m, y_))
  mean_x = sum(eta_ks * proj_k * proj_y)
  var_x = k_(x_, x_, cn = cn_) - sum(eta_ks * proj_k^2)
  return(c(mean_x, var_x))
}

vp_mean_variance_multiple_m = function(x_, m_s, svd_, k_, xns_, y_, cn_ = 1, sigma=1){
  # Variational posterior mean and variance at x_ for every rank m in m_s
  # (1-d design xns_). For rank m:
  #   mean_m = sum_{k<=m} eta_k (v_k' k_n(x)) (v_k' y)
  #   var_m  = k(x, x) - sum_{k<=m} eta_k (v_k' k_n(x))^2,
  # with eta_k = 1 / (d_k + sigma^2) and v_k the k-th column of svd_$u.
  # Returns a length(m_s) x 2 matrix (column 1 mean, column 2 variance).
  #
  # The projections are computed once for max(m_s), and cumulative sums give
  # every rank. This replaces rebuilding an n x n matrix per m, and it no
  # longer fails for m = 1 (diag() of a scalar / dropped column).
  keep = seq_len(max(m_s))
  eta_ks = 1/(svd_$d[keep] + sigma^2)
  V = svd_$u[, keep, drop = FALSE]
  k_n_x = sapply(xns_, function(y) k_(y, x_, cn = cn_))
  proj_k = as.vector(crossprod(V, k_n_x))
  proj_y = as.vector(crossprod(V, y_))
  mean_cum = cumsum(eta_ks * proj_k * proj_y)
  quad_cum = cumsum(eta_ks * proj_k^2)
  ret_mat = matrix(ncol = 2, nrow = length(m_s))
  ret_mat[, 1] = mean_cum[m_s]
  ret_mat[, 2] = k_(x_, x_, cn = cn_) - quad_cum[m_s]
  return(ret_mat)
}

vp_mean_variance_multiple_m_multi_d = function(x_, m_s, svd_, k_, xns_, y_, cn_ = 1, sigma = 1){
  # Multi-d version of vp_mean_variance_multiple_m: xns_ holds one design
  # point per row and x_ is a single point (vector).
  # Returns a length(m_s) x 2 matrix (column 1 mean, column 2 variance) of the
  # rank-m variational posterior at x_, one row per m in m_s.
  #
  # The projections are computed once for max(m_s), and cumulative sums give
  # every rank. This avoids n x n matrices and works for m = 1 (where
  # diag() of a scalar and the dropped column previously broke).
  keep = seq_len(max(m_s))
  eta_ks = 1/(svd_$d[keep] + sigma^2)
  V = svd_$u[, keep, drop = FALSE]
  k_n_x = apply(xns_, 1, function(x) k_(x, x_, cn = cn_))
  proj_k = as.vector(crossprod(V, k_n_x))
  proj_y = as.vector(crossprod(V, y_))
  mean_cum = cumsum(eta_ks * proj_k * proj_y)
  quad_cum = cumsum(eta_ks * proj_k^2)
  ret_mat = matrix(ncol = 2, nrow = length(m_s))
  ret_mat[, 1] = mean_cum[m_s]
  ret_mat[, 2] = k_(x_, x_, cn = cn_) - quad_cum[m_s]
  return(ret_mat)
}

sample_variational_CS = function(gamma_, m_, x0_, svd_, alpha_, cn_ = 1, rand_design=FALSE,
                                 kern = bm_kern){
  # Simulate one data set from the Fourier-series truth with coefficients
  # k^(-1/2 - alpha_) plus N(0, 1) noise, estimate sigma, and return the
  # two-sided level-gamma_ credible interval c(lower, upper) at x0_ of the
  # rank-m_ variational posterior.
  # kern: covariance passed to vp_mean_variance. The body previously used an
  #   undefined `kern`; it now defaults to bm_kern, as in
  #   sample_variational_CS_multiple_m.
  # NOTE(review): with rand_design = TRUE the design is redrawn but svd_ is
  # not recomputed for it -- confirm this is intended.
  n = length(svd_$d)
  if(rand_design){
    xns = runif(n, 0, 1)
  }else{
    xns = c(1:n)/(n+0.5)
  }
  K = 401
  # Fourier coefficients of f0
  f0_ks = (1:K)^(-1/2 - alpha_)
  f0 = compute_fn_from_coefs_x(f0_ks, xs = xns)
  y = f0 + rnorm(n, mean = 0, sd = 1)

  sigma_est = maximise_sigma_known_svd(y, svd_)

  vp_fit = vp_mean_variance(x_ = x0_, m_=m_, svd_=svd_, k_=kern, xns_=xns,
                            y_=y, cn_ = cn_, sigma = sigma_est)
  vp_mean = vp_fit[1]
  vp_sd = sqrt(vp_fit[2])

  # Quantile from gamma_ (this used to read `gamma`, i.e. the base::gamma
  # function, which made the arithmetic fail).
  z_gamma = qnorm((1+gamma_)/2)
  vp_CS = c(vp_mean - vp_sd*z_gamma, vp_mean + vp_sd*z_gamma)
  return(vp_CS)
}

sample_variational_CS_multiple_m = function(gamma_, m_s, x0_, svd_, alpha_, cn_ = 1, kern = bm_kern,
                                            xns, noise = 'gaussian'){
  # Simulate one data set y = f0_xs(xns, alpha_) + noise on the fixed 1-d
  # design xns, estimate sigma by marginal likelihood, and return the
  # level-gamma_ credible interval at x0_ of the rank-m variational posterior
  # for each m in m_s: a length(m_s) x 2 matrix (lower, upper).
  # noise: 'gaussian' or 'laplace'; anything else is an explicit error.
  n = dim(svd_$u)[1]
  f0 = f0_xs(xs = xns, alpha = alpha_)
  # NOTE(review): rlaplace is not base R and does not appear to come from any
  # package attached at the top of this file -- confirm where it is defined.
  eps = switch(noise,
               gaussian = rnorm(n),
               laplace = rlaplace(n),
               stop("unknown noise '", noise, "'; use 'gaussian' or 'laplace'"))
  y = f0 + eps
  sigma_est = maximise_sigma_known_svd(y, svd_)
  cat('Estimated sigma: ', sigma_est, '\n')
  vp_fit = vp_mean_variance_multiple_m(x_=x0_, m_s=m_s, svd_=svd_, k_=kern, xns_=xns,
                                       y_=y, cn_ = cn_, sigma = sigma_est)
  vp_mean = vp_fit[,1]
  vp_var = vp_fit[,2]

  z_gamma = qnorm((1+gamma_)/2)
  vp_CS = cbind(vp_mean - sqrt(vp_var)*z_gamma, vp_mean+sqrt(vp_var)*z_gamma)
  return(vp_CS)
}

compute_NLPD = function(gamma_, m_s, x0_, svd_, alpha_, cn_ = 1, kern = bm_kern,
                        xns, noise = 'gaussian'){
  # Negative log predictive density of the true value f0(x0_) under the
  # rank-m variational posterior, for each m in m_s, from one simulated data
  # set on the fixed 1-d design xns. gamma_ is unused; it is kept so the
  # signature mirrors sample_variational_CS_multiple_m.
  # noise: 'gaussian' or 'laplace'; anything else is an explicit error.
  n = dim(svd_$u)[1]
  f0 = f0_xs(xs = xns, alpha = alpha_)
  # NOTE(review): rlaplace is not base R -- confirm where it is defined.
  eps = switch(noise,
               gaussian = rnorm(n),
               laplace = rlaplace(n),
               stop("unknown noise '", noise, "'; use 'gaussian' or 'laplace'"))
  y = f0 + eps
  sigma_est = maximise_sigma_known_svd(y, svd_)
  vp_fit = vp_mean_variance_multiple_m(x_=x0_, m_s=m_s, svd_=svd_, k_=kern, xns_=xns,
                                       y_=y, cn_ = cn_, sigma = sigma_est)
  vp_mean = vp_fit[,1]
  vp_var = vp_fit[,2]

  # Evaluate at f0(x0_); the previous hard-coded 0 equals f0_xs(1/2, alpha),
  # so results at x0_ = 1/2 are unchanged. log = TRUE avoids -log(0) = Inf
  # when the density underflows.
  f0_x0 = f0_xs(xs = x0_, alpha = alpha_)
  NLPDs = -dnorm(f0_x0, vp_mean, sd = sqrt(vp_var), log = TRUE)
  return(NLPDs)
}


sample_variational_CS_multiple_m_random_design = function(gamma_, m_s, x0_, n, alpha_, cn_ = 1, kern = bm_kern, noise = 'gaussian'){
  # Random-design analogue of sample_variational_CS_multiple_m. It draws n
  # design points from U(0, 1) and simulates y = f0_xs(x, alpha_) + noise.
  # It then builds and decomposes the Gram matrix for this design and
  # estimates sigma. Returns a length(m_s) x 2 matrix of level-gamma_
  # (lower, upper) bounds at x0_.
  # noise: 'gaussian' or 'laplace'; anything else is an explicit error.
  xns = runif(n)
  f0 = f0_xs(xs = xns, alpha = alpha_)
  # NOTE(review): rlaplace is not base R -- confirm where it is defined.
  eps = switch(noise,
               gaussian = rnorm(n),
               laplace = rlaplace(n),
               stop("unknown noise '", noise, "'; use 'gaussian' or 'laplace'"))
  y = f0 + eps
  K_nn = sapply(xns, function(a) sapply(xns, function(b) kern(a, b, cn = cn_)))
  svd_ = svd(K_nn)
  sigma_est = maximise_sigma_known_svd(y, svd_)
  vp_fit = vp_mean_variance_multiple_m(x_=x0_, m_s=m_s, svd_=svd_, k_=kern, xns_=xns,
                                       y_=y, cn_ = cn_, sigma = sigma_est)
  vp_mean = vp_fit[,1]
  vp_var = vp_fit[,2]

  z_gamma = qnorm((1+gamma_)/2)
  vp_CS = cbind(vp_mean - sqrt(vp_var)*z_gamma, vp_mean+sqrt(vp_var)*z_gamma)
  return(vp_CS)
}

compute_NLPD_random_design = function(gamma_, m_s, x0_, n, alpha_, cn_ = 1, kern = bm_kern, noise = 'gaussian'){
  # Negative log predictive density of f0(x0_) under the rank-m variational
  # posterior for each m in m_s, from one simulated data set on a fresh
  # U(0, 1) design of size n. gamma_ is unused (signature parity only).
  # noise: 'gaussian' or 'laplace'; anything else is an explicit error.
  xns = runif(n)
  f0 = f0_xs(xs = xns, alpha = alpha_)
  # NOTE(review): rlaplace is not base R -- confirm where it is defined.
  eps = switch(noise,
               gaussian = rnorm(n),
               laplace = rlaplace(n),
               stop("unknown noise '", noise, "'; use 'gaussian' or 'laplace'"))
  y = f0 + eps
  K_nn = sapply(xns, function(a) sapply(xns, function(b) kern(a, b, cn = cn_)))
  svd_ = svd(K_nn)
  sigma_est = maximise_sigma_known_svd(y, svd_)
  vp_fit = vp_mean_variance_multiple_m(x_=x0_, m_s=m_s, svd_=svd_, k_=kern, xns_=xns,
                                       y_=y, cn_ = cn_, sigma = sigma_est)
  vp_mean = vp_fit[,1]
  vp_var = vp_fit[,2]

  # Evaluate at f0(x0_) (the old constant 0 equals f0_xs(1/2, alpha)), on the
  # log scale so an underflowing density does not become Inf.
  f0_x0 = f0_xs(xs = x0_, alpha = alpha_)
  NLPDs = -dnorm(f0_x0, vp_mean, sd = sqrt(vp_var), log = TRUE)
  return(NLPDs)
}

sample_variational_CS_multiple_m_multi_d = function(gamma_, m_s, x0_, svd_, alpha_, cn_ = 1, kern = bm_kern,
                                                    xns){
  # One simulated data set on the fixed multi-d design xns (one point per
  # row) with truth f0_xs_multi_d and N(0, 1) noise. The noise sd is fixed at
  # 1 rather than estimated. Returns a length(m_s) x 2 matrix of level-gamma_
  # (lower, upper) bounds at x0_ for the rank-m variational posterior.
  n_obs = dim(svd_$u)[1]
  truth = f0_xs_multi_d(Xs = xns, alpha = alpha_)
  y = truth + rnorm(n_obs, mean = 0, sd = 1)
  fit = vp_mean_variance_multiple_m_multi_d(x_=x0_, m_s=m_s, svd_=svd_, k_=kern, xns_=xns,
                                            y_=y, cn_ = cn_, sigma = 1)
  centre = fit[,1]
  half_width = sqrt(fit[,2]) * qnorm((1+gamma_)/2)
  cbind(centre - half_width, centre + half_width)
}
sample_variational_CS_multiple_m_multi_d_adaptive = function(gamma_, m_s, x0_, alpha_, kern = mat_kern_1_2,
                                                             xns){
  # Like sample_variational_CS_multiple_m_multi_d, but both the kernel
  # lengthscale (maximise_cn) and the noise sd (maximise_sigma_known_svd) are
  # estimated from the simulated data before the variational fit.
  # Returns a length(m_s) x 2 matrix of level-gamma_ (lower, upper) bounds.
  n = dim(xns)[1]
  f0 = f0_xs_multi_d(Xs = xns, alpha = alpha_)
  y = f0 + rnorm(n, mean = 0, sd = 1)
  cn_est = maximise_cn(y, xns, kern = kern)

  # Gram matrix with the estimated lengthscale. Previously this used an
  # undefined `cn`, and sigma was estimated from svd_ before svd_ existed.
  K_nn = apply(xns, 1, function(a) apply(xns, 1, function(b) kern(a, b, cn_est)))
  svd_ = svd(K_nn)
  sigma_est = maximise_sigma_known_svd(y, svd_)

  vp_fit = vp_mean_variance_multiple_m_multi_d(x_=x0_, m_s=m_s, svd_=svd_, k_=kern, xns_=xns,
                                               y_=y, cn_ = cn_est, sigma = sigma_est)
  vp_mean = vp_fit[,1]
  vp_var = vp_fit[,2]

  z_gamma = qnorm((1+gamma_)/2)
  vp_CS = cbind(vp_mean - sqrt(vp_var)*z_gamma, vp_mean+sqrt(vp_var)*z_gamma)
  return(vp_CS)
}

compute_NLPD_multi_d = function(gamma_, m_s, x0_, svd_, alpha_, cn_ = 1, kern = bm_kern,
                                xns){
  # Negative log predictive density of f0_xs_multi_d(x0_) under the rank-m
  # variational posterior, for each m in m_s, from one simulated data set on
  # the fixed multi-d design xns. sigma is estimated; gamma_ is unused.
  n = dim(svd_$u)[1]
  f0 = f0_xs_multi_d(Xs = xns, alpha = alpha_)
  y = f0 + rnorm(n, mean = 0, sd = 1)
  sigma_est = maximise_sigma_known_svd(y, svd_)

  vp_fit = vp_mean_variance_multiple_m_multi_d(x_=x0_, m_s=m_s, svd_=svd_, k_=kern, xns_=xns,
                                               y_=y, cn_ = cn_, sigma = sigma_est)
  vp_mean = vp_fit[,1]
  vp_var = vp_fit[,2]

  # Evaluate at the truth f0(x0_) (the old constant 0 is only f0 at the
  # origin), on the log scale to avoid -log(0) = Inf.
  f0_x0 = f0_xs_multi_d(Xs = x0_, alpha = alpha_)
  NLPDs = -dnorm(f0_x0, vp_mean, sd = sqrt(vp_var), log = TRUE)
  return(NLPDs)
}

sample_correlated_inputs = function(n, p, rho){
  # Draw n rows from a p-variate normal with unit variances and common
  # pairwise correlation rho (rho is printed as a trace).
  print(rho)
  corr_mat = matrix(rho, nrow = p, ncol = p)
  diag(corr_mat) = 1
  chol_upper = chol(corr_mat)
  z = matrix(rnorm(n*p), nrow = n, ncol = p)
  z %*% chol_upper
}
sample_variational_CS_multiple_m_multi_d_random_design = function(gamma_, m_s, x0_, n, cor, alpha_, cn_ = 1, kern = mat_kern_1_2){
  # One simulated data set on a fresh random design of n points in
  # d = length(x0_) dimensions. The design is uniform on [-1/2, 1/2]^d when
  # cor is NA, and equicorrelated Gaussian (sample_correlated_inputs)
  # otherwise. sigma is estimated. Returns a length(m_s) x 2 matrix of
  # level-gamma_ (lower, upper) bounds at x0_.
  d = length(x0_)
  if(is.na(cor)){
    print('Uniform Design')
    xns = matrix(runif(n*d, -0.5, 0.5) , nrow = n, ncol = d)
  }else{
    print('Gaussian Design')
    xns = sample_correlated_inputs(n, p = d, rho = cor)
  }

  f0 = f0_xs_multi_d(Xs = xns, alpha = alpha_)
  y = f0 + rnorm(n, mean = 0, sd = 1)
  # Gram matrix with the requested lengthscale cn_ (previously an undefined
  # `cn`, which errored or silently picked up a global).
  K_nn = apply(xns, 1, function(a) apply(xns, 1, function(b) kern(a, b, cn = cn_)))
  svd_ = svd(K_nn)
  sigma_est = maximise_sigma_known_svd(y, svd_)
  cat('Estimated sigma: ', sigma_est, '\n')

  vp_fit = vp_mean_variance_multiple_m_multi_d(x_=x0_, m_s=m_s, svd_=svd_, k_=kern, xns_=xns,
                                               y_=y, cn_ = cn_, sigma = sigma_est)
  vp_mean = vp_fit[,1]
  vp_var = vp_fit[,2]

  z_gamma = qnorm((1+gamma_)/2)
  vp_CS = cbind(vp_mean - sqrt(vp_var)*z_gamma, vp_mean+sqrt(vp_var)*z_gamma)
  return(vp_CS)
}

compute_NLPD_multi_d_random_design = function(gamma_, m_s, x0_, n, cor, alpha_, cn_ = 1, kern = mat_kern_1_2){
  # NLPD of f0_xs_multi_d(x0_) under the rank-m variational posterior, for
  # each m in m_s, from one data set on a fresh random design. The design is
  # uniform on [-1/2, 1/2]^d if cor is NA, otherwise equicorrelated Gaussian.
  # gamma_ is unused (signature parity).
  d = length(x0_)
  if(is.na(cor)){
    print('Uniform Design')
    xns = matrix(runif(n*d, -0.5, 0.5) , nrow = n, ncol = d)
  }else{
    print('Gaussian Design')
    xns = sample_correlated_inputs(n, p = d, rho = cor)
  }

  f0 = f0_xs_multi_d(Xs = xns, alpha = alpha_)
  y = f0 + rnorm(n, mean = 0, sd = 1)
  # Use the requested lengthscale cn_ (was an undefined `cn`).
  K_nn = apply(xns, 1, function(a) apply(xns, 1, function(b) kern(a, b, cn = cn_)))
  svd_ = svd(K_nn)
  sigma_est = maximise_sigma_known_svd(y, svd_)
  cat('Estimated sigma: ', sigma_est, '\n')

  vp_fit = vp_mean_variance_multiple_m_multi_d(x_=x0_, m_s=m_s, svd_=svd_, k_=kern, xns_=xns,
                                               y_=y, cn_ = cn_, sigma = sigma_est)
  vp_mean = vp_fit[,1]
  vp_var = vp_fit[,2]

  # Evaluate at the truth (the old constant 0 is only f0 at the origin), on
  # the log scale so an underflowing density does not become Inf.
  f0_x0 = f0_xs_multi_d(Xs = x0_, alpha = alpha_)
  NLPD = -dnorm(f0_x0, vp_mean, sd = sqrt(vp_var), log = TRUE)
  return(NLPD)
}



estimate_cov_len_bias = function(gamma_, m_, x0_, svd_, alpha_, num_replicates = 1000, cn_ = 1, rand_design = FALSE){
  # Monte-Carlo coverage, mean length and mean bias (interval centre minus
  # truth) of the rank-m_ variational credible interval at x0_, under the
  # Fourier-series truth with coefficients k^(-1/2 - alpha_).
  # Returns c(coverage, mean length, mean bias).
  cat('m_: ',m_,'\n')
  K = 401
  f0_ks = (1:K)^(-1/2 - alpha_)
  # Truth at the requested point (this previously read the global `x0`).
  f0_x0 = as.numeric(compute_fn_from_coefs_x(f0_ks, xs = x0_))
  # rand_design is now forwarded; it used to be accepted and ignored.
  test_intervals = t(replicate(num_replicates,
                               sample_variational_CS(gamma_, m_ = m_,
                                                     x0_ = x0_, svd_ = svd_,
                                                     alpha_ = alpha_, cn_ = cn_,
                                                     rand_design = rand_design)))
  lower = test_intervals[,1]
  upper = test_intervals[,2]
  cov = mean(lower <= f0_x0 & upper >= f0_x0)
  len = mean(upper - lower)
  bias = mean((lower + upper)/2 - f0_x0)

  return(c(cov, len, bias))
}

estimate_cov_len_bias_multiple_m = function(gamma_, m_s, x0_, svd_, alpha_, num_replicates = 1000, cn_ = 1,
                                            kern = bm_kern, xns, noise = 'gaussian'){
  # Monte-Carlo study on the fixed 1-d design xns. Returns one row per m in
  # m_s with 5 columns:
  #   1 coverage of f0(x0_) by the level-gamma_ interval
  #   2 mean interval length
  #   3 root-mean-square error of the interval centre about f0(x0_)
  #   4 mean NLPD, 5 sd of NLPD (from an independent set of replicates)
  # svd_ is expected to be svd() of the Gram matrix on xns (as bm_experiment
  # builds it).
  f0_x0 = f0_xs(xs = x0_, alpha = alpha_)
  # xns is used as supplied. It used to be overwritten with (1:n)/(n + 1/2),
  # which is the grid the fixed-design callers pass anyway.
  test_intervals = replicate(num_replicates,
                             sample_variational_CS_multiple_m(gamma_, m_s = m_s,
                                                              x0_ = x0_, svd_ = svd_,
                                                              alpha_ = alpha_, cn_ = cn_, kern = kern,
                                                              xns = xns, noise = noise))
  # matrix() keeps one row per m even when length(m_s) == 1, where
  # replicate() would return a plain vector and NLPDs[ix,] would fail.
  NLPDs = matrix(replicate(num_replicates,
                           compute_NLPD(gamma_, m_s = m_s,
                                        x0_ = x0_, svd_ = svd_,
                                        alpha_ = alpha_, cn_ = cn_, kern = kern,
                                        xns = xns, noise = noise)),
                 nrow = length(m_s))

  cat('alpha: ', alpha_, '\n')
  cat('f_0(x_0): ', f0_x0, '\n')
  ret_mat = matrix(nrow = length(m_s), ncol = 5)
  for(ix in seq_along(m_s)){
    m_intervals = t(test_intervals[ix,,])
    ret_mat[ix,1] = mean(m_intervals[,1] <= f0_x0 & m_intervals[,2] >= f0_x0)
    ret_mat[ix,2] = mean(m_intervals[,2] - m_intervals[,1])
    ret_mat[ix,3] = sqrt(mean(((m_intervals[,1]+m_intervals[,2])/2 - f0_x0)^2))
    ret_mat[ix,4] = mean(NLPDs[ix,])
    ret_mat[ix,5] = sd(NLPDs[ix,])
  }
  return(ret_mat)
}

estimate_cov_len_bias_multiple_m_random_design = function(gamma_, m_s, x0_, n, alpha_, num_replicates = 1000, cn_ = 1,
                                                          kern = bm_kern, noise = 'gaussian'){
  # Monte-Carlo study with a fresh U(0, 1) design of size n in every
  # replicate. One row per m in m_s with 6 columns: coverage, mean length,
  # sd of length, RMSE of the interval centre about f0(x0_), mean NLPD and
  # sd of NLPD.
  truth_x0 = f0_xs(xs = x0_, alpha = alpha_)
  intervals = replicate(num_replicates,
                        sample_variational_CS_multiple_m_random_design(gamma_, m_s = m_s,
                                                                       x0_ = x0_, n = n,
                                                                       alpha_ = alpha_, cn_ = cn_, kern = kern, noise = noise))
  nlpd_draws = replicate(num_replicates,
                         compute_NLPD_random_design(gamma_, m_s = m_s,
                                                    x0_ = x0_, n = n,
                                                    alpha_ = alpha_, cn_ = cn_, kern = kern, noise = noise))

  cat('alpha: ', alpha_, '\n')
  cat('f_0(x_0): ', truth_x0, '\n')
  summary_mat = matrix(nrow = length(m_s), ncol = 6)
  for(row in seq_along(m_s)){
    iv = t(intervals[row,,])
    lower = iv[,1]
    upper = iv[,2]
    widths = upper - lower
    centres = (lower + upper)/2
    summary_mat[row, ] = c(mean(lower <= truth_x0 & upper >= truth_x0),
                           mean(widths),
                           sd(widths),
                           sqrt(mean((centres - truth_x0)^2)),
                           mean(nlpd_draws[row,]),
                           sd(nlpd_draws[row,]))
  }
  summary_mat
}

estimate_cov_len_bias_multiple_m_multi_d = function(gamma_, m_s, x0_, svd_, alpha_, num_replicates = 1000, cn_ = 1,
                                                    kern = mat_kern_1_2, xns){
  # Monte-Carlo study on a fixed multi-d design xns with a supplied
  # lengthscale cn_. Returns one row per m in m_s with 6 columns: coverage,
  # mean length, sd of length, RMSE of the interval centre about
  # f0_xs_multi_d(x0_), mean NLPD and sd of NLPD.
  # TODO: estimate cn here instead of taking it as given.
  truth_x0 = f0_xs_multi_d(Xs = x0_, alpha = alpha_)
  intervals = replicate(num_replicates,
                        sample_variational_CS_multiple_m_multi_d(gamma_, m_s = m_s,
                                                                 x0_ = x0_, svd_ = svd_,
                                                                 alpha_ = alpha_, cn_ = cn_, kern = kern,
                                                                 xns = xns))
  nlpd_draws = replicate(num_replicates,
                         compute_NLPD_multi_d(gamma_, m_s = m_s,
                                              x0_ = x0_, svd_ = svd_,
                                              alpha_ = alpha_, cn_ = cn_, kern = kern,
                                              xns = xns))

  summary_mat = matrix(nrow = length(m_s), ncol = 6)
  for(row in seq_along(m_s)){
    iv = t(intervals[row,,])
    lower = iv[,1]
    upper = iv[,2]
    widths = upper - lower
    centres = (lower + upper)/2
    summary_mat[row, ] = c(mean(lower <= truth_x0 & upper >= truth_x0),
                           mean(widths),
                           sd(widths),
                           sqrt(mean((centres - truth_x0)^2)),
                           mean(nlpd_draws[row,]),
                           sd(nlpd_draws[row,]))
  }
  summary_mat
}
estimate_cov_len_bias_multiple_m_multi_d_emp_bayes = function(gamma_, m_s, x0_, svd_, alpha_, num_replicates = 1000, cn_ = 1,
                                                              kern = mat_kern_1_2, xns){
  # Empirical-Bayes variant of estimate_cov_len_bias_multiple_m_multi_d.
  # The lengthscale is fitted once by marginal likelihood (maximise_cn) on a
  # pilot data set, and the Gram-matrix SVD is rebuilt with it. The
  # Monte-Carlo study then runs with that lengthscale.
  # svd_ only supplies n (= nrow(svd_$u)); cn_ is not used.
  # Returns one row per m in m_s: coverage, mean length, centre RMSE,
  # mean NLPD, sd NLPD.
  n = dim(svd_$u)[1]
  f0_x0 = f0_xs_multi_d(Xs = x0_, alpha = alpha_)
  # Pilot responses at the design points (previously every observation used
  # the single value f0(x0_)).
  y = f0_xs_multi_d(Xs = xns, alpha = alpha_) + rnorm(n)
  print('Fitting cn')
  cn = maximise_cn(y, xns, kern)
  # Inverts cn = n^(-1/(1 + 2 gamma)), the parameterisation used by maximise_cn.
  gamma_hat = -(1 + log(cn, n))/(2*log(cn, n))
  cat('maximum cn: ', cn, 'gamma: ', gamma_hat, '\n')

  K_nn = apply(xns, 1, function(a) apply(xns, 1, function(b) kern(a, b, cn = cn)))
  svd_fit = svd(K_nn)

  test_intervals = replicate(num_replicates,
                             sample_variational_CS_multiple_m_multi_d(gamma_, m_s = m_s,
                                                                      x0_ = x0_, svd_ = svd_fit,
                                                                      alpha_ = alpha_, cn_ = cn, kern = kern,
                                                                      xns = xns))
  # NLPDs use the same refitted decomposition as the intervals. Previously
  # the caller's svd_ (built for a different lengthscale) was mixed with cn.
  NLPDs = matrix(replicate(num_replicates,
                           compute_NLPD_multi_d(gamma_, m_s = m_s,
                                                x0_ = x0_, svd_ = svd_fit,
                                                alpha_ = alpha_, cn_ = cn, kern = kern,
                                                xns = xns)),
                 nrow = length(m_s))

  ret_mat = matrix(nrow = length(m_s), ncol = 5)
  for(ix in seq_along(m_s)){
    m_intervals = t(test_intervals[ix,,])
    ret_mat[ix,1] = mean(m_intervals[,1] <= f0_x0 & m_intervals[,2] >= f0_x0)
    ret_mat[ix,2] = mean(m_intervals[,2] - m_intervals[,1])
    ret_mat[ix,3] = sqrt(mean(((m_intervals[,1]+m_intervals[,2])/2 - f0_x0)^2))
    ret_mat[ix,4] = mean(NLPDs[ix,])
    ret_mat[ix,5] = sd(NLPDs[ix,])
  }
  cat('maximum cn: ', cn, 'gamma: ', gamma_hat, '\n')
  return(ret_mat)
}


estimate_cov_len_bias_multiple_m_multi_d_random_design = function(gamma_, m_s, x0_, n, cor, alpha_, num_replicates = 1000, cn_ = 1,
                                                                  kern = mat_kern_1_2){
  # Monte-Carlo study with a fresh random multi-d design (n points; uniform
  # when cor is NA, else equicorrelated Gaussian) in every replicate.
  # Returns one row per m in m_s with 6 columns: coverage, mean length, sd of
  # length, centre RMSE about f0_xs_multi_d(x0_), mean NLPD, sd NLPD.
  truth_x0 = f0_xs_multi_d(Xs = x0_, alpha = alpha_)
  cat('alpha: ', alpha_, '\n')
  cat('f_0(x_0): ', truth_x0, '\n')
  print(truth_x0)
  intervals = replicate(num_replicates,
                        sample_variational_CS_multiple_m_multi_d_random_design(gamma_, m_s = m_s,
                                                                               x0_ = x0_, n = n, cor = cor,
                                                                               alpha_ = alpha_, cn_ = cn_, kern = kern))
  nlpd_draws = replicate(num_replicates,
                         compute_NLPD_multi_d_random_design(gamma_, m_s = m_s,
                                                            x0_ = x0_, n = n, cor = cor,
                                                            alpha_ = alpha_, cn_ = cn_, kern = kern))

  summary_mat = matrix(nrow = length(m_s), ncol = 6)
  for(row in seq_along(m_s)){
    iv = t(intervals[row,,])
    lower = iv[,1]
    upper = iv[,2]
    widths = upper - lower
    centres = (lower + upper)/2
    summary_mat[row, ] = c(mean(lower <= truth_x0 & upper >= truth_x0),
                           mean(widths),
                           sd(widths),
                           sqrt(mean((centres - truth_x0)^2)),
                           mean(nlpd_draws[row,]),
                           sd(nlpd_draws[row,]))
  }
  summary_mat
}

r_m = function(x_, xns_, m_, svd_, k_){
  # Returns A k_n(x_) as an n x 1 matrix, where
  # A = V_m diag(1 / (d_k + 1)) V_m^T uses the leading m_ components of svd_
  # (noise variance fixed at 1). k_ is called with its default lengthscale.
  keep = seq_len(m_)
  eta_ks = 1/(svd_$d[keep] + 1)
  # drop = FALSE keeps V_m a matrix for m_ = 1, where diag(scalar) also failed.
  V_m = svd_$u[, keep, drop = FALSE]
  k_n_x = sapply(xns_, function(y) k_(y, x_))
  # Project first; same vector without forming the n x n matrix A.
  return(V_m %*% (eta_ks * crossprod(V_m, k_n_x)))
}

#### LIKELIHOOD FUNCTIONS
lmlikelihood_known_svd = function(y, sigma, svd_){
  # Log marginal likelihood of y ~ N(0, K + sigma^2 I), up to the constant
  # -(n/2) log(2 pi):
  #   -1/2 y' (K + sigma^2 I)^{-1} y - 1/2 sum_k log(d_k + sigma^2)
  # Assumes svd_ = svd(K) for a symmetric PSD K, so that K = V diag(d) V'.
  # Returns a 1 x 1 matrix, as before.
  # Projecting y onto V gives the quadratic form in O(n^2), instead of
  # forming the n x n inverse (O(n^3)) for every sigma. It also avoids
  # diag() of a length-1 vector when n = 1.
  proj_y = crossprod(svd_$v, y)
  shifted = svd_$d + sigma^2
  return(-(1/2)*crossprod(proj_y, proj_y / shifted) - (1/2)*sum(log(shifted)))
}
maximise_sigma_known_svd = function(y, svd_, sigmas = seq(0.1, 2, by = 0.1)){
  # Grid-search estimate of the noise sd: returns the value in `sigmas` that
  # maximises lmlikelihood_known_svd(y, sigma, svd_). Ties go to the first
  # maximiser (which.max). The grid, previously hard-coded, is now a
  # parameter with the same default.
  results = vapply(sigmas,
                   function(s) as.numeric(lmlikelihood_known_svd(y, s, svd_)),
                   numeric(1))
  return(sigmas[which.max(results)])
}

lmlikelihood_unknown_svd = function(y, xns, kern, cn, sigma = 1){
  # Log marginal likelihood (up to -(n/2) log(2 pi)) of y ~ N(0, K + sigma^2 I)
  # with K[i, j] = kern(xns[i, ], xns[j, ], cn); xns holds one design point
  # per row. Returns a 1 x 1 matrix.
  # The closure arguments are a/b so they no longer shadow the response y.
  K_nn = apply(xns, 1, function(a) apply(xns, 1, function(b) kern(a, b, cn)))
  svd_ = svd(K_nn)
  # O(n^2) quadratic form via the projection of y, instead of building the
  # n x n inverse.
  proj_y = crossprod(svd_$v, y)
  shifted = svd_$d + sigma^2
  return(-(1/2)*crossprod(proj_y, proj_y / shifted) - (1/2)*sum(log(shifted)))
}

maximise_cn = function(y, xns, kern = mat_kern_1_2, gammas = seq(0.5, 2.0, by = 0.1)){
  # Empirical-Bayes lengthscale. Evaluates lmlikelihood_unknown_svd at
  # cn = n^(-1/(1 + 2 gamma)) for each gamma in `gammas` and returns the cn
  # with the highest value (first one on ties). The gamma grid, previously
  # hard-coded, is now a parameter with the same default.
  n = length(y)
  cns = n^(-1/(1+2*gammas))
  results = vapply(cns,
                   function(cn) as.numeric(lmlikelihood_unknown_svd(y, xns, kern, cn)),
                   numeric(1))
  return(cns[which.max(results)])
}


#### Below chunk makes the first table in our paper ####
{
#### Experiment 1, Investigate all for different values of alpha and gamma ####
{
# Settings shared by every GP in Experiment 1.
# x0: point at which the credible sets are evaluated (passed on as x0_).
# eta: credible level (passed on as gamma_).
x0 = 0.5
eta = 0.9
# Reference value printed for comparison: P(|W| <= qnorm((1+eta)/2)) for
# W ~ N(0, 1/2).
expected_cov = pnorm(qnorm((1+eta)/2), sd = 1/sqrt(2)) - pnorm(-qnorm((1+eta)/2), sd = 1/sqrt(2))
cat('Expected Coverage: ', expected_cov, '\n')
num_replicates = 500
ns = as.matrix(c(1000))
colnames(ns) = 'ns'
# alphas[i] is paired with gammas[i] (combined column-wise below).
alphas = c(1, 0.5)
gammas = c(0.5, 0.5)
rand_design = FALSE
{
# Wall-clock start for the whole of Experiment 1 (reported at the end).
t1_total = Sys.time()

##### BM Experiment #####
# Earlier version of m_fn that returned three values of m, kept for reference:
# # m_fn = function(n, gamma){
# #   returns the three different values of m that we would like to test
# #   return(c(max(2,
# #               round(n^{1/(1+2*gamma)}/log(n))),
# #               min(n, round(n^{1/(1+2*gamma)}*log(n))),
# #               n))
# }
m_fn = function(n, gamma, alpha){
  # The two numbers of inducing variables m to compare:
  #   1. n^((2 + alpha) / ((1 + 2 gamma)(1 + alpha))), rounded and capped at n
  #   2. n itself (full rank)
  reduced_m = min(n, round(n^{1/(1+2*gamma) * (2+alpha)/(1+alpha)}))
  c(reduced_m, n)
}
# BM rescaling exponents: bm_experiment scales the kernel by (n + 1/2)^beta.
betas = (1-2*gammas)/(1+2*gammas)

smoothness_combos = cbind(alphas, gammas, betas)
# One row per combination of n with an (alpha, gamma, beta) triple.
param_combos = expand.grid.extra(as.matrix(ns), smoothness_combos)

bm_experiment = function(n, alpha, gamma, beta, eta, x0, num_replicates, rand_design=FALSE, noise = 'gaussian'){
  # One Brownian-motion setting. The m values come from m_fn, the kernel is
  # scaled by cn = (n + 1/2)^beta, and the Monte-Carlo summary matrix (one
  # row per m) comes from the fixed- or random-design estimator.
  m_values = m_fn(n, gamma, alpha)
  cn_bm = (n+1/2)^beta
  if(!rand_design){
    print('Using fixed design...')
    design = c(1:n)/(n+0.5)
    gram = sapply(design, function(a) sapply(design, function(b) bm_kern(a, b, cn = cn_bm)))
    svd_decomp = svd(gram)
    return(estimate_cov_len_bias_multiple_m(gamma_ = eta, m_s = m_values,
                                            x0_ = x0, svd_ = svd_decomp, alpha_ = alpha,
                                            num_replicates = num_replicates, cn_ = cn_bm,
                                            kern = bm_kern,
                                            xns = design, noise = noise))
  }
  print('Using random design')
  print('flag1')
  estimate_cov_len_bias_multiple_m_random_design(gamma_ = eta, m_s = m_values,
                                                 x0_ = x0, n = n, alpha_ = alpha,
                                                 num_replicates = num_replicates, cn_ = cn_bm,
                                                 kern = bm_kern, noise = noise)
}

# Run every BM setting; each row of bm_results is one parameter combination.
t1 = Sys.time()
bm_results = t(apply(param_combos, 1, function(p) 
  bm_experiment(n = p[1], alpha = p[2], gamma = p[3], beta = p[4],
                eta = eta, x0 = x0, num_replicates = num_replicates, rand_design = rand_design)
))
# 5 summaries x 2 values of m. The 'bias' columns hold the RMSE of the
# interval centre (column 3 of estimate_cov_len_bias_multiple_m).
# NOTE(review): the random-design estimator returns 6 summaries (it adds the
# sd of length), so these 10 names only fit rand_design = FALSE.
colnames(bm_results) = c(paste0('cov.m', c(1:2)), paste0('len.m', c(1:2)), paste0('bias.m', c(1:2)), paste0('nlpd.m', c(1:2)), paste0('sd.nlpd.m', c(1:2)))
# Prepend n, alpha, gamma (beta is dropped).
bm_results = cbind(param_combos[,1:3], bm_results)
t2 = Sys.time() 
cat('Elapsed: ', difftime(t2, t1, units = 'secs'), '\n')

##### Matern Experiment #####
# ns = as.matrix(c(100, 500))
# colnames(ns) = 'ns'
# alphas = c(1,2)
# gammas = c(0.5, 1.5)
# Matern grid: gamma selects the Matern order in matern_experiment
# (only 1/2, 3/2 and 5/2 are handled there).
smoothness_combos = cbind(alphas, gammas)
colnames(smoothness_combos) = c('alphas','gammas')
param_combos = expand.grid.extra(as.matrix(ns), smoothness_combos)

matern_1_2_experiment = function(n, alpha, eta, x0, num_replicates, rand_design = FALSE, noise = 'gaussian'){
  # Matern-1/2 setting with unit lengthscale. It compares the m values from
  # m_fn(n, 1/2, alpha) using the fixed- or random-design estimator.
  m_values = m_fn(n, 1/2, alpha)
  if(!rand_design){
    print('Using fixed design...')
    design = c(1:n)/(n+0.5)
    gram = sapply(design, function(a) sapply(design, function(b) mat_kern_1_2(a, b)))
    svd_decomp = svd(gram)
    return(estimate_cov_len_bias_multiple_m(gamma_ = eta, m_s = m_values,
                                            x0_ = x0, svd_ = svd_decomp, alpha_ = alpha,
                                            num_replicates = num_replicates, cn_ = 1,
                                            kern = mat_kern_1_2, xns = design, noise = noise))
  }
  print('Using random design')
  estimate_cov_len_bias_multiple_m_random_design(gamma_ = eta, m_s = m_values,
                                                 x0_ = x0, n = n, alpha_ = alpha,
                                                 num_replicates = num_replicates, cn_ = 1,
                                                 kern = mat_kern_1_2, noise = noise)
}
matern_3_2_experiment = function(n, alpha, eta, x0, num_replicates, rand_design = FALSE, noise = 'gaussian'){
  # Matern-3/2 setting with unit lengthscale. It compares the m values from
  # m_fn(n, 3/2, alpha) using the fixed- or random-design estimator.
  # noise is forwarded to the estimators. matern_experiment already passes
  # it, which used to fail with "unused argument".
  m_s = m_fn(n, 3/2, alpha)
  if(rand_design){
    print('Using random design')
    return(estimate_cov_len_bias_multiple_m_random_design(gamma_ = eta, m_s = m_s,
                                                          x0_ = x0, n = n, alpha_ = alpha,
                                                          num_replicates = num_replicates, cn_ = 1,
                                                          kern = mat_kern_3_2, noise = noise))
  }else{
    print('Using fixed design...')
    xns = c(1:n)/(n+0.5)
    K_nn = sapply(xns, function(x) sapply(xns, function(y) mat_kern_3_2(x, y)))
    svd_decomp = svd(K_nn)
    return(estimate_cov_len_bias_multiple_m(gamma_ = eta, m_s = m_s,
                                            x0_ = x0, svd_ = svd_decomp, alpha_ = alpha,
                                            num_replicates = num_replicates, cn_ = 1,
                                            kern = mat_kern_3_2, xns = xns, noise = noise))
  }
}
matern_5_2_experiment = function(n, alpha, eta, x0, num_replicates, rand_design = FALSE, noise = 'gaussian'){
  # Matern-5/2 setting with unit lengthscale. It compares the m values from
  # m_fn(n, 5/2, alpha) using the fixed- or random-design estimator.
  # noise is forwarded to the estimators. matern_experiment already passes
  # it, which used to fail with "unused argument".
  m_s = m_fn(n, 5/2, alpha)
  if(rand_design){
    print('Using random design')
    return(estimate_cov_len_bias_multiple_m_random_design(gamma_ = eta, m_s = m_s,
                                                          x0_ = x0, n = n, alpha_ = alpha,
                                                          num_replicates = num_replicates, cn_ = 1,
                                                          kern = mat_kern_5_2, noise = noise))
  }else{
    print('Using fixed design...')
    xns = c(1:n)/(n+0.5)
    K_nn = sapply(xns, function(x) sapply(xns, function(y) mat_kern_5_2(x, y)))
    svd_decomp = svd(K_nn)
    return(estimate_cov_len_bias_multiple_m(gamma_ = eta, m_s = m_s,
                                            x0_ = x0, svd_ = svd_decomp, alpha_ = alpha,
                                            num_replicates = num_replicates, cn_ = 1,
                                            kern = mat_kern_5_2, xns = xns, noise = noise))
  }
}
matern_experiment = function(n, alpha, gamma, eta, x0, num_replicates, rand_design = FALSE, noise = 'gaussian'){
  # Dispatch one Matern setting to the order-specific runner. gamma must be
  # exactly 1/2, 3/2 or 5/2. Anything else is now an error rather than a
  # printed message whose string return value would corrupt the results.
  if(gamma == 1/2){
    return(matern_1_2_experiment(n, alpha, eta, x0, num_replicates, rand_design = rand_design, noise = noise))
  }else if(gamma == 3/2){
    return(matern_3_2_experiment(n, alpha, eta, x0, num_replicates, rand_design = rand_design, noise = noise))
  }else if(gamma == 5/2){
    # Was `num_replicatesm` (undefined).
    return(matern_5_2_experiment(n, alpha, eta, x0, num_replicates, rand_design = rand_design, noise = noise))
  }
  stop('matern_experiment: gamma not matched (', gamma, '); expected 1/2, 3/2 or 5/2')
}

# Run every Matern setting; each row of mat_results is one combination.
t1 = Sys.time()
mat_results = t(apply(param_combos, 1, function(p) 
  matern_experiment(n = p[1], alpha = p[2], gamma = p[3],
                eta = eta, x0 = x0, num_replicates = num_replicates, rand_design = rand_design)
))
# Same 10 summary names as the BM run ('bias' columns hold the centre RMSE).
colnames(mat_results) = c(paste0('cov.m', c(1:2)), paste0('len.m', c(1:2)), paste0('bias.m', c(1:2)), paste0('nlpd.m', c(1:2)), paste0('sd.nlpd.m', c(1:2)))
mat_results = cbind(param_combos, mat_results)
t2 = Sys.time() 
cat('Elapsed: ', difftime(t2, t1, units = 'secs'), '\n')

##### SE Experiment #####
# x0 = 0.5
# eta = 0.95
# ns = as.matrix(c(100, 500,1000))
# colnames(ns) = 'ns'
# alphas = c(1,2)
# gammas = c(0.5,1)
# SE grid: in se_experiment, gamma sets the lengthscale n^(-1/(1+2 gamma))
# and the m values.
smoothness_combos = cbind(alphas, gammas)
param_combos = expand.grid.extra(as.matrix(ns), smoothness_combos)

se_experiment = function(n, alpha, gamma, eta, x0, num_replicates, rand_design = FALSE, noise = 'gaussian'){
  # Squared-exponential setting with lengthscale cn = n^(-1/(1 + 2 gamma)),
  # compared at the m values from m_fn(n, gamma, alpha).
  # noise is forwarded to the estimators, matching bm_experiment and
  # matern_1_2_experiment (the default keeps the previous behaviour).
  cn = n^{-1/(1+2*gamma)}
  m_s = m_fn(n, gamma, alpha)
  if(rand_design){
    print('Using random design')
    return(estimate_cov_len_bias_multiple_m_random_design(gamma_ = eta, m_s = m_s,
                                                          x0_ = x0, n = n, alpha_ = alpha,
                                                          num_replicates = num_replicates, cn_ = cn,
                                                          kern = se_kern, noise = noise))
  }else{
    print('Using fixed design...')
    xns = c(1:n)/(n+0.5)
    K_nn = sapply(xns, function(x) sapply(xns, function(y) se_kern(x, y, cn = cn)))
    svd_decomp = svd(K_nn)
    return(estimate_cov_len_bias_multiple_m(gamma_ = eta, m_s = m_s,
                                            x0_ = x0, svd_ = svd_decomp, alpha_ = alpha,
                                            num_replicates = num_replicates, cn_ = cn,
                                            kern = se_kern, xns = xns, noise = noise))
  }
}

# Run every SE setting; each row of se_results is one combination.
t1 = Sys.time()
se_results = t(apply(param_combos, 1, function(p) 
  se_experiment(n = p[1], alpha = p[2], gamma = p[3],
                eta = eta, x0 = x0, num_replicates = num_replicates, rand_design = rand_design)
))
# Same 10 summary names as the BM run ('bias' columns hold the centre RMSE).
colnames(se_results) = c(paste0('cov.m', c(1:2)), paste0('len.m', c(1:2)), paste0('bias.m', c(1:2)), paste0('nlpd.m', c(1:2)), paste0('sd.nlpd.m', c(1:2)))
se_results = cbind(param_combos, se_results)
t2 = Sys.time() 

cat('Elapsed: ', difftime(t2, t1, units = 'secs'), '\n')

##### Make final results tables ####
# Label each kernel's results with its GP name, move that label to the first
# column, and stack BM, Matérn and SE results into one data frame.
bm_df = as.data.frame(bm_results)
bm_df['GP'] = rep('BM', dim(bm_df)[1])
mat_df = as.data.frame(mat_results)
mat_df['GP'] = rep('Matérn', dim(mat_df)[1])
se_df = as.data.frame(se_results)
se_df['GP'] = rep('SE', dim(se_df)[1])

bm_df = bm_df[,c(dim(bm_df)[2], 1:(dim(bm_df)[2] - 1))]
mat_df = mat_df[,c(dim(mat_df)[2], 1:(dim(mat_df)[2] - 1))]
se_df = se_df[,c(dim(se_df)[2], 1:(dim(se_df)[2] - 1))]

full_results = rbind(bm_df,
                     mat_df,
                     se_df)
# One name per column: GP, the three settings and the ten metric columns
# (cov, len, bias, nlpd, sd.nlpd for m1 and m2). The sd.nlpd names were
# previously missing, leaving those two columns named NA.
colnames(full_results) = c('GP', 'n', 'alpha', 'gamma',
                           'cm1', 'cm2',
                           'lm1', 'lm2',
                           'bm1', 'bm2',
                           'nlpdm1', 'nlpdm2', 'sd.nlpdm1', 'sd.nlpdm2')

#### End of experiment ####
t2_total <- Sys.time()
cat('Time Elapsed: ', difftime(t2_total, t1_total, units = 'mins'), 'mins')

# Subsets of the stacked table by alpha and sample size n.
alpha_1_results_500 <- full_results[which(full_results$alpha == 1 & full_results$n == 500), ]
alpha_2_results_500 <- full_results[which(full_results$alpha == 2 & full_results$n == 500), ]

alpha_1_results_1000 <- full_results[which(full_results$alpha == 1 & full_results$n == 1000), ]
alpha_2_results_1000 <- full_results[which(full_results$alpha == 2 & full_results$n == 1000), ]

# Rows with the larger alpha first.
full_results <- full_results[order(full_results$alpha, decreasing = TRUE), ]
}
# Write the table without the n column. xtable requires one digits entry per
# column plus one for the row names, so the trailing two-decimal entries are
# derived from the column count instead of being hard-coded (the previous
# 10-entry vector did not match the 13 printed columns and made xtable stop).
print(xtable(full_results[,-2],
             digits = c(1, 1, 0, 1, rep(2, ncol(full_results) - 4))),
      type = "latex",
      file = "FD_results.tex",
      include.rownames=FALSE)
}
#### Experiment 2, same setup as above but with random design ####
{
# Shared settings for all three kernels. eta and x0 are forwarded unchanged
# to each *_experiment helper; rand_design = TRUE selects their random-design
# branch.
x0 = 0.5
eta = 0.9
num_replicates = 100
ns = as.matrix(c(500))
colnames(ns) = 'ns'
alphas = c(1, 0.3)
gammas = c(0.5, 0.5)
rand_design = TRUE

# beta = (1 - 2*gamma)/(1 + 2*gamma); only the BM run below takes it.
betas = (1-2*gammas)/(1+2*gammas)
smoothness_combos = cbind(alphas, gammas, betas)
# One row per (n, alpha, gamma, beta) combination.
param_combos = expand.grid.extra(as.matrix(ns), smoothness_combos)
t1_total = Sys.time()

##### BM #####
t1 = Sys.time()
bm_results = t(apply(param_combos, 1, function(p) 
  bm_experiment(n = p[1], alpha = p[2], gamma = p[3], beta = p[4],
                eta = eta, x0 = x0, num_replicates = num_replicates, rand_design = rand_design)
))
# 12 metric columns; the names suggest coverage (cov), interval length (len)
# and its sd, bias, NLPD and its sd, each for m1 and m2 -- confirm against
# the estimator's return value.
colnames(bm_results) = c(paste0('cov.m', c(1:2)), paste0('len.m', c(1:2)), paste0('sd.len.m', c(1:2)), paste0('bias.m', c(1:2)), paste0('nlpd.m', c(1:2)),paste0('sd.nlpd.m', c(1:2)))
# Keep n, alpha, gamma (drop beta) so the columns match the other kernels.
bm_results = cbind(param_combos[,c(1:3)], bm_results)
t2 = Sys.time() 
cat('Elapsed: ', difftime(t2, t1, units = 'secs'), '\n')
##### Matern #####
# Matern and SE runs take no beta, so rebuild the grid without it.
smoothness_combos = cbind(alphas, gammas)
colnames(smoothness_combos) = c('alphas','gammas')
param_combos = expand.grid.extra(as.matrix(ns), smoothness_combos)
t1 = Sys.time()
mat_results = t(apply(param_combos, 1, function(p) 
  matern_experiment(n = p[1], alpha = p[2], gamma = p[3],
                    eta = eta, x0 = x0, num_replicates = num_replicates, rand_design = rand_design)
))
colnames(mat_results) = c(paste0('cov.m', c(1:2)), paste0('len.m', c(1:2)), paste0('sd.len.m', c(1:2)), paste0('bias.m', c(1:2)), paste0('nlpd.m', c(1:2)),paste0('sd.nlpd.m', c(1:2)))
mat_results = cbind(param_combos, mat_results)
t2 = Sys.time() 
cat('Elapsed: ', difftime(t2, t1, units = 'secs'), '\n')
##### SE #####
# Reuses the beta-free param_combos built for the Matern run.
t1 = Sys.time()
se_results = t(apply(param_combos, 1, function(p) 
  se_experiment(n = p[1], alpha = p[2], gamma = p[3],
                eta = eta, x0 = x0, num_replicates = num_replicates, rand_design = rand_design)
))
colnames(se_results) = c(paste0('cov.m', c(1:2)), paste0('len.m', c(1:2)), paste0('sd.len.m', c(1:2)), paste0('bias.m', c(1:2)), paste0('nlpd.m', c(1:2)),paste0('sd.nlpd.m', c(1:2)))
se_results = cbind(param_combos, se_results)
t2 = Sys.time() 
cat('Elapsed: ', difftime(t2, t1, units = 'secs'), '\n')

##### Make final results table #####
# Label each kernel's rows, move the label to the first column and stack.
bm_df = as.data.frame(bm_results)
bm_df['GP'] = rep('BM', dim(bm_df)[1])
mat_df = as.data.frame(mat_results)
mat_df['GP'] = rep('Matérn', dim(mat_df)[1])
se_df = as.data.frame(se_results)
se_df['GP'] = rep('SE', dim(se_df)[1])
bm_df = bm_df[,c(dim(bm_df)[2], 1:(dim(bm_df)[2] - 1))]
mat_df = mat_df[,c(dim(mat_df)[2], 1:(dim(mat_df)[2] - 1))]
se_df = se_df[,c(dim(se_df)[2], 1:(dim(se_df)[2] - 1))]
full_results = rbind(bm_df,
                     mat_df,
                     se_df)
colnames(full_results) = c('GP', 'n', 'alpha', 'gamma', 
                           'cm1', 'cm2',
                           'lm1', 'lm2',
                           'sd.lm1', 'sd.lm2',
                           'bm1', 'bm2',
                           'nlpdm1', 'nlpdm2', 'sd.nlpdm1', 'sd.nlpdm2')

# Round the sd-of-length columns (9:10), drop them, and fold them into the
# length columns (7:8) as "mean (sd)" strings.
full_results[,c(9:10)] = round(full_results[,c(9:10)], 2)
formatted_results = full_results[,-c(9:10)]
formatted_results[,7] = paste0(round(full_results[,7], 2), ' (', full_results[,9], ')')
formatted_results[,8] = paste0(round(full_results[,8], 2), ' (', full_results[,10], ')')
}
##### End of experiment ####
t2_total = Sys.time()
cat('Time Elapsed: ', difftime(t2_total, t1_total, units = 'mins'), 'mins')

# Rows with the larger alpha first.
formatted_results = formatted_results[order(formatted_results$alpha, decreasing=TRUE),]

# Write the table without the n column. xtable requires one digits entry per
# column plus one for the row names, so the trailing two-decimal entries are
# derived from the column count (the previous 10-entry vector did not match
# the 13 printed columns and made xtable stop).
print(xtable(formatted_results[,-2],
             digits = c(1, 1, 1, 1, rep(2, ncol(formatted_results) - 4))),
      type = "latex",
      file = "RD_results.tex",
      include.rownames=FALSE)
}

#### Below chunk makes the second table in our paper ####
{
m_fn_multi_d <- function(n, gamma, alpha, d) {
  # Inducing-variable counts for the d-dimensional experiments.
  # Returns c(m_small, n): m_small = n^{d/(d+2*gamma) * (2+alpha)/(1+alpha)}
  # capped at n, followed by n itself.
  rate <- d / (d + 2 * gamma) * (2 + alpha) / (1 + alpha)
  m_small <- min(n, n^rate)
  c(m_small, n)
}
### Experiment 3  Multi D Rand Design ####
# Scalar settings. NOTE(review): cor, alpha, gamma and n are not read by the
# experiment-3 call that follows (it iterates over cors/alphas/gammas/ns);
# they are kept because they remain global assignments.
d <- 10
cor <- 0.1
alpha <- 3
gamma <- 1.5
n <- 2000

eta <- 0.9
cn <- 1
x0 <- rep(0, d)            # evaluation point: the origin of R^d
num_replicates <- 10

# One sample size crossed with four (alpha, gamma, rho) settings; rows with
# rho = NA are labelled 'Uniform' in the results table, the rest 'Gaussian'.
ns <- as.matrix(c(1000))
colnames(ns) <- 'ns'
alphas <- c(0.5, 0.7, 0.9, 0.9)
gammas <- rep(0.5, 4)
cors <- c(NA, 0.0, 0.2, 0.5)
t1 <- Sys.time()
multi_d_rand_design_experiment = function(gamma_, x0_, n, cor, alpha_, num_replicates = 100, cn_ = 1,
                                          kern = mat_kern_1_2){
  # Run the multi-dimensional random-design experiment for one setting.
  #
  # Args:
  #   gamma_: forwarded to the estimator as `gamma`.
  #   x0_: evaluation point; its length is the dimension d used for m.
  #   n: sample size. cor: forwarded positionally after n.
  #   alpha_: passed to m_fn_multi_d() and to the estimator.
  #   num_replicates, cn_: forwarded unchanged.
  #   kern: covariance kernel forwarded to the estimator. It was previously
  #     ignored and mat_kern_1_2 was always used; the default keeps that
  #     behaviour for existing callers.
  # The smoothness used to choose m is hard-coded at 1/2 (presumably the
  # Matern-1/2 smoothness of the default kernel -- confirm if kern changes).
  ms = m_fn_multi_d(n, gamma = 1/2, alpha = alpha_, d = length(x0_))
  estimate_cov_len_bias_multiple_m_multi_d_random_design(gamma=gamma_, m_s = ms, x0_ = x0_, n, cor,
                                                         alpha_ = alpha_, num_replicates = num_replicates, cn_ = cn_,
                                                         kern = kern)
}

smoothness_combos = cbind(alphas, gammas, cors)
param_combos = expand.grid.extra(as.matrix(ns), smoothness_combos)
# Printed explicitly: a bare expression inside the enclosing { } chunk is not
# auto-printed.
print(param_combos)

# One row of metrics per row of param_combos (columns: n, alpha, gamma, rho).
results = t(apply(param_combos, 1, function(p) 
  multi_d_rand_design_experiment(gamma_=eta, x0_ = x0, n = p[1], cor = p[4],
                                 alpha_ = p[2], num_replicates = num_replicates, cn_ = cn,
                                 kern = mat_kern_1_2)))

# Take the settings from param_combos itself (dropping n) so the rows are
# aligned with `results` by construction; cbind with smoothness_combos only
# lines up when ns has a single entry.
results = cbind(param_combos[, -1, drop = FALSE], results)
colnames(results) = c('alpha', 'gamma', 'rho', 
                      'cov.m1', 'cov.m2',
                      'len.m1', 'len.m2',
                      'sd.len.m1', 'sd.len.m2',
                      'rmse.m1', 'rmse.m2',
                      'nlpd.m1', 'nlpd.m2', 'sd.nlpd.m1', 'sd.nlpd.m2')

# Round to 2 dp, drop the sd-of-length columns and fold them into the length
# columns as "mean (sd)" strings (this turns the matrix into character).
formatted_results = round(results[,-c(8:9)], 2)
formatted_results[,6] = paste0(round(results[,6], 2), ' (', round(results[,8],2), ')')
formatted_results[,7] = paste0(round(results[,7], 2), ' (', round(results[,9],2), ')')
# rho = NA marks the uniform design; any other value a Gaussian design.
designs = ifelse(is.na(results[, 'rho']), 'Uniform', 'Gaussian')
# Reorder the leading columns to (rho, alpha, gamma).
formatted_results[,c(1:3)] = formatted_results[,c(3,1:2)]
colnames(formatted_results)[1:3] = colnames(formatted_results)[c(3,1:2)]
formatted_results = cbind(designs, formatted_results)

t2 = Sys.time()
cat('Elapsed: ', difftime(t2, t1, units = 'secs'))
# xtable requires one digits entry per column plus one for the row names; the
# previous 11-entry vector did not match the 14 columns and made xtable stop.
print(xtable(formatted_results,
             digits = c(1, 1, 1, 2, rep(2, ncol(formatted_results) - 3))),
      type = "latex",
      file = "multi_d_results.tex",
      include.rownames=FALSE)


### Experiment 4, Investigate Synthetic Dataset ####
# Load the CSV, drop incomplete rows and draw n rows without replacement,
# keeping columns 3-12 as covariates. NOTE(review): no seed is set in this
# section, so the subsample changes between runs.
temp <- na.omit(fread('Bias_correction_ucl.csv'))
n <- 2000
size_dat <- nrow(temp)
dat <- temp[sample(seq_len(size_dat), n, replace = FALSE), c(3:12)]
alpha <- 1
eta <- 0.9
num_replicates <- 50
# Scale each column by its maximum. NOTE(review): this lands in [0, 1] only
# for non-negative columns; a column with negative entries would need
# min-max scaling, (x - min) / (max - min) -- check the data.
dat <- apply(dat, 2, function(col) col / max(col))
n <- nrow(dat)
d <- ncol(dat)
x0 <- rep(0.5, d)   # centre of the unit cube [0, 1]^d

matern_1_2_experiment_multi_d_synthetic = function(alpha, eta, x0, num_replicates, dat_, kern = mat_kern_1_2){
  # Fixed-design experiment on the rows of dat_ (one row per observation).
  #
  # Builds the Gram matrix of `kern` over all pairs of rows (lengthscale
  # cn = 1), takes its SVD and passes it to the multi-d estimator.
  #
  # Args:
  #   alpha: passed to m_fn_multi_d() and to the estimator.
  #   eta: forwarded to the estimator as `gamma_`.
  #   x0: evaluation point, forwarded as `x0_`.
  #   num_replicates: forwarded unchanged.
  #   dat_: design matrix; rows are points, columns are dimensions.
  #   kern: covariance kernel k(x, y, cn). Defaults to mat_kern_1_2, the
  #     previously hard-coded kernel, and now matches the emp_bayes variant.
  # Returns: the value of estimate_cov_len_bias_multiple_m_multi_d().
  Xns = dat_
  cn = 1
  K_nn = apply(Xns, 1, function(x) apply(Xns, 1, function(y) kern(x, y, cn)))
  svd_decomp = svd(K_nn)
  # m is chosen with smoothness 1/2 and d = number of columns.
  m_s = m_fn_multi_d(n = dim(Xns)[1], gamma=1/2, alpha, dim(Xns)[2])
  return(estimate_cov_len_bias_multiple_m_multi_d(gamma_ = eta, m_s = m_s,
                                                  x0_ = x0, svd_ = svd_decomp, alpha_ = alpha,
                                                  num_replicates = num_replicates, cn_ = cn,
                                                  kern = kern, xns = Xns) )
}
matern_1_2_experiment_multi_d_synthetic_emp_bayes = function(alpha, eta, x0, num_replicates, dat_, kern = mat_kern_1_2){
  # Empirical-Bayes variant of the fixed-design multi-d experiment.
  #
  # The Gram matrix of `kern` (lengthscale 1) over all pairs of rows of dat_
  # is decomposed with svd() and handed to the emp-Bayes estimator together
  # with m chosen by m_fn_multi_d() at smoothness 1/2.
  #
  # Args: alpha, eta (forwarded as gamma_), x0 (forwarded as x0_),
  #   num_replicates, dat_ (rows = points) and kern(x, y, cn).
  # Returns: the value of estimate_cov_len_bias_multiple_m_multi_d_emp_bayes().
  lengthscale <- 1
  gram <- apply(dat_, 1, function(row_i) {
    apply(dat_, 1, function(row_j) kern(row_i, row_j, lengthscale))
  })
  decomposition <- svd(gram)
  m_values <- m_fn_multi_d(n = nrow(dat_), gamma = 1/2, alpha, ncol(dat_))
  estimate_cov_len_bias_multiple_m_multi_d_emp_bayes(gamma_ = eta,
                                                     m_s = m_values,
                                                     x0_ = x0,
                                                     svd_ = decomposition,
                                                     alpha_ = alpha,
                                                     num_replicates = num_replicates,
                                                     cn_ = lengthscale,
                                                     kern = kern,
                                                     xns = dat_)
}



# Experiment 4a: Matern-1/2 kernel on the subsample prepared above. The
# result is stored and printed explicitly: a bare expression inside the
# enclosing { } chunk is not auto-printed, so it was previously discarded.
synthetic_results = matern_1_2_experiment_multi_d_synthetic(alpha, eta, x0, num_replicates, dat_ = dat)
print(synthetic_results)

# Experiment 4b: fresh subsample, 500 replicates, SE kernel with the
# empirical-Bayes estimator. x0 is reused from the setup above, so its
# length must equal the 10 selected columns.
temp = na.omit(fread('Bias_correction_ucl.csv'))
n = 2000
size_dat = dim(temp)[1]
dat = temp[sample(c(1:size_dat), n, replace = FALSE), c(3:12)]
alpha = 1
eta = 0.9
num_replicates = 500
#normalise to fit in 0, 1
# NOTE(review): dividing by the column maximum lands in [0, 1] only for
# non-negative columns; use (x - min) / (max - min) if any can be negative.
dat = apply(dat, 2, function(x) x/max(x))
n = dim(dat)[1]
d = dim(dat)[2]
t1 = Sys.time()
emp_bayes_results = matern_1_2_experiment_multi_d_synthetic_emp_bayes(alpha, eta, x0, num_replicates, dat_ = dat, kern = se_kern_multi_d)
t2 = Sys.time()
print(emp_bayes_results)
cat('Elapsed: ', difftime(t2, t1, units = 'secs'), '\n')


}



#### Below chunk makes the posterior plot ####
# Top-level version of the computation wrapped by compute_full_results()
# below. Later in the script full_results, x_obs and y are replaced by that
# function's output; the other globals set here (e.g. f0, xs) are kept.
kern = bm_kern
n = 500
# Equispaced design on [0, 1], truth f0(x) = |x - 0.5|, Gaussian noise sd 0.1.
x_obs = seq(0, 1, length.out = n)
f0 = abs(x_obs - 0.5)
y = f0 + rnorm(length(x_obs), sd = 0.1)
gamma = 0.5
# With gamma = 0.5 the exponent is 0, so cn = 1.
cn = (n+1/2)^{(1-2*gamma)/(1+2*gamma)}

# Evaluation grid 0.2, 0.21, ..., 0.8 and the truth on it.
xs = seq(0.2, 0.8, by = 0.01)
f0_xs = abs(xs - 0.5)
K_nn = sapply(x_obs, function(x) sapply(x_obs, function(y) bm_kern(x, y, cn)))
svd_decomp = svd(K_nn)

# Two sparse settings: m = n^{1/4} (n^{2/8}) and m = n^{3/4}.
m_bad = n^{2/8}
m_good = n^{3/4}
# Result not stored; it is only shown when this line runs at top level.
vp_mean_variance(0.1, n, svd_decomp, bm_kern, x_obs, y, cn)
full_post = t(sapply(xs, function(x) vp_mean_variance(x, n, svd_decomp, bm_kern, x_obs, y, cn)))
bad_var_post = t(sapply(xs, function(x) vp_mean_variance(x, m_bad, svd_decomp, bm_kern, x_obs, y, cn)))
good_var_post = t(sapply(xs, function(x) vp_mean_variance(x, m_good, svd_decomp, bm_kern, x_obs, y, cn)))

# Label each posterior; the A/B/C prefixes make the labels sort in the
# order used by the legend labels of the plots.
full_post = data.frame(full_post)
colnames(full_post) = c('Mean', 'Variance')
full_post['Posterior'] = 'CFull Posterior'
full_post['x'] = xs

bad_var_post = data.frame(bad_var_post)
colnames(bad_var_post) = c('Mean', 'Variance')
bad_var_post['Posterior'] = 'ASGPR (m = n^{1/4})'
bad_var_post['x'] = xs

good_var_post = data.frame(good_var_post)
colnames(good_var_post) = c('Mean', 'Variance')
good_var_post['Posterior'] = 'BSGPR (Larger m = n^{3/4})'
good_var_post['x'] = xs


# good_var_post['Mean'] = good_var_post['Mean'] + 0.005
full_results = rbind(full_post, bad_var_post, good_var_post)

compute_full_results = function(n = 500, noise_sd = 0.1){
  # Simulate data and compute the posteriors compared in the posterior plot.
  #
  # x_obs is an equispaced grid of n points on [0, 1], the truth is
  # f0(x) = |x - 0.5| and y = f0(x_obs) + N(0, noise_sd^2) noise. For every
  # x in xs = 0.2, 0.21, ..., 0.8, vp_mean_variance() is evaluated with the
  # rescaled Brownian-motion kernel bm_kern for m = n (labelled full
  # posterior), m = n^(1/4) and m = n^(3/4).
  #
  # Args:
  #   n: number of observations (default 500, the previous fixed value).
  #   noise_sd: standard deviation of the Gaussian noise (default 0.1).
  # Returns a list with
  #   full_results: data frame with columns Mean, Variance, Posterior, x;
  #     rows stacked as full, m = n^(1/4), m = n^(3/4);
  #   x_obs, f0, y: design points, truth on them and observations;
  #   xs, f0_xs: evaluation grid and the truth on it.
  x_obs = seq(0, 1, length.out = n)
  f0 = abs(x_obs - 0.5)
  y = f0 + rnorm(length(x_obs), sd = noise_sd)
  gamma = 0.5
  # With gamma = 0.5 the exponent is 0, so cn = 1.
  cn = (n+1/2)^{(1-2*gamma)/(1+2*gamma)}

  xs = seq(0.2, 0.8, by = 0.01)
  f0_xs = abs(xs - 0.5)
  K_nn = sapply(x_obs, function(x) sapply(x_obs, function(y) bm_kern(x, y, cn)))
  svd_decomp = svd(K_nn)

  # Evaluate vp_mean_variance() with m inducing variables at every point of
  # xs and return the result as a labelled data frame.
  posterior_frame = function(m, label){
    post = data.frame(t(sapply(xs, function(x) vp_mean_variance(x, m, svd_decomp, bm_kern, x_obs, y, cn))))
    colnames(post) = c('Mean', 'Variance')
    post['Posterior'] = label
    post['x'] = xs
    post
  }

  # The A/B/C prefixes make the labels sort in the order assumed by the
  # legend labels of the plots. Computed in the original order.
  full_post = posterior_frame(n, 'CFull Posterior')
  bad_var_post = posterior_frame(n^{2/8}, 'ASGPR (m = n^{1/4})')
  good_var_post = posterior_frame(n^{3/4}, 'BSGPR (Larger m = n^{3/4})')

  # Uncomment below if want to shift the variational posterior which is close to the full GP
  # good_var_post['Mean'] = good_var_post['Mean'] + 0.005
  full_results = rbind(full_post, bad_var_post, good_var_post)
  return(list(full_results = full_results, x_obs = x_obs, f0 = f0, y = y,
              xs = xs, f0_xs = f0_xs))
}

# Recompute the simulation and pull out the pieces the plots use. f0 is
# taken from the same call so the true curve always matches x_obs rather
# than relying on an earlier global of the same name.
temp = compute_full_results()
full_results = temp$full_results
x_obs = temp$x_obs
y = temp$y
f0 = temp$f0

# Layers shared by both figures: posterior means, bands and the true function
# f0 (black). The legend labels are hard-coded for n = 500
# (500^{1/4} ~ 4.7 -> 5, 500^{3/4} ~ 105.7 -> 106) and follow the A/B/C
# ordering of the Posterior labels.
# NOTE(review): the band is Mean +/- 1.96 * Variance. If vp_mean_variance()
# returns a variance rather than a standard deviation, a 95% band would be
# Mean +/- 1.96 * sqrt(Variance) -- confirm before changing.
posterior_plot = ggplot(data = full_results) +
  geom_line(data = full_results, aes(x = x, y = Mean, color = Posterior), alpha = 1) +
  geom_ribbon(data= full_results, aes(x = x, ymin = Mean - 1.96*Variance, ymax = Mean + 1.96*Variance, fill = Posterior), color = NA, alpha = 0.1) +
  geom_line(data = data.frame(x_obs, f0), aes(x = x_obs, y = f0), color = 'black', alpha = 0.75) +
  xlim(0.2, 0.8) + ggtitle('Comparison of the Posteriors') +
  ylim(-0.1, 0.45) + 
  scale_color_discrete(labels=c('SGPR (m = 5)', 'SGPR (m = 106)', 'GP')) +
  scale_fill_discrete(labels=c('SGPR (m = 5)', 'SGPR (m = 106)', 'GP')) +
  theme(legend.position="bottom", strip.background = element_blank(), strip.text.x = element_blank())

# One panel per posterior, side by side. The plot is passed to ggsave()
# explicitly instead of relying on last_plot().
faceted_plot = posterior_plot + facet_wrap(~Posterior, nrow=1)
print(faceted_plot)
ggsave('neurips_posterior_comparison.pdf', plot = faceted_plot, units = 'in', width = 8, height = 5)

# All three posteriors overlaid in a single panel.
print(posterior_plot)
ggsave('neurips_posterior_comparison_stacked.pdf', plot = posterior_plot, units = 'in', width = 8, height = 5)

