import os
import networkx as nx
import numpy as np

def topological_sort_edges(G):
    """Return the edges of a DAG ordered by the topological rank of their tail.

    Nodes are visited in the order produced by ``nx.topological_sort(G)``,
    and for each visited node its outgoing edges are appended in the order
    reported by ``G.out_edges``.

    Args:
        G: A ``networkx.DiGraph`` that must be acyclic.

    Returns:
        A list with the out-edges of every node, grouped by source node in
        topological order.

    Raises:
        TypeError: If ``G`` is not a ``networkx.DiGraph``.
        ValueError: If ``G`` contains a directed cycle.
    """
    # Both exception types derive from Exception, so callers that used to
    # catch the generic Exception raised here keep working.
    if not isinstance(G, nx.DiGraph):
        raise TypeError("Only implemented for networkx DiGraph !")
    if not nx.is_directed_acyclic_graph(G):
        raise ValueError("Only implemented for Directed Acyclic Graphs !")

    order = []
    for node in nx.topological_sort(G):
        order.extend(G.out_edges(node))

    return order

# Problem configuration: grid side length, data folder and dataset split.
n = 12
folder = "{0}x{0}".format(n)
split = "train"

# Common "./<folder>/<split>" prefix shared by the input and output files.
path_prefix = "./" + folder + "/" + split
weights = np.load(path_prefix + "_vertex_weights.npy")

# Directed n x n grid: every node (row, col) points to the node below it and
# to the node on its right.  Edges are inserted in the same sequence as
# before so that the edge ordering computed below is unaffected.
grid = nx.DiGraph()
for row in range(n):
    for col in range(n):
        here = (row, col)
        if row + 1 < n:
            grid.add_edge(here, (row + 1, col))
        if col + 1 < n:
            grid.add_edge(here, (row, col + 1))

# Fixed edge ordering, plus the reverse lookup edge -> position in a label.
order = list(topological_sort_edges(grid))
E = {edge: position for position, edge in enumerate(order)}

labels = []

for w in weights:
    # Each edge costs minus the sum of its two endpoint weights, so the
    # minimum-cost path found below maximises that sum along the path.
    # `w` is indexed with (row, col) node tuples -- presumably an n x n
    # array per sample; confirm against the data file.
    logits = {(u, v): {"logits": -(w[u] + w[v])} for u, v in order}
    nx.set_edge_attributes(grid, logits)
    trace = nx.bellman_ford_path(grid, (0, 0), (n - 1, n - 1), weight="logits")

    # One entry per edge, set to 1 for the edges used by the path.
    label = np.zeros(len(order))
    for step in zip(trace, trace[1:]):
        label[E[step]] = 1
    labels.append(label)

# One row per weight sample, one column per edge in `order`.
labels = np.stack(labels, axis=0)
np.save(path_prefix + "_diag_labels.npy", labels)