
rm(list = ls())

source("./src/dual_mesh_loc.R")

# Load necessary libraries
library(RSpectra) # For decompose_gamma if used later
library(stats)   # For optim, lgamma

# N: the number of data patterns (observations; rows of X)
# n: the number of basis functions (the number of vertices in the mesh's triangles)
# p: the number of data points used for approximate integration (the number of polygons in the dual mesh)
# N_sum: N + n

# @ A (N*N_sum) -> Assumed implicit via indices n, N
# @ A_tilde (n*N_sum) -> Assumed implicit via indices n, N
# @ X (N*m): a matrix composed of the values of m covariates for N data patterns.
# @ X_tilde (n*m): a matrix composed of the values of m covariates at the integration data points (the code takes n = nrow(X_tilde)).
# @ alpha_weight (n*1): Weights used for approximate integration, obtained when constructing the dual mesh.
# @ beta_c (m*1): the coefficients corresponding to m covariates.
# @ mu (N_sum*1): posterior mean
# @ Sigma_k_list: List representing posterior covariance (low-rank format)
# @ sigma2: Variance hyperparameter
# @ Gamma_base_precision: The base precision matrix (unscaled by sigma2)
# @ Gamma_base_diag_vec: Diagonal of Gamma_base_precision
# @ L_base_factor: Low-rank factor L from Gamma_base_precision = Diag + L L^T
# @ a, b: Hyperparameters for Inverse-Gamma prior on sigma2: IG(a, b)

##################################################################################

# trace of sparse matrix
# Trace of a square matrix: the sum of its diagonal entries.
#   mat: a matrix-like object accepted by diag().
# Returns a single number.
sparsemat_tr <- function(mat) {
  diagonal_entries <- diag(mat)
  sum(diagonal_entries)
}

# log(det)
# Logarithm of the absolute value of the determinant of `mat`.
# The sign reported by determinant() is not inspected, so a negative
# determinant yields the log of its magnitude.
log_det <- function(mat) {
  det_info <- Matrix::determinant(mat, logarithm = TRUE)
  log_modulus <- det_info$modulus
  # Indexing with [1] drops the "logarithm" attribute and returns a bare number.
  log_modulus[1]
}


# Expected weighted intensity at the integration points.
#
# Computes
#   alpha_weight * exp(X_tilde %*% beta_c + mu[1:n] + 0.5 * diag(Sigma_k)[1:n])
# with n = nrow(X_tilde). diag(Sigma_k) is rebuilt from the low-rank form
#   Sigma_k = D_inv - D_inv L' (I + M)^{-1} L'^T D_inv,
#   L' = L_base_factor / sqrt(sigma2_k),   M = L'^T D_inv L'.
#
# Arguments:
#   alpha_weight   integration weights, multiplied element-wise into the result.
#   mu             mean vector of length N_sum; only its first n entries are used.
#   Sigma_k_list   list; only its `diag` element (D_inv, length N_sum) is read here.
#   beta_c         covariate coefficients (vector or one-column matrix).
#   X_tilde        n x m covariate matrix at the integration points.
#   L_base_factor  N_sum x r base factor L from Gamma = G_diag + L L^T.
#   sigma2_k       current sigma^2 value; must be positive.
#
# Returns the weighted exponential in the n x 1 shape produced by
# X_tilde %*% beta_c.
#
# Stops on a dimension mismatch or a missing / non-positive sigma2_k. Warns when
# (I + M) cannot be solved (a pseudo-inverse is used instead) and when computed
# variances are non-positive (they are clamped to machine epsilon).
exp_WeightedVec_lowrank <- function(alpha_weight, mu, Sigma_k_list, beta_c, X_tilde,
                                    L_base_factor, # MUST be the base factor L from Gamma = G_diag + L L^T
                                    sigma2_k) {    # Current sigma^2 value
  n <- nrow(X_tilde)
  N_sum <- length(mu)
  r_prior <- ncol(L_base_factor)
  I_r <- Matrix::Diagonal(r_prior)

  # Ensure correct dimensions
  if (length(Sigma_k_list$diag) != N_sum) stop("Dimension mismatch: Sigma_k_list$diag")
  if (nrow(L_base_factor) != N_sum) stop("Dimension mismatch: L_base_factor")
  # Without this check mu[1:n] would silently pad with NA when n > N_sum.
  if (n > N_sum) stop("Dimension mismatch: nrow(X_tilde) exceeds length(mu)")
  if (is.na(sigma2_k) || sigma2_k <= 0) {
    stop("sigma2_k must be positive in exp_WeightedVec_lowrank")
  }

  # seq_len() keeps the index empty when n == 0 (1:0 would select element 1).
  idx_n <- seq_len(n)
  mu_n <- mu[idx_n]
  D_inv_vec <- Sigma_k_list$diag # = (H_tilde_diag)^{-1}

  # M = L'^T D_inv L' = (1/sigma2_k) * L_base^T D_inv L_base.
  # Multiplying by the vector D_inv_vec scales the rows of L_base_factor.
  M <- (1 / sigma2_k) * Matrix::crossprod(L_base_factor, D_inv_vec * L_base_factor)

  # Middle_inv = (I + M)^{-1}, falling back to a pseudo-inverse if the solve fails.
  Middle_inv <- tryCatch({
    Matrix::solve(I_r + M)
  }, error = function(e) {
    warning("Matrix solve failed for (I+M) in exp_WeightedVec_lowrank. Check M's condition number. Using pseudo-inverse.", call. = FALSE)
    MASS::ginv(as.matrix(I_r + M))
  })

  # L' = L_base / sqrt(sigma2_k)
  L_prime <- (1 / sqrt(sigma2_k)) * L_base_factor

  # V = L' %*% (I + M)^{-1}   (N_sum x r)
  V <- L_prime %*% Middle_inv

  # diag(Sigma_k)_i = D_inv_i - D_inv_i^2 * sum_j V[i, j] * L_prime[i, j].
  # Matrix::rowSums is used because V inherits a Matrix class from
  # Matrix::solve(); base rowSums() rejects such objects unless the Matrix
  # package happens to be attached. It falls back to base for plain matrices.
  Correction_term <- D_inv_vec^2 * Matrix::rowSums(V * L_prime)
  Sigma_k_diag_elems <- D_inv_vec - Correction_term

  # Check for negative variances (shouldn't happen if stable)
  if (any(Sigma_k_diag_elems <= 0)) {
    warning("Non-positive diagonal elements calculated for Sigma_k in exp_WeightedVec_lowrank. Clamping.")
    Sigma_k_diag_elems <- pmax(Sigma_k_diag_elems, .Machine$double.eps)
  }

  Sigma_nn_diag <- Sigma_k_diag_elems[idx_n]

  if (!is.matrix(beta_c)) beta_c <- matrix(beta_c, ncol = 1)
  exp_arg <- X_tilde %*% beta_c + mu_n + 0.5 * Sigma_nn_diag

  as.vector(alpha_weight) * exp(exp_arg)
}

# # Calculates trace(Gamma_base * Sigma_k) using the low-rank representation of Sigma_k
# # Sigma_k = D_inv - (1/sigma2) * (D_inv L_base) * (I + M)^{-1} * (L_base^T D_inv)
# # where M = (1/sigma2) * L_base^T * D_inv * L_base
# # Gamma_base = Diag(Gamma_base_diag_vec) + L_base_factor * t(L_base_factor) (approx)
# trace_Gamma_Sigma_lowrank <- function(Gamma_base_diag_vec, L_base_factor, Sigma_k_list, sigma2_k) {
#   
#   D_inv_vec = Sigma_k_list$diag
#   W = Sigma_k_list$W
#   N_sum = length(D_inv_vec)
#   r_prior = ncol(L_base_factor)
#   I_r = Matrix::Diagonal(r_prior)
#   
#   # Ensure components are valid
#   if(length(Gamma_base_diag_vec) != N_sum) stop("Dimension mismatch: Gamma_base_diag_vec")
#   if(nrow(L_base_factor) != N_sum) stop("Dimension mismatch: L_base_factor")
#   
#   # Term 1: trace(Gamma_diag %*% D_inv)
#   term1 = sum(Gamma_base_diag_vec * D_inv_vec)
#   
#   # Term 2: - trace(Gamma_diag %*% W %*% t(L * D_inv))
#   # = - trace(t(L*D_inv) %*% Gamma_diag %*% W)
#   Term1_k = D_inv_vec * L_base_factor
#   term2 = -sum(Term1_k * (Gamma_base_diag_vec * W)) # Efficient trace calculation
#   
#   
#   # Precompute M and Middle_inv for the current Sigma_k
#   M = (1/sigma2_k) * Matrix::crossprod(L_base_factor, D_inv_vec * L_base_factor)
#   term3 = sum(Matrix::diag(M))
#   
#   # Term 4: - trace(L L^T %*% W %*% t(L*D_inv))
#   # = - trace(t(L*D_inv) %*% L %*% L^T %*% W)
#   Lt_W = Matrix::crossprod(L_base_factor, W) # r x r
#   term4 = -sum(Term1_k * (L_base_factor %*% Lt_W)) # Efficient trace
#   
#   return(term1 + term2 + term3 + term4)
# }

# Calculates trace(Gamma_base * Sigma_k) using the low-rank representation of Sigma_k
# Gamma_base = G_diag + L L^T
# Sigma_k = D_inv - W L'^T D_inv
# trace(Gamma_base Sigma_k) = trace(G_diag D_inv) + trace(L L^T D_inv)
#                           - trace(G_diag W L'^T D_inv) - trace(L L^T W L'^T D_inv)
# Uses cyclic property: trace(A B) = trace(B A)
# Computes trace(Gamma_base %*% Sigma_k) without forming either matrix, using
#   Gamma_base = Diag(Gamma_base_diag_vec) + L L^T
#   Sigma_k    = D_inv - W L'^T D_inv,   L' = L / sqrt(sigma2_k)
# and the cyclic property trace(A B) = trace(B A).
#
# Arguments:
#   Gamma_base_diag_vec  diagonal part of the base precision (length N_sum).
#   L_base_factor        N_sum x r base factor L.
#   Sigma_k_list         list with `diag` (D_inv, length N_sum) and `W` (N_sum x r).
#   sigma2_k             current sigma^2 value; must be positive.
#
# Returns the trace as a number, or NA_real_ (with a warning) when the result
# is not finite. Stops on a dimension mismatch, a missing W, or a
# non-positive sigma2_k.
trace_Gamma_Sigma_lowrank <- function(Gamma_base_diag_vec, L_base_factor, Sigma_k_list, sigma2_k) {

  D_inv_vec <- Sigma_k_list$diag
  W <- Sigma_k_list$W         # W = D_inv L' (I+M)^-1
  L <- L_base_factor          # Base factor L
  s2 <- sigma2_k

  N_sum <- length(D_inv_vec)
  r_prior <- ncol(L)

  # Ensure components are valid
  if (length(Gamma_base_diag_vec) != N_sum) stop("Dimension mismatch: Gamma_base_diag_vec")
  if (nrow(L) != N_sum) stop("Dimension mismatch: L_base_factor")
  # nrow(NULL) is NULL, which would make the next condition zero-length and
  # fail with an unrelated message, so a missing W is reported explicitly.
  if (is.null(W)) stop("Dimension mismatch: W")
  if (nrow(W) != N_sum || ncol(W) != r_prior) stop("Dimension mismatch: W")
  if (s2 <= 0) stop("sigma2_k must be positive in trace_Gamma_Sigma_lowrank")

  # Term 1: trace(G_diag D_inv)
  tr_Gdiag_Dinv <- sum(Gamma_base_diag_vec * D_inv_vec)

  # Term 2: trace(L L^T D_inv) = trace(L^T D_inv L)
  LTDinvL <- Matrix::crossprod(L, D_inv_vec * L) # r x r
  tr_LLT_Dinv <- sum(Matrix::diag(LTDinvL))

  # Term 3: trace(G_diag W L'^T D_inv) = (1/sqrt(s2)) * trace(L^T D_inv G_diag W).
  # trace(A^T B) equals the sum of the element-wise product of A and B.
  Term_T3_Inner <- D_inv_vec * Gamma_base_diag_vec * L # N_sum x r, (D_inv G_diag L)
  tr_Gdiag_W_LT_Dinv_Lprime <- (1 / sqrt(s2)) * sum(Term_T3_Inner * W)

  # Term 4: trace(L L^T W L'^T D_inv)
  #       = (1/sqrt(s2)) * trace((L^T D_inv L) %*% (L^T W)).
  # trace(A B) = sum(t(A) * B). Matrix::t is used so the transpose also works
  # on Matrix-class objects when the Matrix package is not attached.
  LTW <- Matrix::crossprod(L, W) # r x r
  tr_LLT_W_LT_Dinv_Lprime <- (1 / sqrt(s2)) * sum(Matrix::t(LTDinvL) * LTW)

  # Combine terms
  trace_val <- tr_Gdiag_Dinv + tr_LLT_Dinv - tr_Gdiag_W_LT_Dinv_Lprime - tr_LLT_W_LT_Dinv_Lprime

  # Non-finite results can arise when W or D_inv are problematic.
  if (!is.finite(trace_val)) {
    warning("Non-finite result in trace_Gamma_Sigma_lowrank. Check inputs.", call. = FALSE)
    return(NA_real_)
  }

  trace_val
}

# Log determinant of Sigma_k
# Based on det(Sigma) = det(D_inv) / det(I+M)
# where M = (1/sigma2) * L_base^T * D_inv * L_base
# Log-determinant of Sigma_k from its low-rank representation:
#   log det(Sigma_k) = sum(log(D_inv)) - log det(I + M),
#   M = (1 / sigma2_k) * L_base^T D_inv L_base.
#
# Arguments:
#   Sigma_k_list   list; only its `diag` element (D_inv) is read.
#   L_base_factor  base factor L with r columns.
#   sigma2_k       current sigma^2 value.
#
# Returns the log-determinant, or NA with a warning when D_inv has a
# non-positive entry or when det(I + M) is non-finite or non-positive.
log_det_Sigma_lowrank <- function(Sigma_k_list, L_base_factor, sigma2_k) {

  inv_diag <- Sigma_k_list$diag
  identity_r <- Matrix::Diagonal(ncol(L_base_factor))

  # log() of a non-positive entry is undefined, so bail out early.
  if (any(inv_diag <= 0)) {
    warning("Non-positive values in Sigma_k$diag for log_det calculation.")
    return(NA)
  }
  log_det_diag_part <- sum(log(inv_diag))

  # Small r x r matrix I + M built from the current D_inv and sigma2_k.
  gram_scaled <- (1 / sigma2_k) * Matrix::crossprod(L_base_factor, inv_diag * L_base_factor)
  det_info <- Matrix::determinant(identity_r + gram_scaled, logarithm = TRUE)
  log_det_small <- det_info$modulus[1]

  if (!is.finite(log_det_small)) {
    warning("Log determinant of (I+M) is non-finite.")
    return(NA)
  }
  # The determinant itself must be positive for the formula to apply.
  if (det_info$sign <= 0) {
    warning("Determinant of (I+M) is non-positive.")
    return(NA)
  }

  log_det_diag_part - log_det_small
}


# ELBO Calculation based on formula (12) interpretation and standard VI
# Assumes prior precision is sigma^{-2} * Gamma_base
# Assumes prior on sigma2 is IG(a, b)
# Evidence lower bound (up to additive constants) for the low-rank variational
# approximation, assuming prior precision sigma^{-2} * Gamma_base and an
# IG(a, b) prior on sigma2.
#
# Assembled as
#   sum(X %*% beta_c) + sum(mu_N) - sum(expected weighted intensity)
#   - (mu^T G mu + tr(G Sigma_k)) / (2 * sigma2_k)
#   - 0.5 * (N_sum * log(sigma2_k) - log_det_gamma_base)
#   + 0.5 * log det(Sigma_k)
#   - (a + 1) * log(sigma2_k) - b / sigma2_k
#   + n / 2
# where mu_N holds entries (n + 1):N_sum of mu, n = nrow(X_tilde), N = nrow(X).
# The constants N_sum / 2 * (1 + log(2 * pi)) and a * log(b) - lgamma(a) are
# omitted.
#
# Returns the ELBO as a single number, or NA_real_ (with a warning) when any
# component cannot be evaluated or the total is not finite.
F_ELBO_lowrank <- function(X, X_tilde, alpha_weight, mu, Sigma_k_list, beta_c, sigma2_k,
                           Gamma_base_precision, # Base precision matrix (Needed for mu^T G mu)
                           Gamma_base_diag_vec,  # Diagonal of base precision
                           L_base_factor,        # Low-rank factor of base precision
                           log_det_gamma_base,   # Precomputed log(det(Gamma_base))
                           a, b) {               # IG prior parameters for sigma2
  n <- nrow(X_tilde)
  N <- nrow(X)
  N_sum <- n + N
  mu_N <- mu[(n + 1):N_sum]

  # Likelihood-related terms.
  exp_w_vec <- exp_WeightedVec_lowrank(alpha_weight, mu, Sigma_k_list, beta_c, X_tilde,
                                       L_base_factor, sigma2_k)
  # A non-finite expected intensity means the bound cannot be evaluated.
  # Overwriting such entries with 0 would remove their penalty and report a
  # finite, overly optimistic ELBO, so the failure is reported like the others.
  if (any(!is.finite(exp_w_vec))) {
    warning("Non-finite values in exp_w_vec within ELBO calculation.", call. = FALSE)
    return(NA_real_)
  }
  log_likelihood_term <- sum(X %*% beta_c) + sum(mu_N) - sum(exp_w_vec)

  # Prior energy: E_q[mu^T Gamma_base mu] = mu^T Gamma_base mu + tr(Gamma_base Sigma_k)
  muT_G_mu <- as.numeric(Matrix::crossprod(mu, Gamma_base_precision %*% mu))
  tr_G_Sigma <- trace_Gamma_Sigma_lowrank(Gamma_base_diag_vec, L_base_factor, Sigma_k_list, sigma2_k)

  if (is.na(tr_G_Sigma)) {
    warning("trace_Gamma_Sigma calculation failed in ELBO.", call. = FALSE)
    return(NA_real_)
  }

  prior_energy_expect <- muT_G_mu + tr_G_Sigma
  log_prior_mu_energy_term <- -1 / (2 * sigma2_k) * prior_energy_expect

  # -1/2 * log det(sigma2 * Gamma_base^{-1})
  #   = -1/2 * (N_sum * log(sigma2_k) - log_det_gamma_base)
  log_prior_mu_det_term <- -0.5 * (N_sum * log(sigma2_k) - log_det_gamma_base)

  # Entropy of q: only the 0.5 * log det(Sigma_k) part is kept.
  log_det_sigma_val <- log_det_Sigma_lowrank(Sigma_k_list, L_base_factor, sigma2_k)

  if (is.na(log_det_sigma_val)) {
    warning("Log_det computation failed in ELBO.", call. = FALSE)
    return(NA_real_)
  }
  entropy_term <- 0.5 * log_det_sigma_val

  # log p(sigma2) for sigma2 ~ IG(a, b), without a * log(b) - lgamma(a).
  log_prior_sigma2_term <- -(a + 1) * log(sigma2_k) - b / sigma2_k

  # Constant n / 2 taken from Eq. 12 (its origin is not derived in this file).
  const_term_n_half <- n / 2.0

  ELBO <- log_likelihood_term + log_prior_mu_energy_term + log_prior_mu_det_term +
    entropy_term + log_prior_sigma2_term + const_term_n_half

  if (!is.finite(ELBO)) {
    warning("Non-finite ELBO computed.", call. = FALSE)
    return(NA_real_)
  }

  as.numeric(ELBO)
}


# Helper function for L-BFGS: Calculate Negative ELBO w.r.t. mu (ignoring constants w.r.t mu)
# Objective for the L-BFGS update of mu: the negative of the mu-dependent part
# of the ELBO,
#   -( sum(mu[(n+1):(n+N)]) - sum(expected weighted intensity)
#      - mu^T Gamma_base mu / (2 * sigma2_k) ).
# `tA_eN_vec` is accepted so the signature matches the gradient function; it
# is not used in the objective itself.
NegELBO_mu_lowrank <- function(mu, Sigma_k_list, beta_c_k, sigma2_k, Gamma_base_precision,
                               X_tilde, alpha_weight, n, N, tA_eN_vec, L_base_factor) {

  mu_col <- as.matrix(mu)
  weighted_intensity <- exp_WeightedVec_lowrank(alpha_weight, mu_col, Sigma_k_list, beta_c_k,
                                                X_tilde, L_base_factor, sigma2_k)

  # Sum of the mean over the rows that belong to the N observations.
  observed_rows <- (n + 1):(n + N)
  observed_sum <- sum(mu_col[observed_rows, 1, drop = FALSE])

  # Integral approximation enters with a negative sign.
  integral_part <- -sum(weighted_intensity)

  # Gaussian prior penalty, scaled by the current sigma2_k.
  quadratic_form <- as.numeric(Matrix::crossprod(mu_col, Gamma_base_precision %*% mu_col))
  prior_penalty <- -1 / (2 * sigma2_k) * quadratic_form

  -(observed_sum + integral_part + prior_penalty)
}

# Helper function for L-BFGS: Calculate Gradient of Negative ELBO w.r.t. mu
# Gradient (with respect to mu) of the negative ELBO used by the L-BFGS update.
# The ELBO gradient is
#   tA_eN_vec - c(expected weighted intensity, 0_N) - Gamma_base mu / sigma2_k
# and this function returns its negation as a plain vector.
grad_NegELBO_mu_lowrank <- function(mu, Sigma_k_list, beta_c_k, sigma2_k, Gamma_base_precision,
                                    X_tilde, alpha_weight, n, N, tA_eN_vec, L_base_factor) {
  mu_col <- as.matrix(mu)
  weighted_intensity <- exp_WeightedVec_lowrank(alpha_weight, mu_col, Sigma_k_list, beta_c_k,
                                                X_tilde, L_base_factor, sigma2_k)

  # Intensity contributes only to the first n coordinates; pad the rest with 0.
  intensity_padded <- matrix(c(weighted_intensity, rep(0, N)), ncol = 1)
  # Precomputed indicator vector supplied by the caller.
  observation_indicator <- matrix(tA_eN_vec, ncol = 1)
  # Derivative of -mu^T G mu / (2 * sigma2) is -G mu / sigma2.
  prior_pull <- -(1 / sigma2_k) * (Gamma_base_precision %*% mu_col)

  elbo_gradient <- observation_indicator - intensity_padded + prior_pull

  # optim() minimises, so hand back the gradient of the negative ELBO.
  as.vector(-elbo_gradient)
}

# Helper function for CG: Calculate Hessian-vector product (Hessian_of_NegELBO %*% v)
# Hessian of NegELBO = - Hessian of ELBO
# Hessian of ELBO = H_lik + H_prior
# H_lik = -block_diag(diag(exp_w_vec), 0)
# H_prior = -1/sigma2 * Gamma_base_precision
# HessianVecProd calculates -(H_lik + H_prior) %*% v
#   = block_diag(diag(exp_w_vec), 0) %*% v + (1/sigma2) * Gamma_base_precision %*% v
# Product of the negative-ELBO Hessian (in mu) with a vector v:
#   (1 / sigma2_k) * Gamma_base_precision %*% v
#   + c(expected weighted intensity * v[1:n], 0_N)
# evaluated at mu_k. The result keeps the column shape of the matrix product.
# Stops if v does not have n + N rows.
HessianVecProd_mu_lowrank <- function(v, mu_k, Sigma_k_list, beta_c_k, sigma2_k, Gamma_base_precision,
                                      X_tilde, alpha_weight, n, N, L_base_factor) {
  direction <- as.matrix(v)
  if (nrow(direction) != n + N) stop("Vector v has incorrect dimension for HessianVecProd")

  curvature_weights <- exp_WeightedVec_lowrank(alpha_weight, mu_k, Sigma_k_list, beta_c_k,
                                               X_tilde, L_base_factor, sigma2_k)

  # Prior contribution: scaled base precision applied to the direction.
  prior_part <- (1 / sigma2_k) * (Gamma_base_precision %*% direction)

  # Likelihood contribution: diagonal weights act on the first n coordinates,
  # the remaining N coordinates receive zero.
  weighted_head <- as.vector(curvature_weights) * direction[1:n, 1, drop = FALSE]
  likelihood_part <- matrix(c(weighted_head, rep(0, N)), ncol = 1)

  prior_part + likelihood_part
}


# Basic conjugate gradient (CG) solver for A x = b, where A is supplied as a matrix-vector product function
# Conjugate gradient iteration for A x = b.
#
# Arguments:
#   Ax_func  function(v, ...) returning A %*% v.
#   b        right-hand side (coerced with as.matrix()).
#   x0       starting point (coerced with as.matrix()).
#   tol      residual-norm tolerance; also used as the relative threshold in
#            the curvature check below.
#   maxiter  maximum number of iterations (0 performs none).
#   ...      forwarded unchanged to Ax_func.
#
# Returns list(x, iterations, converged). Warns and stops iterating when the
# curvature p' A p is zero, non-finite or too small relative to |p| * |A p|,
# when the residual becomes non-finite, or when maxiter is exhausted without
# convergence.
cg_solver <- function(Ax_func, b, x0, tol = 1e-5, maxiter = 100, ...) {
  x <- as.matrix(x0)
  b <- as.matrix(b)
  r <- b - Ax_func(x, ...)
  p <- r
  rsold <- sum(r^2)
  if (sqrt(rsold) < tol) {
    return(list(x = x, iterations = 0, converged = TRUE))
  }

  iter <- 0
  converged <- FALSE
  # seq_len() gives an empty loop for maxiter = 0; 1:0 would iterate twice.
  for (i in seq_len(maxiter)) {
    iter <- i
    Ap <- Ax_func(p, ...)
    pAp <- sum(p * Ap)
    if (!is.finite(pAp) || pAp == 0) {
      warning("CG: pAp is zero or non-finite. Stopping.")
      break
    }
    # Curvature check: a non-positive or relatively tiny p' A p indicates that
    # the operator is not (numerically) positive definite. The vector lengths
    # are computed directly instead of through an SVD-based norm().
    if (pAp <= tol * sqrt(sum(p^2)) * sqrt(sum(Ap^2))) {
      warning("CG: Matrix may not be positive definite or pAp is numerically zero/negative/non-finite. Stopping.")
      break
    }

    alpha <- rsold / pAp
    x <- x + alpha * p
    r <- r - alpha * Ap
    rsnew <- sum(r^2)
    # A non-finite residual cannot recover; stop before it contaminates p.
    if (!is.finite(rsnew)) {
      warning("CG: residual norm is non-finite. Stopping.")
      break
    }
    if (sqrt(rsnew) < tol) {
      converged <- TRUE
      break
    }
    p <- r + (rsnew / rsold) * p
    rsold <- rsnew
  }

  # `iter` (not the loop variable) is used so this also works when the loop
  # body never ran.
  if (iter == maxiter && !converged) {
    warning(paste("CG did not converge within", maxiter, "iterations."))
  }
  list(x = x, iterations = iter, converged = converged)
}


# Main algorithm incorporating sigma2
HVGA_new_lowrank = function(A, A_tilde, # Note: A, A_tilde seem unused in the core logic provided
                            X, X_tilde, alpha_weight,
                            mu0, Sigma0_list, beta_c0, sigma2_0, # Initial values
                            Gamma_base_precision, # Base precision (unscaled)
                            Gamma_base_diag_vec, L_base_factor, # Decomposed base prior parts
                            a, b, # IG prior parameters for sigma2
                            maxiter, tol, sigma2_tol = 1e-4, # Tolerances
                            mu_update_method = "newton", # "newton", "lbfgs", "cg"
                            cg_tol = 1e-5, cg_maxiter = 100,
                            lbfgs_maxit = 100, lbfgs_factr = 1e7
){
  
  if (!mu_update_method %in% c("newton", "lbfgs", "cg")) {
    stop("mu_update_method must be 'newton', 'lbfgs', or 'cg'")
  }
  if (sigma2_0 <= 0) stop("Initial sigma2_0 must be positive.")
  if (a <= 0 || b <= 0) warning("IG parameters a, b should be positive.")
  
  pb = txtProgressBar(style = 3)
  
  # Initialize state
  mu_k = mu0
  Sigma_k_list = Sigma0_list # Includes diag, W, L (should be L_base_factor)
  # # Ensure L stored in Sigma0_list is L_base_factor
  # if(!identical(Sigma0_list$L, L_base_factor)) {
  #   warning("Sigma0_list$L does not match L_base_factor. Storing L_base_factor in Sigma_k_list.")
  #   Sigma_k_list$L <- L_base_factor
  # }
  beta_c_k = beta_c0
  sigma2_k = sigma2_0
  # Precompute log determinant of Gamma_base_precision
  log_det_gamma_base = tryCatch({
    Matrix::determinant(Gamma_base_precision, logarithm = TRUE)$modulus[1]
  }, error = function(e){
    warning("Could not compute log determinant of Gamma_base_precision. Setting to 0.", call.=FALSE)
    # Handle error appropriately, maybe stop or use 0 if determinant doesn't matter for relative ELBO
    0
  })
  
  
  # History tracking
  dmu_norm = c(); dSigma_norm_approx = c(); dbeta_norm = c(); F_ELBO_vec = c()
  sigma2_hist = c(sigma2_k)
  beta_c_list = list(); beta_c_list[[1]] = beta_c0
  
  # Get dimensions and prior rank
  N = dim(X)[1]; n = dim(X_tilde)[1]; N_sum = n + N
  m = ncol(X) # Number of covariates
  r_prior = ncol(L_base_factor)
  I_r = Matrix::Diagonal(r_prior)
  
  # Precompute constants
  tA_eN_vec = c(rep(0, n), rep(1, N)) # Maps to mu_N part
  tX_eN = Matrix::crossprod(X, matrix(1, nrow=N, ncol=1))
  
  # Calculate initial ELBO
  # Ensure F_ELBO_lowrank uses L_base_factor from its argument, not from Sigma_k_list if different
  # Calculate initial ELBO
  F_ELBO_vec[1] = F_ELBO_lowrank(X, X_tilde, alpha_weight, mu_k, Sigma_k_list, beta_c_k, sigma2_k,
                                 Gamma_base_precision, Gamma_base_diag_vec, L_base_factor,
                                 log_det_gamma_base, # Pass precomputed value
                                 a, b)
  
  ELBO_tol = tol[1]
  beta_tol = tol[2]
  star_time = Sys.time()
  
  for (i in 1:maxiter) {
    # Store previous state for norm calculation
    mu_old = mu_k
    Sigma_k_diag_old = Sigma_k_list$diag
    Sigma_k_W_old = Sigma_k_list$W
    beta_c_old = beta_c_k
    sigma2_old = sigma2_k
    
    # --- 1. Update mu (using sigma2_k) ---
    if (mu_update_method == "newton") {
      # Requires Hessian = H_lik + H_prior = block_diag(ew,0) + (1/sigma2)*Gamma_base
      exp_w_vec_k = exp_WeightedVec_lowrank(alpha_weight, mu_k, Sigma_k_list, beta_c_k, X_tilde,
                                            L_base_factor, sigma2_k) # Add L_base_factor, sigma2_k
      
      # Check for Inf or NaN in exp_w_vec_k
      if (any(!is.finite(exp_w_vec_k))) {
        warning("Non-finite values found in exp_w_vec_k. Check calculations.")
        # Handle this case, maybe clamp values or stop
        exp_w_vec_k[!is.finite(exp_w_vec_k)] <- .Machine$double.xmax # Or some large reasonable value / stop
      }
      # Clamp very small values to avoid numerical zero? (Use with caution)
      # exp_w_vec_k <- pmax(exp_w_vec_k, .Machine$double.eps) 
      
      # Gradient of positive ELBO (f_mu)
      f_mu_lik_term2 = matrix(c(exp_w_vec_k, rep(0, N)), ncol=1)
      f_mu_lik_term3 = matrix(tA_eN_vec, ncol=1)
      f_mu_prior_term = -(1 / sigma2_k) * (Gamma_base_precision %*% mu_k)
      f_mu = f_mu_lik_term3 - f_mu_lik_term2 + f_mu_prior_term
      
      # Hessian of positive ELBO (grad.f_mu)
      diag_update_mat = Matrix::Diagonal(n = n, x = as.vector(exp_w_vec_k))
      block_diag_update = Matrix::bdiag(diag_update_mat, Matrix::Matrix(0, nrow=N, ncol=N, sparse=TRUE))
      grad.f_mu_unreg = block_diag_update + (1 / sigma2_k) * Gamma_base_precision # H = H_lik + H_prior
      
      # --- Add Regularization ---
      lambda_reg = 1e-6 # Regularization parameter, adjust if needed (e.g., 1e-8, 1e-5)
      N_sum_local = nrow(grad.f_mu_unreg) # Get dimension locally
      if (is(grad.f_mu_unreg, "sparseMatrix")) {
        grad.f_mu = grad.f_mu_unreg + lambda_reg * Matrix::Diagonal(N_sum_local)
      } else {
        grad.f_mu = grad.f_mu_unreg
        diag(grad.f_mu) = diag(grad.f_mu) + lambda_reg # Add to diagonal for dense matrix
      }
      # --- End Regularization ---
      
      # Newton step: H * d = -f_mu
      dmu = tryCatch({
        Matrix::solve(grad.f_mu, -f_mu, sparse = is(grad.f_mu, "sparseMatrix"), tol = .Machine$double.eps)
      }, error = function(e) {
        warning(paste("Matrix::solve failed even after regularization:", e$message))
        # Return a zero step or gradient descent step as fallback
        # return(matrix(0, nrow=length(mu_k), ncol=1)) # Zero step
        return(- (f_mu * 1e-4)) # Small gradient step (adjust step size)
      })
      
      mu_new = mu_k + 0.1*dmu # Solve H*d = -f_mu, then mu_new = mu_k + d
      
    } else if (mu_update_method == "lbfgs") {
      mu_k_vec = as.vector(mu_k)
      optim_result = tryCatch({
        stats::optim(par = mu_k_vec, fn = NegELBO_mu_lowrank, gr = grad_NegELBO_mu_lowrank,
                     method = "L-BFGS-B", control = list(maxit = lbfgs_maxit, factr = lbfgs_factr),
                     # Pass ALL necessary arguments for fn and gr:
                     Sigma_k_list = Sigma_k_list, beta_c_k = beta_c_k, sigma2_k = sigma2_k,
                     Gamma_base_precision = Gamma_base_precision, X_tilde = X_tilde,
                     alpha_weight = alpha_weight, n = n, N = N, tA_eN_vec = tA_eN_vec,
                     L_base_factor = L_base_factor) # Pass L_base_factor
      }, error = function(e) {
        warning("optim (L-BFGS) failed: ", e$message); list(par = mu_k_vec, convergence = -1) })
      if (optim_result$convergence != 0) {
        warning(paste("L-BFGS for mu did not converge. Code:", optim_result$convergence)) }
      mu_new = matrix(optim_result$par, ncol=1)
      
    } else if (mu_update_method == "cg") {
      # Solve Hessian(NegELBO) * d = -Gradient(NegELBO)
      # Hessian(NegELBO) = H_lik + (1/sigma2)*Gamma_base
      # Gradient(NegELBO) = -f_mu (where f_mu = grad(ELBO))
      # Target: (H_lik + (1/sigma2)*G) * d = f_mu
      
      exp_w_vec_k_cg = exp_WeightedVec_lowrank(alpha_weight, mu_k, Sigma_k_list, beta_c_k, X_tilde,
                                               L_base_factor, sigma2_k)
      f_mu_lik_term2_cg = matrix(c(exp_w_vec_k_cg, rep(0, N)), ncol=1)
      f_mu_lik_term3_cg = matrix(tA_eN_vec, ncol=1)
      f_mu_prior_term_cg = -(1 / sigma2_k) * (Gamma_base_precision %*% mu_k)
      f_mu_cg = f_mu_lik_term3_cg - f_mu_lik_term2_cg + f_mu_prior_term_cg
      
      b_cg = f_mu_cg # Right hand side is grad(ELBO)
      dmu0 = matrix(0, nrow = N_sum, ncol = 1) # Initial guess for step d
      
      cg_result = cg_solver(Ax_func = HessianVecProd_mu_lowrank, # HessianVecProd calculates (H_lik + H_prior_scaled)*v
                            b = b_cg, x0 = dmu0, tol = cg_tol, maxiter = cg_maxiter,
                            # Pass ALL necessary arguments for Ax_func:
                            mu_k = mu_k, Sigma_k_list = Sigma_k_list, beta_c_k = beta_c_k,
                            sigma2_k = sigma2_k, Gamma_base_precision = Gamma_base_precision,
                            X_tilde = X_tilde, alpha_weight = alpha_weight, n = n, N = N,
                            L_base_factor = L_base_factor) # Pass L_base_factor
      
      dmu = cg_result$x # Solution d = H^{-1} f_mu
      mu_new = mu_k + dmu # Update mu_new = mu_k + d
    }
    
    mu_k = mu_new # Accept update
    dmu_norm = append(dmu_norm, norm(mu_k - mu_old, "F"), length(dmu_norm))
    
    
    # --- 2. Update Sigma (using updated mu_k, current sigma2_k) ---
    # Uses Woodbury update based on Sigma^{-1} = H = H_lik + H_prior
    # H = block_diag(exp_w_vec, 0) + (1/sigma2_k) * Gamma_base
    # H = block_diag(...) + (1/sigma2_k) * (Diag(G_diag) + L L^T)
    # H = Diag( H_diag' ) + (1/sigma2_k) * L L^T
    # where H_diag' = block_diag(...)_diag + (1/sigma2_k)*G_diag
    
    exp_w_vec_mu_updated = exp_WeightedVec_lowrank(alpha_weight, mu_k, Sigma_k_list, beta_c_k, X_tilde,
                                                   L_base_factor, sigma2_k) # Pass args
    
    diag_lik_update_vec = c(as.vector(exp_w_vec_mu_updated), rep(0, N))
    H_tilde_diag_vec = diag_lik_update_vec + (1/sigma2_k) * Gamma_base_diag_vec
    H_tilde_diag_vec = pmax(H_tilde_diag_vec, .Machine$double.eps) # Ensure positivity
    
    D_inv_vec = 1 / H_tilde_diag_vec
    
    # M = (L')^T D_inv L' where L' = L_base / sqrt(sigma2_k)
    M = (1/sigma2_k) * Matrix::crossprod(L_base_factor, D_inv_vec * L_base_factor)
    
    # Middle_inv = solve(I + M)
    Middle_inv = tryCatch({
      Matrix::solve(I_r + M)
    }, error = function(e){
      warning("Matrix solve failed for (I+M) in exp_WeightedVec_lowrank. Check M's condition number. Using pseudo-inverse.", call. = FALSE)
      # Fallback or error handling, e.g., using MASS::ginv or stopping
      MASS::ginv(as.matrix(I_r + M)) # Ensure MASS is loaded or handle differently
    })
    
    # W component for the structure: Sigma = D_inv - W_new L'^T D_inv
    # W_new = D_inv L' Middle_inv = D_inv (L_base/sqrt(s2)) Middle_inv
    Term1 = (D_inv_vec / sqrt(sigma2_k)) * L_base_factor # D_inv * L'
    W_calc = Term1 %*% Middle_inv
    
    # Store the new Sigma state
    # We store D_inv, W_calc, and L_base_factor. L' is implicitly defined.
    Sigma_new_list = list(diag = D_inv_vec, W = W_calc, L = L_base_factor) # Store L_base
    
    # Approximate Norm Change
    dSigma_diag_norm = norm(Sigma_new_list$diag - Sigma_k_diag_old, "2")
    dSigma_W_norm = norm(Sigma_new_list$W - Sigma_k_W_old, "F") # Compare W directly
    dSigma_norm_approx_val = sqrt(dSigma_diag_norm^2 + dSigma_W_norm^2)
    dSigma_norm_approx = append(dSigma_norm_approx, dSigma_norm_approx_val, length(dSigma_norm_approx))
    
    Sigma_k_list = Sigma_new_list # Accept update
    
    
    # --- 3. Update beta (using updated mu_k, Sigma_k_list, current sigma2_k) ---
    # Beta update depends on E[exp(eta)] which uses mu_k and diag(Sigma_k)
    # Use the approximation for diag(Sigma_k) ~ Sigma_k_list$diag (D_inv)
    exp_w_vec_sigma_updated = exp_WeightedVec_lowrank(alpha_weight, mu_k, Sigma_k_list, beta_c_k, X_tilde,
                                                      L_base_factor, sigma2_k) # Pass args
    
    weights_vec = as.numeric(exp_w_vec_sigma_updated)
    if (length(weights_vec) != nrow(X_tilde)) stop("Dim mismatch weights_vec")
    
    # Hessian of NegELBO w.r.t beta = t(X_tilde) %*% diag(weights) %*% X_tilde
    # grad.f_beta = Matrix::crossprod(X_tilde * sqrt(weights_vec)) # More stable calculation
    # grad.f_beta = Matrix::crossprod(X_tilde, weights_vec * X_tilde) # Equivalent but less stable
    # Check dimensions carefully
    if (m == 1) {
      grad.f_beta_raw = matrix(sum(weights_vec * X_tilde^2), 1, 1)
      f_beta = matrix(sum(weights_vec * X_tilde), 1, 1) - tX_eN # grad(NegELBO) = -f_beta
    } else {
      Y_weighted = weights_vec * X_tilde # Element-wise product
      grad.f_beta_raw = Matrix::crossprod(X_tilde, Y_weighted) # Hessian of NegELBO w.r.t beta
      # Gradient of positive ELBO w.r.t beta = t(X) 1_N - t(X_tilde) weights
      f_beta = tX_eN - Matrix::crossprod(X_tilde, weights_vec) # grad(NegELBO) = -f_beta
    }
    
    # --- Regularization ---
    lambda_reg = 1e-6 # Small regularization parameter, adjust if needed
    # Ensure grad.f_beta_raw is a base R matrix for diag<-
    if (is(grad.f_beta_raw, "sparseMatrix")) {
      grad.f_beta_reg = grad.f_beta_raw + lambda_reg * Matrix::Diagonal(m)
    } else {
      grad.f_beta_reg = as.matrix(grad.f_beta_raw) # Ensure it's a base matrix if not sparse
      diag(grad.f_beta_reg) = diag(grad.f_beta_reg) + lambda_reg
    }
    # --- End Regularization ---
    
    
    # Solve for dbeta_c: grad.f_beta_reg * dbeta = -f_beta
    dbeta_c = tryCatch({
      Matrix::solve(grad.f_beta_raw, -f_beta, sparse = FALSE, tol = .Machine$double.eps)
    }, error = function(e) {
      warning(paste("Iteration", i, ": solve still failed after regularization:", e$message))
      # Return a zero step or handle differently if solve still fails
      matrix(0, nrow = m, ncol = 1)
    })
    
    beta_c_new = beta_c_k + dbeta_c
    
    dbeta_norm = append(dbeta_norm, norm(dbeta_c, "F"), length(dbeta_norm))
    beta_c_k = beta_c_new # Accept update
    beta_c_list[[length(beta_c_list)+1]] = beta_c_k
    
    
    # --- 4. Update sigma^2 (using updated mu_k, Sigma_k_list, beta_k) ---
    # Formula: sigma2_{k+1}^{-1} = (n + N + 2a + 2) / ( mu_k^T G_base mu_k + tr(G_base Sigma_k) - 2b )
    # Numerator calculation (Note: Eq 13 uses n+N+2a-2. Ensure 'a' corresponds.)
    numerator = n + N + 2*a + 2
    
    # Denominator calculation: E_q[mu^T Gamma_base mu] + 2b
    muT_G_mu_k = as.numeric(Matrix::crossprod(mu_k, Gamma_base_precision %*% mu_k))
    tr_G_Sigma_k = trace_Gamma_Sigma_lowrank(Gamma_base_diag_vec, L_base_factor, Sigma_k_list, sigma2_k)
    # Note: trace calculation is approximate.
    
    # Check for NA result from trace
    if (is.na(tr_G_Sigma_k)) {
      warning(paste("trace_Gamma_Sigma failed during sigma2 update at iteration", i, ". Skipping sigma2 update."))
      sigma2_new = sigma2_k # Keep old value
    } else {
      # Original denominator calculation (check signs based on Eq 13 vs derivation)
      # Derivation: sigma2 = ( muT G mu + tr(G Sigma) - 2b ) / (N_sum + 2a + 2)
      # Eq 13: sigma = sqrt( [ muT G mu + tr(G Sigma) - 2b ] / (n+N+2a+2) )
      # Seems consistent. The code calculates sigma2_new = denominator / numerator
      numerator = n + N + 2*a + 2
      denominator = muT_G_mu_k + tr_G_Sigma_k - 2*b # Match derivation/Eq 13 numerator
      
      if(denominator <= 0){
        warning(paste("Denominator for sigma2 update is non-positive (", denominator, ") at iteration", i, ". Clamping sigma2."))
        sigma2_new = sigma2_k # Keep previous value or clamp
        # sigma2_new = max(sigma2_k * 0.1, .Machine$double.eps) # Alternative: Decrease but keep positive
      } else {
        sigma2_new = denominator / numerator
      }
      # Clamp sigma2 just in case
      sigma2_new = max(.Machine$double.eps, sigma2_new)
    }
    sigma2_k = sigma2_new # Accept update
    
    sigma2_hist = append(sigma2_hist, sigma2_k, length(sigma2_hist))
    
    
    # --- Calculate ELBO and Check Convergence ---
    F_ELBO_current = F_ELBO_lowrank(X, X_tilde, alpha_weight, mu_k, Sigma_k_list, beta_c_k, sigma2_k,
                                    Gamma_base_precision, Gamma_base_diag_vec, L_base_factor,
                                    log_det_gamma_base, # Pass precomputed value
                                    a, b)
    F_ELBO_vec = append(F_ELBO_vec, F_ELBO_current, length(F_ELBO_vec))
    
    # Convergence Check
    # Use ELBO change from previous iteration's *final* ELBO
    elbo_diff = abs(F_ELBO_current - F_ELBO_vec[length(F_ELBO_vec) - 1])
    beta_diff = max(abs(beta_c_k - beta_c_old))
    sigma2_diff = abs(sigma2_k - sigma2_old)
    
    if(!is.na(elbo_diff) && elbo_diff <= ELBO_tol &&
       beta_diff <= beta_tol &&
       sigma2_diff <= sigma2_tol) {
      message("\nConvergence criteria met.")
      break
    }
    
    setTxtProgressBar(pb, i/maxiter)
  } # End of main loop
  
  end_time = Sys.time()
  close(pb)
  run_time = as.numeric(difftime(end_time, star_time, units = "secs"))
  
  # Return results including sigma2
  result = list(mu_k, dmu_norm, Sigma_k_list, dSigma_norm_approx, beta_c_list, dbeta_norm,
                sigma2_k, sigma2_hist, # Add sigma2 results
                F_ELBO_vec, run_time, mu_update_method)
  names(result) = c("mu", "dmu_norm", "Sigma_lowrank", "dSigma_norm_approx", "beta_hist", "dbeta_norm",
                    "sigma2", "sigma2_hist", # Add sigma2 names
                    "ELBO", "running_time", "mu_update_method")
  return(result)
}

#' Main algorithm incorporating sigma2, with inner loops for beta and mu
#'
#' Each outer iteration performs, in order:
#'   0. an inner loop of (scaled) Newton steps on the coefficients beta,
#'   1. an inner loop of mu updates using `mu_update_method`,
#'   2. a Woodbury-style rebuild of the low-rank posterior covariance,
#'   4. a closed-form update of sigma2,
#' then evaluates the ELBO and checks the outer convergence criteria.
#'
#' @param A,A_tilde Unused in this function; kept for interface compatibility.
#' @param X Covariate matrix for the data pattern; `nrow(X)` is N and
#'   `ncol(X)` is the number of covariates m.
#' @param X_tilde Covariate matrix at the integration points; `nrow(X_tilde)`
#'   is n.
#' @param alpha_weight Integration weights, forwarded to
#'   `exp_WeightedVec_lowrank()` and `F_ELBO_lowrank()`.
#' @param mu0,Sigma0_list,beta_c0,sigma2_0 Initial values. `Sigma0_list` must
#'   provide `$diag` and `$W`; `sigma2_0` must be positive.
#' @param Gamma_base_precision Base precision matrix (not scaled by sigma2).
#' @param Gamma_base_diag_vec,L_base_factor Diagonal and low-rank factor of the
#'   base precision (Gamma = Diag + L L^T).
#' @param a,b Hyperparameters of the prior on sigma2.
#' @param maxiter Maximum number of outer iterations.
#' @param tol Numeric vector of length >= 2: `tol[1]` is the ELBO tolerance,
#'   `tol[2]` the beta tolerance of the outer convergence check.
#' @param sigma2_tol Outer tolerance on the absolute change of sigma2.
#' @param inner_maxiter Maximum iterations of each inner loop.
#' @param inner_tol_mu,inner_tol_beta Inner-loop tolerances on the step norms.
#' @param inner_tol_sigma Currently unused (Sigma is updated once per outer
#'   iteration).
#' @param mu_update_method One of "newton", "lbfgs", "cg".
#' @param cg_tol,cg_maxiter Tolerance and iteration cap passed to `cg_solver()`.
#' @param cg_step Step sizes: `cg_step[1]` scales the beta Newton step (used
#'   for every `mu_update_method`), `cg_step[2]` scales the mu step of the
#'   "cg" method. A scalar is recycled. Defaults to full steps.
#' @param lbfgs_maxit Currently unused: the L-BFGS branch takes a single
#'   `optim()` iteration per inner step.
#' @param lbfgs_factr `factr` control value passed to `stats::optim()`.
#' @return A named list with elements "mu", "dmu_norm", "Sigma_lowrank",
#'   "dSigma_norm_approx", "beta_hist", "dbeta_norm", "sigma2", "sigma2_hist",
#'   "ELBO", "running_time" (seconds) and "mu_update_method".
HVGA_new_lowrank2 <- function(A, A_tilde, # Note: A, A_tilde are unused
                              X, X_tilde, alpha_weight,
                              mu0, Sigma0_list, beta_c0, sigma2_0, # Initial values
                              Gamma_base_precision, # Base precision (unscaled)
                              Gamma_base_diag_vec, L_base_factor, # Decomposed base prior parts
                              a, b, # IG prior parameters for sigma2
                              maxiter, tol, sigma2_tol = 1e-4, # Outer loop tolerances
                              inner_maxiter = 10, # Max iterations for inner loops
                              inner_tol_mu = 1e-5, # Inner tolerance for mu update
                              inner_tol_sigma = 1e-5, # Inner tolerance for Sigma update (unused)
                              inner_tol_beta = 1e-5, # Inner tolerance for beta update
                              mu_update_method = "newton", # "newton", "lbfgs", "cg"
                              cg_tol = 1e-5, cg_maxiter = 100, cg_step = c(1, 1),
                              lbfgs_maxit = 100, lbfgs_factr = 1e7
) {

  # --- Argument checks ---
  if (!mu_update_method %in% c("newton", "lbfgs", "cg")) {
    stop("mu_update_method must be 'newton', 'lbfgs', or 'cg'")
  }
  if (sigma2_0 <= 0) stop("Initial sigma2_0 must be positive.")
  if (a <= 0 || b <= 0) warning("IG parameters a, b should be positive.")
  if (length(tol) < 2) {
    stop("tol must have length 2: c(ELBO tolerance, beta tolerance).")
  }
  if (!is.numeric(cg_step) || length(cg_step) < 1) {
    stop("cg_step must be a numeric vector: c(beta step, mu step).")
  }
  # cg_step[1] is read for every method and cg_step[2] for "cg"; recycle a
  # scalar so that neither lookup yields NA.
  cg_step <- rep_len(cg_step, 2L)
  # The ginv fallback lives in the Sigma update, which runs for every method.
  if (!requireNamespace("MASS", quietly = TRUE)) {
    warning("Package 'MASS' needed for ginv fallback in matrix solves. Please install it.", call. = FALSE)
  }

  pb <- txtProgressBar(style = 3)
  on.exit(close(pb), add = TRUE) # Close the bar even if an error is raised

  # --- Initialize state ---
  mu_k <- mu0
  Sigma_k_list <- Sigma0_list # Includes diag, W, L (should be L_base_factor)
  beta_c_k <- beta_c0
  sigma2_k <- sigma2_0

  # Precompute log determinant of Gamma_base_precision
  log_det_gamma_base <- tryCatch({
    Matrix::determinant(Gamma_base_precision, logarithm = TRUE)$modulus[1]
  }, error = function(e) {
    warning("Could not compute log determinant of Gamma_base_precision. Setting to 0.", call. = FALSE)
    0
  })

  # History tracking (for outer loop changes)
  dmu_norm_outer <- c()
  dSigma_norm_approx <- c()
  dbeta_norm_outer <- c()
  F_ELBO_vec <- c()
  sigma2_hist <- c(sigma2_k)
  beta_c_list <- list(beta_c0)

  # Dimensions and prior rank
  N <- dim(X)[1]
  n <- dim(X_tilde)[1]
  N_sum <- n + N
  m <- ncol(X) # Number of covariates
  r_prior <- ncol(L_base_factor)
  I_r <- Matrix::Diagonal(r_prior)

  # Precompute constants
  tA_eN_vec <- c(rep(0, n), rep(1, N)) # Ones on the last N entries of mu
  tX_eN <- Matrix::crossprod(X, matrix(1, nrow = N, ncol = 1))

  # Initial ELBO
  F_ELBO_vec[1] <- F_ELBO_lowrank(X, X_tilde, alpha_weight, mu_k, Sigma_k_list, beta_c_k, sigma2_k,
                                  Gamma_base_precision, Gamma_base_diag_vec, L_base_factor,
                                  log_det_gamma_base, a, b)

  ELBO_tol <- tol[1] # Outer loop ELBO tolerance
  beta_tol <- tol[2] # Outer loop beta tolerance
  star_time <- Sys.time()

  # --- Outer Loop ---
  for (i in seq_len(maxiter)) {
    # State at the START of the outer iteration, for the outer convergence check
    mu_outer_old <- mu_k
    beta_c_outer_old <- beta_c_k
    sigma2_outer_old <- sigma2_k

    # --- 0. Inner Loop: Update beta ---
    for (inner_iter_beta in seq_len(inner_maxiter)) {
      # Beta update depends on the latest mu_k, Sigma_k_list and sigma2_k
      exp_w_vec_sigma_updated <- exp_WeightedVec_lowrank(alpha_weight, mu_k, Sigma_k_list, beta_c_k, X_tilde,
                                                         L_base_factor, sigma2_k)
      if (any(!is.finite(exp_w_vec_sigma_updated))) {
        warning(paste("Iter", i, "InnerBeta", inner_iter_beta, ": Non-finite exp_w_vec_sigma_updated. Clamping."), call. = FALSE)
        exp_w_vec_sigma_updated[!is.finite(exp_w_vec_sigma_updated)] <- .Machine$double.xmax
      }

      weights_vec <- as.numeric(exp_w_vec_sigma_updated)
      if (length(weights_vec) != nrow(X_tilde)) stop("Dim mismatch weights_vec")

      # Newton quantities: grad.f_beta_raw is the Hessian of -ELBO w.r.t. beta,
      # f_beta = t(X) 1_N - t(X_tilde) weights is the gradient of +ELBO.
      if (m == 1) {
        grad.f_beta_raw <- matrix(sum(weights_vec * X_tilde^2), 1, 1)
        f_beta <- tX_eN - matrix(sum(weights_vec * X_tilde), 1, 1)
      } else {
        Y_weighted <- weights_vec * X_tilde
        grad.f_beta_raw <- Matrix::crossprod(X_tilde, Y_weighted)
        f_beta <- tX_eN - Matrix::crossprod(X_tilde, weights_vec)
      }

      # Small ridge on the diagonal to keep the solve well conditioned
      lambda_reg <- 1e-6
      if (is(grad.f_beta_raw, "sparseMatrix")) {
        grad.f_beta_reg <- grad.f_beta_raw + lambda_reg * Matrix::Diagonal(m)
      } else {
        grad.f_beta_reg <- as.matrix(grad.f_beta_raw)
        diag(grad.f_beta_reg) <- diag(grad.f_beta_reg) + lambda_reg
      }

      # Ascent Newton step: Hess(-ELBO) %*% dbeta = grad(ELBO).
      # The regularized Hessian is the one that must be solved against.
      dbeta_c <- tryCatch({
        Matrix::solve(grad.f_beta_reg, f_beta, sparse = FALSE, tol = .Machine$double.eps)
      }, error = function(e) {
        warning(paste("Iter", i, "InnerBeta", inner_iter_beta, ": Beta solve failed:", e$message), call. = FALSE)
        matrix(0, nrow = m, ncol = 1)
      })
      beta_step <- cg_step[1]
      beta_c_k <- beta_c_k + beta_step * dbeta_c # Accept inner update

      # Inner convergence is judged on the unscaled Newton step
      dbeta_norm_inner <- norm(dbeta_c, "F")
      if (dbeta_norm_inner < inner_tol_beta) {
        break
      }
    } # End inner beta loop
    # Record beta exactly once per outer iteration, whether the inner loop
    # converged or ran out of iterations.
    beta_c_list[[length(beta_c_list) + 1]] <- beta_c_k

    # --- 1. Inner Loop: Update mu ---
    for (inner_iter_mu in seq_len(inner_maxiter)) {
      mu_inner_old <- mu_k # mu at start of this inner iteration

      if (mu_update_method == "newton") {
        # --- Newton Step for Mu ---
        exp_w_vec_k <- exp_WeightedVec_lowrank(alpha_weight, mu_k, Sigma_k_list, beta_c_k, X_tilde,
                                               L_base_factor, sigma2_k)
        if (any(!is.finite(exp_w_vec_k))) {
          warning(paste("Iter", i, "InnerMu", inner_iter_mu, ": Non-finite exp_w_vec_k. Clamping."), call. = FALSE)
          exp_w_vec_k[!is.finite(exp_w_vec_k)] <- .Machine$double.xmax
        }

        # grad(ELBO) = t(A) 1_N - [exp_w; 0] - Gamma mu / sigma2
        f_mu_lik_term2 <- matrix(c(exp_w_vec_k, rep(0, N)), ncol = 1)
        f_mu_lik_term3 <- matrix(tA_eN_vec, ncol = 1)
        f_mu_prior_term <- -(1 / sigma2_k) * (Gamma_base_precision %*% mu_k)
        f_mu <- f_mu_lik_term3 - f_mu_lik_term2 + f_mu_prior_term

        # Hessian of -ELBO: block_diag(diag(exp_w), 0) + Gamma / sigma2
        diag_update_mat <- Matrix::Diagonal(n = n, x = as.vector(exp_w_vec_k))
        block_diag_update <- Matrix::bdiag(diag_update_mat, Matrix::Matrix(0, nrow = N, ncol = N, sparse = TRUE))
        grad.f_mu_unreg <- block_diag_update + (1 / sigma2_k) * Gamma_base_precision

        lambda_reg <- 1e-6
        N_sum_local <- nrow(grad.f_mu_unreg)
        if (is(grad.f_mu_unreg, "sparseMatrix")) {
          grad.f_mu <- grad.f_mu_unreg + lambda_reg * Matrix::Diagonal(N_sum_local)
        } else {
          grad.f_mu <- grad.f_mu_unreg
          diag(grad.f_mu) <- diag(grad.f_mu) + lambda_reg
        }

        # Ascent Newton direction: Hess(-ELBO) %*% d = grad(ELBO). This is the
        # same sign convention as the beta step and the CG branch; solving
        # against -f_mu would move mu downhill on the ELBO.
        dmu <- tryCatch({
          Matrix::solve(grad.f_mu, f_mu, sparse = is(grad.f_mu, "sparseMatrix"), tol = .Machine$double.eps)
        }, error = function(e) {
          warning(paste("Iter", i, "InnerMu", inner_iter_mu, ": Mu Newton solve failed:", e$message), call. = FALSE)
          matrix(0, nrow = length(mu_k), ncol = 1)
        })

        mu_new <- mu_k + 0.01 * dmu # Damped update step

      } else if (mu_update_method == "lbfgs") {
        # --- L-BFGS Step for Mu ---
        mu_k_vec <- as.vector(mu_k)
        optim_result <- tryCatch({
          # maxit = 1: a single L-BFGS iteration per inner step
          stats::optim(par = mu_k_vec, fn = NegELBO_mu_lowrank, gr = grad_NegELBO_mu_lowrank,
                       method = "L-BFGS-B", control = list(maxit = 1, factr = lbfgs_factr),
                       Sigma_k_list = Sigma_k_list, beta_c_k = beta_c_k, sigma2_k = sigma2_k,
                       Gamma_base_precision = Gamma_base_precision, X_tilde = X_tilde,
                       alpha_weight = alpha_weight, n = n, N = N, tA_eN_vec = tA_eN_vec,
                       L_base_factor = L_base_factor)
        }, error = function(e) {
          warning(paste("Iter", i, "InnerMu", inner_iter_mu, ": optim (L-BFGS) failed:", e$message), call. = FALSE)
          list(par = mu_k_vec, convergence = -1)
        })
        if (optim_result$convergence != 0) {
          warning(paste("Iter", i, "InnerMu", inner_iter_mu, ": L-BFGS step warning. Code:", optim_result$convergence), call. = FALSE)
        }
        mu_new <- matrix(optim_result$par, ncol = 1)

      } else if (mu_update_method == "cg") {
        # --- CG Step for Mu (solving H*d = f_mu for step d) ---
        exp_w_vec_k_cg <- exp_WeightedVec_lowrank(alpha_weight, mu_k, Sigma_k_list, beta_c_k, X_tilde,
                                                  L_base_factor, sigma2_k)
        f_mu_lik_term2_cg <- matrix(c(exp_w_vec_k_cg, rep(0, N)), ncol = 1)
        f_mu_lik_term3_cg <- matrix(tA_eN_vec, ncol = 1)
        f_mu_prior_term_cg <- -(1 / sigma2_k) * (Gamma_base_precision %*% mu_k)
        f_mu_cg <- f_mu_lik_term3_cg - f_mu_lik_term2_cg + f_mu_prior_term_cg # grad(ELBO)

        b_cg <- f_mu_cg # Right hand side for H d = b
        dmu0 <- matrix(0, nrow = N_sum, ncol = 1) # Initial guess for step d

        cg_result <- cg_solver(Ax_func = HessianVecProd_mu_lowrank, # Ax calculates H*v
                               b = b_cg, x0 = dmu0, tol = cg_tol, maxiter = cg_maxiter,
                               mu_k = mu_k, Sigma_k_list = Sigma_k_list, beta_c_k = beta_c_k,
                               sigma2_k = sigma2_k, Gamma_base_precision = Gamma_base_precision,
                               X_tilde = X_tilde, alpha_weight = alpha_weight, n = n, N = N,
                               L_base_factor = L_base_factor)
        # Only warn if multiple CG iterations were intended per step
        if (!cg_result$converged && cg_maxiter > 1) {
          warning(paste("Iter", i, "InnerMu", inner_iter_mu, ": CG step did not converge."), call. = FALSE)
        }
        dmu <- cg_result$x # Approximate solution of H d = f_mu
        mu_step <- cg_step[2]
        # Fixed step length; a backtracking (Armijo) line search on
        # NegELBO_mu_lowrank could replace it if the fixed step proves unstable.
        mu_new <- mu_k + mu_step * dmu
      }

      # Check inner convergence for mu
      mu_inner_diff <- norm(mu_new - mu_inner_old, "F")
      mu_k <- mu_new # Accept inner update

      if (mu_inner_diff < inner_tol_mu) {
        break
      }
    } # End inner mu loop

    # --- 2. Update Sigma (using updated mu_k, current sigma2_k) ---
    # Woodbury update based on Sigma^{-1} = H = H_lik + H_prior, with
    #   H = block_diag(exp_w_vec, 0) + (1/sigma2_k) * (Diag(G_diag) + L L^T)
    #     = Diag(H_diag') + (1/sigma2_k) * L L^T,
    # where H_diag' = block_diag(...)_diag + (1/sigma2_k) * G_diag.
    Sigma_k_diag_old <- Sigma_k_list$diag
    Sigma_k_W_old <- Sigma_k_list$W
    exp_w_vec_mu_updated <- exp_WeightedVec_lowrank(alpha_weight, mu_k, Sigma_k_list, beta_c_k, X_tilde,
                                                    L_base_factor, sigma2_k)

    diag_lik_update_vec <- c(as.vector(exp_w_vec_mu_updated), rep(0, N))
    H_tilde_diag_vec <- diag_lik_update_vec + (1 / sigma2_k) * Gamma_base_diag_vec
    H_tilde_diag_vec <- pmax(H_tilde_diag_vec, .Machine$double.eps) # Ensure positivity

    D_inv_vec <- 1 / H_tilde_diag_vec

    # M = (L')^T D_inv L' where L' = L_base / sqrt(sigma2_k)
    M <- (1 / sigma2_k) * Matrix::crossprod(L_base_factor, D_inv_vec * L_base_factor)

    # Middle_inv = solve(I + M), with a pseudo-inverse fallback
    Middle_inv <- tryCatch({
      Matrix::solve(I_r + M)
    }, error = function(e) {
      warning("Matrix solve failed for (I+M) in Sigma update. Check M's condition number. Using pseudo-inverse.", call. = FALSE)
      MASS::ginv(as.matrix(I_r + M))
    })

    # W component for the structure: Sigma = D_inv - W_new L'^T D_inv
    # W_new = D_inv L' Middle_inv = D_inv (L_base/sqrt(s2)) Middle_inv
    Term1 <- (D_inv_vec / sqrt(sigma2_k)) * L_base_factor # D_inv * L'
    W_calc <- Term1 %*% Middle_inv

    # Store D_inv, W_calc and L_base_factor; L' is implicitly defined.
    Sigma_new_list <- list(diag = D_inv_vec, W = W_calc, L = L_base_factor)

    # Approximate norm change (Euclidean norm of the diagonal change, which is
    # what the spectral norm of a single column amounts to)
    dSigma_diag_norm <- sqrt(sum((Sigma_new_list$diag - Sigma_k_diag_old)^2))
    dSigma_W_norm <- norm(Sigma_new_list$W - Sigma_k_W_old, "F")
    dSigma_norm_approx_val <- sqrt(dSigma_diag_norm^2 + dSigma_W_norm^2)
    dSigma_norm_approx <- append(dSigma_norm_approx, dSigma_norm_approx_val, length(dSigma_norm_approx))

    Sigma_k_list <- Sigma_new_list # Accept update

    # --- 4. Update sigma^2 (once per outer iteration) ---
    # sigma2_new = (mu^T G mu + tr(G Sigma) - 2b) / (n + N + 2a + 2)
    # NOTE(review): the sign of the 2*b term is kept from the original
    # implementation; confirm it against the derivation for the IG(a, b) prior.
    muT_G_mu_k <- as.numeric(Matrix::crossprod(mu_k, Gamma_base_precision %*% mu_k))
    tr_G_Sigma_k <- trace_Gamma_Sigma_lowrank(Gamma_base_diag_vec, L_base_factor, Sigma_k_list, sigma2_k)

    if (is.na(tr_G_Sigma_k)) {
      warning(paste("Iter", i, ": trace_Gamma_Sigma failed during sigma2 update. Skipping sigma2 update."), call. = FALSE)
      sigma2_new <- sigma2_k # Keep old value
    } else {
      numerator_s2 <- n + N + 2 * a + 2
      denominator_s2 <- muT_G_mu_k + tr_G_Sigma_k - 2 * b

      if (denominator_s2 <= 0) {
        warning(paste("Iter", i, ": Denominator for sigma2 update non-positive (", denominator_s2, "). Clamping sigma2."), call. = FALSE)
        sigma2_new <- max(sigma2_k * 0.1, .Machine$double.eps) # Cautious decrease
      } else {
        sigma2_new <- denominator_s2 / numerator_s2
      }
      sigma2_new <- max(.Machine$double.eps, sigma2_new) # Ensure positivity
    }
    sigma2_k <- sigma2_new # Accept update
    sigma2_hist <- append(sigma2_hist, sigma2_k, length(sigma2_hist))

    # --- Outer loop changes and ELBO ---
    dmu_norm_outer <- append(dmu_norm_outer, norm(mu_k - mu_outer_old, "F"), length(dmu_norm_outer))

    beta_diff_outer <- norm(beta_c_k - beta_c_outer_old, "F") # Overall change in beta
    dbeta_norm_outer <- append(dbeta_norm_outer, beta_diff_outer, length(dbeta_norm_outer))

    F_ELBO_current <- F_ELBO_lowrank(X, X_tilde, alpha_weight, mu_k, Sigma_k_list, beta_c_k, sigma2_k,
                                     Gamma_base_precision, Gamma_base_diag_vec, L_base_factor,
                                     log_det_gamma_base, a, b)
    F_ELBO_vec <- append(F_ELBO_vec, F_ELBO_current, length(F_ELBO_vec))

    # --- Outer Convergence Check ---
    # ELBO change relative to the previous outer iteration, plus the overall
    # change of beta and sigma2 across this outer iteration.
    elbo_diff <- abs(F_ELBO_current - F_ELBO_vec[length(F_ELBO_vec) - 1])
    sigma2_diff_outer <- abs(sigma2_k - sigma2_outer_old)

    if (!is.na(elbo_diff) && elbo_diff <= ELBO_tol &&
        beta_diff_outer <= beta_tol &&
        sigma2_diff_outer <= sigma2_tol) {
      message("\nOuter loop convergence criteria met.")
      break
    }

    setTxtProgressBar(pb, i / maxiter)
  } # --- End of Outer loop ---

  end_time <- Sys.time()
  run_time <- as.numeric(difftime(end_time, star_time, units = "secs"))

  # Return results including sigma2
  result <- list(mu_k, dmu_norm_outer, Sigma_k_list, dSigma_norm_approx, beta_c_list, dbeta_norm_outer,
                 sigma2_k, sigma2_hist,
                 F_ELBO_vec, run_time, mu_update_method)
  names(result) <- c("mu", "dmu_norm", "Sigma_lowrank", "dSigma_norm_approx", "beta_hist", "dbeta_norm",
                     "sigma2", "sigma2_hist",
                     "ELBO", "running_time", "mu_update_method")
  return(result)
}


# Input preparation helper: decompose_gamma() produces the Gamma_base_diag_vec and L_base_factor inputs expected by the fitting functions above.

#' Decompose a sparse precision matrix into Diagonal + Low Rank approximation
#'
#' Splits `Gamma_base_precision` into its diagonal (clamped below at
#' `tol_eigen`) and a low-rank factor `L` built from the positive eigenpairs
#' of the off-diagonal remainder, so that the matrix is approximated by
#' `Diag(Gamma_base_diag_vec) + L %*% t(L)`.
#'
#' @param Gamma_base_precision Symmetric base precision matrix (unscaled).
#'   Non-symmetric input is rejected.
#' @param rank_r Number of eigenpairs requested from `RSpectra::eigs_sym()`;
#'   a single number with `1 <= rank_r <= nrow(Gamma_base_precision)`.
#' @param tol_eigen Threshold used both to clamp small diagonal entries and to
#'   decide which eigenvalues count as positive.
#' @param eig_which Which eigenvalues to request: "LM" (largest magnitude, the
#'   default) or "LA" (largest algebraic). Only eigenvalues above `tol_eigen`
#'   are kept, so with "LM" large negative eigenvalues occupy requested slots
#'   and are then discarded; "LA" avoids that.
#' @return A list with `Gamma_base_diag_vec` (the clamped diagonal),
#'   `L_base_factor` (an `nrow x r_actual` matrix, possibly with zero columns)
#'   and `r_actual` (the number of retained eigenpairs).
decompose_gamma <- function(Gamma_base_precision, rank_r,
                            tol_eigen = sqrt(.Machine$double.eps),
                            eig_which = c("LM", "LA")) {
  eig_which <- match.arg(eig_which)
  # Fail early with a clear message instead of a solver error further down
  if (!is.numeric(rank_r) || length(rank_r) != 1L || !is.finite(rank_r) || rank_r < 1) {
    stop("rank_r must be a single finite number >= 1.")
  }

  if (!is(Gamma_base_precision, "symmetricMatrix")) {
    if (!Matrix::isSymmetric(Gamma_base_precision)) {
      stop("Input Gamma_base_precision must be symmetric.")
    }
    Gamma_base_precision <- Matrix::forceSymmetric(Gamma_base_precision)
  }
  N_sum <- nrow(Gamma_base_precision)
  if (N_sum == 0) stop("Gamma_base_precision is empty.")
  if (rank_r > N_sum) {
    stop("rank_r must not exceed the dimension of Gamma_base_precision.")
  }

  message("Step 1: Extracting diagonal...")
  Gamma_base_diag_vec <- Matrix::diag(Gamma_base_precision)
  if (any(Gamma_base_diag_vec <= tol_eigen)) {
    warning("Some diagonal elements <= tol_eigen. Clamping.")
    Gamma_base_diag_vec[Gamma_base_diag_vec <= tol_eigen] <- tol_eigen
  }
  Gamma_diag_mat <- Matrix::Diagonal(n = N_sum, x = Gamma_base_diag_vec)

  message("Step 2: Calculating off-diagonal remainder matrix R...")
  # The remainder is taken against the clamped diagonal, so that
  # Diag + R still reproduces the input wherever no clamping occurred.
  R_offdiag <- Gamma_base_precision - Gamma_diag_mat
  if (!is(R_offdiag, "symmetricMatrix")) {
    R_offdiag <- Matrix::forceSymmetric(R_offdiag)
  }
  R_offdiag_for_eigs <- as(R_offdiag, "dgCMatrix") # Ensure dgCMatrix for RSpectra

  message(paste("Step 3: Computing partial eigen decomp (k =", rank_r, ")..."))
  eigen_decomp <- tryCatch(
    RSpectra::eigs_sym(
      R_offdiag_for_eigs,
      k = rank_r,
      which = eig_which,
      opts = list(retvec = TRUE, ncv = min(N_sum, max(2 * rank_r + 1, 20)))
    ),
    error = function(e) {
      stop(paste("RSpectra::eigs_sym failed:", e$message))
    }
  )

  message("Step 4: Constructing L_base_factor...")
  eigenvalues <- eigen_decomp$values
  eigenvectors <- eigen_decomp$vectors
  # Only positive eigenpairs can contribute to a real factor L with L L^T >= 0
  positive_idx <- which(eigenvalues > tol_eigen)

  if (length(positive_idx) == 0) {
    warning("No positive eigenvalues found. L_base_factor will be empty.")
    L_base_factor <- matrix(0.0, nrow = N_sum, ncol = 0)
    r_actual <- 0
  } else {
    pos_eigenvalues <- eigenvalues[positive_idx]
    pos_eigenvectors <- eigenvectors[, positive_idx, drop = FALSE]
    order_idx <- order(pos_eigenvalues, decreasing = TRUE)
    pos_eigenvalues <- pos_eigenvalues[order_idx]
    pos_eigenvectors <- pos_eigenvectors[, order_idx, drop = FALSE]
    # Scale each eigenvector by sqrt(eigenvalue): L = V diag(sqrt(lambda))
    L_base_factor <- sweep(pos_eigenvectors, MARGIN = 2, STATS = sqrt(pos_eigenvalues), FUN = `*`)
    r_actual <- ncol(L_base_factor)
    message(paste("  Actual rank r_actual =", r_actual))
  }
  message("Decomposition complete.")
  return(list(
    Gamma_base_diag_vec = Gamma_base_diag_vec,
    L_base_factor = L_base_factor,
    r_actual = r_actual
  ))
}

