
source("./optimization_hyper.R")
getwd()
load("./data/nngp_prior_full.RData")

# -----------------------------------------------------------
# optimization_hyper.R: optimization4.R + hyperparameter
#
# Initial values and prior decomposition for the VINNGP fit below.
# NOTE(review): n, N, win, Gamma_precision and decompose_gamma() are not
# defined in this chunk; presumably they come from the sourced script or the
# loaded RData — confirm.
nN <- n + N
# Intercept starts at log(N / window area); both slopes start at 1.
beta_ini <- matrix(
  c(log(N) - log((win$xrange[2] - win$xrange[1]) * (win$yrange[2] - win$yrange[1])), 1, 1),
  3, 1
)
# beta_ini = matrix(c(1, 1, 1), 3, 1)
mu_ini <- matrix(0, nN, 1)
desired_rank <- 50
decomposition_result <- decompose_gamma(Gamma_precision, rank_r = desired_rank)
# Extract the decomposition results by their element names
Gamma_base_diag_vec <- decomposition_result$Gamma_base_diag_vec
L_base_factor <- decomposition_result$L_base_factor
r_actual <- decomposition_result$r_actual

# N_sum may already exist (e.g. restored by load()). It sizes Sigma_ini$W, so it
# must agree with the length of mu_ini; fail loudly instead of building
# mismatched initial values.
if (!exists("N_sum")) N_sum <- n + N
if (!isTRUE(N_sum == nN)) {
  stop("N_sum (", paste(N_sum, collapse = ", "), ") must equal n + N (", nN,
       ") so that Sigma_ini$W matches mu_ini.", call. = FALSE)
}

if (is.null(L_base_factor)) {
  stop("L_base_factor is NULL after decomposition. Check decompose_gamma.")
}
# Rank of the prior low-rank factor. ncol() of the actual matrix is used; r_actual
# is only cross-checked. isTRUE() also covers a NULL/missing r_actual, for which
# a bare `!=` comparison would error inside if().
r_prior_dim <- ncol(L_base_factor)
if (!isTRUE(r_prior_dim == r_actual)) {
  warning("ncol(L_base_factor) does not match r_actual. Using ncol.")
}

# Initial variational covariance: diagonal part = inverse of the prior
# diagonal, low-rank correction W = 0, base factor L from the decomposition.
Sigma_ini <- list(
  diag = 1 / Gamma_base_diag_vec,
  W = matrix(0.0, nrow = N_sum, ncol = r_prior_dim),
  L = L_base_factor
)
sigma2_ini <- 1
# IG(a, b) prior parameters for sigma2
a <- 0.01
b <- 0.01

# Fit VINNGP (low-rank variational covariance, CG updates for mu).
# Positional arguments follow the signature of HVGA_new_lowrank2, which is
# defined outside this chunk (presumably in optimization_hyper.R).
VINNGP_cg_lowrank2 <- HVGA_new_lowrank2(
  A, A_tilde,
  X, X_tilde, alpha_weight,
  mu_ini, Sigma_ini, beta_ini, sigma2_ini,
  Gamma_precision,                    # base precision (unscaled)
  Gamma_base_diag_vec, L_base_factor, # decomposed base prior parts
  a, b,                               # IG prior parameters for sigma2
  # outer loop controls / tolerances
  maxiter = 50, tol = c(5, 0.05), sigma2_tol = 1e-4,
  # inner loop controls / tolerances
  inner_maxiter = 2,
  inner_tol_mu = 1e-2,
  inner_tol_sigma = 1e-2,
  inner_tol_beta = 1e-2,
  # conjugate-gradient settings for the mu update
  mu_update_method = "cg",
  cg_tol = 1e-6, cg_maxiter = 1, cg_step = c(0.01, 0.05)
)

# Fitted quantities. beta is the last entry of the beta history. mu[1:n] is
# paired with X_tilde (the rows weighted by alpha_weight) further down, and
# mu[(n+1):(n+N)] with X.
beta_vinngp <- VINNGP_cg_lowrank2$beta_hist[[length(VINNGP_cg_lowrank2$beta_hist)]]
mu_vinngp_n <- VINNGP_cg_lowrank2$mu[1:n]
mu_vinngp_pred <- mu_vinngp_n
mu_vinngp_N <- VINNGP_cg_lowrank2$mu[(n + 1):(n + N)]
sigma2_vinngp <- VINNGP_cg_lowrank2$sigma2

# Marginal posterior variances of the VINNGP latent field.
# Sigma_lowrank carries D_inv (`diag`, the inverse of H_tilde's diagonal) and
# the base low-rank factor L. By the Woodbury identity
#   Sigma = D_inv - D_inv L' (I + M)^{-1} L'^T D_inv,
#   L' = L / sqrt(sigma2),   M = L'^T D_inv L' = (1 / sigma2) L^T D_inv L,
# so only the r x r matrix (I + M) has to be inverted.
# NOTE(review): the W component of Sigma_lowrank is not used here — confirm
# that this matches the covariance parameterisation of HVGA_new_lowrank2.
Sigma_vinngp_list <- VINNGP_cg_lowrank2$Sigma_lowrank
L_base_factor <- Sigma_vinngp_list$L
r_prior <- ncol(L_base_factor)
I_r <- Matrix::Diagonal(r_prior)
D_inv_vec <- Sigma_vinngp_list$diag # = (H_tilde_diag)^{-1}

# `D_inv_vec * L_base_factor` scales row i of L by D_inv_vec[i]
M <- (1 / sigma2_vinngp) * Matrix::crossprod(L_base_factor, D_inv_vec * L_base_factor)

# Middle_inv = (I + M)^{-1}; fall back to the Moore-Penrose pseudo-inverse if
# the solve fails, reporting the underlying error.
Middle_inv <- tryCatch(
  Matrix::solve(I_r + M),
  error = function(e) {
    warning("Matrix solve failed for (I+M) while computing VINNGP posterior variances (",
            conditionMessage(e), "). Check M's condition number. Using pseudo-inverse.",
            call. = FALSE)
    MASS::ginv(as.matrix(I_r + M))
  }
)

# L' = L_base / sqrt(sigma2_vinngp)
L_prime <- (1 / sqrt(sigma2_vinngp)) * L_base_factor

# V = L' %*% Middle_inv  (N_sum x r)
V <- L_prime %*% Middle_inv

# diag(Sigma)_i = D_inv_i - D_inv_i^2 * sum_j V[i, j] * L_prime[i, j]
Correction_term <- D_inv_vec^2 * rowSums(V * L_prime)
Sigma_k_diag_elems <- D_inv_vec - Correction_term

# Non-positive variances indicate numerical trouble; clamp them to a tiny
# positive value. na.rm = TRUE keeps an NA from making the if() condition NA
# (which would error); NAs are reported separately and left as NA.
if (any(Sigma_k_diag_elems <= 0, na.rm = TRUE)) {
  warning("Non-positive diagonal elements calculated for the VINNGP posterior covariance. Clamping.",
          call. = FALSE)
  Sigma_k_diag_elems <- pmax(Sigma_k_diag_elems, .Machine$double.eps)
}
if (anyNA(Sigma_k_diag_elems)) {
  warning("NA/NaN diagonal elements calculated for the VINNGP posterior covariance.",
          call. = FALSE)
}

Sigma_vinngp_nn_diag <- Sigma_k_diag_elems[1:n]

# llk_VISPDE_full = as.numeric(sum(as.vector(X_cov1%*%beta_opt + B_mat%*%mu_opt)) - sum(alpha_weight*exp(as.vector(X_cov2%*%beta_opt + A_mat%*%mu_opt))))
#
# VINNGP log-likelihood summaries. Both share the quadrature term
# sum(alpha_weight * exp(eta_quad)); they differ in the first sum:
#   llk_vinngp_full   - linear predictor summed over the X rows
#   llk_vinngp_full_n - linear predictor summed over the X_tilde rows
eta_obs_vinngp <- as.vector(X %*% beta_vinngp + mu_vinngp_N)
eta_quad_vinngp <- as.vector(X_tilde %*% beta_vinngp + mu_vinngp_n)
quad_term_vinngp <- sum(alpha_weight * exp(eta_quad_vinngp))
llk_vinngp_full <- as.numeric(sum(eta_obs_vinngp) - quad_term_vinngp)
llk_vinngp_full_n <- as.numeric(sum(eta_quad_vinngp) - quad_term_vinngp)
rm(eta_obs_vinngp, eta_quad_vinngp, quad_term_vinngp)
# llk_vinngp_full_pred = as.numeric(sum(as.vector(X_test%*%beta_vinngp +mu_vinngp_pred)) - sum(alpha_weight*exp(as.vector(X_tilde%*%beta_vinngp + mu_vinngp_n))))

#-------------------------------------------- INLA -------------------------------------------------
y0 <- x0 <- seq(win$xrange[1], win$xrange[2], length.out = spatstat.options()$npixel)
spde_sim <- inla.spde2.pcmatern(
  mesh = mesh_full,
  alpha = 2,
  prior.range = c(0.05, 0.01), ### P(practical range < 0.05) = 0.01
  prior.sigma = c(1, 0.01)     ### P(sigma > 1) = 0.01
)
y.pp <- dat_observed$pres
e.pp <- dat_observed$quad.size
Amat <- inla.spde.make.A(mesh_full, as.matrix(dat_observed[, c("x", "y")]))
stk.pp <- inla.stack(
  data = list(y = y.pp, e = e.pp),
  A = list(1, Amat),
  effects = list(list(Intercept = 1,
                      DCIS_1 = dat_observed$DCIS_1,
                      Myoepi_KRT15 = dat_observed$Myoepi_KRT15),
                 list(i = 1:mesh_full$n)),
  tag = 'pp')
INLA <- inla(y ~ 0 + Intercept + DCIS_1 + Myoepi_KRT15 + f(i, model = spde_sim),
             family = 'poisson', data = inla.stack.data(stk.pp),
             control.predictor = list(A = inla.stack.A(stk.pp)),
             E = inla.stack.data(stk.pp)$e)

# Keep only the linear-predictor rows whose name prefix (before the first ".")
# is "APredictor".
dnames <- vapply(strsplit(dimnames(INLA$summary.linear.predictor)[[1]], ".", fixed = TRUE),
                 `[`, character(1), 1)
INLA$fitted.values <- INLA$summary.linear.predictor[dnames == "APredictor", ][, "mean"]
INLA$fitted.values.sd <- INLA$summary.linear.predictor[dnames == "APredictor", ][, "sd"]

# Rows follow dat_observed: the N observed points first, then the n quadrature
# points at (N+1):(N+n).
llk_INLA_full <- as.numeric(sum(INLA$fitted.values[1:N])) -
  as.numeric(sum(alpha_weight * exp(INLA$fitted.values[(N + 1):(N + n)])))
# Quadrature-node counterpart of llk_vinngp_full_n / llk_VIFRK_full_n: the first
# sum must run over the quadrature rows (N+1):(N+n). The previous code summed
# rows 1:n, i.e. the first n observed points.
llk_INLA_full_pred <- as.numeric(sum(INLA$fitted.values[(N + 1):(N + n)])) -
  as.numeric(sum(alpha_weight * exp(INLA$fitted.values[(N + 1):(N + n)])))

#--------------------------------------------- VIFRK -----------------------------------------------
library(scampr)

# Model data: the N observed points (pres = 1, quadrature weight 0) followed by
# the n quadrature points (pres = 0, weight alpha_weight).
dat_VIFRK <- data.frame(
  "x" = dat_observed$x,
  "y" = dat_observed$y,
  "pres" = c(rep(1L, N), rep(0L, n)),
  "DCIS_1" = dat_observed$DCIS_1,
  "Myoepi_KRT15" = dat_observed$Myoepi_KRT15,
  "quad.size" = c(rep(0, N), alpha_weight)
)

# VIFRK model training
VIFRK <- scampr(
  pres ~ DCIS_1 + Myoepi_KRT15,
  data = dat_VIFRK,
  coord.names = c("x", "y"),
  quad.weights.name = "quad.size",
  model.type = "PO",
  basis.functions = simple_basis(10, dat_VIFRK)
)

# Log-likelihood summaries; the quadrature term is shared by both.
vifrk_fitted <- VIFRK$fitted.values
vifrk_quad_fitted <- vifrk_fitted[(N + 1):length(vifrk_fitted)]
vifrk_quad_term <- as.numeric(sum(alpha_weight * exp(vifrk_quad_fitted)))
llk_VIFRK_full <- as.numeric(sum(vifrk_fitted[1:N])) - vifrk_quad_term
llk_VIFRK_full_n <- as.numeric(sum(vifrk_quad_fitted)) - vifrk_quad_term
# llk_VIFRK_full_pred = as.numeric(sum(VIFRK$fitted.values[N+mesh_ind_loc_test])) - as.numeric(sum(alpha_weight*exp(VIFRK$fitted.values[(N+1): length(VIFRK$fitted.values)])))

# Prior covariance of the basis-function random effect at the quadrature rows:
# Z diag(v) Z^T with v = exp(last numbasis coefficients).
# NOTE(review): assumes VIFRK$coefficients is laid out as
# [fixed effects, numbasis basis terms, numbasis log-variance terms] — confirm
# against scampr, in particular whether those terms are log-variances or log-SDs.
vifrk_n_fixed <- ncol(X)
numbasis <- (length(VIFRK$coefficients) - vifrk_n_fixed) / 2
Z <- VIFRK$tmb.data$Z_PO_quad
vifrk_var_idx <- (numbasis + vifrk_n_fixed + 1):(2 * numbasis + vifrk_n_fixed)
VIFRK_var <- Z %*% (as.vector(exp(VIFRK$coefficients[vifrk_var_idx])) * t(Z))
rm(vifrk_fitted, vifrk_quad_fitted, vifrk_quad_term, vifrk_n_fixed, vifrk_var_idx)


# ---------------------------------------------- result ---------------------------------------------
# Posterior mean / sd of the log-intensity at the n quadrature locations, one
# column per method. As in llk_INLA_full and dat_VIFRK, the INLA and VIFRK
# fitted values list the N observed points first, so the quadrature rows are
# (N+1):(N+n). The INLA columns previously used rows 1:n (observed points), so
# they were not comparable with the VINNGP/VIFRK columns.
logintensity_mean_full_df <- data.frame(
  "VINNGP" = as.vector(X_tilde %*% beta_vinngp + mu_vinngp_n),
  "INLA" = as.vector(INLA$fitted.values[(N + 1):(N + n)]),
  "VIFRK" = as.vector(VIFRK$fitted.values[(N + 1):(N + n)])
)

logintensity_sd_full_df <- data.frame(
  "VINNGP" = sqrt(as.vector(Sigma_vinngp_nn_diag)),
  "INLA" = as.vector(INLA$fitted.values.sd[(N + 1):(N + n)]),
  "VIFRK" = sqrt(as.vector(diag(VIFRK_var)))
)

# Summary table: log-likelihood summaries and running time per method.
table_full <- data.frame(
  Method = c("VoGCAM", "INLA", "VIFRK"),
  llk_N = c(llk_vinngp_full, llk_INLA_full, llk_VIFRK_full),
  llk_n = c(llk_vinngp_full_n, llk_INLA_full_pred, llk_VIFRK_full_n),
  Time = c(VINNGP_cg_lowrank2[["running_time"]], INLA$cpu.used[4], VIFRK$cpu[1])
)
table_full

# save(list = ls(all.names = TRUE), file = "./data/vinngp_full.RData", envir = .GlobalEnv)
# save(list = ls(all.names = TRUE), file = "./data/comparsion_fullcase.RData", envir = .GlobalEnv)
# Persist only the objects needed to reproduce the comparison tables/plots.
# NOTE(review): "comparsion" in the file name looks like a typo, but other code
# may read this exact path, so it is left unchanged.
save(
  list = c("logintensity_mean_full_df", "logintensity_sd_full_df",
           "VIFRK", "INLA", "VINNGP_cg_lowrank2", "mesh_full", "table_full"),
  file = "./data/comparsion_fullcase.RData",
  envir = .GlobalEnv
)
