# NOTE(review): this removes every (non-hidden) object from the environment
# the script is run in, so sourcing it from an interactive session discards
# whatever that session held before. It does not unload packages or reset
# options — confirm the workspace wipe is really intended.
rm(list = ls())

library(tidyverse)
library(gridExtra)

# # Initialize results tibble
# results <- tibble(
#   seed = integer(),
#   cd_type = character(),
#   SHD_with_oracle = integer()
# )


# base_dir <- "experiments/5_causal_discovery/non-linear1/11_nodes/2_neighbors_6_parents_y"
# for (seed in seq(40)) {
#   oracle_cpdag <- readRDS(paste0(base_dir, "/", paste0(seed, "_oracle_cpdag.rds")))$cpdag
#   pc_cpdag <- readRDS(paste0(base_dir, "/", paste0(seed, "_pc_1000n_alpha0.05_cpdag.rds")))$cpdag
#   fges_cpdag <- readRDS(paste0(base_dir, "/", paste0(seed, "_fges_1000_cpdag.rds")))$cpdag

#   # remove y node from oracle cpdag
#   oracle_cpdag <- igraph::delete_vertices(oracle_cpdag, "Y")

#   oracle_cpdag_graphnel <- igraph::as_graphnel(oracle_cpdag)
#   pc_cpdag_graphnel <- igraph::as_graphnel(pc_cpdag)
#   # fges_cpdag_graphnel <- igraph::as_graphnel(fges_cpdag)


#   shd_pc_oracle <- pcalg::shd(oracle_cpdag_graphnel, pc_cpdag_graphnel)
#   # shd_fges_oracle <- pcalg::shd(oracle_cpdag_graphnel, fges_cpdag_graphnel)

#   results <- results %>%
#     add_row(seed = seed, cd_type = "pc", SHD_with_oracle = shd_pc_oracle) #%>%
#     # add_row(seed = seed, cd_type = "fges", SHD_with_oracle = shd_fges_oracle)
# }

# write_delim(results, "R/experiments/results/non-linear1/base_shd.csv")

# Empty, typed accumulator for the SHD comparison. The columns identify the
# graph configuration (n_nodes, n_neighbors, n_parents), the seed, the causal
# discovery method (cd_type) and its distance to the oracle CPDAG.
results <- tibble(
  n_nodes = integer(0),
  n_neighbors = integer(0),
  n_parents = integer(0),
  seed = integer(0),
  cd_type = character(0),
  SHD_with_oracle = integer(0)
)


# Compute the structural Hamming distance (SHD, via pcalg::shd) between the
# oracle CPDAG and the CPDAGs estimated by PC and FGES, for every graph
# configuration and seed found under `base_dir`, and append the rows to
# `results`.
#
# Directory layout read below:
#   <base_dir>/<k>_nodes/<a>_neighbors_<b>_parents_y/<seed>_<method>_cpdag.rds
base_dir <- "experiments/5_causal_discovery/non-linear1/"
n_nodes_dirs <- list.files(base_dir, full.names = TRUE)

# Anchored patterns: an entry whose name does not match exactly is skipped
# with a warning instead of being parsed to NA (which would abort the script
# at the first scalar `if`).
n_nodes_pattern <- "^([0-9]+)_nodes$"
dag_type_pattern <- "^([0-9]+)_neighbors_([0-9]+)_parents_y$"

# Number of seeds stored per configuration (files are named 1..n_seeds).
n_seeds <- 40L

# Read the `cpdag` element of the RDS file "<seed><suffix>" inside `dir`.
read_cpdag <- function(dir, seed, suffix) {
  readRDS(file.path(dir, paste0(seed, suffix)))$cpdag
}

# Rows are collected in a list and bound once after the loops; growing a
# tibble with add_row() on every iteration copies it each time.
shd_rows <- list()

for (n_nodes_dir in n_nodes_dirs) {
  n_nodes_name <- basename(n_nodes_dir)
  if (!grepl(n_nodes_pattern, n_nodes_name)) {
    warning("Skipping unexpected entry: ", n_nodes_dir, call. = FALSE)
    next
  }
  n_nodes <- as.integer(sub(n_nodes_pattern, "\\1", n_nodes_name))

  # The 21-node graphs are excluded from this comparison.
  if (n_nodes == 21L) {
    next
  }

  dag_type_dirs <- list.files(n_nodes_dir, full.names = TRUE)
  for (dag_type_dir in dag_type_dirs) {
    dag_type_name <- basename(dag_type_dir)
    if (!grepl(dag_type_pattern, dag_type_name)) {
      warning("Skipping unexpected entry: ", dag_type_dir, call. = FALSE)
      next
    }
    n_neighbors <- as.integer(sub(dag_type_pattern, "\\1", dag_type_name))
    n_parents <- as.integer(sub(dag_type_pattern, "\\2", dag_type_name))

    # For 11-node graphs only the 2-neighbors / 6-parents configuration is
    # kept. Scalar condition, hence the short-circuit operators.
    if (n_nodes == 11L && (n_neighbors != 2L || n_parents != 6L)) {
      next
    }

    for (seed in seq_len(n_seeds)) {
      oracle_cpdag <- read_cpdag(dag_type_dir, seed, "_oracle_cpdag.rds")
      pc_cpdag <- read_cpdag(
        dag_type_dir, seed, "_pc_1000n_alpha0.05_cpdag.rds"
      )
      fges_cpdag <- read_cpdag(dag_type_dir, seed, "_fges_1000_cpdag.rds")

      # Remove the Y node from the oracle CPDAG before comparing
      # (presumably the estimated CPDAGs do not contain Y — confirm).
      oracle_cpdag <- igraph::delete_vertices(oracle_cpdag, "Y")

      # pcalg::shd() is given graphNEL objects, so convert from igraph first.
      oracle_cpdag_graphnel <- igraph::as_graphnel(oracle_cpdag)
      pc_cpdag_graphnel <- igraph::as_graphnel(pc_cpdag)
      fges_cpdag_graphnel <- igraph::as_graphnel(fges_cpdag)

      shd_pc_oracle <- pcalg::shd(oracle_cpdag_graphnel, pc_cpdag_graphnel)
      shd_fges_oracle <- pcalg::shd(
        oracle_cpdag_graphnel, fges_cpdag_graphnel
      )

      # One row per method, in the order pc then fges; the scalar columns
      # are recycled to length two.
      shd_rows[[length(shd_rows) + 1L]] <- tibble(
        n_nodes = n_nodes,
        n_neighbors = n_neighbors,
        n_parents = n_parents,
        seed = seed,
        cd_type = c("pc", "fges"),
        SHD_with_oracle = c(shd_pc_oracle, shd_fges_oracle)
      )
    }
  }
}

# Append everything to the (initially empty) accumulator in a single bind;
# with no collected rows `results` is left unchanged.
results <- bind_rows(c(list(results), shd_rows))

# Persist the per-seed SHD table.
# NOTE(review): no `delim` is passed, so readr's documented default separator
# (a single space) is used even though the file is named ".csv" — confirm that
# downstream readers expect a space-delimited file.
results %>%
  write_delim("R/experiments/results/non-linear1/all_shd_fges.csv")


# Aggregate over seeds: mean, median and standard deviation of the SHD for
# each (n_nodes, n_neighbors, n_parents, cd_type) configuration. The output
# columns are named "<statistic>_SHD_with_oracle".
results_grouped <- results %>%
  group_by(n_nodes, n_neighbors, n_parents, cd_type) %>%
  summarise(
    across(
      SHD_with_oracle,
      list(mean = mean, median = median, std = sd),
      .names = "{.fn}_{.col}"
    ),
    .groups = "drop"
  )

















# # Combine with results gaussian base
# tibble_gaussian <- read_delim("R/experiments/results/gaussian/results_nodes.csv")
# tibble_gaussian <- tibble_gaussian %>%
#   filter(n_nodes == 11) %>%
#   filter(explanation_type %in% c("pc_sample", "fges_sample"))

# # Add cd_type to tibble_gaussian
# tibble_gaussian <- tibble_gaussian %>%
#   mutate(cd_type = ifelse(explanation_type == "pc_sample", "pc", "fges"))

# # Join the two tibbles based on (seed, cd_type)
# test <- tibble_gaussian %>%
#   left_join(results, by = c("seed", "cd_type")) %>%
#   select(-n_nodes)

# test2 <- test %>%
#   group_by(seed, explanation_type) %>%
#   summarize(
#     average_l2_error = mean(l2_error),
#     SHD_with_oracle = first(SHD_with_oracle)
#   ) %>%
#   ungroup()

# # We want to plot SHD with oracle vs L2 error
# # We first average over MECs then over datapoints
# test3 <- test %>%
#   group_by(explanation_type, seed, idx_datapoint, cd_type, SHD_with_oracle) %>%
#   summarise(l2_error_over_mec = mean(l2_error), .groups = "drop") %>%
#   group_by(explanation_type, seed, cd_type,SHD_with_oracle) %>%
#   summarise(l2_error_over_datapoints = mean(l2_error_over_mec), .groups = "drop")

# ggplot(test2, aes(x = SHD_with_oracle, y = average_l2_error, color = explanation_type)) +
#   geom_point() +
#   labs(
#     title = "SHD with Oracle vs L2 Error",
#     x = "SHD with Oracle",
#     y = "Average L2 Error",
#     color = "CD Type"
#   ) +
#   theme_minimal()

# # Min l2 error over MECs
# test4 <- test %>%
#   group_by(explanation_type, seed, idx_datapoint, cd_type, SHD_with_oracle) %>%
#   summarise(l2_error_over_mec = mean(l2_error), .groups = "drop") %>%
#   group_by(explanation_type, seed, cd_type,SHD_with_oracle) %>%
#   summarise(l2_error_over_datapoints = min(l2_error_over_mec), .groups = "drop")