library(pcalg, quietly = TRUE)

source('R/cpdag_enumerate.R')

#' Enumerate the Markov Equivalence Class of an igraph DAG.
#'
#' Converts the input DAG into its CPDAG with `pcalg::dag2cpdag()` and then
#' enumerates every DAG represented by that CPDAG with `enumerate_cpdag()`.
#'
#' @param igraph_dag A directed acyclic graph given as an igraph object.
#' @return An unnamed list of igraph DAGs, one per adjacency matrix returned by
#'   `enumerate_cpdag()`, i.e. the members of the Markov Equivalence Class
#'   represented by the CPDAG of `igraph_dag`.
dag_to_dags_in_cpdag <- function(igraph_dag) {
  # pcalg requires a graphNEL object, so convert before computing the CPDAG.
  cpdag <- pcalg::dag2cpdag(igraph::igraph.to.graphNEL(igraph_dag))
  # Adjacency matrix in the form [a, b] = 1 if a -> b.
  cpdag_adj <- graph::adjacencyMatrix(as(cpdag, "graphBAM"))
  # enumerate_cpdag() requires an adjacency matrix.
  dags_adjacency_matrices <- enumerate_cpdag(cpdag_adj)
  lapply(unname(dags_adjacency_matrices), function(adj_matrix) {
    igraph::graph_from_adjacency_matrix(adj_matrix, mode = "directed")
  })
}


#' Enumerate the DAGs represented by the CPDAG estimated by `pcalg::pc`.
#'
#' The CPDAG is extracted from `pc.fit` as a pcalg `amat`, transposed to the
#' convention [a, b] = 1 if a -> b, and every DAG it represents is enumerated
#' with `enumerate_cpdag()`.
#'
#' @param pc.fit The output of `pcalg::pc` (a CPDAG).
#' @param cpdag2dag_method Currently unused by this function; retained so that
#'   existing callers passing it keep working.
#' @return A list with elements
#'   * `cpdag`: the estimated CPDAG as a directed igraph object, and
#'   * `dags`: an unnamed list of igraph DAGs, one per adjacency matrix
#'     returned by `enumerate_cpdag()` (an empty list when it returns none).
pc_fit_to_dags_in_cpdag <- function(pc.fit, cpdag2dag_method = 'ours') {
  # pcalg's amat uses [b, a] = 1 if a -> b; transpose so [a, b] = 1 if a -> b.
  cpdag_object <- t(as(pc.fit, "amat"))
  cpdag_igraph <- igraph::graph_from_adjacency_matrix(
    cpdag_object,
    mode = "directed"
  )

  dags_adjacency_matrices <- enumerate_cpdag(cpdag_object)
  # lapply() yields an empty list when there are no DAGs, whereas
  # `for (i in seq(0))` iterates over c(1, 0) and indexes an empty list.
  igraph_dags <- lapply(unname(dags_adjacency_matrices), function(adj_matrix) {
    igraph::graph_from_adjacency_matrix(adj_matrix, mode = "directed")
  })
  list(cpdag = cpdag_igraph, dags = igraph_dags)
}

#' Enumerate the DAGs represented by a CPDAG given as an adjacency matrix.
#'
#' @param cpdag_adj_mat Adjacency matrix of the CPDAG; it is coerced with
#'   `as.matrix()` for igraph and passed unchanged to `enumerate_cpdag()`.
#'   NOTE(review): presumably in the form [a, b] = 1 if a -> b, like the other
#'   enumerate_cpdag() inputs in this file — confirm against callers.
#' @return A list with elements
#'   * `cpdag`: the CPDAG as a directed igraph object, and
#'   * `dags`: an unnamed list of igraph DAGs, one per adjacency matrix
#'     returned by `enumerate_cpdag()` (an empty list when it returns none).
adj_mat_cpdag_to_dags_in_cpdag <- function(cpdag_adj_mat) {
  cpdag_igraph <- igraph::graph_from_adjacency_matrix(
    as.matrix(cpdag_adj_mat),
    mode = "directed"
  )

  dags_adjacency_matrices <- enumerate_cpdag(cpdag_adj_mat)
  # lapply() yields an empty list when there are no DAGs, whereas
  # `for (i in seq(0))` iterates over c(1, 0) and indexes an empty list.
  igraph_dags <- lapply(unname(dags_adjacency_matrices), function(adj_matrix) {
    igraph::graph_from_adjacency_matrix(adj_matrix, mode = "directed")
  })
  list(cpdag = cpdag_igraph, dags = igraph_dags)
}


#' Compute the CPDAG of an igraph DAG as a pcalg-style adjacency matrix.
#'
#' The DAG is converted to graphNEL, turned into its CPDAG by
#' `pcalg::dag2cpdag()`, and the resulting adjacency matrix is transposed so
#' that entry [b, a] is 1 when a -> b, which is the orientation pcalg expects.
#'
#' @param igraph_true_dag A DAG given as an igraph object.
#' @return The CPDAG adjacency matrix in pcalg's [b, a] = 1 if a -> b form.
true_dag_to_cpdag_adjmatrix <- function(igraph_true_dag) {
  dag_graphnel <- igraph::igraph.to.graphNEL(igraph_true_dag)
  cpdag_graph <- pcalg::dag2cpdag(dag_graphnel)
  # graphBAM's adjacency matrix is oriented [a, b] = 1 if a -> b ...
  forward_adj <- graph::adjacencyMatrix(as(cpdag_graph, "graphBAM"))
  # ... so flip it into pcalg's convention.
  t(forward_adj)
}

#' Enumerate the Markov Equivalence Class of a DAG using pcalg.
#'
#' Given an input DAG, return a list containing each DAG in its Markov
#' Equivalence Class. We use the pcalg method `pdag2allDags` to generate all
#' DAGs in the MEC. This function may not work for 15 or more nodes, which is
#' why this function is deprecated.
#'
#' @param igraph_dag A DAG given as an igraph object.
#' @return A list of igraph DAGs with vertex names taken from the CPDAG; an
#'   empty list if `pdag2allDags` reports no DAGs.
dag_to_dags_in_cpdag_pcalg <- function(igraph_dag) {
  cpdag_adj <- true_dag_to_cpdag_adjmatrix(igraph_dag)
  dags <- pcalg::pdag2allDags(cpdag_adj)
  # dags$dags is NULL when no DAG is found; nrow(NULL) would otherwise make
  # the list allocation fail.
  if (is.null(dags$dags)) {
    return(list())
  }

  # Each row of dags$dags is a flattened amat over the CPDAG's nodes
  # (rebuilt below with byrow = TRUE); hoist the node count out of the loop.
  n_nodes <- ncol(cpdag_adj)
  lapply(seq_len(nrow(dags$dags)), function(dag_idx) {
    adj_matrix <- matrix(dags$dags[dag_idx, ], n_nodes, n_nodes, byrow = TRUE)
    # amat type from pcalg follows amat[b, a] = 1 if a -> b; igraph the opposite.
    adj_matrix <- t(adj_matrix)
    colnames(adj_matrix) <- rownames(adj_matrix) <- dags$nodeNms
    igraph::graph_from_adjacency_matrix(adj_matrix, mode = "directed")
  })
}



#' Validate the output of causal discovery.
#'
#' Checks whether the CPDAG contains any DAGs, which is not always the case,
#' by asking `pcalg::pdag2allDags` to enumerate them.
#'
#' @param pc.fit Either a `pcAlgo` object (output of `pcalg::pc`) or an
#'   adjacency matrix (causal discovery oracle) that is passed to
#'   `pcalg::pdag2allDags` unchanged.
#' @return TRUE if at least one DAG is represented by the CPDAG, else FALSE.
validate_causal_discovery <- function(pc.fit) {
  # inherits() checks the full class hierarchy, so subclasses of pcAlgo are
  # also converted to an amat.
  if (inherits(pc.fit, "pcAlgo")) {
    cpdag_object <- as(pc.fit, "amat")
  } else {
    cpdag_object <- pc.fit
  }

  dags <- pcalg::pdag2allDags(cpdag_object)
  # pdag2allDags leaves `dags` NULL when the CPDAG represents no DAG.
  !is.null(dags$dags)
}

#' Test whether two adjacency matrices encode the same labelled directed graph.
#'
#' Both matrices use the convention [a, b] != 0 if a -> b. When both carry row
#' names forming the same set, `adj_b` is reordered by name to match `adj_a`;
#' otherwise vertices are matched by position.
#'
#' @param adj_a,adj_b Square adjacency matrices.
#' @return TRUE if the two edge sets coincide, FALSE otherwise.
.same_labelled_dag <- function(adj_a, adj_b) {
  if (!identical(dim(adj_a), dim(adj_b))) {
    return(FALSE)
  }
  names_a <- rownames(adj_a)
  names_b <- rownames(adj_b)
  if (!is.null(names_a) && !is.null(names_b) &&
      identical(colnames(adj_b), names_b) && setequal(names_a, names_b)) {
    adj_b <- adj_b[names_a, names_a, drop = FALSE]
  }
  all((adj_a != 0) == (adj_b != 0))
}

#' Check whether the CPDAG returned by `pcalg::pc` contains the true DAG.
#'
#' Every DAG represented by the estimated CPDAG is enumerated with
#' `pcalg::pdag2allDags` and compared with `true_dag` as a labelled graph:
#' vertices are matched by name (or by position when names are absent or
#' differ) and every edge orientation must agree. An unlabelled isomorphism
#' test is not enough, e.g. a -> b (c isolated) is isomorphic to b -> c
#' (a isolated) although they are different DAGs.
#'
#' @param pc.fit The output of `pcalg::pc`.
#' @param true_dag The ground-truth DAG as an igraph object.
#' @return TRUE if `true_dag` is one of the DAGs in the CPDAG, FALSE otherwise
#'   (including when the CPDAG represents no DAG at all).
is_true_dag_in_cpdag <- function(pc.fit, true_dag) {
  cpdag_object <- as(pc.fit, "amat")
  dags <- pcalg::pdag2allDags(cpdag_object)
  if (is.null(dags$dags)) {
    return(FALSE)
  }

  # Dense adjacency of the true DAG ([a, b] = 1 if a -> b), with vertex names
  # as dimnames when the graph has a "name" attribute.
  true_adj <- igraph::as_adjacency_matrix(true_dag, sparse = FALSE)
  # Rows of dags$dags are flattened amats over the CPDAG's nodes, so size the
  # matrix from the CPDAG rather than from true_dag.
  n_nodes <- ncol(cpdag_object)

  for (dag_idx in seq_len(nrow(dags$dags))) {
    adj_matrix <- matrix(dags$dags[dag_idx, ], n_nodes, n_nodes, byrow = TRUE)
    # amat type from pcalg follows amat[b, a] = 1 if a -> b; igraph the opposite.
    adj_matrix <- t(adj_matrix)
    colnames(adj_matrix) <- rownames(adj_matrix) <- dags$nodeNms

    if (.same_labelled_dag(true_adj, adj_matrix)) {
      return(TRUE)
    }
  }
  FALSE
}
