# Setup ----
# Run this script in a fresh session (e.g. via `Rscript`) instead of calling
# rm(list = ls()): clearing the global environment is a hidden side effect on
# whoever sources the script, and it does not reset loaded packages or options.

library(tidyverse)
library(gridExtra)

# Empty results table that fixes the column names and types. Each row will be
# one (n_nodes, n_neighbors, n_parents, seed, cd_type) combination, and
# `mec_size` holds the number of DAGs stored in that CPDAG's `dags` list.
results <- tibble(
  n_nodes = integer(),
  n_neighbors = integer(),
  n_parents = integer(),
  seed = integer(),
  cd_type = character(),
  mec_size = integer()
)


# Collect MEC sizes ----
# Expected layout, as read below:
#   <base_dir>/<n>_nodes/<k>_neighbors_<p>_parents_y/<seed><suffix>
# where <suffix> identifies the causal-discovery method (see
# method_cpdag_suffixes).
base_dir <- "experiments/5_causal_discovery/non-linear1/"

# File-name suffix (appended to the seed number) of the CPDAG saved for each
# method; the names become the `cd_type` column.
method_cpdag_suffixes <- c(
  oracle = "_oracle_cpdag.rds",
  pc = "_pc_1000n_alpha0.05_standardized_cpdag.rds",
  fges = "_fges_1000_standardized_cpdag.rds"
)

# List the entries of `dir` whose basename fully matches the anchored regex
# `pattern` and return a tibble with their full `path` plus one integer column
# per capture group (column names given by `groups`, in capture order).
# Entries that do not match (stray files, differently named folders) are
# reported and skipped instead of being parsed to NA.
list_matching_dirs <- function(dir, pattern, groups) {
  paths <- list.files(dir, full.names = TRUE)
  entry_names <- basename(paths)
  matched <- grepl(pattern, entry_names)
  if (!all(matched)) {
    message(
      "Skipping unrecognised entries in ", dir, ": ",
      paste(entry_names[!matched], collapse = ", ")
    )
  }
  parsed <- tibble(path = paths[matched])
  for (k in seq_along(groups)) {
    # With an anchored pattern, sub() replaces the whole name by capture k.
    parsed[[groups[[k]]]] <- as.integer(
      sub(pattern, paste0("\\", k), entry_names[matched])
    )
  }
  parsed
}

# Read every CPDAG file under `base_dir` and return one row per
# (n_nodes, n_neighbors, n_parents, seed, cd_type), where
# mec_size = length(<cpdag>$dags). Seeds 1..n_seeds are read for every
# dag-type directory. Node counts listed in `skip_n_nodes` are ignored.
# Rows are ordered by directory, then seed, then method (order of
# `cpdag_suffixes`). A missing CPDAG file makes readRDS() stop with an error.
# NOTE(review): the reason 21-node runs are excluded is not recorded here.
collect_mec_sizes <- function(base_dir,
                              n_seeds = 40L,
                              skip_n_nodes = 21L,
                              cpdag_suffixes = method_cpdag_suffixes) {
  empty <- tibble(
    n_nodes = integer(),
    n_neighbors = integer(),
    n_parents = integer(),
    seed = integer(),
    cd_type = character(),
    mec_size = integer()
  )
  rows <- list()

  node_dirs <- list_matching_dirs(base_dir, "^([0-9]+)_nodes$", "n_nodes")
  node_dirs <- node_dirs[!node_dirs$n_nodes %in% skip_n_nodes, ]

  for (i in seq_len(nrow(node_dirs))) {
    dag_dirs <- list_matching_dirs(
      node_dirs$path[[i]],
      "^([0-9]+)_neighbors_([0-9]+)_parents_y$",
      c("n_neighbors", "n_parents")
    )
    for (j in seq_len(nrow(dag_dirs))) {
      for (seed in seq_len(n_seeds)) {
        mec_sizes <- vapply(
          cpdag_suffixes,
          function(suffix) {
            cpdag_file <- file.path(dag_dirs$path[[j]], paste0(seed, suffix))
            length(readRDS(cpdag_file)$dags)
          },
          integer(1)
        )
        # Collect one small tibble per seed; bound once at the end instead of
        # growing a tibble with add_row() on every iteration.
        rows[[length(rows) + 1L]] <- tibble(
          n_nodes = node_dirs$n_nodes[[i]],
          n_neighbors = dag_dirs$n_neighbors[[j]],
          n_parents = dag_dirs$n_parents[[j]],
          seed = seed,
          cd_type = names(cpdag_suffixes),
          mec_size = unname(mec_sizes)
        )
      }
    }
  }

  # Starting from `empty` keeps the column types even when nothing was found.
  bind_rows(c(list(empty), rows))
}

results <- collect_mec_sizes(base_dir)

# Save / reload ----
results_path <- "R/experiments/results/non-linear1/all_mec_sizes.csv"
# Create the output folder if needed; writing would otherwise fail.
dir.create(dirname(results_path), recursive = TRUE, showWarnings = FALSE)
# write_delim() defaults to delim = " ", which produced a space-separated file
# despite the .csv extension; write a genuine comma-separated file instead.
write_csv(results, results_path)

# Analysis ----
# Reload from disk (so this section can be rerun from the saved file) with the
# same column types as the in-memory schema instead of guessed doubles.
results <- read_csv(
  results_path,
  col_types = cols(cd_type = col_character(), .default = col_integer())
)

# Summarise MEC sizes over seeds for every
# (n_nodes, n_neighbors, n_parents, cd_type) combination.
# sd() is NA for a group that contains a single seed.
results_grouped <- results %>%
  group_by(n_nodes, n_neighbors, n_parents, cd_type) %>%
  summarize(
    mean_mec_size = mean(mec_size),
    median_mec_size = median(mec_size),
    std_mec_size = sd(mec_size),
    .groups = "drop"
  )
















#
#
# # Combine with the Gaussian baseline results (11-node PC/FGES samples)
# tibble_gaussian <- read_delim("R/experiments/results/gaussian/results_nodes.csv")
# tibble_gaussian <- tibble_gaussian %>%
#   filter(n_nodes == 11) %>%
#   filter(explanation_type %in% c("pc_sample", "fges_sample"))
#
# # Add cd_type to tibble_gaussian
# tibble_gaussian <- tibble_gaussian %>%
#   mutate(cd_type = ifelse(explanation_type == "pc_sample", "pc", "fges"))
#
# # Join the two tibbles based on (seed, cd_type)
# test <- tibble_gaussian %>%
#   left_join(results, by = c("seed", "cd_type")) %>%
#   select(-n_nodes)
#
# test2 <- test %>%
#   group_by(seed, explanation_type) %>%
#   summarize(
#     average_l2_error = mean(l2_error),
#     SHD_with_oracle = first(SHD_with_oracle)
#   ) %>%
#   ungroup()
#
# # We want to plot SHD with oracle vs L2 error
# # We first average over MECs then over datapoints
# test3 <- test %>%
#   group_by(explanation_type, seed, idx_datapoint, cd_type, SHD_with_oracle) %>%
#   summarise(l2_error_over_mec = mean(l2_error), .groups = "drop") %>%
#   group_by(explanation_type, seed, cd_type,SHD_with_oracle) %>%
#   summarise(l2_error_over_datapoints = mean(l2_error_over_mec), .groups = "drop")
#
# ggplot(test2, aes(x = SHD_with_oracle, y = average_l2_error, color = explanation_type)) +
#   geom_point() +
#   labs(
#     title = "SHD with Oracle vs L2 Error",
#     x = "SHD with Oracle",
#     y = "Average L2 Error",
#     color = "CD Type"
#   ) +
#   theme_minimal()
#
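# # Min L2 error — NOTE(review): as written this averages over MECs and then takes the min over datapoints; swap mean/min if the min over MECs was intended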
# test4 <- test %>%
#   group_by(explanation_type, seed, idx_datapoint, cd_type, SHD_with_oracle) %>%
#   summarise(l2_error_over_mec = mean(l2_error), .groups = "drop") %>%
#   group_by(explanation_type, seed, cd_type,SHD_with_oracle) %>%
#   summarise(l2_error_over_datapoints = min(l2_error_over_mec), .groups = "drop")