import pickle
import os
import dgl
import tqdm
import torch
import numpy as np
import networkx as nx

# Convert the raw "max degree" (max*pickle) and "shortest path" (sho*pickle)
# dataset files found in the current working directory into dicts holding a
# DGL graph plus features, written under ``converted/`` with the same name.
#
# Each input file is unpacked as a ``(train, val, test)`` triple whose items
# are ``(graph, answer)`` pairs; ``graph`` is read through ``.node_features``
# (indexed with torch-style ``dim=``, so presumably a torch tensor -- confirm)
# and ``.g`` (a graph handed to ``dgl.from_networkx``).
#
# NOTE: pickle.load can execute arbitrary code; only run this on trusted files.

OUTPUT_DIR = "converted"

# Cost used for "no edge" in the dense adjacency matrices fed to
# floyd_warshall; nodes unreachable from the source keep this value.
UNREACHABLE = 1e9


def floyd_warshall(adj):
    """Return all-pairs shortest-path distances for a dense (n, n) cost matrix.

    ``adj[i, j]`` is the direct cost from ``i`` to ``j`` (a large finite
    sentinel such as ``UNREACHABLE`` where there is no edge, 0 on the
    diagonal). The inner two Floyd-Warshall loops are vectorized: for each
    pivot ``k`` the whole matrix is relaxed through ``k`` at once. This gives
    the same result as the element-wise in-place version as long as
    ``dist[k, k]`` stays 0, i.e. the graph has no negative cycles.

    Args:
        adj: square array-like of edge costs. It is not modified.

    Returns:
        A new float64 ``numpy`` array of shortest-path costs.
    """
    dist = np.array(adj, dtype=np.float64, copy=True)
    for k in range(dist.shape[0]):
        # The right-hand side is materialized before ``out=dist`` is written,
        # so row/column k are read with their pre-pivot values (which cannot
        # change during this pivot anyway because dist[k, k] == 0).
        np.minimum(dist, dist[:, k:k + 1] + dist[k:k + 1, :], out=dist)
    return dist


def convert_max_degree_sample(sample):
    """Convert one ``(graph, answer)`` pair from a max-degree file.

    Returns:
        ``{'graph': DGLGraph, 'node_feats': graph.node_features,
        'edge_feats': None}``.

    Raises:
        ValueError: if the DGL graph's maximum in-degree does not match the
            stored answer (sanity check on the networkx -> DGL conversion).
    """
    graph, answer = sample
    feats = graph.node_features
    g = dgl.from_networkx(graph.g)
    max_in_degree = g.in_degrees().max().item()
    # ``not (... < tol)`` rather than ``>= tol`` so a NaN answer also fails.
    if not abs(max_in_degree - answer) < 1e-9:
        raise ValueError(
            f"max in-degree {max_in_degree} does not match stored answer {answer}"
        )
    return {'graph': g, 'node_feats': feats, 'edge_feats': None}


def convert_shortest_path_sample(sample):
    """Convert one ``(graph, answer)`` pair from a shortest-path file.

    The source node is taken as the argmax of feature column 1. Distances from
    that node to every node are computed twice: by hop count ('one') and by
    the edges' 'weight' attribute ('w'). Only feature columns 0-1 are kept;
    column 2 is dropped (presumably a target-node marker -- confirm).

    Returns:
        dict with keys 'graph', 'node_feats', 'edge_feats', 'one', 'w'; the
        last two are float32 tensors of length ``num_nodes``.
    """
    graph, _answer = sample
    feats = graph.node_features
    # to_directed() so each undirected edge yields both directions, each
    # carrying its 'weight' into edata.
    g = dgl.from_networkx(graph.g.to_directed(), edge_attrs=['weight'])
    edge_feats = g.edata.pop('weight')

    n = g.num_nodes()
    src_ids, dst_ids = (t.numpy() for t in g.edges())

    hop_adj = np.full((n, n), UNREACHABLE, dtype=np.float64)
    hop_adj[src_ids, dst_ids] = 1
    np.fill_diagonal(hop_adj, 0)

    weight_adj = np.full((n, n), UNREACHABLE, dtype=np.float64)
    weight_adj[src_ids, dst_ids] = edge_feats.numpy()
    np.fill_diagonal(weight_adj, 0.)

    hop_dist = floyd_warshall(hop_adj)
    weight_dist = floyd_warshall(weight_adj)

    source = int(feats[:, 1].argmax(dim=0))
    return {
        'graph': g,
        'node_feats': feats[:, :2],
        'edge_feats': edge_feats,
        'one': torch.tensor(hop_dist[source, :], dtype=torch.float32),
        'w': torch.tensor(weight_dist[source, :], dtype=torch.float32),
    }


def convert_file(fn, convert_sample):
    """Convert every sample of the ``(train, val, test)`` pickle ``fn``.

    The result is pickled as a ``(train, val, test)`` tuple of lists to
    ``OUTPUT_DIR/fn``; the output directory is created if missing.
    """
    print(fn)
    with open(fn, "rb") as f:
        train, val, test = pickle.load(f)

    converted = tuple(
        [convert_sample(sample) for sample in tqdm.tqdm(split)]
        for split in (train, val, test)
    )

    # Create the directory up front so a missing folder does not crash the
    # script only after the (slow) conversion has finished.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    with open(os.path.join(OUTPUT_DIR, fn), "wb") as ff:
        pickle.dump(converted, ff)


lists = os.listdir()

for fn in lists:
    if not fn.endswith('pickle'):
        continue
    if fn.startswith('max'):
        convert_file(fn, convert_max_degree_sample)
    elif fn.startswith('sho'):
        convert_file(fn, convert_shortest_path_sample)
