library(igraph, warn.conflicts = FALSE)

#' Base constructor for SCM objects
#'
#' This function creates a new SCM object based on the provided directed
#' acyclic graph (DAG), and an ordered list of functions for the nodes in the
#' DAG. For usage examples see its subclasses, which can pass their own class
#' name(s) through `class`.
#'
#' @param dag An igraph object representing the DAG structure of the SCM.
#' @param node_functions A list of functions, one for each node in the DAG,
#'   following the order returned by `igraph::V(dag)`.
#' @param class Character vector of subclass names placed in front of
#'   `"scm"` in the class attribute of the result. Defaults to none, which
#'   yields a plain `"scm"` object.
#'
#' @return A list with elements `dag` and `node_functions`, of class
#'   `c(class, "scm")`.
new_scm <- function(dag, node_functions, class = character()) {
  if (!igraph::is_igraph(dag) || !igraph::is_dag(dag)) {
    stop("`dag` must be a directed acyclic igraph graph.", call. = FALSE)
  }
  if (!is.list(node_functions) ||
    length(node_functions) != igraph::vcount(dag)) {
    stop(
      "`node_functions` must be a list with one element per node of `dag`.",
      call. = FALSE
    )
  }
  # vapply() (unlike sapply()) always yields a logical vector, even when
  # `node_functions` is empty.
  if (!all(vapply(node_functions, is.function, logical(1)))) {
    stop(
      "Every element of `node_functions` must be a function.",
      call. = FALSE
    )
  }
  if (!is.character(class)) {
    stop("`class` must be a character vector.", call. = FALSE)
  }

  structure(
    list(dag = dag, node_functions = node_functions),
    class = c(class, "scm")
  )
}


#' Generate data from an SCM
#'
#' S3 generic that dispatches on the class of `scm`; the data generation
#' itself is implemented by the methods for concrete SCM classes.
#'
#' @param scm An object inheriting from class "scm".
#' @param n_samples Number of samples to generate.
#' @param seed Optional seed, `NULL` by default. How it is used is up to the
#'   dispatched method (presumably for reproducible sampling — confirm).
#'
#' @return Whatever the dispatched method returns.
#' @export
generate_data <- function(scm, n_samples, seed = NULL) {
  UseMethod("generate_data")
}

#' Generate data from an SCM under interventions
#'
#' S3 generic; the call is dispatched on the class of `scm` and the work is
#' done by the method registered for that class.
#'
#' @param scm An object inheriting from class "scm".
#' @param n_samples Number of samples to generate.
#' @param interventions The interventions to apply. Its expected format is
#'   defined by the methods (TODO: document once settled).
#' @param seed Optional seed, `NULL` by default; its use is up to the method.
#'
#' @return Whatever the dispatched method returns.
#' @export
generate_data_under_intervention <- function(scm,
                                             n_samples,
                                             interventions,
                                             seed = NULL) {
  UseMethod("generate_data_under_intervention", scm)
}


#' Print an SCM as one assignment per node
#'
#' Nodes are visited in a topological order of the DAG. Each node is printed
#' as `<name> <- ` followed by the printed node function. When the graph has
#' no vertex names, the vertex index is used as the name.
#'
#' @param x An object of class "scm".
#' @param ... Unused; present for compatibility with the `print()` generic.
#'
#' @return `x`, invisibly, as is conventional for print methods.
#' @export
print.scm <- function(x, ...) {
  dag <- x$dag
  node_functions <- x$node_functions

  # Loop-invariant, so look it up once. For unnamed graphs `V(dag)$name` is
  # NULL and indexing it would error, so fall back to the vertex indices.
  node_names <- igraph::V(dag)$name
  if (is.null(node_names)) {
    node_names <- as.character(seq_len(igraph::vcount(dag)))
  }

  for (node_idx in as.integer(igraph::topo_sort(dag))) {
    cat(sprintf("%s <- ", node_names[[node_idx]]))
    print(node_functions[[node_idx]])
  }

  invisible(x)
}


#' Plot the DAG of an SCM
#'
#' Draws `x$dag` with `plot()`, which dispatches on the class of the graph
#' (igraph's plot method for igraph graphs).
#'
#' @param x An object of class "scm".
#' @param ... Further arguments forwarded to `plot()` for the DAG (e.g.
#'   `layout` or `main`).
#' @param edge.arrow.size Size of the edge arrow heads.
#' @param vertex.size Size of the vertices.
#'
#' @return `x`, invisibly.
#' @export
plot.scm <- function(x, ..., edge.arrow.size = 0.5, vertex.size = 20) {
  # igraph applies per-edge / per-vertex plotting settings only when they
  # carry the `edge.` / `vertex.` prefix; unprefixed names such as
  # `arrow.size` or `size` are not applied to edges or vertices.
  plot(
    x$dag,
    edge.arrow.size = edge.arrow.size,
    vertex.size = vertex.size,
    ...
  )
  invisible(x)
}