import graphviz


def draw(subgraph, knowledge_graph):
    """Build a Graphviz digraph for a selection of edges of ``knowledge_graph``.

    Args:
        subgraph: Iterable of edge keys. Each key indexes both
            ``knowledge_graph.edge_ids`` and ``knowledge_graph.relation_types``.
        knowledge_graph: Object exposing four subscriptable attributes:
            ``edge_ids`` (edge key -> ``(head, tail)`` pair),
            ``relation_types`` (edge key -> relation id),
            ``node_labels`` (node id -> display label) and
            ``relation_labels`` (relation id -> display label).

    Returns:
        A ``graphviz.Digraph`` laid out with the ``dot`` engine. Each endpoint
        becomes a node named ``str(node_id)`` and carrying its label. Each edge
        key becomes one labelled edge from head to tail.
    """
    graph = graphviz.Digraph(engine="dot")
    for edge_key in subgraph:
        source, target = knowledge_graph.edge_ids[edge_key]
        relation_type = knowledge_graph.relation_types[edge_key]
        source_name = str(source)
        target_name = str(target)
        graph.node(source_name, label=knowledge_graph.node_labels[source])
        graph.node(target_name, label=knowledge_graph.node_labels[target])
        graph.edge(
            source_name,
            target_name,
            label=knowledge_graph.relation_labels[relation_type],
        )
    return graph


def draw_from_triples(triples, knowledge_graph):
    """Render explicit ``(head, relation, tail)`` triples as a Graphviz digraph.

    Args:
        triples: Iterable of indexable triples. Element 0 is the head node id,
            element 1 the relation id and element 2 the tail node id.
        knowledge_graph: Object whose subscriptable ``node_labels`` maps node
            ids to display labels and whose ``relation_labels`` maps relation
            ids to display labels.

    Returns:
        A ``graphviz.Digraph`` using the ``dot`` engine. Nodes are named
        ``str(node_id)`` and each triple contributes one labelled edge.
    """
    digraph = graphviz.Digraph(engine="dot")
    for triple in triples:
        source, relation_id, target = triple[0], triple[1], triple[2]
        # Declare head first, then tail, before adding the edge between them.
        for node_id in (source, target):
            digraph.node(str(node_id), label=knowledge_graph.node_labels[node_id])
        digraph.edge(
            str(source),
            str(target),
            label=knowledge_graph.relation_labels[relation_id],
        )
    return digraph

def draw_from_wikidata_ids(triples, wiki_client):
    """Render Wikidata-id triples as a Graphviz digraph, resolving labels on the fly.

    Args:
        triples: Iterable of indexable triples. Element 0 is the head entity
            id, element 1 the relation (property) id and element 2 the tail
            entity id. Ids must be sliceable and hashable. They are presumably
            QID/PID strings such as ``"Q42"``/``"P31"``; confirm with callers.
        wiki_client: Object whose ``query_all(kind, key)`` returns an iterable
            of labels. ``"qid2label"`` is used for entities and ``"pid2label"``
            for relations. Only the first label yielded is used.

    Returns:
        A ``graphviz.Digraph`` using the ``dot`` engine. Each distinct entity
        id becomes exactly one node. The node is labelled with its first
        resolved label, or with the first 12 characters of the id when no
        label is found. Each triple becomes one edge labelled with the
        relation's first resolved label. When no relation label is found, the
        first 12 characters of the relation id are used instead of raising
        ``IndexError``.
    """
    dot = graphviz.Digraph(engine="dot")
    node_names = {}  # entity id -> synthetic Graphviz node name
    relation_labels = {}  # relation id -> resolved edge label (memoised)

    def first_label(kind, key):
        # Take only the first label without materialising the whole result.
        # Fall back to a truncated id so missing labels never raise.
        return next(iter(wiki_client.query_all(kind, key)), key[:12])

    def node_for(entity):
        # Name nodes by a synthetic, colon-free identifier keyed on the entity
        # id, not by its label. Labels are not unique, so distinct entities
        # would be merged. Digraph.edge() also parses "name:port", so a label
        # containing ':' would attach the edge to a stray implicit node.
        name = node_names.get(entity)
        if name is None:
            name = f"n{len(node_names)}"
            node_names[entity] = name
            dot.node(name, label=first_label("qid2label", entity))
        return name

    for edge in triples:
        head, relation, tail = edge[0], edge[1], edge[2]
        head_name = node_for(head)
        tail_name = node_for(tail)
        if relation not in relation_labels:
            relation_labels[relation] = first_label("pid2label", relation)
        dot.edge(head_name, tail_name, label=relation_labels[relation])
    return dot
