# Setup ----
# Deliberately no rm(list = ls()): silently wiping the caller's workspace is a
# side effect a script should not have. Run via `Rscript` for a clean session.

library(tidyverse)
library(igraph)

# Seeds 1..n_seeds are looked up in every DAG configuration folder.
n_seeds <- 40
# NOTE(review): not referenced anywhere in this script - confirm it can go.
n_cd <- 1000

# Root of the causal-discovery outputs, relative to the working directory.
base_dir <- 'experiments/5_causal_discovery/non-linear1/'

# Empty result table: one row per PC CPDAG file compared against its oracle.
# Columns: path, n_nodes, n_neighbors, n_parents, seed, cd_type, correct.
results <- tibble(
  path        = character(0),
  n_nodes     = integer(0),
  n_neighbors = integer(0),
  n_parents   = integer(0),
  seed        = integer(0),
  cd_type     = character(0),
  correct     = logical(0)
)


# Compare every PC-estimated CPDAG against the oracle CPDAG of the same seed.
# Layout implied by the name parsing below:
#   <base_dir>/<N>_nodes/<K>_neighbors_<P>_parents_y/
#     <seed>_oracle_cpdag.rds
#     <seed>_pc_<m>n_alpha0.05_cpdag.rds   (zero or more per seed)
# Each .rds is read as a list whose `cpdag` element is passed to igraph's
# as_adjacency_matrix().
# Rows are collected in a list and bound onto `results` once at the end,
# instead of copying the whole tibble with add_row() for every file.
result_rows <- list()

node_folders <- list.files(base_dir, full.names = TRUE)
for (node_folder in node_folders) {
  # NA for entries not named '<N>_nodes' (e.g. stray files); those are skipped
  # rather than crashing the `if` below with a missing value.
  n_nodes <- as.integer(
    str_match(basename(node_folder), '^([0-9]+)_nodes$')[, 2]
  )
  # The 21-node setting is excluded (reason not recorded in this script).
  if (is.na(n_nodes) || n_nodes == 21L) {
    next
  }
  for (dag_folder in list.files(node_folder, full.names = TRUE)) {
    dag_match <- str_match(
      basename(dag_folder),
      '^([0-9]+)_neighbors_([0-9]+)_parents_y$'
    )
    n_neighbors <- as.integer(dag_match[, 2])
    n_parents <- as.integer(dag_match[, 3])
    files <- list.files(dag_folder, full.names = TRUE)
    file_names <- basename(files)
    for (seed in seq_len(n_seeds)) {
      oracle_path <- file.path(dag_folder, paste0(seed, '_oracle_cpdag.rds'))
      if (!file.exists(oracle_path)) {
        next
      }
      oracle_cpdag <- readRDS(oracle_path)$cpdag
      oracle_cpdag_adj <- as_adjacency_matrix(oracle_cpdag, sparse = FALSE)

      # Anchored on the basename with escaped dots. An unanchored substring
      # pattern would let seed 1 also pick up '11_pc_...', '21_pc_...',
      # '31_pc_...' and score them against seed 1's oracle.
      pc_pattern <- paste0('^', seed, '_pc_[0-9]+n_alpha0\\.05_cpdag\\.rds$')
      for (pc_path in files[grepl(pc_pattern, file_names)]) {
        pc_cpdag <- readRDS(pc_path)$cpdag
        pc_cpdag_adj <- as_adjacency_matrix(pc_cpdag, sparse = FALSE)
        result_rows[[length(result_rows) + 1L]] <- tibble(
          path = pc_path,
          n_nodes = n_nodes,
          n_neighbors = n_neighbors,
          n_parents = n_parents,
          seed = seed,
          cd_type = 'pc',
          # Exact match of the adjacency matrices (values and dimnames).
          correct = identical(pc_cpdag_adj, oracle_cpdag_adj)
        )
      }
    }
  }
}

results <- bind_rows(c(list(results), result_rows))


# Write outputs ----
# Paths are relative to the working directory (a different root than
# `base_dir`). readr does not create missing parent directories, so make sure
# the output directory exists before writing.
out_dir <- 'R/experiments/results/non-linear1'
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

# Full table: one row per PC CPDAG file with its correctness flag.
write_delim(
  results,
  file.path(out_dir, 'cd_non-linear1_correctness.csv'),
  delim = ','
)

# Paths of the PC CPDAGs that did not match their oracle, for inspection.
incorrect_results <- results %>%
  filter(!correct) %>%
  select(path)

write_delim(
  incorrect_results,
  file.path(out_dir, 'cd_non-linear1_incorrect_paths.csv'),
  delim = ','
)