
# Load the fitting routines and the precomputed NNGP prior workspace.
source(file = "./optimization_hyper.R")
# Echo the working directory so the relative paths above/below can be checked.
getwd()
load(file = "./data/nngp_prior_full.RData")

# -----------------------------------------------------------
# Problem size and low-rank decomposition of the prior precision.
# NOTE(review): `n`, `N` and `Gamma_precision` are not defined in this script;
# presumably they come from ./data/nngp_prior_full.RData — confirm.
nN <- n + N
desired_rank <- 50
decomposition_result <- decompose_gamma(Gamma_precision, rank_r = desired_rank)
Gamma_base_diag_vec <- decomposition_result$Gamma_base_diag_vec
L_base_factor <- decomposition_result$L_base_factor
r_actual <- decomposition_result$r_actual

# Keep an N_sum already present in the workspace; otherwise derive it.
if (!exists("N_sum")) {
  N_sum <- n + N
}

# Validate the low-rank factor and fix the column count used for W below.
if (is.null(L_base_factor)) {
  stop("L_base_factor is NULL after decomposition. Check decompose_gamma.")
} else if (ncol(L_base_factor) != r_actual) {
  warning("ncol(L_base_factor) does not match r_actual. Using ncol.")
  r_prior_dim <- ncol(L_base_factor)
} else {
  r_prior_dim <- r_actual
}

# Scalar hyperparameters handed to HVGA_new_lowrank2 (meaning defined there).
a <- 0.01
b <- 0.01

# Initial values for case (a). Each is assigned exactly once; the original
# script set beta_ini1 / mu_ini1 twice with identical values.
beta_ini1 <- matrix(c(1, 1), 2, 1)
mu_ini1 <- matrix(0, nN, 1)
# Sigma starts from the prior diagonal, a zero low-rank W and the prior factor.
Sigma_ini1 <- list(
  diag = 1 / Gamma_base_diag_vec,
  W = matrix(0.0, nrow = N_sum, ncol = r_prior_dim),
  L = L_base_factor
)
sigma2_ini1 <- 1

# Case (a) fit: starts from beta_ini1 / mu_ini1 / Sigma_ini1 / sigma2_ini1.
# NOTE(review): tol = c(5, 0.05) here but c(0.01, 0.05) for case (b);
# confirm the looser first tolerance is intentional.
VINNGP_cg_lowrank2 <- HVGA_new_lowrank2(
  A, A_tilde,
  X, X_tilde, alpha_weight,
  mu_ini1, Sigma_ini1, beta_ini1, sigma2_ini1,
  Gamma_precision,
  Gamma_base_diag_vec, L_base_factor,
  a, b,
  maxiter = 100,
  tol = c(5, 0.05),
  sigma2_tol = 1e-4,
  inner_maxiter = 2,
  inner_tol_mu = 1e-2,
  inner_tol_sigma = 1e-2,
  inner_tol_beta = 1e-2,
  mu_update_method = "cg",
  cg_tol = 1e-6,
  cg_maxiter = 1,
  cg_step = c(0.01, 0.05)
)

# Case (b): a second initialisation with beta = (0, 0) and sigma2 = 0.5.
# Sigma starts from the prior diagonal, a zero low-rank W and the prior factor.
sigma2_ini2 <- 0.5
beta_ini2 <- matrix(c(0, 0), 2, 1)
mu_ini2 <- matrix(0, nN, 1)
Sigma_ini2 <- list(
  diag = 1 / Gamma_base_diag_vec,
  W = matrix(0.0, nrow = N_sum, ncol = r_prior_dim),
  L = L_base_factor
)

VINNGP_cg_lowrank2_2 <- HVGA_new_lowrank2(
  A, A_tilde,
  X, X_tilde, alpha_weight,
  mu_ini2, Sigma_ini2, beta_ini2, sigma2_ini2,
  Gamma_precision,
  Gamma_base_diag_vec, L_base_factor,
  a, b,
  maxiter = 100,
  tol = c(0.01, 0.05),
  sigma2_tol = 1e-4,
  inner_maxiter = 2,
  inner_tol_mu = 1e-2,
  inner_tol_sigma = 1e-2,
  inner_tol_beta = 1e-2,
  mu_update_method = "cg",
  cg_tol = 1e-6,
  cg_maxiter = 1,
  cg_step = c(0.01, 0.05)
)

# Persist both fits; the plotting section then rebuilds everything from disk.
save(
  VINNGP_cg_lowrank2, VINNGP_cg_lowrank2_2,
  file = "./data/convergence_plot.RData",
  envir = .GlobalEnv
)
#-------------------------------------------plot-------------------------------------------
# NOTE(review): rm(list = ls()) clears the entire global workspace (including
# anything source()d above); only run this section in a throwaway session.
rm(list = ls())
load(file = "./data/convergence_plot.RData")

# Gather per-iteration diagnostics from the two reloaded fits.

# Column-bind vectors that may differ in length, padding the shorter ones with
# NA. Plain cbind() would silently recycle them, plotting repeated values as
# if they were extra iterations.
cbind_padded <- function(...) {
  cols <- lapply(list(...), as.vector)
  n_max <- max(lengths(cols))
  cols <- lapply(cols, function(v) {
    length(v) <- n_max  # extends with NA
    v
  })
  do.call(cbind, cols)
}

dbeta1 <- VINNGP_cg_lowrank2[["dbeta_norm"]]
dbeta2 <- VINNGP_cg_lowrank2_2[["dbeta_norm"]]
dmu1 <- VINNGP_cg_lowrank2[["dmu_norm"]]
dmu2 <- VINNGP_cg_lowrank2_2[["dmu_norm"]]
# Display adjustment: the first case-(b) mu step is divided by 3. Panel (b) is
# drawn with ylim = c(0, 1) and its top tick labelled 3.0.
# NOTE(review): this only reads correctly if dmu2[1] is about 3 before scaling,
# and the relabelled tick is shared by the beta/Sigma curves — confirm.
dmu2[1] <- dmu2[1] / 3
dSigma1 <- VINNGP_cg_lowrank2[["dSigma_norm_approx"]]
dSigma2 <- VINNGP_cg_lowrank2_2[["dSigma_norm_approx"]]
dresult1 <- cbind_padded(dbeta1 = dbeta1, dmu1 = dmu1, dSigma1 = dSigma1)
dresult2 <- cbind_padded(dbeta2 = dbeta2, dmu2 = dmu2, dSigma2 = dSigma2)

# The two runs may stop after different numbers of iterations, so their ELBO
# traces can have different lengths.
ELBO1 <- VINNGP_cg_lowrank2[["ELBO"]]
ELBO2 <- VINNGP_cg_lowrank2_2[["ELBO"]]
ELBO_result <- cbind_padded(ELBO1 = ELBO1, ELBO2 = ELBO2)

# Three-panel figure: update norms for case (a), case (b), then both ELBOs.
new_colors_conv <- c(1, "blue", "darkgrey")
default_lwd <- 2

# Legend for the ||delta beta||_2, ||delta mu||_2, ||delta Sigma||_F curves,
# shared by panels (a) and (b).
add_norm_legend <- function(col = new_colors_conv, lwd = default_lwd) {
  legend(
    x = "topright",
    legend = c(
      expression(paste("||", delta, beta, "||")[2]),
      expression(paste("||", delta, mu, "||")[2]),
      expression(paste("||", delta, Sigma, "||")[F])
    ),
    col = col, lty = 1, lwd = lwd, cex = 1.4, box.lwd = 1,
    x.intersp = 1.5, y.intersp = 1.5,
    # NOTE(review): strwidth() measures this literal text, not the legend
    # expression; the 0.18 factor is an empirical width adjustment.
    text.width = strwidth("expression(paste(\"|\", delta, beta, \"|\")[2])") * 0.18,
    adj = 0
  )
}

# Save the previous settings so every parameter changed here is restored at
# the end; the original reset only mfrow and left mar/oma modified.
old_par <- par(mfrow = c(1, 3), mar = c(4, 4, 2, 1) + 0.1, oma = c(2, 2, 2, 2))

# Panel (a)
matplot(dresult1, type = "l", lty = 1, col = new_colors_conv, lwd = default_lwd,
        xlab = "Iteration", ylab = "", yaxt = "n")
axis(2, las = 2)
add_norm_legend()
mtext("case (a)", side = 3, line = 0.5, at = -2, cex = 1.4, adj = 0)

# Panel (b): y range is fixed to [0, 1] and the top tick is labelled 3.0 to
# match the scaled first dmu2 value.
# NOTE(review): any beta/Sigma norm above 1 is clipped by this ylim.
matplot(dresult2, type = "l", lty = 1, col = new_colors_conv, lwd = default_lwd, ylim = c(0, 1),
        xlab = "Iteration", ylab = "", yaxt = "n")
y_ticks <- c(seq(0, 0.8, by = 0.2), 1)
y_labels <- c(as.character(seq(0, 0.8, by = 0.2)), expression(3.0))
axis(2, at = y_ticks, labels = y_labels, las = 1)
add_norm_legend()
mtext("case (b)", side = 3, line = 0.5, at = -2, cex = 1.4, adj = 0)

# Panel (c): ELBO traces. matplot already draws the x axis, so no axis(1) call.
# NOTE(review): ylim is hard-coded; ELBO values outside [-3000, 3600] are clipped.
matplot(ELBO_result, type = "l", lty = 1:2, col = c("purple", "orange"), lwd = default_lwd,
        ylim = c(-3000, 3600), xlab = "Iteration", ylab = "ELBO", yaxt = "n")
y_ticks <- seq(-3000, 3600, by = 1200)
y_labels <- as.character(y_ticks)
axis(2, at = y_ticks, labels = y_labels, las = 1)
legend(x = "bottomright", legend = c("case (a)", "case (b)"),
       col = c("purple", "orange"), lty = 1:2, lwd = 2, cex = 1.2, box.lwd = 1,
       x.intersp = 1.5, y.intersp = 1.5,
       text.width = strwidth("expression(paste(\"|\", delta, beta, \"|\")[2])") * 0.25,
       adj = 0)

par(old_par)

