## Script to apply the PC algorithm to generated data.
## The DAGs contained in the output CPDAG are enumerated, and both the CPDAG and those DAGs are saved.

source('R/scm/graph_functions.R')
source('R/scm/node_functions.R')
source('R/scm/binary_scm.R')
source('R/scm/additive_noise_scm.R')
source('R/scm/linear_gaussian_scm.R')

source('R/causal_discovery.R')
source('R/cpdag_enumerate.R')

source('R/experiments/helper_functions.R')

library(pcalg)
library(kpcalg)

#' Run the PC algorithm on a stored data set and save the resulting CPDAG.
#'
#' Reads the list stored at `data_path` with readRDS(), keeps the first
#' `n_observational` rows and the first `n_nodes - 1` columns of its
#' `data_observational` element, optionally passes the result through
#' scale(), runs pcalg::pc() with a conditional-independence test chosen from
#' the stored `scm_type`, converts the fit with pc_fit_to_dags_in_cpdag() and
#' hands the CPDAG, the DAGs and the run metadata to save_cpdag().
#'
#' @param data_path Path to an RDS file holding a list with the elements
#'   `data_observational`, `scm_type`, `n_nodes`, `n_expected_neighbors`,
#'   `n_expected_parents_y` and `seed`.
#' @param n_observational Number of leading rows of `data_observational` to
#'   use; must lie between 1 and the number of stored rows.
#' @param alpha Significance level forwarded to pcalg::pc().
#' @param standardize If TRUE, the selected data are passed through scale()
#'   before the independence tests are set up.
#' @return Called for its side effects (logging and save_cpdag()); the value
#'   is whatever the final log_to_shell() call returns.
apply_pc <- function(data_path, n_observational, alpha, standardize) {
  log_to_shell(paste0("Running PC with ", n_observational, " observational data points, alpha = ", alpha, ", standardized = ", standardize))
  data_and_parameters <- readRDS(data_path)
  data_observational <- data_and_parameters$data_observational
  scm_type <- data_and_parameters$scm_type
  n_nodes <- data_and_parameters$n_nodes
  n_expected_neighbors <- data_and_parameters$n_expected_neighbors
  n_expected_parents_y <- data_and_parameters$n_expected_parents_y
  seed <- data_and_parameters$seed

  # Fail early with a clear message: on a data frame, out-of-range row indices
  # would otherwise yield NA rows instead of an error, and a request for zero
  # rows would be turned into "row 1" by seq().
  n_available <- nrow(data_observational)
  if (is.null(n_available)) {
    stop("data_observational in ", data_path,
         " must be a matrix or data frame", call. = FALSE)
  }
  if (length(n_observational) != 1 || !is.finite(n_observational) ||
      n_observational < 1 || n_observational > n_available) {
    stop("n_observational must be a single number between 1 and ",
         n_available, " (rows available in ", data_path, ")", call. = FALSE)
  }

  # Limit amount of observational data and exclude Y.
  # NOTE(review): assumes Y is the last of the n_nodes columns -- confirm
  # against the data generator.
  # drop = FALSE keeps the two-dimensional shape even for one row or column.
  data_observational <- data_observational[seq_len(n_observational),
                                           seq_len(n_nodes - 1),
                                           drop = FALSE]

  if (standardize) {
    data_observational <- scale(data_observational)
  }

  # Select the conditional-independence test and its sufficient statistic
  if (scm_type == "gaussian") {
    log_to_shell("Using gaussCItest")
    ci_test <- "gaussCItest"
    indepTest <- pcalg::gaussCItest
    suffStat <- list(C = cor(data_observational), n = nrow(data_observational))
  } else if (scm_type == "non-linear1") {
    log_to_shell("Using hsic.gamma")
    ci_test <- "hsic.gamma"
    indepTest <- kpcalg::kernelCItest
    suffStat <- list(data = data_observational, ic.method = "hsic.gamma")
  } else {
    stop("Unknown SCM type")
  }

  # Apply PC algorithm
  log_to_shell("Running PC...")
  pc.fit <- pcalg::pc(suffStat = suffStat,
                      indepTest = indepTest,
                      alpha = alpha, labels = colnames(data_observational),
                      u2pd = 'retry',
                      verbose = FALSE)
  log_to_shell("Finished running PC. Enumerating DAGs in MEC.")

  # Convert pc.fit output to cpdag and dags
  cpdag_and_dags <- pc_fit_to_dags_in_cpdag(pc.fit)
  cpdag <- cpdag_and_dags$cpdag
  dags <- cpdag_and_dags$dags

  log_to_shell("Finished enumerating DAGs. Saving results.")
  # Build the name under which the results are saved
  save_name <- paste0("pc_", n_observational, "n_alpha", alpha)
  if (standardize) {
    save_name <- paste0(save_name, "_standardized")
  }

  # No trailing comma after the last argument: it would pass an extra empty
  # argument to save_cpdag().
  save_cpdag(cpdag,
             dags,
             save_name,
             scm_type,
             n_nodes,
             n_expected_neighbors,
             n_expected_parents_y,
             seed,
             data_path = data_path,
             n_observational = n_observational,
             method = "pc",
             alpha = alpha,
             ci_test = ci_test,
             pc.fit = pc.fit)
  log_to_shell("Saved results.")
}

# Command-line entry point: parse the four positional arguments, validate
# them, and run apply_pc().
args <- commandArgs(trailingOnly = TRUE)

if (length(args) != 4) {
  stop("Usage: Rscript causal_discovery_pc.R <data_path> <n_observational> <alpha> <standardize>")
}

log_to_shell(paste0("causal_discovery_pc.R ", paste(args, collapse = " ")))

data_path <- args[1]
n_observational <- as.numeric(args[2])
alpha <- as.numeric(args[3])
standardize <- as.logical(args[4])

# as.numeric() / as.logical() return NA for text they cannot parse (for
# example as.logical("1") or as.logical("yes")). Report that here instead of
# letting it surface later as an obscure failure inside apply_pc().
if (is.na(n_observational)) {
  stop("<n_observational> must be a number, got '", args[2], "'", call. = FALSE)
}
if (is.na(alpha)) {
  stop("<alpha> must be a number, got '", args[3], "'", call. = FALSE)
}
if (is.na(standardize)) {
  stop("<standardize> must be TRUE or FALSE, got '", args[4], "'", call. = FALSE)
}

apply_pc(data_path, n_observational, alpha, standardize)