from dgl.data import QM7bDataset, DGLDataset
import dgl
import torch
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import networkx as nx
import numpy as np


from dataset import load_dataset


def read_graphs(raw_data, num_nodes=10):
    """Convert raw edge lists into undirected NetworkX graphs.

    Args:
        raw_data: Iterable with one edge list per graph. Each edge is an
            indexable object whose first two items are the endpoint node ids.
        num_nodes: Number of nodes ``0 .. num_nodes - 1`` added to every graph
            before its edges, so isolated nodes are preserved. Defaults to 10,
            the value this function previously hard-coded. Edges that mention
            ids outside this range still add those endpoints (NetworkX adds
            missing nodes automatically).

    Returns:
        List of ``nx.Graph`` objects, in the same order as ``raw_data``.
    """
    graphs = []
    for list_edges in raw_data:
        g = nx.Graph()
        g.add_nodes_from(np.arange(num_nodes))
        # Only the first two entries of each edge are used as endpoints.
        g.add_edges_from((edge[0], edge[1]) for edge in list_edges)
        graphs.append(g)
    return graphs


class GraphCovers(DGLDataset):
    """Binary graph-classification dataset built from two sets of graphs.

    Graphs are read from ``graphs_0.npy`` (label 0) and ``graphs_1.npy``
    (label 1) inside ``data_dir``. Each graph's node features
    (``ndata['attr']``) are eigenvectors of its Laplacian ``L = D - A``.
    """

    # Maximum number of Laplacian eigenvectors used as node features. NumPy
    # clamps the slice, so a graph with fewer nodes gets one column per node.
    n_eigvecs = 18

    def __init__(self, min_node=4, max_node=15, raw_dir=None, force_reload=True,
                 verbose=False, data_dir='../datasets/GraphCovers'):
        """Create the dataset and load all graphs.

        Args:
            min_node: Stored on the instance; not used by this class.
            max_node: Stored on the instance; not used by this class.
            raw_dir: Forwarded to ``DGLDataset``.
            force_reload: Forwarded to ``DGLDataset``. Defaults to True because
                all preprocessing is done in ``process()``.
            verbose: Forwarded to ``DGLDataset``.
            data_dir: Directory containing ``graphs_0.npy`` and
                ``graphs_1.npy``. The default is relative to the current
                working directory, which was the previously hard-coded path.
        """
        self.min_node = min_node
        self.max_node = max_node
        # Not used inside this class; kept as a public attribute for callers.
        self.scaler = StandardScaler()
        # Must be set before super().__init__(), which invokes process().
        self.data_dir = data_dir
        super().__init__(name='cover',
                         raw_dir=raw_dir,
                         force_reload=force_reload,
                         verbose=verbose)

    def process(self):
        """Populate ``self.graphs`` and ``self.label`` from the .npy files."""
        self.graphs, self.label = self._load_graph()

    def _load_graph(self):
        """Load both graph sets and attach Laplacian eigenvector features.

        Returns:
            Tuple ``(graphs, labels)``. ``graphs`` is a list of DGL graphs.
            ``labels`` is a ``torch.long`` tensor of the same length, holding 0
            for graphs from ``graphs_0.npy`` followed by 1 for graphs from
            ``graphs_1.npy``.
        """
        import os

        graphs_0 = read_graphs(np.load(os.path.join(self.data_dir, 'graphs_0.npy')))
        graphs_1 = read_graphs(np.load(os.path.join(self.data_dir, 'graphs_1.npy')))

        n_0, n_1 = len(graphs_0), len(graphs_1)
        targets = torch.zeros(n_0 + n_1, dtype=torch.long)
        targets[n_0:] = 1  # second set is the positive class

        dgl_graphs = []
        for g in graphs_0 + graphs_1:
            graph = dgl.from_networkx(g)
            graph.ndata['attr'] = self._laplacian_eigvecs(graph, self.n_eigvecs)
            dgl_graphs.append(graph)

        return dgl_graphs, targets

    @staticmethod
    def _laplacian_eigvecs(graph, k):
        """Return the first ``k`` eigenvectors of ``graph``'s Laplacian ``D - A``.

        ``D`` is built from out-degrees and ``A`` from the dense adjacency
        matrix. ``np.linalg.eigh`` returns eigenvalues in ascending order, so
        column 0 belongs to the smallest eigenvalue.

        Returns:
            ``torch.float`` tensor with one row per node and
            ``min(k, num_nodes)`` columns.
        """
        # NOTE(review): eigh assumes a symmetric matrix. This presumably holds
        # because dgl.from_networkx turns each undirected edge into two
        # directed ones; confirm if non-NetworkX-undirected inputs are used.
        deg = graph.out_degrees().cpu().numpy()
        adj = graph.adj().to_dense().cpu().numpy()
        _, eigvecs = np.linalg.eigh(np.diag(deg) - adj)
        return torch.tensor(eigvecs[:, :k], dtype=torch.float)

    def download(self):
        """Do nothing: data is read from local .npy files in ``process()``."""
        pass

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)

    def __getitem__(self, i):
        """Return the ``(graph, label)`` pair at index ``i``."""
        return self.graphs[i], self.label[i]

    @property
    def n_labels(self):
        """Number of target classes (labels are 0 or 1)."""
        return 2