
source("./optimization_hyper.R")
getwd()
load("./data/nngp_prior_hole.RData")

# -----------------------------------------------------------
# Initial values for the variational fit.
# optimization_hyper.R: optimization4.R + hyperparameter
# n, N and Gamma_precision are not defined in this script; presumably they
# come from the loaded .RData file -- verify against the data script.
nN <- n + N
# Alternative starting value for beta, kept for reference:
# beta_ini = matrix(c(log(N)-log((win$xrange[2]-win$xrange[1])*(win$yrange[2]-win$yrange[1])), 1), 2, 1)
beta_ini <- matrix(c(1, 1), nrow = 2, ncol = 1)
mu_ini <- matrix(0, nrow = nN, ncol = 1)

# Decompose Gamma_precision with the requested rank.
desired_rank <- 50
decomposition_result <- decompose_gamma(Gamma_precision, rank_r = desired_rank)
# Extract the results using the correct names
Gamma_base_diag_vec <- decomposition_result$Gamma_base_diag_vec
L_base_factor <- decomposition_result$L_base_factor
r_actual <- decomposition_result$r_actual

# Keep an N_sum that already exists in the workspace; otherwise derive it.
if (!exists("N_sum")) {
  N_sum <- n + N
}

# Column count for W: when the factor's own column count disagrees with the
# reported rank, the column count wins (with a warning).
if (is.null(L_base_factor)) {
  stop("L_base_factor is NULL after decomposition. Check decompose_gamma.")
} else if (ncol(L_base_factor) == r_actual) {
  r_prior_dim <- r_actual # Use the value from decomposition result
} else {
  warning("ncol(L_base_factor) does not match r_actual. Using ncol.")
  r_prior_dim <- ncol(L_base_factor) # Use actual dimension from matrix
}

# Starting covariance representation: elementwise inverse of the base
# diagonal, an all-zero N_sum x r_prior_dim matrix W, and the base factor L.
Sigma_ini <- list(
  diag = 1 / Gamma_base_diag_vec,
  W = matrix(0.0, nrow = N_sum, ncol = r_prior_dim),
  L = L_base_factor
)
sigma2_ini <- 1
# a and b are handed to HVGA_new_lowrank2 further down in this script.
a <- 0.01
b <- 0.01


# Run the variational fit. The first fourteen arguments are positional, so
# their order must not change; the rest are passed by name.
VINNGP_cg_lowrank2_hole <- HVGA_new_lowrank2(
  A,
  A_tilde,
  X,
  X_tilde,
  alpha_weight,
  # Initial values
  mu_ini,
  Sigma_ini,
  beta_ini,
  sigma2_ini,
  # Base precision (unscaled)
  Gamma_precision,
  # Decomposed base prior parts
  Gamma_base_diag_vec,
  L_base_factor,
  # IG prior parameters for sigma2
  a,
  b,
  # Outer loop tolerances
  maxiter = 50,
  tol = c(5, 0.05),
  sigma2_tol = 1e-4,
  # Max iterations for inner loops
  inner_maxiter = 2,
  # Inner tolerances for the mu, Sigma and beta updates
  inner_tol_mu = 1e-2,
  inner_tol_sigma = 1e-2,
  inner_tol_beta = 1e-2,
  mu_update_method = "cg",
  cg_tol = 1e-6,
  cg_maxiter = 1,
  cg_step = c(0.01, 0.05)
)

# result: pull the fitted quantities out of the returned list.
sigma2_vinngp_hole <- VINNGP_cg_lowrank2_hole$sigma2

# Last entry of the beta history.
beta_vinngp_hole <- VINNGP_cg_lowrank2_hole$beta_hist[[
  length(VINNGP_cg_lowrank2_hole$beta_hist)
]]

# mu is split by position: entries 1..n, entries n+1..n+N, and the entries
# selected by mesh_ind_loc_test (an index object not defined in this script).
mu_vinngp_n_hole <- VINNGP_cg_lowrank2_hole$mu[1:n]
mu_vinngp_N_hole <- VINNGP_cg_lowrank2_hole$mu[(n + 1):(n + N)]
mu_vinngp_pred_hole <- VINNGP_cg_lowrank2_hole$mu[mesh_ind_loc_test]

# ---- Marginal variances from the low-rank covariance representation ----
# Computes diag(Sigma_k) without forming Sigma_k, using
#   Sigma_k = D_inv - D_inv L' (I + M)^{-1} L'^T D_inv,
#   M       = L'^T D_inv L' = (1 / sigma2_k) * L_base^T D_inv L_base,
#   L'      = L_base / sqrt(sigma2_k).
Sigma_vinngp_list_hole <- VINNGP_cg_lowrank2_hole$Sigma_lowrank
# NOTE: this rebinds L_base_factor to the factor stored in the fit result.
L_base_factor <- Sigma_vinngp_list_hole$L
r_prior <- ncol(L_base_factor)
I_r <- Matrix::Diagonal(r_prior)
D_inv_vec <- Sigma_vinngp_list_hole$diag # = (H_tilde_diag)^{-1}

# D_inv_vec is recycled down the columns, i.e. it scales the rows of L_base.
M <- (1 / sigma2_vinngp_hole) *
  Matrix::crossprod(L_base_factor, D_inv_vec * L_base_factor)

# Middle_inv = solve(I + M); fall back to a pseudo-inverse if that fails,
# and report why the direct solve failed.
Middle_inv <- tryCatch(
  Matrix::solve(I_r + M),
  error = function(e) {
    warning(
      "Matrix solve failed for (I+M) in exp_WeightedVec_lowrank. Check M's condition number. Using pseudo-inverse.",
      " Original error: ", conditionMessage(e),
      call. = FALSE
    )
    MASS::ginv(as.matrix(I_r + M))
  }
)

# L' = L_base / sqrt(sigma2_vinngp)
L_prime <- (1 / sqrt(sigma2_vinngp_hole)) * L_base_factor

# Precompute V = L' %*% Middle_inv
V <- L_prime %*% Middle_inv # N_sum x r

# diag(Sigma_k)_i = D_inv_i - D_inv_i^2 * sum_j V[i, j] * L_prime[i, j].
# Matrix::rowSums is used (like every other Matrix call here) because
# V * L_prime can be a Matrix-class object, which base::rowSums cannot
# handle when the Matrix package is not attached; it also accepts base
# matrices.
Correction_term <- D_inv_vec^2 * Matrix::rowSums(V * L_prime)
Sigma_k_diag_elems <- D_inv_vec - Correction_term

# Check for negative variances (shouldn't happen if stable)
if (any(Sigma_k_diag_elems <= 0)) {
  warning("Non-positive diagonal elements calculated for Sigma_k in exp_WeightedVec_lowrank. Clamping.")
  Sigma_k_diag_elems <- pmax(Sigma_k_diag_elems, .Machine$double.eps)
}

# Variances for the first n entries only.
Sigma_vinngp_nn_diag_hole <- Sigma_k_diag_elems[seq_len(n)]

# Earlier formulation, kept for reference:
# llk_VISPDE_full = as.numeric(sum(as.vector(X_cov1%*%beta_opt + B_mat%*%mu_opt)) - sum(alpha_weight*exp(as.vector(X_cov2%*%beta_opt + A_mat%*%mu_opt))))

# Fitted log-intensity for the first n entries; computed once and reused in
# both log-likelihood expressions below.
logintensity_vinngp_hole <- as.vector(
  X_tilde %*% beta_vinngp_hole + mu_vinngp_n_hole
)

# Sum of the linear predictor over the N entries minus the alpha_weight-
# weighted sum of exp(log-intensity) over the n entries.
llk_vinngp_hole <- as.numeric(
  sum(as.vector(X %*% beta_vinngp_hole + mu_vinngp_N_hole)) -
    sum(alpha_weight * exp(logintensity_vinngp_hole))
)

# Same weighted term, but the first sum runs over the n entries.
llk_vinngp_hole_n <- as.numeric(
  sum(logintensity_vinngp_hole) -
    sum(alpha_weight * exp(logintensity_vinngp_hole))
)
# llk_vinngp_hole_pred = as.numeric(sum(as.vector(X_test%*%beta_vinngp_hole +mu_vinngp_pred_hole)) - sum(alpha_weight*exp(as.vector(X_tilde%*%beta_vinngp_hole + mu_vinngp_n_hole))))

#------------------------------------ INLA -----------------------------------------------
# Pixel-centre coordinates used to wrap gridcov as a spatstat image.
# NOTE(review): y0 is built from win$xrange as well, which is only right
# when the window has identical x and y ranges -- confirm against the
# script that generated gridcov before changing it.
y0 <- x0 <- seq(win$xrange[1], win$xrange[2],
                length.out = spatstat.options()$npixel)

# SPDE model with PC priors.
spde_sim <- inla.spde2.pcmatern(
  mesh = mesh_full,
  alpha = 2,
  prior.range = c(0.05, 0.01), ### P(practic.range<0.05)=0.01
  prior.sigma = c(1, 0.01) ### P(sigma>1)=0.01
)

# Covariate value at each mesh vertex followed by each observed point,
# looked up at the nearest pixel of gridcov via a (row, col) index matrix.
covariate_sim <- gridcov[Reduce("cbind", nearest.pixel(
  c(mesh_vertices[, 1], xy.c_observed[, 1]),
  c(mesh_vertices[, 2], xy.c_observed[, 2]),
  im(gridcov, x0, y0)
))]

nv <- n
n.c <- N
# Response: nv zeros followed by n.c ones.
# Exposure: alpha_weight followed by n.c zeros.
y.pp.c <- rep(0:1, c(nv, n.c))
e.pp.c <- c(alpha_weight, rep(0, n.c))

# Projection matrix: an nv x nv identity stacked on top of the projector
# for the observed points.
lmat.c <- inla.spde.make.A(mesh_full, xy.c_observed)
imat <- Diagonal(nv, rep(1, nv))
A.pp.c <- rbind(imat, lmat.c)

# data stack
stk.pp <- inla.stack(
  data = list(y = y.pp.c, e = e.pp.c),
  A = list(1, A.pp.c),
  tag = "pp.c",
  effects = list(
    list(b0 = 1, covariate = covariate_sim),
    list(i = 1:nv)
  )
)

# INLA model training
pp.c.res_hole <- inla(
  y ~ 0 + b0 + covariate + f(i, model = spde_sim),
  family = "poisson",
  data = inla.stack.data(stk.pp),
  control.predictor = list(A = inla.stack.A(stk.pp), compute = TRUE),
  E = inla.stack.data(stk.pp)$e
)

# Keep only the linear-predictor rows whose name starts with "APredictor"
# (the part of the row name before the first "."). vapply guarantees a
# character vector, and TRUE is spelled out because T can be reassigned.
dnames <- vapply(
  strsplit(rownames(pp.c.res_hole$summary.linear.predictor), ".", fixed = TRUE),
  function(x) x[1],
  character(1)
)
pp.c.res_hole$fitted.values <-
  pp.c.res_hole$summary.linear.predictor[dnames == "APredictor", "mean"]
pp.c.res_hole$fitted.values.sd <-
  pp.c.res_hole$summary.linear.predictor[dnames == "APredictor", "sd"]

# Log-likelihood summaries, using the same two expressions as for the other
# methods in this script.
llk_INLA_hole <- as.numeric(sum(pp.c.res_hole$fitted.values[(n + 1):(N + n)])) -
  as.numeric(sum(alpha_weight * exp(pp.c.res_hole$fitted.values[1:n])))
llk_INLA_hole_n <- as.numeric(sum(pp.c.res_hole$fitted.values[1:n])) -
  as.numeric(sum(alpha_weight * exp(pp.c.res_hole$fitted.values[1:n])))
# llk_INLA_hole_pred = as.numeric(sum(pp.c.res_hole$fitted.values[mesh_ind_loc_test])) - as.numeric(sum(alpha_weight*exp(pp.c.res_hole$fitted.values[1:n])))

#------------------------------------ VIFRK ----------------------------------------------
library(scampr)

# model fit
# Rows are ordered observed points first (pres = 1, quad.size = 0), then mesh
# vertices (pres = 0, quad.size = alpha_weight). The covariate is read from
# gridcov at the nearest pixel of each location.
dat_VIFRK_hole <- data.frame(
  x = c(xy.c_observed[, 1], mesh_vertices[, 1]),
  y = c(xy.c_observed[, 2], mesh_vertices[, 2]),
  pres = rep(1:0, times = c(N, n)),
  covariates = gridcov[Reduce("cbind", nearest.pixel(
    c(xy.c_observed[, 1], mesh_vertices[, 1]),
    c(xy.c_observed[, 2], mesh_vertices[, 2]),
    im(gridcov, x0, y0)
  ))],
  quad.size = c(rep(0, N), alpha_weight)
)

# VIFRK model training
VIFRK_hole <- scampr(
  pres ~ covariates,
  data = dat_VIFRK_hole,
  coord.names = c("x", "y"),
  quad.weights.name = "quad.size",
  model.type = "PO",
  basis.functions = simple_basis(10, dat_VIFRK_hole)
)

# Log-likelihood summaries: fitted values 1..N belong to the observed points,
# the remaining ones to the rows after them.
llk_VIFRK_hole <-
  as.numeric(sum(VIFRK_hole$fitted.values[1:N])) -
  as.numeric(sum(
    alpha_weight *
      exp(VIFRK_hole$fitted.values[(N + 1):length(VIFRK_hole$fitted.values)])
  ))
llk_VIFRK_hole_n <-
  as.numeric(sum(
    VIFRK_hole$fitted.values[(N + 1):length(VIFRK_hole$fitted.values)]
  )) -
  as.numeric(sum(
    alpha_weight *
      exp(VIFRK_hole$fitted.values[(N + 1):length(VIFRK_hole$fitted.values)])
  ))
# llk_VIFRK_hole_pred = as.numeric(sum(VIFRK_hole$fitted.values[N+mesh_ind_loc_test])) - as.numeric(sum(alpha_weight*exp(VIFRK_hole$fitted.values[(N+1): length(VIFRK_hole$fitted.values)])))

# The coefficients beyond the ncol(X) leading ones are split into two halves
# of numbasis entries each; the exponentiated second half weights Z Z^T.
# NOTE(review): presumably these are log-scale variance parameters of the
# basis coefficients -- confirm against the scampr documentation.
numbasis <- (length(VIFRK_hole$coefficients) - ncol(X)) / 2
Z <- VIFRK_hole$tmb.data$Z_PO_quad
VIFRK_var_hole <- Z %*% (
  as.vector(exp(
    VIFRK_hole$coefficients[(numbasis + ncol(X) + 1):(2 * numbasis + ncol(X))]
  )) * t(Z)
)

#--------------------------------------- result --------------------------------------------
# Fitted log-intensity per method (first n VINNGP/INLA values, VIFRK values
# N+1..N+n), one column per method.
logintensity_mean_hole_df <- data.frame(
  VINNGP = as.vector(logintensity_vinngp_hole),
  INLA = as.vector(pp.c.res_hole$fitted.values[1:n]),
  VIFRK = as.vector(VIFRK_hole$fitted.values[(N + 1):(N + n)])
)

# Matching standard deviations; the VINNGP and VIFRK columns are square roots
# of variances, the INLA column is the stored sd.
logintensity_sd_hole_df <- data.frame(
  VINNGP = sqrt(as.vector(Sigma_vinngp_nn_diag_hole)),
  INLA = as.vector(pp.c.res_hole$fitted.values.sd[1:n]),
  VIFRK = sqrt(as.vector(diag(VIFRK_var_hole)))
)

# Summary table: both log-likelihood variants and the run time per method.
table_hole <- data.frame(
  Method = c("VoGCAM", "INLA", "VIFRK"),
  llk_N = c(llk_vinngp_hole, llk_INLA_hole, llk_VIFRK_hole),
  llk_n = c(llk_vinngp_hole_n, llk_INLA_hole_n, llk_VIFRK_hole_n),
  Time = c(
    VINNGP_cg_lowrank2_hole[["running_time"]],
    pp.c.res_hole$cpu.used[4],
    VIFRK_hole$cpu[1]
  )
)
table_hole

# save(list = ls(all.names = TRUE), file = "./data/vinngp_hole.RData", envir = .GlobalEnv)
# Persist only the objects needed downstream. dual_mesh_plot is not created
# in this script, so it has to exist in the workspace already.
save(
  list = c(
    "logintensity_mean_hole_df",
    "logintensity_sd_hole_df",
    "VINNGP_cg_lowrank2_hole",
    "pp.c.res_hole",
    "VIFRK_hole",
    "table_hole",
    "dual_mesh_plot"
  ),
  file = "./data/comparsion_holecase.RData",
  envir = .GlobalEnv
)
