source("R/scm/base_scm.R")

#' Constructor for binary SCM objects
#'
#' Creates a new binary structural causal model (SCM) from a directed acyclic
#' graph (DAG) and a list of node functions, one per node in the DAG.
#' Binary SCMs differ from the other SCMs in that every structural equation is
#' a Bernoulli draw. Each node function returns the success probability p:
#' root-node functions receive the number of samples, and all other node
#' functions receive their parents' values. The only randomness comes from
#' sampling that Bernoulli. See `generate_data.binary_scm()` for how data is
#' generated.
#'
#' @param dag An igraph object representing the DAG structure of the SCM.
#' @param node_functions A list of functions, one for each node in the DAG,
#' following the order returned by igraph::V(dag). Every element must inherit
#' from "binary_root_node_function" (see `get_binary_root_node_function()`)
#' or "logit_node_function" (see `get_logit_function()`).
#'
#' @return The object built by `new_scm()`, with "binary_scm" prepended to its
#' class vector.
#'
#' @examples
#' \dontrun{
#' dag <- igraph::graph_from_literal(A -+ B, B -+ C)
#' node_functions <- list(get_binary_root_node_function(0.5),
#'                        get_logit_function(0, 0.5),
#'                        get_logit_function(0.5, -1))
#' scm <- new_binary_scm(dag, node_functions)
#' }
new_binary_scm <- function(dag, node_functions) {
  allowed_classes <- c("binary_root_node_function", "logit_node_function")
  # vapply always yields a logical vector (also for an empty list), unlike
  # sapply; inherits() with a vector `what` is TRUE if any class matches.
  is_supported <- vapply(node_functions, inherits, logical(1),
                         what = allowed_classes)
  stopifnot(all(is_supported))

  base_scm <- new_scm(dag, node_functions)

  class(base_scm) <- c("binary_scm", class(base_scm))
  base_scm
}

# Draw `n_samples` observational samples from a binary SCM.
#
# Nodes are visited in topological order. Each node function yields the
# probability that the node is TRUE, and one uniform draw per sample turns it
# into a Bernoulli outcome stored as 1/0 in a numeric matrix. Columns are
# named after the DAG vertices when the DAG has names.
generate_data.binary_scm <- function(scm, n_samples, seed = NULL) {
  # Optional reproducibility: reseeds the global RNG before any draws.
  if (!is.null(seed)) set.seed(seed)

  graph <- scm$dag
  fns <- scm$node_functions

  samples <- matrix(0, nrow = n_samples, ncol = igraph::vcount(graph))
  colnames(samples) <- igraph::V(graph)$name

  # Parents-first ordering guarantees every parent column is filled in
  # before a child reads it.
  for (v in as.numeric(igraph::topo_sort(graph))) {
    parent_ids <- as.numeric(igraph::neighbors(graph, v, mode = "in"))

    # Root nodes are handed the sample count; other nodes get their parents'
    # columns (one row per sample, one column per parent).
    fn_input <- if (length(parent_ids) == 0) {
      n_samples
    } else {
      samples[, parent_ids, drop = FALSE]
    }

    p_true <- fns[[v]](fn_input)
    samples[, v] <- runif(n_samples) <= p_true
  }

  samples
}

# Draw `n_samples` samples from a binary SCM under hard interventions.
#
# Sampling is the same as in generate_data.binary_scm(), except that each
# intervened node's column is set to the supplied value(s) instead of being
# drawn from its structural equation. Its descendants then see the intervened
# values as parent inputs.
#
# `interventions` is a named list (or named vector). Its names are node
# INDICES written as strings (e.g. "2"), not vertex names. Each value is
# assigned to that node's column, so it should be a scalar or a vector of
# length `n_samples`. Entries whose name is not a valid node index are
# ignored with a warning.
#
# Returns a numeric n_samples x n_nodes matrix of 1/0 outcomes (or the
# intervened values).
generate_data_under_intervention.binary_scm <- function(scm, n_samples,
                                                        interventions,
                                                        seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }
  dag <- scm$dag
  node_functions <- scm$node_functions

  n_nodes <- igraph::vcount(dag)
  data <- matrix(0, nrow = n_samples, ncol = n_nodes)
  colnames(data) <- igraph::V(dag)$name

  # Keys are matched against node indices only. Any other key (a typo, a
  # vertex name such as "A", or a missing name) would otherwise be dropped
  # silently, so report it.
  intervention_keys <- names(interventions)
  if (is.null(intervention_keys)) {
    intervention_keys <- rep("", length(interventions))
  }
  unknown_keys <- setdiff(intervention_keys, as.character(seq_len(n_nodes)))
  if (length(unknown_keys) > 0) {
    warning(
      "Ignoring interventions whose names are not node indices (1..",
      n_nodes, "): ", paste0('"', unknown_keys, '"', collapse = ", "),
      call. = FALSE
    )
  }

  # Integer indices keep as.character() in plain decimal form; a double such
  # as 100000 would become "1e+05" and never match its key.
  topo_order <- as.integer(igraph::topo_sort(dag))
  for (node_idx in topo_order) {
    node_key <- as.character(node_idx)
    if (node_key %in% intervention_keys) {
      data[, node_idx] <- interventions[[node_key]]
      next
    }

    parents_idxs <- as.integer(igraph::neighbors(dag, node_idx, mode = "in"))
    if (length(parents_idxs) == 0) {
      # Root node: its function takes the number of samples.
      probability_true <- node_functions[[node_idx]](n_samples)
    } else {
      parent_outputs <- data[, parents_idxs, drop = FALSE]
      probability_true <- node_functions[[node_idx]](parent_outputs)
    }
    # Bernoulli draw: TRUE (stored as 1) with probability `probability_true`.
    data[, node_idx] <- runif(n_samples) <= probability_true
  }
  data
}

# Build the node function for a binary root node (a node without parents).
#
# `p` is the probability that the node is TRUE and must be a single number in
# [0, 1]. The result is a function of `n` that returns `rep(p, n)`. It has
# class "binary_root_node_function" and stores `p` in its "p" attribute.
get_binary_root_node_function <- function(p) {
  # A vector `p` would make func(n) return length(p) * n probabilities, which
  # no longer line up with the n samples, so only a scalar is accepted.
  # NA fails the range check as well.
  stopifnot(is.numeric(p), length(p) == 1L, p >= 0 && p <= 1)

  func <- function(n) {
    rep(p, n)
  }
  class(func) <- "binary_root_node_function"
  attr(func, "p") <- p
  func
}

# Build the node function for a binary node with parents.
#
# The returned function maps a matrix of parent values to
# plogis(intercept + parent_data %*% parent_weights), giving one probability
# per row. `intercept` must be a single non-NA number. `parent_weights` must be
# a non-NA numeric vector with one weight per parent. The function has class
# "logit_node_function" and stores "n_parents", "intercept" and
# "parent_weights" as attributes.
get_logit_function <- function(intercept, parent_weights) {
  # A vector intercept would be recycled across samples, and NA parameters
  # would silently turn every probability (and sampled value) into NA.
  stopifnot(is.numeric(intercept), length(intercept) == 1L, !is.na(intercept))
  stopifnot(is.numeric(parent_weights), !anyNA(parent_weights))

  # Each row of parent_data is one sample. Column j holds the value of the
  # parent whose weight is parent_weights[j]. In generate_data.binary_scm the
  # columns follow igraph::neighbors(dag, node, mode = "in") order.
  func <- function(parent_data) {
    plogis(as.vector(intercept + parent_data %*% parent_weights))
  }

  attr(func, "n_parents") <- length(parent_weights)
  attr(func, "intercept") <- intercept
  attr(func, "parent_weights") <- parent_weights
  class(func) <- "logit_node_function"
  func
}


# Randomly parameterise a binary (logit) SCM on a fixed DAG.
#
# Each root node gets a Bernoulli probability drawn from
# `binary_root_node_weight_function(1)`. Every other node gets an intercept
# from `intercept_weight_function(1)` and one weight per parent from
# `parent_weight_function(<number of parents>)`. Nodes are processed in
# vertex-id order, and the draws happen in that order, which is what makes
# `seed` reproducible. Returns the result of new_binary_scm().
sample_logit_scm_given_dag <- function(dag, binary_root_node_weight_function,
                                       intercept_weight_function, parent_weight_function, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }

  build_node_function <- function(v) {
    in_degree <- length(igraph::neighbors(dag, v, mode = "in"))
    if (in_degree == 0) {
      root_p <- binary_root_node_weight_function(1)
      return(get_binary_root_node_function(root_p))
    }
    # Intercept first, then parent weights: keeps the original RNG order.
    b0 <- intercept_weight_function(1)
    w <- parent_weight_function(in_degree)
    get_logit_function(b0, w)
  }

  node_functions <- lapply(seq_len(igraph::vcount(dag)), build_node_function)
  names(node_functions) <- igraph::V(dag)$name

  new_binary_scm(dag, node_functions)
}


# Print a binary SCM as one structural equation per line.
#
# Nodes are listed in topological order. Root nodes print as
# "name <- Bernoulli(p)", all other nodes as
# "name <- sigmoid(intercept + parent * weight + ...)", with numbers shown to
# two decimals. Parent names that are purely numeric get an "x_" prefix so
# they cannot be mistaken for coefficients. `...` is unused and is accepted
# only for compatibility with the print() generic. Returns `x` invisibly, as
# print methods conventionally do.
print.binary_scm <- function(x, ...) {
  dag <- x$dag
  node_functions <- x$node_functions

  node_names <- igraph::V(dag)$name
  # Without a "name" vertex attribute V(dag)$name is NULL; label nodes by
  # vertex index instead of failing with "subscript out of bounds".
  if (is.null(node_names)) {
    node_names <- as.character(seq_len(igraph::vcount(dag)))
  }

  for (node_idx in as.integer(igraph::topo_sort(dag))) {
    node_func <- node_functions[[node_idx]]

    if (inherits(node_func, "binary_root_node_function")) {
      rhs <- sprintf("Bernoulli(%.2f)", attr(node_func, "p"))
    } else {
      n_parents <- attr(node_func, "n_parents")
      weights <- attr(node_func, "parent_weights")[seq_len(n_parents)]
      parent_idxs <- as.integer(igraph::neighbors(dag, node_idx, mode = "in"))
      parent_labels <- node_names[parent_idxs[seq_len(n_parents)]]

      # A bare number such as "1 * 0.50" reads like a coefficient; prefix it.
      is_numeric_name <- grepl("^[0-9]+$", parent_labels)
      parent_labels[is_numeric_name] <- paste0("x_", parent_labels[is_numeric_name])

      terms <- c(
        sprintf("%.2f", attr(node_func, "intercept")),
        sprintf("%s * %.2f", parent_labels, weights)
      )
      # Joining all terms avoids a dangling " + " when there are no parents.
      rhs <- sprintf("sigmoid(%s)", paste(terms, collapse = " + "))
    }

    cat(sprintf("%s <- %s\n", node_names[[node_idx]], rhs))
  }

  invisible(x)
}
