library(data.table)
library(purrr)
library(cmprsk)
library(survival)


# Single-site competing-risks Fine-Gray model

# Fit a single-site Fine-Gray competing-risks model with cmprsk::crr().
#
# Arguments:
#   MM              data.table with one row per patient.
#   continuous_covs character vector of numeric covariate columns (may be
#                   empty / NULL).
#   factor_covs     character vector of categorical covariate columns (may be
#                   empty / NULL).
#   Trt             treatment column; always the first model term.
#   time_col        follow-up time column.
#   event_col       event-status column; crr() is called with cencode = 0
#                   (censored) and failcode = 1 (event of interest).
#   variant_name    suffix used in the file name when serializing.
#   serialize       if TRUE, save the fit to
#                   file.path(object_dir, "CR_FG_Model_<variant_name>.rds").
#   break_ties      if TRUE and the time column contains duplicates, add
#                   abs(N(0, 0.01)) jitter to every time before fitting.
#   object_dir      directory used when serialize = TRUE.
#
# Side effects: adds a "<col>_factor" column to MM (by reference) for Trt and
# each entry of factor_covs. The time column of MM is not modified.
#
# Returns the fitted "crr" object.
build_cr_fg_model <- function(MM, continuous_covs, factor_covs, Trt = "Trt_Metformin",
                              time_col = "Time_CR", event_col = "CR_Event",
                              variant_name = "", serialize = FALSE, break_ties = FALSE,
                              object_dir = "Models") {

  ftime <- MM[[time_col]]
  fstatus <- MM[[event_col]]

  # Jitter a local copy of the times so the caller's data is left untouched
  # and repeated calls do not accumulate noise. abs() keeps jittered times at
  # or above their observed values (never negative), as in build_fl_dt().
  if (break_ties && anyDuplicated(ftime) > 0) {
    ftime <- ftime + abs(rnorm(length(ftime), sd = 0.01))
  }

  # Set up factor covariates (treatment first)
  factor_src <- c(Trt, factor_covs)
  factor_list <- list()
  for (fc in factor_src) {
    fc_name <- paste0(fc, "_factor")
    MM[, (fc_name) := factor(get(fc))]
    factor_list[[fc_name]] <- MM[[fc_name]]
  }

  # Build model formula: treatment, factor covariates, continuous covariates
  model_spec <- reformulate(c(paste0(factor_src, "_factor"), continuous_covs))

  # Keep NA rows in the model frame so the design matrix stays row-aligned
  # with ftime / fstatus; crr()'s own na.action then drops incomplete rows
  # jointly instead of failing on mismatched lengths.
  model_frame <- model.frame(model_spec, data = MM, na.action = na.pass)
  covariates_matrix <- model.matrix(
    model_spec,
    data = model_frame,
    contrasts.arg = lapply(factor_list, contrasts)
  )[, -1, drop = FALSE]  # drop intercept; keep a named matrix for 1 column

  # Run competing risks regression
  cr_fg_model <- crr(
    ftime = ftime, fstatus = fstatus,
    cov1 = covariates_matrix,
    failcode = 1, cencode = 0
  )

  if (serialize) {
    object_name <- paste0("CR_FG_Model_", variant_name)
    saveRDS(cr_fg_model, file = file.path(object_dir, paste0(object_name, ".rds")))
  }

  cr_fg_model
}




# Meta-analysis

# Prepare data
# Extract a one-row hazard-ratio record from a fitted model.
#
# model:      fitted model whose summary() has a `coef` matrix; only its first
#             row (the first covariate, i.e. the treatment term in
#             build_cr_fg_model()) is used.
# model_name: label stored in the Study column.
#
# Returns data.table(Study, HR, SE), where HR is column 2 and SE column 3 of
# that row. NOTE(review): for cmprsk summary.crr these are exp(coef) and
# se(coef), i.e. SE is on the log-HR scale -- confirm for other model classes.
get_model_HR_data <- function(model, model_name) {
  coef_table <- summary(model)$coef
  lead_row <- coef_table[1, ]
  data.table(
    Study = model_name,
    HR = lead_row[[2]],
    SE = lead_row[[3]]
  )
}

# Stack per-model hazard-ratio rows into a single table.
#
# models: list of fitted models. imap() passes each element's name (or its
#   list position when unnamed) to get_model_HR_data() as the Study label.
#
# Returns a data.table with one row per model. rbindlist() is used instead
# of the superseded purrr::imap_dfr(): it needs no hidden dplyr dependency and
# returns a genuine data.table, which get_meta_HR() modifies with `:=`.
build_HR_table <- function(models) {
  rbindlist(imap(models, get_model_HR_data))
}


# Run meta-analysis
# Pool per-study hazard ratios by inverse-variance weighting.
#
# HR_data: data.table with columns HR (hazard ratio) and SE (standard error of
#   log(HR); get_model_HR_data() fills it from se(coef) of the model summary).
#   A `weight` column is added / overwritten in HR_data by reference.
# RE:      0 = fixed-effect pooling; 1 = DerSimonian-Laird random effects,
#   which adds the between-study variance tau^2 to every SE^2.
# conf_z:  normal quantile for the confidence interval (1.96 -> 95%).
#
# Pooling is done on the log-HR scale (the scale SE refers to); the pooled
# estimate and interval bounds are transformed back to the HR scale.
# Returns list(HR, SE, CIU, CIL): pooled HR, SE of the pooled log-HR, and the
# upper / lower confidence bounds on the HR scale.
get_meta_HR <- function(HR_data, RE = 0, conf_z = 1.96) {

  log_hr <- log(HR_data$HR)

  # Fixed-effect inverse-variance weights
  HR_data[, weight := 1 / (SE^2)]

  if (RE == 1) {
    w_fe <- HR_data$weight
    k <- nrow(HR_data)
    log_hr_fe <- sum(w_fe * log_hr) / sum(w_fe)
    Q <- sum(w_fe * (log_hr - log_hr_fe)^2)
    dl_denom <- sum(w_fe) - sum(w_fe^2) / sum(w_fe)
    # tau^2 is truncated at 0; with a single study the denominator is 0 and
    # tau^2 is undefined, so no between-study variance is added.
    tau_sq <- if (k > 1 && isTRUE(dl_denom > 0)) {
      max(0, (Q - (k - 1)) / dl_denom)
    } else {
      0
    }
    # tau_sq is already a variance, so it is added to SE^2 unsquared.
    HR_data[, weight := 1 / ((SE^2) + tau_sq)]
  }

  w <- HR_data$weight
  log_hr_meta <- sum(w * log_hr) / sum(w)
  SE_meta <- sqrt(1 / sum(w))

  list(
    HR = exp(log_hr_meta), SE = SE_meta,
    CIU = exp(log_hr_meta + conf_z * SE_meta),
    CIL = exp(log_hr_meta - conf_z * SE_meta)
  )
}





# Federation

# Prepare data
# Build the pooled data set passed to federate_MM() (ODACoR input layout).
#
# MM:        data.table with one row per patient.
# covs:      covariate columns to include; renamed z1, z2, ... in order.
# main_site: value of site_col that identifies the local site (id.site = 1);
#            all other sites are coded id.site = 2.
# id_col:    patient-id column. Not needed to build the output (all columns
#            are taken from the same rows in the same order); kept for
#            interface compatibility.
# site_col:  site column.
# n_sample:  if > 0, first draw this many rows without replacement.
# seed:      if >= 0 and sub-sampling, passed to set.seed() (global RNG).
# time_col, event_col: follow-up time and event-type columns; defaults match
#            those of build_cr_fg_model().
#
# Returns a data.table with columns id.site, t_surv, type, censor
# (0 when type == 0, else 1) and z1..zp. If any t_surv values are tied, every
# t_surv receives non-negative jitter abs(N(0, 0.01)).
build_fl_dt <- function(MM, covs, main_site,
                        id_col, site_col,
                        n_sample = 0, seed = -1,
                        time_col = "Time_CR", event_col = "CR_Event") {

  needed <- c(site_col, time_col, event_col, covs)
  missing_cols <- setdiff(needed, names(MM))
  if (length(missing_cols) > 0) {
    stop("build_fl_dt: missing columns: ",
         paste(missing_cols, collapse = ", "), call. = FALSE)
  }

  # Optionally sub-sample
  if (n_sample > 0) {
    if (seed >= 0) {
      set.seed(seed)
    }
    MM_sample <- MM[sample.int(nrow(MM), n_sample, replace = FALSE)]
  } else {
    MM_sample <- MM
  }

  events <- MM_sample[[event_col]]
  fl_dt <- data.table(
    id.site = ifelse(MM_sample[[site_col]] == main_site, 1, 2),
    t_surv = MM_sample[[time_col]],
    type = events,
    censor = ifelse(events == 0, 0, 1)
  )

  # Covariates come from the same rows, so a positional bind keeps alignment
  # without a (patient, site) join.
  if (length(covs) > 0) {
    cov_dt <- MM_sample[, .SD, .SDcols = covs]
    setnames(cov_dt, covs, paste0("z", seq_along(covs)))
    fl_dt <- cbind(fl_dt, cov_dt)
  }

  # Break ties
  if (anyDuplicated(fl_dt$t_surv) > 0) {
    fl_dt[, t_surv := t_surv + abs(rnorm(.N, sd = 0.01))]
  }

  fl_dt
}




# Run federation
# This function depends on code (get_meta_est, d_initialize, d_distribute,
# d_assemble) from the following paper:
# Citation:
# Zhang, Dazheng, et al.
# "Learning competing risks across multiple hospitals: one-shot distributed algorithms."
# Journal of the American Medical Informatics Association 31.5 (2024): 1102-1112.

# GitHub: https://github.com/Penncil/ODACoR

# One-shot ODACoR federated Fine-Gray estimate.
#
# all_sites: pooled data with an id.site column (presumably the output of
#   build_fl_dt(); confirm the column layout against the ODACoR helpers).
#   NOTE(review): n_sites counts distinct id.site values, so the helpers
#   presumably expect sites coded 1..n_sites.
#
# Relies on get_meta_est(), d_initialize(), d_distribute() and d_assemble(),
# which are not defined in this file (ODACoR code).
#
# Returns beta_bar - solve(global_second_bbar) %*% global_first_bbar: a single
# Newton-type step from the meta-analysis estimate beta_bar, using the
# assembled global first- and second-order summaries.
federate_MM <- function(all_sites) {

  n_sites <- length(unique(all_sites$id.site))

  # Meta-analysis estimator used as the starting point
  beta_bar <- get_meta_est(all_sites, n_sites)

  init <- d_initialize(all_sites, n_sites)
  T_all <- init$T_all
  fit <- init$fit

  # Summary-level statistics for each site
  site_stats <- d_distribute(all_sites, beta_bar, T_all, fit)

  # Combine the site summaries into global first / second-order terms
  global_terms <- d_assemble(
    site_stats$n_list,
    site_stats$cov_sum_list,
    site_stats$U_list,
    site_stats$W_list,
    site_stats$Z_list
  )

  # One-step estimator
  sol_ODACoR_O <- beta_bar -
    solve(global_terms$global_second_bbar) %*% global_terms$global_first_bbar
  sol_ODACoR_O
}



