# Start from an empty global environment. This removes objects only; packages
# that are already attached stay attached.
rm(list = ls())
library(foreach)
library(parallel)
library(doSNOW)
library(survival)
library(COLSA)

# Load the R implementation of online method
# (the path is relative to the current working directory)
source("online/online_update.R")

# coxphSGD is treated as optional here: it is attached only when installed.
if (requireNamespace("coxphSGD", quietly = TRUE)) {
  library(coxphSGD)
}

# Command-line arguments (all positional and optional):
#   1. K        number of data batches (values of the `group` column); default 6
#   2. n_small  size label used in the data and result file names; default 100
#   3. method   estimator to run; default "colsa"
args <- commandArgs(trailingOnly = TRUE)
K <- if (length(args) >= 1) as.numeric(args[1]) else 6
n_small <- if (length(args) >= 2) as.numeric(args[2]) else 100
method <- if (length(args) >= 3) args[3] else "colsa"

# Validate before any work is done: K and n_small are formatted with "%d" when
# the result file is written, which fails for NA or fractional values, and an
# unrecognised method would otherwise yield an empty result set.
if (is.na(K) || K < 1 || K != round(K)) {
  stop("K must be a positive whole number, got: ", args[1], call. = FALSE)
}
if (!is.finite(n_small) || n_small != round(n_small)) {
  stop("n_small must be a whole number, got: ", args[2], call. = FALSE)
}
known_methods <- c(
  "oracle", "colsa", "meta", "online_bc", "sgd", "sgd_offline"
)
if (!(method %in% known_methods ||
  grepl("^colsa_fixed_[0-9]+$", method))) {
  stop(
    "Unknown method '", method, "'. Expected one of: ",
    paste(known_methods, collapse = ", "), ", or colsa_fixed_<n_basis>.",
    call. = FALSE
  )
}

# Number of simulated data sets (seeds 1..n_replication)
n_replication <- 500

# Parallel computing setup
# detectCores() can return NA when the core count cannot be determined;
# na.rm = TRUE then falls back to a single worker instead of passing NA on.
n_cores <- max(1L, detectCores() - 1L, na.rm = TRUE)
cl <- makeCluster(n_cores)
registerDoSNOW(cl)

# Progress bar; the callback is handed to foreach through .options.snow
pb <- txtProgressBar(max = n_replication, style = 3)
progress <- function(n) setTxtProgressBar(pb, n)
opts <- list(progress = progress)

# Packages attached on every worker. coxphSGD is optional for this script, so
# it is requested only when an SGD method needs it or when it is installed;
# requesting it unconditionally would break runs on machines without it.
worker_packages <- c("survival", "COLSA", "adagio")
if (method %in% c("sgd", "sgd_offline") ||
  requireNamespace("coxphSGD", quietly = TRUE)) {
  worker_packages <- c(worker_packages, "coxphSGD")
}

# Parallel simulation
#
# One task per replication i: read the i-th simulated data set, fit the model
# selected by `method`, and return a plain numeric vector. Layout by method:
#   oracle              : seed, time, coef, se
#   colsa, colsa_fixed_*: seed, time, n_basis, coef, se
#   meta                : seed, time, n_valid, coef, se
#   online_bc           : seed, time, n_used, g, coef, se
#   sgd, sgd_offline    : seed, time, coef, se
# Replications that raise an error are dropped (.errorhandling = "remove").
#
# on.exit() only has an effect inside a function call, so cleanup is done with
# tryCatch(finally = ): the progress bar is closed and the cluster is stopped
# whether the simulation finishes, fails, or is interrupted.
results <- tryCatch(
  foreach(
    i = seq_len(n_replication),
    .packages = worker_packages,
    .options.snow = opts,
    .errorhandling = "remove"
  ) %dopar% {
    set.seed(i)
    # Load online method in worker
    source("online/online_update.R")
    data <- read.csv(
      paste0("data/sim_K=", K, "_n=", n_small, "_seed=", i, ".csv")
    )
    # Columns 4 to 9 of the data file are used as covariates
    covariate_names <- names(data)[4:9]
    form <- as.formula(
      paste("Surv(time, status) ~", paste(covariate_names, collapse = " + "))
    )
    # Batches after the first one; seq_len(K)[-1] is empty when K == 1,
    # whereas 2:K would iterate over c(2, 1).
    later_batches <- seq_len(K)[-1]
    start_time <- Sys.time()
    if (method == "oracle") {
      # Cox model fitted once on the full data set
      fit <- coxph(formula = form, data = data)
      elapsed <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
      as.vector(c(i, elapsed, fit$coefficients, sqrt(diag(fit$var))))
    } else if (method == "colsa") {
      boundary <- c(0, max(data$time))
      df_sub <- subset(data, group == 1)
      # Choose the number of basis functions (1 to 5) by AIC on batch 1
      aics <- vapply(seq_len(5), function(n_basis) {
        AIC(colsa(form, df_sub, n_basis, boundary, scale = 1))
      }, numeric(1))
      n_basis_best <- which.min(aics)
      alpha_best <- n_basis_best / nrow(df_sub)^0.2

      fit <- colsa(form, df_sub, n_basis_best, boundary)
      for (batch in later_batches) {
        df_sub <- data[data$group == batch, , drop = FALSE]
        fit <- update(fit, df_sub, alpha = alpha_best)
      }
      elapsed <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
      n_basis_final <- tail(fit$n_basis, 1)
      res <- c(i, elapsed, n_basis_final, coef(fit), sqrt(diag(vcov(fit))))
      as.vector(res)
    } else if (startsWith(method, "colsa_fixed")) {
      # method is "colsa_fixed_<n>": the suffix is the fixed number of bases
      n_basis_fixed <- as.numeric(sub("colsa_fixed_", "", method))
      boundary <- c(0, max(data$time))
      df_sub <- subset(data, group == 1)
      fit <- colsa(form, df_sub, n_basis_fixed, boundary, scale = 1)
      for (batch in later_batches) {
        df_sub <- data[data$group == batch, , drop = FALSE]
        fit <- update(fit, df_sub, n_basis = n_basis_fixed)
      }
      elapsed <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
      res <- c(i, elapsed, n_basis_fixed, coef(fit), sqrt(diag(vcov(fit))))
      as.vector(res)
    } else if (method == "meta") {
      # Per-batch Cox fits combined by inverse-variance weighting
      coefs <- list()
      vars <- list()
      n_valid <- 0
      for (k in seq_len(K)) {
        df_sub <- subset(data, group == k)
        # Skip batch if too few events
        if (sum(df_sub$status) < 20) next
        fit <- tryCatch(
          suppressWarnings(coxph(formula = form, data = df_sub)),
          error = function(e) NULL
        )
        if (is.null(fit)) next
        n_valid <- n_valid + 1
        coefs[[length(coefs) + 1]] <- fit$coefficients
        vars[[length(vars) + 1]] <- fit$var
      }
      # Fail with a clear message instead of an obscure solve() error
      if (n_valid == 0) {
        stop("meta: no batch had enough events for a Cox fit")
      }
      # Compute inverse covariance matrices and weighted coefficients
      var_invs <- lapply(vars, solve)
      var_sum <- Reduce("+", var_invs)
      coef_sum <- Reduce("+", mapply(
        function(coef, var_inv) var_inv %*% coef, coefs, var_invs,
        SIMPLIFY = FALSE
      ))
      coef_meta <- solve(var_sum, coef_sum)
      se_meta <- sqrt(diag(solve(var_sum)))
      elapsed <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
      # Return: seed, time, n_valid, coef, se
      as.vector(c(i, elapsed, n_valid, coef_meta, se_meta))
    } else if (method == "online_bc") {
      # Online Cox method (Wu et al. 2021) with bias correction
      # Using onlinecox() implementation with full Formula (2.32) variance
      #
      # Paper: "Online Updating of Survival Analysis"
      # Journal of Computational and Graphical Statistics
      # DOI: 10.1080/10618600.2020.1870481
      #
      # NOTE: Batches with < min_events events are skipped to ensure stability.

      min_events <- 20L

      # Initialize with first batch
      df_sub <- subset(data, group == 1)
      fit <- onlinecox(
        form, df_sub,
        n_groups = 5L, adaptive = TRUE, max_groups = 20L
      )

      # Update with remaining batches, tracking skipped batches: a batch
      # counts as used only if the fit's sample size fit$n grew.
      n_used <- 1
      for (batch in later_batches) {
        df_sub <- subset(data, group == batch)
        n_before <- fit$n
        fit <- update(fit, df_sub, r_k = 2, min_events = min_events)
        if (fit$n > n_before) n_used <- n_used + 1
      }

      elapsed <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
      # Return: seed, time, n_used, g, coef, se
      as.vector(c(i, elapsed, n_used, fit$g, coef(fit), fit$se))
    } else if (method == "sgd") {
      # Online SGD Cox - using coxphSGD package
      data_list <- split(data, data$group)
      p <- length(covariate_names)

      # Initialize with first batch
      beta_init <- tryCatch(
        {
          coef(coxph(form, data = data_list[[1]]))
        },
        error = function(e) rep(0, p)
      )

      # Learning rate: 1 / (max(100, 500/sqrt(K)) * t^0.5)
      c_lr <- max(100, 500 / sqrt(K))
      fit <- coxphSGD(
        formula = form,
        data = data_list,
        epsilon = 0,
        learn.rates = function(t) 1 / (c_lr * t^0.5),
        beta.zero = beta_init,
        max.iter = K
      )

      coef_sgd <- tail(fit$coefficients, 1)[[1]]

      # Compute standard errors using cumulative information matrix:
      # iter.max = 0 evaluates each batch's Cox fit at the supplied init.
      # NOTE(review): batch 1 is evaluated at the zero vector rather than at
      # beta_init - confirm this is intended.
      info_cum <- matrix(0, p, p)
      beta_history <- fit$coefficients
      for (k in seq_len(K)) {
        beta_k <- if (k == 1) rep(0, p) else beta_history[[k]]
        fit_k <- coxph(
          form,
          data = data_list[[k]], init = beta_k, iter.max = 0
        )
        info_cum <- info_cum + solve(fit_k$var)
      }
      se_sgd <- sqrt(diag(solve(info_cum)))

      elapsed <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
      as.vector(c(i, elapsed, coef_sgd, se_sgd))
    } else if (method == "sgd_offline") {
      # Offline SGD Cox - allows multiple passes over the data
      # (K * n_epochs iterations)
      data_list <- split(data, data$group)
      p <- length(covariate_names)
      n_epochs <- 100

      # Initialize with first batch
      beta_init <- tryCatch(
        {
          coef(coxph(form, data = data_list[[1]]))
        },
        error = function(e) rep(0, p)
      )

      # Create repeated data list for multiple epochs
      data_list_repeated <- rep(data_list, n_epochs)

      # Learning rate: 1 / (max(100, 500/sqrt(K)) * t^0.5)
      c_lr <- max(100, 500 / sqrt(K))
      fit <- coxphSGD(
        formula = form,
        data = data_list_repeated,
        epsilon = 0,
        learn.rates = function(t) 1 / (c_lr * t^0.5),
        beta.zero = beta_init,
        max.iter = K * n_epochs
      )

      coef_sgd <- tail(fit$coefficients, 1)[[1]]

      # Compute standard errors using cumulative information matrix (one
      # pass), with every batch evaluated at the final estimate
      info_cum <- matrix(0, p, p)
      beta_final <- coef_sgd
      for (k in seq_len(K)) {
        fit_k <- coxph(
          form,
          data = data_list[[k]], init = beta_final, iter.max = 0
        )
        info_cum <- info_cum + solve(fit_k$var)
      }
      se_sgd <- sqrt(diag(solve(info_cum)))

      elapsed <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
      as.vector(c(i, elapsed, coef_sgd, se_sgd))
    } else {
      # Without this branch an unrecognised method silently returns NULL
      stop("Unknown method: ", method)
    }
  },
  finally = {
    close(pb)
    stopCluster(cl)
  }
)

# Assemble the per-replication result vectors into one data frame (one row per
# successful replication) and write it to results/.

# Failed replications are absent from `results`; NULL entries carry no data.
results <- Filter(Negate(is.null), results)
n_ok <- length(results)
if (n_ok == 0L) {
  stop(
    "No replication produced a result for method '", method,
    "'; nothing to write.",
    call. = FALSE
  )
}
if (n_ok < n_replication) {
  message(sprintf(
    "%d of %d replications failed and were dropped.",
    n_replication - n_ok, n_replication
  ))
}

if (method == "colsa") {
  # Unify result length, pad with NA
  # (assigning a larger length to a vector fills the new tail with NA)
  length_max <- max(lengths(results))
  df <- as.data.frame(do.call(rbind, lapply(results, function(x) {
    length(x) <- length_max
    x
  })))
  rownames(df) <- NULL
} else {
  df <- as.data.frame(do.call(rbind, results))
}

results_dir <- "results"
if (!dir.exists(results_dir)) {
  dir.create(results_dir, recursive = TRUE, showWarnings = FALSE)
}
result_file <- file.path(
  results_dir,
  sprintf("sim_K=%d_n=%d_method=%s.csv", K, n_small, method)
)
write.csv(df, file = result_file, row.names = FALSE)
