import torch
import networkx as nx
import math
import random
import numpy as np

from graph_learning.dataset.graph import GLGraph, edge_batch
from graph_learning.dataset import DatasetConfig
import dgl

class DataList(list):
    """A ``list`` subclass that can carry extra attributes.

    Built-in ``list`` instances have no ``__dict__``, so arbitrary attributes
    (such as dataset split boundaries) cannot be attached to them. Subclassing
    restores a per-instance ``__dict__`` while keeping full list behaviour.
    """

@DatasetConfig.register('dist')
class DistDatasetConfig(DatasetConfig):
    """Synthetic weighted-graph dataset built from random Watts-Strogatz graphs.

    Registered under the name ``'dist'``. Every graph is a connected
    Watts-Strogatz graph converted to a directed graph. Each arc gets a random
    integer weight in [2, 10], stored as a float under ``'weight'``. Each node
    gets a constant node feature ``ndata['x']`` of ones with width 1.

    Modes (``--mode``):
      * ``'all'``  -- a ``DataList`` of 120 train graphs (100 nodes), then 40
        validation and 40 test graphs (200 nodes each). The attributes
        ``train_index`` / ``valid_index`` mark the split boundaries.
      * ``'each'`` -- a single 2000-node graph.
    """

    @classmethod
    def define_parser(cls, parser):
        """Add the ``dist``-specific command-line options to *parser*."""
        super().define_parser(parser)
        # No default: build_dataset() rejects an unset/unknown mode explicitly.
        parser.add_argument('--mode', choices=['all', 'each'])
        # NOTE(review): --depth is not read anywhere in this class at present.
        parser.add_argument('--depth', type=int, default=4)

    def build_graph(self, n, lb, ub, name):
        """Build one random connected, directed, weighted graph.

        Args:
            n: Number of nodes.
            lb: Lower bound of the uniform draw for the rewiring probability.
            ub: Upper bound of the uniform draw for the rewiring probability.
            name: Suffix for the graph name; stored as ``'dist_<name>'`` in
                ``gdata['name']``.

        Returns:
            A ``GLGraph`` wrapping a DGL graph that has edge data ``'weight'``
            and node data ``'x'`` (ones of shape ``(n, 1)``).
        """
        p = random.uniform(lb, ub)
        k = int(random.uniform(math.log(n), 2 * math.log(n)))
        # Clamp k into [2, n]. For n <= 7 the truncated draw above can be 0 or
        # 1. That gives a lattice with no edges, which can never be connected,
        # so networkx would exhaust its retries and raise. The upper bound only
        # bites for n < 2, where k == n yields a complete graph. For larger n,
        # 2 * log(n) < n, so k is unchanged.
        k = min(n, max(2, k))

        g = nx.connected_watts_strogatz_graph(n, k, p).to_directed()

        # to_directed() turns every undirected edge into two arcs, and each arc
        # draws its own weight, so u->v and v->u generally differ.
        # NOTE(review): confirm asymmetric weights are intended.
        for _, _, attrs in g.edges(data=True):
            attrs['weight'] = float(random.randint(2, 10))

        dgl_g = dgl.from_networkx(g, edge_attrs=['weight'])
        # Constant (uninformative) node feature of width 1.
        dgl_g.ndata['x'] = torch.ones(dgl_g.number_of_nodes(), 1)

        gl_g = GLGraph(dgl_g)
        gl_g.gdata['name'] = f'dist_{name}'

        # NOTE: `--depth` is not consumed here. An earlier, disabled version
        # computed per-depth all-pairs Dijkstra lengths and stored them as
        # ndata['pair_labels'].
        return gl_g

    def build_dataset(self):
        """Build the dataset selected by ``self.mode``.

        Returns:
            For ``'all'``: a ``DataList`` holding train, then validation, then
            test graphs. ``data[:data.train_index]`` are the train graphs,
            ``data[data.train_index:data.valid_index]`` the validation graphs,
            and ``data[data.valid_index:]`` the test graphs.
            For ``'each'``: a single ``GLGraph``.

        Raises:
            ValueError: If ``self.mode`` is neither ``'all'`` nor ``'each'``.
        """
        if self.mode == 'all':
            train_num, valid_num, test_num = 120, 40, 40
            lb, ub = 0.2, 0.3
            data = DataList(
                [self.build_graph(100, lb, ub, f'train_{i}') for i in range(train_num)]
                + [self.build_graph(200, lb, ub, f'valid_{i}') for i in range(valid_num)]
                + [self.build_graph(200, lb, ub, f'test_{i}') for i in range(test_num)]
            )
            # Split boundaries within the flat list.
            data.train_index = train_num
            data.valid_index = train_num + valid_num
            return data

        if self.mode == 'each':
            return self.build_graph(2000, 0.2, 0.3, 'each')

        # Previously an unset/unknown mode fell through to `return data` with
        # `data` unbound; fail with an actionable message instead.
        raise ValueError(
            f"unknown dist dataset mode {self.mode!r}; expected 'all' or 'each'"
        )
