# real-data.R

library(dplyr)
library(survival)
library(COLSA)
library(coxphSGD)

#' Match column names and order of two matrices, padding zeros for missing columns
#'
#' Reorders the columns of `matrix2` to follow `colnames(matrix1)`. Columns of
#' `matrix1` that are absent from `matrix2` are added and filled with 0;
#' columns of `matrix2` that are absent from `matrix1` are dropped.
#'
#' @param matrix1 Reference matrix; only its column names are used.
#' @param matrix2 Matrix (or data frame) to be matched. A named vector (what
#'   `model.matrix(...)[, -1]` collapses to for a single row) is treated as a
#'   one-row matrix.
#' @return `matrix2` unchanged when its columns already match `matrix1`;
#'   otherwise a numeric matrix with `nrow(matrix2)` rows whose columns are
#'   named and ordered exactly as `colnames(matrix1)`.
match_matrices <- function(matrix1, matrix2) {
  ref_cols <- colnames(matrix1)

  # A one-row model.matrix()[, -1] drops to a named vector; restore the row.
  if (is.null(dim(matrix2))) {
    matrix2 <- matrix(matrix2, nrow = 1L, dimnames = list(NULL, names(matrix2)))
  }

  if (ncol(matrix2) == length(ref_cols) && all(colnames(matrix2) == ref_cols)) {
    return(matrix2)
  }

  # Start from all zeros (the padding value) and copy the shared columns in,
  # so the result is always a matrix rather than a cbind()-ed data.frame.
  out <- matrix(
    0,
    nrow = nrow(matrix2),
    ncol = length(ref_cols),
    dimnames = list(rownames(matrix2), ref_cols)
  )
  shared <- intersect(ref_cols, colnames(matrix2))
  out[, shared] <- as.matrix(matrix2[, shared, drop = FALSE])
  out
}

# Load data (the .RData file is expected to provide `data_0306_comp`)
load("./data/data_0306_comp.RData")

# Define formula (using original variable names, R will auto-generate dummy variables)
form <- ~ rec_age + don_age + factor(don_rec_gender) + factor(rec_bmi) +
  factor(don_bmi) + factor(rec_race) + factor(don_race) + factor(hla_mm) +
  factor(rec_diabetes) + factor(CAN_DRUG_TREAT_HYPERTEN) +
  factor(don_type_recat)

# Per-state event count and event proportion (events / subjects), ranked by
# event count, descending. `order` is the 1-based rank (1 = most events).
data_0306_comp_prop <- data_0306_comp %>%
  group_by(state) %>%
  dplyr::summarise(
    count = sum(delta_gft_failure_5_DCGF),
    prop = sum(delta_gft_failure_5_DCGF) / n()
  ) %>%
  arrange(desc(count)) %>%
  mutate(order = row_number())

# Attach the per-state summary to every subject. Join explicitly on `state`
# so any pre-existing column named count/prop/order cannot silently become an
# extra join key.
data_0306_comp <- data_0306_comp %>%
  left_join(data_0306_comp_prop, by = "state")
cat("Data dimensions:", dim(data_0306_comp), "\n")

# Integer group id per state; ranks are contiguous, so this mirrors `order`.
data_0306_comp$state_factor <- as.numeric(factor(data_0306_comp$order))

# Ascending list of group ids (one per state); drives the per-group loop.
unique_orders <- data_0306_comp$state_factor %>% unique() %>% sort()

# The rank-1 state (most events) defines the reference set of design columns.
data_first_reg <- data_0306_comp %>% filter(order == 1)
X1 <- model.matrix(form, data = data_first_reg)
X1 <- X1[, -1]  # drop the intercept column
# Syntactic names so the columns can be used directly in a formula later.
colnames(X1) <- make.names(colnames(X1))

# Pre-allocate result storage lists
data_0306_comp_for_online <- vector("list", length(unique_orders))
loc_beta_cox <- vector("list", length(unique_orders))
loc_beta_cox_se <- vector("list", length(unique_orders))

# Build the design matrix ONCE on the full data so every state shares the same
# levels (and hence the same reference level) for each factor() term.
# Re-building it per state lets a state that lacks a level re-base its
# dummies; match_matrices() would then zero-fill the missing dummy column and
# silently recode those subjects. na.pass keeps rows aligned with the
# time/status vectors even if a covariate value is missing.
mf_all <- model.frame(form, data = data_0306_comp, na.action = na.pass)
X_all <- model.matrix(form, data = mf_all)[, -1, drop = FALSE]
colnames(X_all) <- make.names(colnames(X_all))

# Loop through state_factor: one data.frame of (group, time, status, X) each
for (j in unique_orders) {
  rows_j <- which(data_0306_comp$state_factor == j)
  subdata <- data_0306_comp[rows_j, , drop = FALSE]
  # Restrict/reorder to the reference columns defined by X1.
  X <- match_matrices(X1, X_all[rows_j, , drop = FALSE])

  data_0306_comp_for_online[[j]] <- data.frame(
    group = j,
    time = subdata$time_gft_failure_5,
    status = subdata$delta_gft_failure_5_DCGF,
    X
  )
}

# Combine all grouped data
data_0306_comp_for_online <- do.call(rbind, data_0306_comp_for_online)

# Get covariate names from design matrix
covariate_names <- colnames(X1)
formula <- as.formula(paste("Surv(time, status) ~", paste(covariate_names, collapse = " + ")))
boundary <- c(0, max(data_0306_comp_for_online$time))

# Choose the number of basis functions (1..5) by AIC on the first batch.
df_sub <- subset(data_0306_comp_for_online, group == 1)
aics <- vapply(seq_len(5), function(n_basis) {
  AIC(colsa(formula, df_sub, n_basis, boundary, scale = 1))
}, numeric(1))

n_basis_best <- which.min(aics)
alpha_best <- n_basis_best / nrow(df_sub)^0.1
cat("Best number of basis functions:", n_basis_best, "\n")
fit <- colsa(formula, df_sub, n_basis_best, boundary)

# Stream every remaining group in ascending order. The list is derived from
# the data (not hard-coded) so no batch is skipped and none is empty.
online_batches <- setdiff(sort(unique(data_0306_comp_for_online$group)), 1)
for (batch in online_batches) {
  cat("Processing batch:", batch, "\n")
  df_sub <- filter(data_0306_comp_for_online, group == batch)
  fit <- update(fit, df_sub, alpha = alpha_best, nu = 0.1)
}

data_list <- split(data_0306_comp_for_online, data_0306_comp_for_online$group)
# Number of batches = number of groups actually present (previously a
# hard-coded 49, which breaks data_list[[k]] when fewer groups exist).
K <- length(data_list)
p <- length(covariate_names)

# ===== 1. Oracle Method =====
cat("\n===== Running Oracle Method =====\n")
# Pooled Cox fit on all batches at once.
fit_oracle <- survival::coxph(formula, data = data_0306_comp_for_online)
coef_oracle <- fit_oracle$coefficients
se_oracle <- sqrt(diag(fit_oracle$var))
z_oracle <- coef_oracle / se_oracle
cat("Oracle completed.\n")

# ===== 2. COLSA Method =====
cat("\n===== Running COLSA Method =====\n")
summary(fit)
# Point estimates, standard errors (sqrt of the covariance diagonal), Wald z.
coef_colsa <- coef(fit)
se_colsa <- vcov(fit) %>% diag() %>% sqrt()
z_colsa <- coef_colsa / se_colsa
cat("COLSA completed.\n")

# ===== 3. Meta Method =====
cat("\n===== Running Meta Method =====\n")

# Per-batch Cox fits. A batch is usable only if its fit succeeds, has no NA
# coefficients, AND its covariance matrix can be inverted. Coefficients and
# inverse covariances are stored together so they cannot fall out of step
# when some batches are discarded.
meta_fits <- lapply(seq_len(K), function(k) {
  fit_k <- tryCatch(
    suppressWarnings(coxph(formula = formula, data = data_list[[k]])),
    error = function(e) NULL
  )
  if (is.null(fit_k) || anyNA(fit_k$coefficients)) {
    return(NULL)
  }
  var_inv <- tryCatch(solve(fit_k$var), error = function(e) NULL)
  if (is.null(var_inv)) {
    return(NULL)
  }
  list(coef = fit_k$coefficients, var_inv = var_inv)
})
meta_fits <- Filter(Negate(is.null), meta_fits)
n_valid <- length(meta_fits)

cat("Meta: used", n_valid, "of", K, "batches\n")

if (n_valid > 0) {
  # Inverse-variance weighting:
  #   beta = (sum_k V_k^-1)^-1 sum_k V_k^-1 beta_k,  Var = (sum_k V_k^-1)^-1
  coefs_meta <- lapply(meta_fits, `[[`, "coef")
  var_invs <- lapply(meta_fits, `[[`, "var_inv")
  var_sum <- Reduce("+", var_invs)
  coef_sum <- Reduce("+", Map(
    function(coef, var_inv) var_inv %*% coef, coefs_meta, var_invs
  ))
  coef_meta <- setNames(
    as.vector(solve(var_sum, coef_sum)),
    names(coefs_meta[[1]])
  )
  se_meta <- sqrt(diag(solve(var_sum)))
  z_meta <- coef_meta / se_meta
} else {
  # No usable batch: report NA rather than crashing inside solve().
  coef_meta <- rep(NA_real_, p)
  se_meta <- rep(NA_real_, p)
  z_meta <- rep(NA_real_, p)
}
cat("Meta completed.\n")

# ===== 4. Online Method =====
cat("\n===== Running Online Method =====\n")
source("online/online_update.R")

# Initialise the online Cox fit on the first batch, then stream the rest.
df_sub <- subset(data_0306_comp_for_online, group == 1)
fit_online <- onlinecox(formula, df_sub, n_groups = 4L, adaptive = TRUE, max_groups = 20L)
n_used <- 1

# seq_len(K)[-1] is empty when K < 2, whereas 2:K would count down to 1.
for (batch in seq_len(K)[-1]) {
  df_sub <- subset(data_0306_comp_for_online, group == batch)
  n_before <- fit_online$n
  fit_online <- update(fit_online, df_sub, r_k = 1.5, min_events = 20L)
  # A batch counts as used when the update increases fit_online$n.
  if (fit_online$n > n_before) n_used <- n_used + 1
}

cat("Online: used", n_used, "of", K, "batches, g =", fit_online$g, "\n")
coef_online <- coef(fit_online)
se_online <- fit_online$se
z_online <- coef_online / se_online
cat("Online completed.\n")

# ===== 5. SGD Method =====
cat("\n===== Running SGD Method =====\n")

# Warm start: Cox fit on the first batch, or zeros if that fit errors.
# coxph reports NA for aliased (non-estimable) coefficients; an NA start
# value would propagate through every SGD iterate, so those are reset to 0.
beta_init <- tryCatch({
  coef(coxph(formula, data = data_list[[1]]))
}, error = function(e) rep(0, p))
beta_init[is.na(beta_init)] <- 0

# Learning rate at iteration t: 1 / (c_lr * t^0.6).
c_lr <- 1000
fit_sgd <- tryCatch({
  coxphSGD(
    formula = formula,
    data = data_list,
    epsilon = 1e-6,
    learn.rates = function(t) 1 / (c_lr * t^0.6),
    beta.zero = beta_init,
    max.iter = K
  )
}, error = function(e) {
  cat("SGD failed:", conditionMessage(e), "\n")
  NULL
})

if (!is.null(fit_sgd)) {
  beta_history <- fit_sgd$coefficients
  # Final estimate = last element of the iterate history.
  coef_sgd <- beta_history[[length(beta_history)]]

  cat("Computing SGD standard errors...\n")
  # Sum each batch's information matrix (inverse of coxph's variance) taken
  # at a fixed beta (iter.max = 0: no Newton steps), then invert the total.
  info_cum <- matrix(0, p, p)
  for (k in seq_len(K)) {
    # NOTE(review): if coxphSGD stops early (epsilon), the history can be
    # shorter than K; clamp the index to reuse the last iterate instead of
    # aborting with "subscript out of bounds".
    beta_k <- if (k == 1) beta_init else beta_history[[min(k, length(beta_history))]]
    fit_k <- tryCatch({
      coxph(formula, data = data_list[[k]], init = beta_k, iter.max = 0)
    }, error = function(e) NULL)
    if (!is.null(fit_k) && !any(is.na(fit_k$var))) {
      info_k <- tryCatch(solve(fit_k$var), error = function(e) NULL)
      if (!is.null(info_k)) {
        info_cum <- info_cum + info_k
      }
    }
  }
  se_sgd <- tryCatch(sqrt(diag(solve(info_cum))), error = function(e) rep(NA, p))
  z_sgd <- coef_sgd / se_sgd
} else {
  coef_sgd <- rep(NA, p)
  se_sgd <- rep(NA, p)
  z_sgd <- rep(NA, p)
}
cat("SGD completed.\n")

# ===== Comparison of All Methods =====
cat("\n===== Comparison of Methods (Coefficients) =====\n")
comparison <- data.frame(
  Oracle = coef_oracle,
  COLSA = coef_colsa,
  Meta = coef_meta,
  Online = coef_online,
  SGD = coef_sgd
)
print(round(comparison, 4))

cat("\n===== Hazard Ratios (exp(coef)) =====\n")
# Element-wise exp over the coefficient table; keeps its row/column names.
hr_comparison <- exp(comparison)
print(round(hr_comparison, 3))

# ===== Save Complete Results Table =====
cat("\n===== Saving Results =====\n")

result_table <- data.frame(
  Variable = names(coef_oracle),
  HR_Oracle = round(exp(coef_oracle), 3),
  Z_Oracle = round(z_oracle, 2),
  HR_COLSA = round(exp(coef_colsa), 3),
  Z_COLSA = round(z_colsa, 2),
  HR_Meta = round(exp(coef_meta), 3),
  Z_Meta = round(z_meta, 2),
  HR_Online = round(exp(coef_online), 3),
  Z_Online = round(z_online, 2),
  HR_SGD = round(exp(coef_sgd), 3),
  Z_SGD = round(z_sgd, 2)
)

print(result_table)

# write.csv() does not create directories; make sure the target exists.
dir.create("results", showWarnings = FALSE, recursive = TRUE)
write.csv(result_table, "results/real_data_comparison.csv", row.names = FALSE)
cat("Results saved to results/real_data_comparison.csv\n")

