from dgl.data import QM7bDataset, DGLDataset
import dgl
import torch
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import networkx as nx
import numpy as np
from .GraphCoversRepo.covers import gen_graphCovers


from dataset import load_dataset


def read_graphs(raw_data, num_nodes=10):
    """Build one undirected networkx graph per edge list in ``raw_data``.

    Every graph is first populated with the nodes ``0 .. num_nodes - 1`` so
    that isolated nodes are kept. Then each edge ``(edge[0], edge[1])`` from
    its edge list is added. Endpoints outside that range become extra nodes,
    as usual for ``nx.Graph.add_edges_from``.

    Args:
        raw_data: iterable of edge lists; each edge is an indexable object whose
            first two items are the endpoint ids (any further items are ignored).
        num_nodes: number of nodes pre-created in every graph. The default of 10
            is the value this function always used before it became a parameter.

    Returns:
        list of ``nx.Graph``, one per entry of ``raw_data``, in the same order.
    """
    graphs = []
    for list_edges in raw_data:
        g = nx.Graph()
        g.add_nodes_from(np.arange(num_nodes))
        # Only the first two items are used, so longer records are never
        # mistaken for (u, v, attr_dict) triples by networkx.
        g.add_edges_from((edge[0], edge[1]) for edge in list_edges)
        graphs.append(g)
    return graphs


class GraphCoversGeneration(DGLDataset):
    """Synthetic graph-classification dataset built from covers of one base graph.

    A fixed 7-node base graph (8 undirected edges, listed in both directions)
    is passed to ``gen_graphCovers``. Every cover it returns becomes one DGL
    graph, and every node gets the constant feature ``1.0`` in
    ``ndata['attr']``. Graphs are labelled by their position in the generated
    list, split into ``num_classes`` contiguous groups.
    """

    # Default number of classes; an instance overrides it from the constructor.
    num_classes = 3

    def __init__(self, min_node=4, max_node=15, raw_dir=None, force_reload=True,
                 verbose=False, num_classes=3):
        """Create the dataset.

        Args:
            min_node, max_node: stored as attributes but not used by this class.
            raw_dir: forwarded to ``DGLDataset``.
            force_reload: forwarded to ``DGLDataset``. It defaults to True
                because the graphs are generated while loading.
            verbose: forwarded to ``DGLDataset``.
            num_classes: number of contiguous label groups (default 3).

        Raises:
            ValueError: if ``num_classes`` is smaller than 1.
        """
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        # NOTE(review): min_node, max_node and scaler are never read in this
        # class. They are kept so existing attribute access keeps working
        # (presumably mirrors a sibling dataset class -- confirm before removing).
        self.min_node = min_node
        self.max_node = max_node
        self.scaler = StandardScaler()
        # Set before super().__init__, which (per the DGL docs) may call process().
        self.num_classes = num_classes
        super().__init__(name='cover-gen',
                         raw_dir=raw_dir,
                         force_reload=force_reload,
                         verbose=verbose)

    def process(self):
        """Generate the graphs and labels and store them on ``self``."""
        self.graphs, self.label = self._load_graph()

    @staticmethod
    def _contiguous_labels(n_items, n_classes):
        """Assign ``n_items`` items to ``n_classes`` contiguous label groups.

        Class ``c`` (for ``c >= 1``) starts at index ``(n_items // n_classes) * c``.
        Any remainder of the integer division therefore goes to the last class.
        If ``n_items < n_classes``, every item gets label ``n_classes - 1``.

        Returns:
            ``torch.long`` tensor of shape ``(n_items,)``.
        """
        targets = torch.zeros(n_items, dtype=torch.long)
        chunk = n_items // n_classes
        for c in range(1, n_classes):
            targets[chunk * c:] = c
        return targets

    def _load_graph(self):
        """Build the covers of the base graph and their class labels.

        Returns:
            tuple ``(dgl_graphs, targets)``: a list of DGL graphs and a
            ``torch.long`` tensor of the same length with values in
            ``[0, self.num_classes)``.
        """
        # Base graph written 1-indexed for readability, converted to 0-indexed below.
        edge_index = [
            [1, 2], [2, 1],
            [2, 3], [3, 2],
            [2, 4], [4, 2],
            [3, 4], [4, 3],
            [4, 5], [5, 4],
            [5, 6], [6, 5],
            [5, 7], [7, 5],
            [6, 7], [7, 6]
        ]
        edge_index = [[node - 1 for node in pair] for pair in edge_index]
        # Two 0-indexed base-graph edges, (1, 3) and (4, 5). NOTE(review):
        # presumably the edges that gen_graphCovers permutes when lifting; confirm.
        cycle_edge = [[1, 3], [4, 5]]

        graph_covers = gen_graphCovers(edge_index, degree=3, cycle_edge=cycle_edge, nb_covers=6)
        dgl_graphs = []
        for cover in graph_covers:
            graph = dgl.from_networkx(cover.nxGraph)
            # Structure-only task: every node carries the same constant feature.
            graph.ndata['attr'] = torch.ones((graph.num_nodes(), 1))
            dgl_graphs.append(graph)

        targets = self._contiguous_labels(len(dgl_graphs), self.num_classes)
        return dgl_graphs, targets

    def download(self):
        """No-op: the data is generated in ``process`` rather than downloaded."""
        pass

    def __len__(self):
        """Return the number of generated graphs."""
        return len(self.graphs)

    def __getitem__(self, i):
        """Return the ``(graph, label)`` pair at index ``i``."""
        return self.graphs[i], self.label[i]

    @property
    def n_labels(self):
        """Number of classes; labels in ``self.label`` lie in ``[0, n_labels)``."""
        return self.num_classes