# NOTE: the former `rm(list = ls())` was removed on purpose. Under Rscript the
# global environment is already empty, and under source() it would silently
# delete the caller's workspace. Every global this script reads is assigned
# below, so no workspace reset is needed.
library(simsurv)
library(MASS)
library(foreach)
library(parallel)
library(doSNOW)

# Simulation function
#
# Simulate right-censored survival data whose event times come from
# simsurv() with `mixture = TRUE`.
#
# Covariates generated here:
#   x1, x2  bivariate normal, mean (5, 5), covariance [[10, 3], [3, 2]]
#   x3      binary, P(x3 = 1) = 0.8
#   x4      categorical in 1..4, sampling weights depend on x3
# x3 and x4 are expanded with model.matrix(), giving the design columns
# x1, x2, x31, x42, x43, x44 (intercept dropped).
#
# Args:
#   N        number of subjects to simulate.
#   lambdas  passed unchanged to simsurv(lambdas = ).
#   gammas   passed unchanged to simsurv(gammas = ).
#   pmix     passed unchanged to simsurv(pmix = ).
#   beta     numeric vector of six coefficients, matched by position to the
#            design columns listed above.
#   rateC    rate of the exponential censoring times.
#
# Returns:
#   A data.frame with columns id, time (minimum of event and censoring time),
#   status (1 if the event time is <= the censoring time, else 0) and the six
#   design columns.
sim_mix_weibul <- function(N, lambdas, gammas, pmix, beta, rateC) {
  Sigma <- matrix(c(10, 3, 3, 2), nrow = 2)

  # mvrnorm() returns a bare vector when n == 1; matrix() keeps it N x 2.
  covs <- as.data.frame(
    matrix(mvrnorm(n = N, mu = c(5, 5), Sigma = Sigma), ncol = 2)
  )
  names(covs) <- c("x1", "x2")

  # Levels are fixed explicitly so the design matrix always has the same
  # columns, even if a small sample happens to miss a category.
  covs$x3 <- factor(rbinom(N, 1, prob = 0.8), levels = 0:1)

  x3_is_zero <- covs$x3 == 0
  x4 <- integer(N)
  x4[x3_is_zero] <- sample(
    1:4, sum(x3_is_zero),
    replace = TRUE,
    prob = c(0.2, 0.2, 0.3, 0.3)
  )
  # NOTE(review): these weights sum to 1.2; sample() rescales them, so the
  # effective probabilities are weights / 1.2 -- confirm this is intended.
  x4[!x3_is_zero] <- sample(
    1:4, sum(!x3_is_zero),
    replace = TRUE,
    prob = c(0.1, 0.2, 0.4, 0.5)
  )
  covs$x4 <- factor(x4, levels = 1:4)

  # drop = FALSE keeps a matrix (and its column names) when N == 1.
  design_matrix <- model.matrix(
    ~ x1 + x2 + x3 + x4,
    data = covs
  )[, -1, drop = FALSE]
  if (length(beta) != ncol(design_matrix)) {
    stop(
      "`beta` must have ", ncol(design_matrix), " elements (",
      paste(colnames(design_matrix), collapse = ", "), "), not ",
      length(beta), ".",
      call. = FALSE
    )
  }

  covariate_data <- data.frame(id = seq_len(N), design_matrix)
  sim_out <- simsurv(
    lambdas = lambdas,
    gammas = gammas,
    interval = c(1e-8, 10000),
    betas = setNames(beta, colnames(design_matrix)),
    mixture = TRUE,
    pmix = pmix,
    x = covariate_data
  )

  # Independent exponential censoring; the observed time is whichever of the
  # event and censoring times comes first.
  event_time <- sim_out$eventtime
  censor_time <- rexp(N, rate = rateC)

  data.frame(
    id = seq_len(N),
    time = pmin(event_time, censor_time),
    status = as.numeric(event_time <= censor_time),
    covariate_data[, -1]
  )
}

# Parameter settings
# Optional positional command-line arguments:
#   1. K        total number of groups (default 6). The first six group sizes
#               are fixed below, so K must be a whole number >= 6.
#   2. n_small  size of every group beyond the sixth (default 100); must be a
#               non-negative whole number.
args <- commandArgs(trailingOnly = TRUE)
K <- if (length(args) >= 1) as.numeric(args[1]) else 6
n_small <- if (length(args) >= 2) as.numeric(args[2]) else 100

# Fail early with a clear message: otherwise a bad value only surfaces later
# as an obscure rep()/sprintf("%d") error, possibly inside a worker.
if (!is.finite(K) || K != round(K) || K < 6) {
  stop("K must be a whole number >= 6, got: ", args[1], call. = FALSE)
}
if (!is.finite(n_small) || n_small != round(n_small) || n_small < 0) {
  stop(
    "n_small must be a non-negative whole number, got: ", args[2],
    call. = FALSE
  )
}

rateC <- 6
# Group sizes: three large, three medium, then K - 6 groups of size n_small.
ns <- c(1500, 1500, 1500, 500, 500, 500, rep(n_small, K - 6))
n_replication <- 500
N_total <- sum(ns)


# Output directory for the simulated data sets
data_dir <- "data"
if (!dir.exists(data_dir)) {
  dir.create(data_dir, recursive = TRUE)
}

# Progress bar, advanced by doSNOW through the `progress` callback
pb <- txtProgressBar(max = n_replication, style = 3)
progress <- function(n) setTxtProgressBar(pb, n)
opts <- list(progress = progress)

# Parallel computing setup
# detectCores() may return NA; na.rm = TRUE then falls back to one worker.
n_cores <- max(1, detectCores() - 1, na.rm = TRUE)
cl <- makeCluster(n_cores)

# Parallel simulation
# Each replication i seeds the RNG with i, simulates N_total subjects,
# assigns them to the K groups in blocks of sizes `ns`, and writes one CSV.
#
# Cleanup uses tryCatch(finally = ) rather than on.exit(): on.exit() is only
# a reliable cleanup hook inside a function, not at the top level of a
# script. The finally clause closes the progress bar and stops the cluster
# exactly once, whether or not the loop fails.
results <- tryCatch(
  {
    registerDoSNOW(cl)
    foreach(
      i = seq_len(n_replication),
      .packages = c("simsurv", "MASS"),
      .options.snow = opts
    ) %dopar% {
      set.seed(i)
      sim_data <- sim_mix_weibul(
        N = N_total,
        lambdas = c(10, 20),
        gammas = c(3, 5),
        pmix = 0.5,
        beta = c(0.15, -0.15, 0.3, 0.3, 0.3, 0.3),
        rateC = rateC
      )
      sim_data$group <- rep(1:K, times = ns)
      out_file <- file.path(
        data_dir,
        sprintf("sim_K=%d_n=%d_seed=%d.csv", K, n_small, i)
      )
      write.csv(sim_data, file = out_file, row.names = FALSE)
    }
  },
  finally = {
    close(pb)
    stopCluster(cl)
  }
)