## Script to sample SCMs for each of the previously generated DAGs

source('R/scm/graph_functions.R')
source('R/scm/node_functions.R')
source('R/scm/binary_scm.R')
source('R/scm/additive_noise_scm.R')
source('R/scm/linear_gaussian_scm.R')

source('R/experiments/helper_functions.R')

#' Sample an SCM for a DAG using the standard experiment settings
#'
#' Dispatches on `type` to one of the project SCM samplers, using the fixed
#' hyper-parameters shared by all experiments.
#'
#' @param dag DAG to build the SCM on; passed unchanged to the sampler.
#' @param type SCM family, a single string. `"gaussian"` uses
#'   `sample_linear_gaussian_scm_given_dag()` with unit noise variances and
#'   `standardized = FALSE`; `"non-linear1"` uses
#'   `sample_weighted_additive_noise_scm_given_dag()` with a sigmoid
#'   (`plogis`) node function and Normal(0, 1) / Uniform(-1, 1) noise.
#' @param seed Random seed forwarded to the selected sampler.
#' @return The SCM object returned by the selected sampler.
sample_scm <- function(dag, type, seed) {
  # Validate up front so a bad `type` fails with a clear message instead of
  # a generic `if` condition error.
  if (!is.character(type) || length(type) != 1L || is.na(type)) {
    stop("`type` must be a single, non-missing string.", call. = FALSE)
  }

  if (type == "gaussian") {
    # NOTE(review): sample_symmetric_uniform(n, a, b) presumably draws n
    # values with magnitude in [a, b] and random sign -- confirm in helpers.
    lg_weight_function <- function(n) {
      sample_symmetric_uniform(n, 0.5, 2)
    }
    # Every node gets noise variance 1.
    lg_noise_variance_function <- function(n) {
      rep(1, n)
    }
    sample_linear_gaussian_scm_given_dag(dag,
                                         lg_weight_function,
                                         lg_noise_variance_function,
                                         standardized = FALSE,
                                         seed)
  } else if (type == "non-linear1") {
    weight_function_parents <- function(n) {
      sample_symmetric_uniform(n, 0.5, 1.5)
    }
    weight_function_scale <- function(n) {
      sample_symmetric_uniform(n, 1, 3)
    }
    list_of_noise_distributions <- list(distributions3::Normal(0, 1),
                                        distributions3::Uniform(-1, 1))
    list_of_node_functions <- list(plogis) # sigmoid
    sample_weighted_additive_noise_scm_given_dag(dag,
                                                 weight_function_parents,
                                                 weight_function_scale,
                                                 list_of_node_functions,
                                                 list_of_noise_distributions,
                                                 seed)
  } else {
    # Previously this fell through and failed with "object 'scm' not found".
    stop("Unknown SCM type '", type,
         "'; expected \"gaussian\" or \"non-linear1\".", call. = FALSE)
  }
}

#' Sample and save SCMs for every generated DAG file
#'
#' Walks exactly two directory levels below `dag_base_folder`
#' (`<base>/<node folder>/<dag type folder>/`). For every `.rds` file found
#' there, it reads the stored DAG list, samples one SCM per entry of
#' `scm_types` via `sample_scm()`, and writes each one with `save_scm()`.
#'
#' @param dag_base_folder Root folder containing the generated DAGs.
#' @param scm_types Character vector of SCM types passed to `sample_scm()`.
#' @return `NULL`, invisibly; called for its side effects (logging and
#'   saving SCMs).
create_scms_for_each_dag <- function(dag_base_folder = "experiments/1_dags/",
                                     scm_types = c("gaussian", "non-linear1")) {
  for (dag_node_folder in list.dirs(dag_base_folder, recursive = FALSE)) {
    for (dag_type_folder in list.dirs(dag_node_folder, recursive = FALSE)) {
      # The dot must be escaped: an unescaped ".rds$" also matches names
      # such as "keywords", which readRDS() cannot read.
      dag_files <- list.files(dag_type_folder,
                              pattern = "\\.rds$",
                              full.names = TRUE)
      for (file_path in dag_files) {
        dag_list <- readRDS(file_path)
        log_to_shell(paste0("Creating SCMs for ", file_path))
        for (scm_type in scm_types) {
          log_to_shell(paste0("Creating ", scm_type, " SCMs for ", file_path))
          scm <- sample_scm(dag_list$dag, scm_type, dag_list$seed)
          save_scm(scm, scm_type, file_path,
                   dag_list$n_nodes,
                   dag_list$n_expected_neighbors,
                   dag_list$n_expected_parents_y,
                   dag_list$seed)
        }
      }
    }
  }
  invisible(NULL)
}


# Script entry point: sample and save SCMs for every DAG under the default folder.
invisible(create_scms_for_each_dag())