# ============================================================================
# Online Cox Model with Piecewise Constant Hazard
# ============================================================================
# Reference: Wu et al. (2021) "Online Updating of Survival Analysis"
#            Journal of Computational and Graphical Statistics
#            DOI: 10.1080/10618600.2020.1870481
#
# Implements:
#   - Section 2.3: Fixed Partition
#   - Section 2.4: Adaptive Partition
#   - Section 2.5: Fixed Partition and Bias Correction
#   - Section 2.6: Adaptive Partition and Bias Correction
# ============================================================================

# Element-wise exponential that keeps the argument passed to exp() non-positive:
# inputs > 0 are evaluated as 1 / exp(-x), all other inputs as exp(x).
# In IEEE double arithmetic both forms agree, so the result equals exp(x); large
# positive x still gives Inf, i.e. the split adds no overflow protection.
# ifelse() keeps the attributes of the comparison (e.g. the N x 1 dims of
# x %*% beta), and NA/NaN inputs yield NA.
safe_exp <- function(x) {
  ifelse(x > 0, 1 / exp(-x), exp(x))
}

# ============================================================================
# Profile Log-likelihood for β
# ============================================================================
# Paper Formula (2.5), page 2:
#   "ℓ(β, λ|D) = Σ_{j=1}^J d_j log λ_j + Σ_{i=1}^N δ_i x_i^T β
#                - Σ_{j=1}^J λ_j { Σ_{i=1}^N Δ_j(t_i) exp(x_i^T β) }"
#
# Paper Formula (2.6), page 3 - solving ∂ℓ/∂λ_j = 0:
#   "λ_j = d_j / Σ_{i=1}^N Δ_j(t_i) exp(x_i^T β)"
#
# Profile likelihood: substitute λ_j from (2.6) into (2.5), remove constants:
#   ℓ(β) = Σ_i δ_i x_i^T β - Σ_j d_j log(Σ_i Δ_j(t_i) exp(x_i^T β)) + const
# ============================================================================
neg_loglik <- function(beta, x, status, d, delta, g) {
  # Linear predictor x_i^T β (N x 1), computed once and reused below.
  eta <- x %*% beta
  # Risk-weighted exposure per interval: Σ_i Δ_j(t_i) exp(x_i^T β), j = 1..g
  risk <- as.vector(crossprod(delta[, 1:g], safe_exp(eta)))
  d_g <- d[1:g]
  # Intervals without events contribute nothing to Σ_j d_j log(risk_j).
  keep <- d_g > 0
  # Σ_{i: δ_i = 1} x_i^T β
  event_term <- sum(eta[status == 1])
  log_risk_term <- sum(d_g[keep] * log(risk[keep]))
  # Negative profile log-likelihood, suitable for minimisation with nlm().
  log_risk_term - event_term
}

# ============================================================================
# λ_j from Score Equation
# ============================================================================
# Paper Formula (2.6), page 3:
#   "d_j/λ_j - Σ_{i=1}^N Δ_j(t_i) exp(x_i^T β) = 0"
#
# Solving for λ_j:
#   "λ_j = d_j / Σ_{i=1}^N Δ_j(t_i) exp(x_i^T β)"
# ============================================================================
compute_lambda <- function(x, beta, g, d, delta) {
  d_g <- d[1:g]
  # Exposure weighted by relative risk: Σ_i Δ_j(t_i) exp(x_i^T β) per interval
  exposure <- as.vector(crossprod(delta[, 1:g], safe_exp(x %*% beta)))
  # Closed-form λ_j = d_j / exposure_j; intervals with no events get λ_j = 0
  # (which also avoids 0/0 when the exposure is zero).
  ifelse(d_g <= 0, 0, d_g / exposure)
}

# ============================================================================
# Score Function M(θ)
# ============================================================================
# Paper Formula (2.6), page 3:
#   "M(θ) = [Σ_{i=1}^N δ_i x_{i1} - Σ_{j=1}^J λ_j Σ_{i=1}^N Δ_j(t_i) exp(x_i^T β) x_{i1}
#           ...
#           Σ_{i=1}^N δ_i x_{ip} - Σ_{j=1}^J λ_j Σ_{i=1}^N Δ_j(t_i) exp(x_i^T β) x_{ip}
#           d_1/λ_1 - Σ_{i=1}^N Δ_1(t_i) exp(x_i^T β)
#           ...
#           d_J/λ_J - Σ_{i=1}^N Δ_J(t_i) exp(x_i^T β)]^T = 0"
#
# Score for β_r:
#   ∂ℓ/∂β_r = Σ_i δ_i x_{ir} - Σ_j λ_j Σ_i Δ_j(t_i) exp(x_i^T β) x_{ir}
#           = Σ_i δ_i x_{ir} - Σ_i x_{ir} exp(x_i^T β) (Σ_j λ_j Δ_j(t_i))
#
# Score for λ_j:
#   ∂ℓ/∂λ_j = d_j/λ_j - Σ_i Δ_j(t_i) exp(x_i^T β)
# ============================================================================
compute_score <- function(beta, lambda, x, status, d, delta, g, p) {
  # exp(x_i^T β)
  exp_xb <- as.vector(safe_exp(x %*% beta))
  # Δ_j(t_i) matrix (N × g); drop = FALSE keeps it a matrix when g == 1
  delta_g <- delta[, 1:g, drop = FALSE]
  d_g <- d[1:g]

  # Score for β: Σ_i δ_i x_i - Σ_i x_i w_i,
  # where w_i = exp(x_i^T β) × (Σ_j λ_j Δ_j(t_i))
  w <- as.vector(exp_xb * (delta_g %*% lambda))
  score_beta <- as.vector(crossprod(x, status) - crossprod(x, w))

  # Score for λ_j: d_j/λ_j - Σ_i Δ_j(t_i) exp(x_i^T β)
  denom <- as.vector(crossprod(delta_g, exp_xb))
  # - d_j = 0: the d_j log λ_j term of (2.5) vanishes, so the score is
  #   -denom_j (it must NOT be zeroed, otherwise the bias-correction term
  #   M(θ̌) ignores intervals without events in this block).
  # - d_j > 0, λ_j > 0: the regular expression d_j/λ_j - denom_j.
  # - d_j > 0, λ_j <= 0: off the parameter space (score would be +Inf);
  #   kept at 0 so downstream linear algebra stays finite.
  score_lambda <- ifelse(
    d_g > 0,
    ifelse(lambda > 0, d_g / lambda - denom, 0),
    -denom
  )

  c(score_beta, score_lambda)
}

# ============================================================================
# Hessian Matrix H
# ============================================================================
# Paper Formula (2.7), page 3:
#   "H_{r,s} = -∂²ℓ/∂β_r∂β_s = Σ_{j=1}^J λ_j { Σ_{i=1}^N Δ_j(t_i) x_{ir} x_{is} exp(x_i^T β) }
#
#    H_{p+m,r} = H_{r,p+m} = -∂²ℓ/∂λ_m∂β_r = Σ_{i=1}^N Δ_m(t_i) x_{ir} exp(x_i^T β)
#
#    H_{p+m,p+n} = -∂²ℓ/∂λ_m∂λ_n = 1_{(m=n)} d_m/λ_m²"
#
# Note: H is the negated Hessian (i.e., -∂²ℓ/∂θ∂θ^T)
# ============================================================================
compute_hessian <- function(x, beta, lambda, p, g, d, delta) {
  # Relative risks exp(x_i^T β) and the exposure matrix Δ_j(t_i) (N x g)
  risk <- as.vector(safe_exp(x %*% beta))
  D <- delta[, 1:g, drop = FALSE]
  beta_idx <- 1:p
  lam_idx <- (p + 1):(p + g)
  out <- matrix(0, p + g, p + g)

  # β–β block: Σ_i x_i x_i^T exp(x_i^T β) Λ_i, with Λ_i = Σ_j λ_j Δ_j(t_i)
  cum_haz <- as.vector(risk * (D %*% lambda))
  out[beta_idx, beta_idx] <- crossprod(x, cum_haz * x)

  # β–λ blocks: entry (r, m) = Σ_i Δ_m(t_i) x_ir exp(x_i^T β); H is symmetric
  cross <- crossprod(x, risk * D)
  out[beta_idx, lam_idx] <- cross
  out[lam_idx, beta_idx] <- t(cross)

  # λ–λ block is diagonal with d_m / λ_m^2 (0 when d_m = 0 or λ_m <= 0)
  d_g <- d[1:g]
  curvature <- ifelse(d_g > 0 & lambda > 0, d_g / lambda^2, 0)
  out[cbind(lam_idx, lam_idx)] <- curvature

  out
}

# ============================================================================
# Full Variance Matrix V = H^{-1}
# ============================================================================
# Returns the full (p+g) × (p+g) variance matrix
# ============================================================================
compute_full_variance <- function(H) {
  # Full covariance of θ = (β, λ) as the inverse of the negated Hessian.
  # solve() signals an error when H is (numerically) singular.
  V <- solve(H)
  V
}

# ============================================================================
# Variance for β via Schur Complement
# ============================================================================
# V(β) = (H_ββ - H_βλ H_λλ^{-1} H_λβ)^{-1}
# By the block-matrix inversion formula this equals the β-block of H^{-1}.
# ============================================================================
compute_variance_beta <- function(H, p, g) {
  beta_idx <- seq_len(p)
  lam_idx <- p + seq_len(g)
  # drop = FALSE keeps every block a matrix. Without it, p == 1 (or g == 1)
  # collapses a block to a vector, and for p == 1, g > 1 the product
  # Hbl %*% ... %*% t(Hbl) is non-conformable.
  Hbb <- H[beta_idx, beta_idx, drop = FALSE]
  Hbl <- H[beta_idx, lam_idx, drop = FALSE]
  Hll <- H[lam_idx, lam_idx, drop = FALSE]
  # Schur complement of H_λλ (H assumed symmetric, so H_λβ = t(H_βλ));
  # solve(Hll, t(Hbl)) avoids forming H_λλ^{-1} explicitly.
  schur <- Hbb - Hbl %*% solve(Hll, t(Hbl))
  solve(schur)
}

# ============================================================================
# Partition Breaks based on Event Time Quantiles
# ============================================================================
compute_breaks <- function(event_times, g) {
  te <- sort(event_times)
  m <- length(te)

  # One interval only: [0, 2 * max event time)
  if (g <= 1) {
    return(c(0, max(te) * 2))
  }

  # Fewer events than intervals: place interior cut points evenly between
  # the smallest and largest event time.
  if (m < g) {
    even <- seq(min(te), max(te), length.out = g)
    return(c(0, even[-1], max(te) * 2))
  }

  # k-th of the g-quantiles of the sorted event times: the average of the
  # two straddling order statistics when m * k / g is an integer, otherwise
  # the next order statistic.
  cut_point <- function(k) {
    pos <- m * k / g
    lo <- floor(pos)
    if (lo + 1 > m) {
      return(te[m])
    }
    if (pos == lo) (te[lo] + te[lo + 1]) / 2 else te[lo + 1]
  }
  interior <- vapply(seq_len(g - 1), cut_point, numeric(1))
  c(0, interior, max(te) * 2)
}

# ============================================================================
# Compute d_j and Δ_j(t_i)
# ============================================================================
# Paper Formula (2.5), page 2:
#   "d_j = Σ_{i=1}^N δ_i 1_{[a_{j-1},a_j)}(t_i)"
#   (number of events in interval j)
#
# Paper Formula (2.3), page 2:
#   "Δ_j(t) = { 0              if t < a_{j-1}
#             { t - a_{j-1}    if a_{j-1} ≤ t < a_j
#             { a_j - a_{j-1}  if t ≥ a_j"
#   (exposure time in interval j)
#
# The stored top break a_J is finite (compute_breaks() uses 2 * the largest
# first-batch event time), but later blocks may contain times beyond it.
# With open_last = TRUE (default) the last interval is [a_{J-1}, ∞), so every
# event is counted in some d_j and exposure is not truncated; otherwise
# Σ_j d_j < Σ_i δ_i while the β score/likelihood still count every δ_i.
# open_last = FALSE restores the closed last interval [a_{J-1}, a_J).
#
# Returns list(d = length-g numeric vector, delta = N x g numeric matrix);
# delta is a matrix even when N == 1.
# ============================================================================
compute_events <- function(time, status, breaks, g, open_last = TRUE) {
  lower <- breaks[seq_len(g)]
  upper <- breaks[seq_len(g) + 1]
  if (open_last) {
    upper[g] <- Inf
  }
  n <- length(time)
  list(
    # d_j = Σ_i δ_i 1_{[lower_j, upper_j)}(t_i)
    d = vapply(seq_len(g), function(j) {
      sum(status[time >= lower[j] & time < upper[j]])
    }, numeric(1)),
    # Δ_j(t_i) = pmax(0, pmin(t, upper_j) - lower_j): all three cases of (2.3)
    delta = matrix(
      vapply(seq_len(g), function(j) {
        pmax(0, pmin(time, upper[j]) - lower[j])
      }, numeric(n)),
      nrow = n, ncol = g
    )
  )
}

# ============================================================================
# Section 2.4: Expansion Matrix P for Adaptive Partition
# ============================================================================
# Paper page 4-5:
#   "To be specific, we introduce the (p+J+1)×(p+J) expansion matrix P_{k-1},
#    where P_{k-1}(i,i) = 1, i = 1,...,(p+J-1),
#    P_{k-1}(p+J, p+J) = w_J^p, P_{k-1}(p+J+1, p+J) = w_{J+1}^p, and 0 elsewhere."
#
#   "H*_{k-1} = P_{k-1} H_{k-1} P_{k-1}^T"  (Hessian expansion)
#
# Paper Remark 2.4.1, page 5:
#   "We further impose constraints on w_J^p, w_{J+1}^p, w_J^q, and w_{J+1}^q
#    (w_J^p w_J^q + w_{J+1}^p w_{J+1}^q = 1 and all are positive)"
#   "we set w_J^p = w_{J+1}^p = 0.5"
#
# Note: Code generalizes to split any interval j_split (not just J-th)
# ============================================================================
create_P <- function(p, g, j_split, wp_j = 0.5, wp_j1 = 0.5) {
  split_col <- p + j_split
  P <- matrix(0, nrow = p + g + 1, ncol = p + g)
  # β's and the λ's before the split interval map one-to-one.
  before <- seq_len(split_col - 1)
  P[cbind(before, before)] <- 1
  # The split λ is shared between the two new sub-intervals.
  P[split_col, split_col] <- wp_j
  P[split_col + 1, split_col] <- wp_j1
  # λ's after the split move down by one row.
  if (j_split < g) {
    after <- p + seq(j_split + 1, g)
    P[cbind(after + 1, after)] <- 1
  }
  P
}

# ============================================================================
# Section 2.4: Expansion Matrix Q for Adaptive Partition
# ============================================================================
# Paper page 5:
#   "We further introduce the (p+J+1)×(p+J) expansion matrix Q_{k-1},
#    where Q_{k-1}(i,i) = 1, i = 1,...,(p+J-1),
#    Q_{k-1}(p+J, p+J) = w_J^q, Q_{k-1}(p+J+1, p+J) = w_{J+1}^q, and 0 elsewhere."
#
#   "θ*_{k-1} = Q_{k-1} θ_{k-1}"  (parameter expansion)
#
# Paper Remark 2.4.1, page 5:
#   "we set w_J^q = w_{J+1}^q = 1"
#
# Constraint verification:
#   w_J^p × w_J^q + w_{J+1}^p × w_{J+1}^q = 0.5×1 + 0.5×1 = 1 ✓
#   This ensures P^T Q = I_{p+J}
# ============================================================================
create_Q <- function(p, g, j_split, wq_j = 1, wq_j1 = 1) {
  k <- p + g
  split_col <- p + j_split
  # Row-expand the (p+g) identity: row split_col appears twice, so every
  # parameter after the split interval shifts down by one row.
  rows <- c(seq_len(split_col), seq(split_col, k))
  Q <- diag(k)[rows, , drop = FALSE]
  # Weights for the two copies of the split λ.
  Q[split_col, split_col] <- wq_j
  Q[split_col + 1, split_col] <- wq_j1
  Q
}

# ============================================================================
# Split breaks at interval j_split
# ============================================================================
# Insert one new break inside interval j_split = [breaks[j_split], breaks[j_split + 1]).
# The new break is the median of the event times falling in that interval
# (when there are at least two), otherwise the interval midpoint.
# If the median is not strictly inside the interval (e.g. tied event times
# at the lower bound) or is NA, the midpoint is used instead: a zero-width
# sub-interval would have an all-zero exposure column and a singular Hessian.
# Returns the break vector with length(breaks) + 1 elements.
split_breaks <- function(breaks, j_split, event_times) {
  lo <- breaks[j_split]
  hi <- breaks[j_split + 1]
  inside <- event_times[event_times >= lo & event_times < hi]
  midpoint <- (lo + hi) / 2
  new_break <- if (length(inside) >= 2) median(inside) else midpoint
  if (!isTRUE(new_break > lo && new_break < hi)) {
    new_break <- midpoint
  }
  c(breaks[1:j_split], new_break, breaks[(j_split + 1):length(breaks)])
}

# ============================================================================
# Main Function: onlinecox (initialization with first batch)
# ============================================================================
#' Online Cox Model with Adaptive Partition
#'
#' Fits the first data block (Section 2.3): quantile-based breaks, the
#' profile MLE of beta via nlm(), closed-form lambda, and the negated
#' Hessian H_1 together with its inverse V_1 (Formula 2.32, first batch).
#'
#' @param formula Survival formula
#' @param data Initial batch data
#' @param n_groups Initial number of intervals (J_0)
#' @param adaptive Use adaptive partition (default TRUE)
#' @param max_groups Maximum number of intervals (J_max)
#' @return An object of class `"onlinecox"`, to be passed to update().
#' @export
onlinecox <- function(formula, data, n_groups = 4L, adaptive = TRUE, max_groups = 50L) {
  frame <- model.frame(formula, data)
  resp <- model.response(frame)
  # The first design column is dropped; it is assumed to be the intercept,
  # which the baseline hazard λ absorbs.
  x <- model.matrix(formula, data)[, -1, drop = FALSE]
  time <- resp[, 1]
  status <- resp[, 2]
  n <- nrow(x)
  p <- ncol(x)
  g <- n_groups

  event_times <- time[status == 1]
  breaks <- compute_breaks(event_times, g)
  ev <- compute_events(time, status, breaks, g)

  # θ̂_1: profile MLE of β, then λ in closed form (Formula 2.6)
  profile_nll <- function(b) neg_loglik(b, x, status, ev$d, ev$delta, g)
  beta <- nlm(profile_nll, rep(0, p), fscale = n)$estimate
  lambda <- compute_lambda(x, beta, g, ev$d, ev$delta)

  # H̃_1 = H(θ̂_1) and Ṽ_1 = H̃_1^{-1}
  H <- compute_hessian(x, beta, lambda, p, g, ev$d, ev$delta)
  V_full <- compute_full_variance(H)
  V_beta <- compute_variance_beta(H, p, g)

  fit <- list(
    coef = beta,
    se = sqrt(pmax(diag(V_beta), 0)),
    theta = c(beta, lambda),
    hessian = H,
    variance = V_full,  # full (p+g) x (p+g) variance used by Formula (2.32)
    p = p,
    g = g,
    n = n,
    breaks = breaks,
    event_times = event_times,
    adaptive = adaptive,
    max_groups = max_groups,
    formula = formula
  )
  class(fit) <- "onlinecox"
  fit
}

# ============================================================================
# Main Function: update.onlinecox
# Implements Section 2.6: Adaptive Partition and Bias Correction
# ============================================================================
#' Update an online Cox model with a new block of data
#'
#' Implements Section 2.6 of Wu et al. (2021): optional adaptive splitting of
#' one baseline-hazard interval (Section 2.4), the block MLE, an intermediary
#' estimator, the bias-corrected cumulative update (Formula 2.31) and the
#' sandwich variance of Formula (2.32).
#'
#' @param object An `onlinecox` object from onlinecox() or a previous update.
#' @param newdata Data for the new block, evaluated with `object$formula`.
#' @param r_k Expansion rate r_k of Remark 2.4.4 (the paper uses 1).
#' @param min_events Blocks with fewer events than this are skipped and
#'   `object` is returned unchanged.
#' @param max_cond Largest accepted kappa() estimate of the updated
#'   cumulative Hessian; above it the block is skipped.
#' @param quiet If FALSE (default) a warning reports the error whenever a
#'   block is skipped because the update failed; TRUE skips silently.
#' @param ... Currently unused.
#' @return The updated `onlinecox` object, or `object` unchanged when the
#'   block is skipped.
#' @export
update.onlinecox <- function(object, newdata, r_k = 1, min_events = 10L,
                             max_cond = 1e6, quiet = FALSE, ...) {
  mf <- model.frame(object$formula, newdata)
  y <- model.response(mf)
  x <- model.matrix(object$formula, newdata)[, -1, drop = FALSE]
  time <- y[, 1]
  status <- y[, 2]
  n <- nrow(x)
  p <- object$p
  g <- object$g

  # Skip batch if too few events
  n_events <- sum(status)
  if (n_events < min_events) {
    return(object)
  }

  # Any error during the update (singular matrix, ill-conditioned Hessian,
  # ...) skips this batch and returns the previous fit unchanged.
  tryCatch({
    new_event_times <- time[status == 1]
    all_event_times <- c(object$event_times, new_event_times)

    # ------------------------------------------------------------------------
    # Paper Remark 2.4.4, page 5: split the j_max-th interval, where
    #   j_max = argmax_j {j | d_j > r_k · (Σ_{ℓ=1}^J d_ℓ) / J},  r_k ≥ 1,
    # i.e. the largest index whose event count exceeds the threshold.
    # ------------------------------------------------------------------------
    need_split <- FALSE
    j_split <- NULL
    if (object$adaptive && g < object$max_groups) {
      d_current <- compute_events(time, status, object$breaks, g)$d
      mean_d <- sum(d_current) / g
      threshold <- r_k * mean_d
      candidates <- which(d_current > threshold)
      if (length(candidates) > 0) {
        j_split <- max(candidates)
        need_split <- TRUE
      }
    }

    # ------------------------------------------------------------------------
    # Section 2.4, page 5, and Formula (2.32): expand the cumulative state
    #   H̃*_{k-1} = P H̃_{k-1} P^T,  θ̃*_{k-1} = Q θ̃_{k-1},
    #   Ṽ*_{k-1} = Q Ṽ_{k-1} Q^T
    # with the Remark 2.4.1 weights w^p = 0.5, w^q = 1.
    # ------------------------------------------------------------------------
    if (need_split) {
      P <- create_P(p, g, j_split, wp_j = 0.5, wp_j1 = 0.5)
      Q <- create_Q(p, g, j_split, wq_j = 1, wq_j1 = 1)
      H_tilde_prev <- P %*% object$hessian %*% t(P)
      theta_tilde_prev <- as.vector(Q %*% object$theta)
      V_tilde_prev <- Q %*% object$variance %*% t(Q)
      breaks_new <- split_breaks(object$breaks, j_split, all_event_times)
      g_new <- g + 1
    } else {
      H_tilde_prev <- object$hessian
      theta_tilde_prev <- object$theta
      V_tilde_prev <- object$variance
      breaks_new <- object$breaks
      g_new <- g
    }
    beta_idx <- 1:p
    lambda_idx <- (p + 1):(p + g_new)

    ev <- compute_events(time, status, breaks_new, g_new)

    # Step 1: block MLE θ̂_{n_k,k} (Section 2.3, Formula 2.6)
    beta_hat <- nlm(
      function(b) neg_loglik(b, x, status, ev$d, ev$delta, g_new),
      rep(0, p), fscale = n
    )$estimate
    lambda_hat <- compute_lambda(x, beta_hat, g_new, ev$d, ev$delta)
    theta_hat <- c(beta_hat, lambda_hat)

    # Step 2: H_{n_k,k} and V̂_{n_k,k} = H_{n_k,k}^{-1} at the block MLE
    H_k <- compute_hessian(x, beta_hat, lambda_hat, p, g_new, ev$d, ev$delta)
    V_hat_k <- compute_full_variance(H_k)

    # Step 3: intermediary estimator (Remark 2.6.1)
    #   θ̌ = (H̃*_{k-1} + H_{n_k,k})^{-1} (H̃*_{k-1} θ̃*_{k-1} + H_{n_k,k} θ̂)
    # λ is floored at 1e-10 so the Hessian and score are finite, and the
    # floored vector is the θ̌ used in every later step (H̃, M and the
    # H̃ θ̌ term of Formula 2.31 must all refer to the same point).
    theta_check <- as.vector(solve(
      H_tilde_prev + H_k,
      H_tilde_prev %*% theta_tilde_prev + H_k %*% theta_hat
    ))
    beta_check <- theta_check[beta_idx]
    lambda_check <- pmax(theta_check[lambda_idx], 1e-10)
    theta_check <- c(beta_check, lambda_check)

    # Steps 4-5: H̃_{n_k,k} and score M_{n_k,k} evaluated at θ̌
    H_tilde_k <- compute_hessian(x, beta_check, lambda_check, p, g_new,
                                 ev$d, ev$delta)
    M_k <- compute_score(beta_check, lambda_check, x, status, ev$d, ev$delta,
                         g_new, p)

    # Step 7: cumulative Hessian H̃_k = H̃*_{k-1} + H̃_{n_k,k}. Its condition
    # number is checked before it is inverted; ill-conditioned batches are
    # rejected (the error is turned into a skip by the handler below).
    H_tilde_new <- H_tilde_prev + H_tilde_k
    if (kappa(H_tilde_new) > max_cond) {
      stop("Hessian condition number exceeds threshold")
    }

    # Step 6: bias-corrected update, Formula (2.31)
    #   θ̃_k = H̃_k^{-1} (H̃*_{k-1} θ̃*_{k-1} + H̃_{n_k,k} θ̌ + M_{n_k,k}(θ̌))
    theta_tilde_new <- as.vector(solve(
      H_tilde_new,
      H_tilde_prev %*% theta_tilde_prev + H_tilde_k %*% theta_check + M_k
    ))
    beta_new <- theta_tilde_new[beta_idx]
    lambda_new <- pmax(theta_tilde_new[lambda_idx], 1e-10)

    # Formula (2.32): sandwich variance
    #   Ṽ_k = H̃_k^{-1} (H̃*_{k-1} Ṽ*_{k-1} H̃*_{k-1}^T
    #                   + H̃_{n_k,k} V̂_{n_k,k} H̃_{n_k,k}^T) (H̃_k^{-1})^T
    H_sum_inv <- solve(H_tilde_new)
    middle <- H_tilde_prev %*% V_tilde_prev %*% t(H_tilde_prev) +
      H_tilde_k %*% V_hat_k %*% t(H_tilde_k)
    V_full_new <- H_sum_inv %*% middle %*% t(H_sum_inv)
    # NOTE(review): `se` is taken from the β-block of H̃_k^{-1}
    # (compute_variance_beta), not from V_full_new; vcov.onlinecox() uses the
    # same Hessian-based estimate. Confirm which estimator is intended.
    V_beta <- compute_variance_beta(H_tilde_new, p, g_new)

    object$coef <- beta_new
    object$se <- sqrt(pmax(diag(V_beta), 0))
    object$theta <- c(beta_new, lambda_new)
    object$hessian <- H_tilde_new
    object$variance <- V_full_new  # full (p+g) x (p+g) variance (Formula 2.32)
    object$g <- g_new
    object$n <- object$n + n
    object$breaks <- breaks_new
    object$event_times <- all_event_times
    object
  }, error = function(e) {
    # Best effort: keep the previous fit, but report why the batch was dropped.
    if (!quiet) {
      warning("update.onlinecox: batch skipped: ", conditionMessage(e),
              call. = FALSE)
    }
    object
  })
}

#' @export
coef.onlinecox <- function(object, ...) {
  # Current regression coefficients stored on the fit.
  object[["coef"]]
}

#' @export
vcov.onlinecox <- function(object, ...) {
  # β covariance from the stored cumulative (negated) Hessian via the
  # Schur complement, i.e. the β-block of its inverse.
  H <- object[["hessian"]]
  compute_variance_beta(H, p = object[["p"]], g = object[["g"]])
}

#' @export
print.onlinecox <- function(x, ...) {
  # Header line with sample size, dimensions and mode, then a table of
  # coefficients and standard errors rounded to 4 decimals.
  cat("Online Cox Model (",
      "n=", x$n,
      ", p=", x$p,
      ", g=", x$g,
      ", adaptive=", x$adaptive,
      ")\n",
      sep = "")
  tab <- cbind(coef = x$coef, se = x$se)
  print(round(tab, 4))
  invisible(x)
}
