## This file contains all functions related to graphs

library(pcalg)
suppressPackageStartupMessages(library(igraph))


#' Generate an Erdos-Renyi DAG with a final 'predictive' sink node.
#'
#' This function generates an Erdos-Renyi DAG as an igraph graph, with a final
#' 'predictive' sink node. We can specify the total number of nodes (including
#' the predictive node), the expected number of neighbors of each node, and the
#' expected number of parents of the predictive node. This function makes use
#' of the pcalg::randDAG function to generate the DAG (without y). Every
#' component of that DAG that does not contain y receives at least one edge
#' into y, so the returned graph is weakly connected.
#'
#' @param n_nodes The total number of nodes the graph will contain, including
#' the predictive node. Must be at least 2.
#' @param n_expected_neighbors The expected number of neighbors of each node,
#' possibly excluding the predictive node if n_expected_parents_y is set.
#' @param n_expected_parents_y The expected number of parents the predictive
#' node y will have. If NULL, this will be set to n_expected_neighbors.
#' @param min_parents_y Lower bound on the sampled target number of parents of
#' y (the target is still capped at n_nodes - 1).
#' @param seed Optional seed. If not NULL it is passed to set.seed(), which
#' changes the global random number generator state.
#'
#' @return An object of class "igraph". Feature vertices are named by prefixing
#' their original name with "X"; the predictive node (the last vertex) is
#' named "Y".
#'
#' @examples
#' \dontrun{
#' generate_erdosrenyi_dag_sink_y(n_nodes = 10,
#'                                n_expected_neighbors = 2,
#'                                n_expected_parents_y = 5)
#' }
generate_erdosrenyi_dag_sink_y <- function(n_nodes,
                                           n_expected_neighbors,
                                           n_expected_parents_y = NULL,
                                           min_parents_y = 2,
                                           seed = NULL) {

  # Checks
  if (n_nodes < 2) {
    stop("n_nodes must be at least 2 (one feature node plus the predictive node).")
  }
  # NOTE(review): randDAG below is called on n_nodes - 1 nodes, which allows at
  # most n_nodes - 2 neighbors; confirm whether this bound should be tighter.
  if (n_expected_neighbors > n_nodes - 1) {
    warning("The expected number of neighbors cannot be larger than the number of nodes minus 1. Setting n_expected_neighbors to n_nodes - 1.")
    n_expected_neighbors <- n_nodes - 1
  }
  # Resolve the NULL default *before* any comparison: comparing NULL yields
  # logical(0), and `if (logical(0))` errors with "argument is of length zero".
  if (is.null(n_expected_parents_y)) {
    n_expected_parents_y <- n_expected_neighbors
  }
  if (n_expected_parents_y == n_nodes - 1) {
    warning("The expected number of parents of the predictive node is equal to the number of nodes minus 1. Amount of parents will be equal to n_nodes - 1.")
  }
  if (n_expected_parents_y > n_nodes - 1) {
    warning("The expected number of parents of the predictive node cannot be larger than the number of nodes minus 1. Setting n_expected_parents_y to n_nodes - 1.")
    n_expected_parents_y <- n_nodes - 1
  }
  if (n_expected_parents_y < min_parents_y) {
    warning("The expected number of parents of the predictive node cannot be smaller than min_parents_y. Setting n_expected_parents_y to min_parents_y.")
    n_expected_parents_y <- min_parents_y
  }

  if (!is.null(seed)) {
    set.seed(seed)
  }

  # Give every edge a weight of 1 (we don't make use of them)
  weight_function <- function(n) rep(1, n)

  n_features <- n_nodes - 1

  dag <- pcalg::randDAG(n_features, n_expected_neighbors, wFUN = weight_function)
  dag <- igraph::graph_from_graphnel(dag)

  # Append the predictive node y as the last vertex.
  dag <- igraph::add_vertices(dag, 1, name = n_nodes)
  node_y_idx <- n_nodes

  # Target number of parents of y: a Binomial draw, raised to min_parents_y and
  # capped at the number of feature nodes. The probability is clamped to 1 so a
  # min_parents_y above n_nodes - 1 cannot make rbinom() return NA.
  parent_prob <- min(1, n_expected_parents_y / n_features)
  n_parents_to_add <- max(min_parents_y,
                          stats::rbinom(n = 1, size = n_features, prob = parent_prob))
  n_parents_to_add <- min(n_parents_to_add, n_features)

  n_parents_y <- 0
  graph_components <- igraph::components(dag)

  # Add at least one edge from each component that does not contain y into y.
  while (graph_components$no > 1) {
    y_component <- graph_components$membership[[node_y_idx]]
    # First component (by id) that does not contain y.
    component_idx <- setdiff(seq_len(graph_components$no), y_component)[[1]]
    component_members <- which(graph_components$membership == component_idx)

    # sample() on a single number n would draw from 1:n instead (see the first
    # paragraph of Details in ?sample), so only sample when there is a choice.
    if (length(component_members) == 1) {
      parent_idx <- component_members
    } else {
      parent_idx <- sample(component_members, 1)
    }

    dag <- igraph::add_edges(dag, c(parent_idx, node_y_idx), weight = 1)
    n_parents_y <- n_parents_y + 1

    graph_components <- igraph::components(dag)
  }

  # Add random extra parents (rejection sampling over feature nodes) until the
  # target is reached. Terminates because the target is at most n_features.
  while (n_parents_y < n_parents_to_add) {
    parents_y <- as.integer(igraph::neighbors(dag, node_y_idx, mode = "in"))
    candidate_idx <- sample(seq_len(n_features), 1)
    if (!(candidate_idx %in% parents_y)) {
      dag <- igraph::add_edges(dag, c(candidate_idx, node_y_idx), weight = 1)
      n_parents_y <- n_parents_y + 1
    }
  }

  # Feature nodes become "X<old name>"; the predictive node becomes "Y".
  new_names <- paste0("X", igraph::V(dag)$name)
  new_names[[node_y_idx]] <- "Y"
  igraph::V(dag)$name <- new_names

  dag
}


#' Keep only the largest connected component of a graph.
#'
#' Drops every connected component except the one with the most vertices.
#' Used in graph generation.
#'
#' @param igraph An object of class "igraph".
#'
#' @return The subgraph of `igraph` induced by the vertices of its largest
#' component, as reported by igraph::components() with its default settings.
#' When several components tie for the largest size, the one with the smallest
#' component id is kept (which.max returns the first maximum).
remove_smaller_components <- function(igraph) {
  component_info <- igraph::components(igraph)
  biggest_id <- which.max(component_info$csize)
  in_biggest <- component_info$membership == biggest_id

  igraph::induced_subgraph(igraph, igraph::V(igraph)[in_biggest])
}

#' Function to safely remove a node from a (causal) PDAG. Note that this has to
#' be used carefully. In the case of mediator relationships X -> Y -> Z, when
#' removing Y we add a directed edge X -> Z. In the case of a v-structure
#' X -> Y <- Z we simply remove Y and the two edges. In the case of Y being a
#' common cause X <- Y -> Z the relationship is more complicated. In the case
#' of Causal Shapley values, since we only perform do-interventions on features,
#' we can safely remove the variable in the case of Y being a common cause,
#' as a do intervention on X will not affect Z when Y is a common cause, and
#' vice versa. This is assuming this function will only be used to remove the variable
#' the model is predicting.
#'
#' @param graph An object of class "igraph" (assumed directed).
#' @param node_to_remove The vertex to remove, as a vertex id or name (anything
#' accepted by igraph::V(graph)[...]).
#'
#' @return A new graph without `node_to_remove`, in which every in-neighbor of
#' the removed node has an edge to every distinct out-neighbor (unless they
#' were already adjacent). `graph` itself is not modified.
remove_node <- function(graph, node_to_remove) {
  # add_edges()/delete_vertices() return new graphs, so `graph` is untouched.
  updated_graph <- graph

  parents <- igraph::neighbors(graph, node_to_remove, mode = "in")
  children <- igraph::neighbors(graph, node_to_remove, mode = "out")

  # Bridge mediator paths: for 'A -> node_to_remove -> B', add 'A -> B'.
  for (parent in parents) {
    for (child in children) {
      # A vertex can be both parent and child, e.g. if the PDAG stores an
      # undirected edge as two opposite directed edges. Never add a self-loop.
      if (parent == child) {
        next
      }
      if (!igraph::are_adjacent(updated_graph, parent, child)) {
        updated_graph <- igraph::add_edges(updated_graph, c(parent, child))
      }
    }
  }

  # Remove the node and its edges. Adding edges does not change vertex ids or
  # names, so node_to_remove still refers to the same vertex.
  vertex_to_remove <- igraph::V(updated_graph)[node_to_remove]
  updated_graph <- igraph::delete_vertices(updated_graph, vertex_to_remove)

  updated_graph
}