# Remove every non-hidden object from the current environment so the script
# starts from a clean workspace.
# NOTE(review): this also wipes the workspace of anyone who source()s this
# file interactively; running the script in a fresh R session is the safer
# way to get the same guarantee.
rm(list = ls())

library(tidyverse)
library(gridExtra)

# Empty accumulator for the per-seed results: one row per
# (n_nodes, n_neighbors, n_parents, seed, cd_type) combination, holding the
# SHD of that method's CPDAG against the oracle CPDAG.
results <- tibble(
  n_nodes         = integer(0),
  n_neighbors     = integer(0),
  n_parents       = integer(0),
  seed            = integer(0),
  cd_type         = character(0),
  SHD_with_oracle = integer(0)
)


# Root of the experiment outputs. The parsing below expects the layout
#   <base_dir>/<N>_nodes/<K>_neighbors_<P>_parents_y/<seed>_*.rds
base_dir <- "experiments/5_causal_discovery/non-linear1/"
n_nodes_dirs <- list.files(base_dir, full.names = TRUE)

# Number of seeds read per DAG configuration (seeds are 1..n_seeds).
n_seeds <- 40L

# Read one saved result file, take its `cpdag` element and convert it to a
# graphNEL object, which is the input type handed to pcalg::shd() below.
load_cpdag_graphnel <- function(dir, file_name) {
  cpdag <- readRDS(file.path(dir, file_name))$cpdag
  igraph::as_graphnel(cpdag)
}

# Rows are collected in a list and bound once after the loops; appending with
# add_row() on every iteration re-copies the whole table each time.
shd_rows <- list()

for (n_nodes_dir in n_nodes_dirs) {
  # A name that does not reduce to digits yields NA; the coercion warning is
  # silenced here because the NA case is reported explicitly just below.
  n_nodes <- suppressWarnings(
    as.integer(gsub("([0-9]+)_nodes", "\\1", basename(n_nodes_dir)))
  )
  if (is.na(n_nodes)) {
    # Previously a stray file or folder in base_dir made the `== 21` test
    # below fail with "missing value where TRUE/FALSE needed".
    warning(
      "Skipping entry that does not look like '<N>_nodes': ", n_nodes_dir,
      call. = FALSE
    )
    next
  }
  # NOTE(review): 21-node runs are excluded on purpose by the original script;
  # the reason is not recorded here.
  if (n_nodes == 21) {
    next
  }

  dag_type_dirs <- list.files(n_nodes_dir, full.names = TRUE)
  for (dag_type_dir in dag_type_dirs) {
    dag_type_name <- basename(dag_type_dir)
    n_neighbors <- as.integer(
      gsub("([0-9]+)_neighbors_[0-9]+_parents_y", "\\1", dag_type_name)
    )
    n_parents <- as.integer(
      gsub("[0-9]+_neighbors_([0-9]+)_parents_y", "\\1", dag_type_name)
    )

    for (seed in seq_len(n_seeds)) {
      # Disabled in the original script: remove the Y node from the oracle
      # CPDAG before comparing.
      # oracle_cpdag <- igraph::delete_vertices(oracle_cpdag, "Y")
      oracle_cpdag_graphnel <- load_cpdag_graphnel(
        dag_type_dir, paste0(seed, "_oracle_cpdag.rds")
      )
      pc_cpdag_graphnel <- load_cpdag_graphnel(
        dag_type_dir, paste0(seed, "_pc_1000n_alpha0.05_standardized_cpdag.rds")
      )
      fges_cpdag_graphnel <- load_cpdag_graphnel(
        dag_type_dir, paste0(seed, "_fges_1000_standardized_cpdag.rds")
      )

      # Structural Hamming distance of each estimated CPDAG to the oracle.
      shd_pc_oracle <- pcalg::shd(oracle_cpdag_graphnel, pc_cpdag_graphnel)
      shd_fges_oracle <- pcalg::shd(oracle_cpdag_graphnel, fges_cpdag_graphnel)

      # Two rows per seed (pc first, then fges); scalars are recycled.
      shd_rows[[length(shd_rows) + 1L]] <- tibble(
        n_nodes = n_nodes,
        n_neighbors = n_neighbors,
        n_parents = n_parents,
        seed = seed,
        cd_type = c("pc", "fges"),
        SHD_with_oracle = c(shd_pc_oracle, shd_fges_oracle)
      )
    }
  }
}

# Append everything to the (initially empty) results table in one step.
results <- bind_rows(results, shd_rows)

# Persist the per-seed SHD table.
# NOTE(review): write_delim() defaults to a single-space delimiter, so despite
# the ".csv" extension this file is space-separated. The delimiter is spelled
# out to make that visible and keep the on-disk format unchanged; switch to
# write_csv() only after confirming what downstream readers expect.
all_shd_path <- "R/experiments/results/non-linear1/all_shd.csv"
# Create the output directory if it is missing so the write does not fail on a
# fresh checkout; an existing directory is left untouched.
dir.create(dirname(all_shd_path), recursive = TRUE, showWarnings = FALSE)
write_delim(results, all_shd_path, delim = " ")


# Aggregate over seeds: one row per (n_nodes, n_neighbors, n_parents, cd_type)
# holding the mean, median and standard deviation of SHD_with_oracle.
results_grouped <- results %>%
  group_by(n_nodes, n_neighbors, n_parents, cd_type) %>%
  summarise(
    across(
      SHD_with_oracle,
      list(mean = mean, median = median, std = sd),
      .names = "{.fn}_{.col}"
    ),
    .groups = "drop"
  )
