## Script to convert a DAG to a CPDAG and enumerate all possible DAGs in the CPDAG
## Output is saved in 5_causal_discovery in the format seed_oracle_cpdag.rds


source('R/scm/graph_functions.R')
source('R/scm/node_functions.R')
source('R/scm/binary_scm.R')
source('R/scm/additive_noise_scm.R')
source('R/scm/linear_gaussian_scm.R')
source('R/scm/graph_functions.R')

source('R/causal_discovery.R')
source('R/cpdag_enumerate.R')

source('R/experiments/helper_functions.R')

library(pcalg)

#' Build and save the oracle CPDAG, plus the DAGs it contains, for one stored DAG.
#'
#' Reads the list saved at `dag_path`, drops the "Y" vertex from its `dag`
#' element, converts the remaining graph to a CPDAG with pcalg, enumerates the
#' DAGs via `adj_mat_cpdag_to_dags_in_cpdag()` and stores the result twice
#' through `save_cpdag()` (once per label below).
#'
#' @param dag_path Path to an .rds file holding a list with the elements
#'   `dag`, `n_nodes`, `seed`, `n_expected_neighbors` and
#'   `n_expected_parents_y`.
#' @return Whatever the final `log_to_shell()` call returns.
oracle_for_dag <- function(dag_path) {
  log_to_shell(paste0("Creating oracles for ", dag_path, "..."))
  stored <- readRDS(dag_path)

  # The CPDAG is computed on the graph without the Y node.
  graph_without_y <- igraph::delete_vertices(stored$dag, "Y")
  cpdag_graph <- pcalg::dag2cpdag(igraph::as_graphnel(graph_without_y))
  # Adjacency matrix in the form [a, b] = 1 if a -> b.
  cpdag_adj <- graph::adjacencyMatrix(as(cpdag_graph, "graphBAM"))
  enumerated <- adj_mat_cpdag_to_dags_in_cpdag(cpdag_adj)

  log_to_shell(paste0("Saving oracles for ", dag_path, "..."))
  # The same oracle result is saved under both labels, in this order.
  # NOTE(review): the label is the 4th positional argument of save_cpdag(),
  # presumably the data/SCM type -- confirm against its definition.
  for (label in c("gaussian", "non-linear1")) {
    save_cpdag(enumerated$cpdag, enumerated$dags, "oracle", label,
               stored$n_nodes, stored$n_expected_neighbors,
               stored$n_expected_parents_y, stored$seed,
               true_dag_path = dag_path)
  }
  log_to_shell(paste0("Done saving oracles for ", dag_path, "."))
}

#' Run `oracle_for_dag()` on every stored DAG file.
#'
#' Scans exactly two directory levels below `dag_base_folder` (its direct
#' sub-folders, then their direct sub-folders) and calls `oracle_for_dag()` on
#' each entry in those second-level folders whose name ends in ".rds".
#'
#' @param dag_base_folder Root folder to scan. Defaults to
#'   "experiments/1_dags/", the location that used to be hard-coded here.
#' @return `NULL`, invisibly; called for the side effects of
#'   `oracle_for_dag()`.
oracle_for_all_dags <- function(dag_base_folder = "experiments/1_dags/") {
  for (dag_node_folder in list.dirs(dag_base_folder, recursive = FALSE)) {
    for (dag_type_folder in list.dirs(dag_node_folder, recursive = FALSE)) {
      # The dot is escaped so that only a literal ".rds" extension matches.
      # An unescaped ".rds$" also accepts names such as "foo_rds", because "."
      # is a regex wildcard.
      rds_files <- list.files(dag_type_folder, pattern = "\\.rds$",
                              full.names = TRUE)
      for (file_path in rds_files) {
        oracle_for_dag(file_path)
      }
    }
  }
  invisible(NULL)
}

# Batch mode over every stored DAG is disabled; uncomment to enable it.
# oracle_for_all_dags()

# Command-line entry point: the first trailing argument is the path of the DAG
# .rds file to process.
args <- commandArgs(trailingOnly = TRUE)
if (length(args) < 1L) {
  # Without this guard args[1] is NA, which only fails later inside readRDS()
  # with an unhelpful message.
  stop("Usage: Rscript causal_discovery_oracle.R <path/to/dag.rds>",
       call. = FALSE)
}
dag_path <- args[1]
log_to_shell(paste0("causal_discovery_oracle.R ", dag_path))
oracle_for_dag(dag_path)