from dgl.data import QM7bDataset
import dgl
import torch
from sklearn.preprocessing import MinMaxScaler, StandardScaler, OneHotEncoder
import numpy as np
import networkx as nx


from dataset import load_dataset



class ZincDataset(QM7bDataset):
    """ZINC molecular graphs exposed through the ``QM7bDataset`` interface.

    Only the loading logic is overridden: graphs and regression targets are
    obtained from the project-local ``load_dataset(f'ZINC_{split}', ...)``
    instead of the QM7b source data.

    Each DGL graph carries two node fields:

    * ``ndata['label']`` -- the integer node label taken from the networkx
      node attribute ``'label'``.
    * ``ndata['attr']`` -- a one-hot encoding of ``label`` over
      ``N_ATOM_TYPES`` categories when ``embed_z`` is true, otherwise
      ``label`` reshaped to ``(num_nodes, 1)``.

    Targets are stored in ``self.label`` as a ``torch.float`` tensor of shape
    ``(num_graphs, 1)``.
    """

    # Size of the one-hot vocabulary. The encoder is fitted on 0..N_ATOM_TYPES-1,
    # so with embed_z=True any node label outside that range makes
    # OneHotEncoder.transform raise (its default handle_unknown='error').
    N_ATOM_TYPES = 28

    def __init__(self, min_node=4, max_node=15, raw_dir=None, force_reload=True, verbose=False, split='train', embed_z=True):
        """Build the dataset for one ZINC split.

        Args:
            min_node: Forwarded to ``load_dataset`` as ``min_node``
                (presumably the minimum graph size kept -- confirm in
                ``load_dataset``).
            max_node: Forwarded to ``load_dataset`` as ``max_node``.
            raw_dir: Forwarded to the DGL base-class initializer.
            force_reload: Forwarded to the base-class initializer. Defaults
                to ``True`` because filtering happens while loading, so a
                cached copy built with other ``min_node``/``max_node`` values
                could be stale.
            verbose: Forwarded to the base-class initializer.
            split: Split suffix; the data is loaded as ``f'ZINC_{split}'``.
            embed_z: If true, ``ndata['attr']`` is one-hot encoded;
                otherwise it holds the raw labels as a column.
        """
        # These attributes must exist before the base initializer runs,
        # because that initializer drives process() -> _load_graph().
        self.min_node = min_node
        self.max_node = max_node
        self.split = split
        self.embed_z = embed_z
        self.encoder = self._make_one_hot_encoder()
        self.encoder.fit(np.arange(self.N_ATOM_TYPES).reshape((-1, 1)))
        # Deliberately skip QM7bDataset.__init__ and call the next initializer
        # in the MRO directly, so this dataset uses its own name and no URL.
        super(QM7bDataset, self).__init__(name='zinc',
                                          url=None,
                                          raw_dir=raw_dir,
                                          force_reload=force_reload,
                                          verbose=verbose)

    @staticmethod
    def _make_one_hot_encoder():
        """Return a dense-output ``OneHotEncoder`` on old and new scikit-learn."""
        try:
            # scikit-learn >= 1.2 names this option ``sparse_output``; the
            # old ``sparse`` keyword was removed in 1.4 (TypeError there).
            return OneHotEncoder(sparse_output=False)
        except TypeError:
            # scikit-learn < 1.2 only accepts ``sparse``.
            return OneHotEncoder(sparse=False)

    def process(self):
        """Populate ``self.graphs`` and ``self.label`` from the ZINC split."""
        # Same attribute names the QM7bDataset parent uses, so its inherited
        # accessors keep working (presumably __getitem__/__len__ -- confirm).
        self.graphs, self.label = self._load_graph()

    def _load_graph(self):
        """Load the configured ZINC split and convert it to DGL graphs.

        Returns:
            A pair ``(graphs, targets)``. ``graphs`` is a list of DGL graphs
            with ``ndata['label']`` and ``ndata['attr']`` set. ``targets`` is
            a ``torch.float`` tensor of shape ``(num_graphs, 1)``.
        """
        nx_graphs, targets = load_dataset(f'ZINC_{self.split}', min_node=self.min_node, max_node=self.max_node, node_attributes=True)

        dgl_graphs = []
        for nx_graph in nx_graphs:
            # Relabel nodes to consecutive integers 0..n-1 before conversion.
            nx_graph = nx.convert_node_labels_to_integers(nx_graph)
            graph = dgl.from_networkx(nx_graph, node_attrs=['label'])
            node_labels = graph.ndata['label']
            if self.embed_z:
                # NOTE(review): OneHotEncoder's default dtype is float64, so
                # 'attr' is a float64 tensor here; cast if the model expects
                # float32.
                one_hot = self.encoder.transform(node_labels.numpy().reshape((-1, 1)))
                graph.ndata['attr'] = torch.tensor(one_hot)
            else:
                graph.ndata['attr'] = node_labels.reshape((-1, 1))
            dgl_graphs.append(graph)

        return dgl_graphs, torch.tensor(targets, dtype=torch.float).reshape((-1, 1))

    def download(self):
        """No-op: there is no remote source; ``process`` calls ``load_dataset``."""