import graphviz


def draw(edge_ids, relation_types, graph, seed_node=None, **kwargs):
    """Build a Graphviz digraph from a collection of labelled edges.

    Each edge contributes two node declarations (head, then tail) followed
    by one directed edge from head to tail. A node shared by several edges
    is therefore declared once per incident edge, always with the same
    attributes.

    Args:
        edge_ids: Object whose ``.T`` attribute iterates over
            ``(head, tail)`` pairs -- presumably a 2 x num_edges array;
            confirm against callers.
        relation_types: Iterable paired element-wise (via ``zip``) with
            the edges; each item is used to index ``graph.relation_labels``.
        graph: Object exposing ``node_labels`` and ``relation_labels``,
            indexed by node id and relation type respectively.
        seed_node: Node id to highlight. Endpoints that compare equal to
            it are outlined in red, all others in black.
        **kwargs: Accepted but not used by this function.

    Returns:
        The populated ``graphviz.Digraph`` (created with the ``dot``
        engine).
    """
    canvas = graphviz.Digraph(engine="dot")
    edge_pairs = edge_ids.T
    for endpoints, rel in zip(edge_pairs, relation_types):
        head, tail = endpoints
        # Declare both endpoints (head first) before linking them.
        for node in (head, tail):
            node_name = str(node)
            node_text = graph.node_labels[node]
            if node == seed_node:
                outline = "red"
            else:
                outline = "black"
            canvas.node(node_name, label=node_text, color=outline)
        edge_text = graph.relation_labels[rel]
        canvas.edge(str(head), str(tail), label=edge_text)
    return canvas
