from dgl.data import QM7bDataset
import dgl
import torch
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import networkx as nx


from dataset import load_dataset



class QM9Dataset(QM7bDataset):
    """QM9 molecular graphs exposed through DGL's ``QM7bDataset`` interface.

    Overrides ``QM7bDataset`` so that, instead of downloading and parsing the
    QM7b ``.mat`` file, the graphs come from the project's
    ``load_dataset('QM9', ...)`` loader, filtered by node count, and the
    regression targets are scaled with ``self.scaler``.

    ``process`` populates ``self.graphs`` (DGL graphs with node feature
    ``'attr'``) and ``self.label`` (float32 tensor of scaled targets). These are
    presumably the attributes the inherited ``QM7bDataset`` accessors read.
    """

    def __init__(self, min_node=4, max_node=15, raw_dir=None, force_reload=True,
                 verbose=False, scaler=None):
        """Load, filter and scale QM9.

        Parameters
        ----------
        min_node, max_node : int
            Bounds forwarded to ``load_dataset``. Assumed to be the minimum and
            maximum number of nodes kept per molecule (TODO confirm against
            ``load_dataset``).
        raw_dir : str or None
            Forwarded to the DGL base dataset constructor.
        force_reload : bool
            Defaults to ``True`` because filtering and target scaling are done
            in ``process``. If the base class instead loads a cached copy
            without calling ``process``, two things go wrong. First,
            ``min_node``/``max_node`` are not re-applied, so the cache may hold
            graphs filtered with other bounds. Second, ``self.scaler`` is never
            fitted.
        verbose : bool
            Forwarded to the DGL base dataset constructor.
        scaler : object or None
            Transformer with a ``fit_transform`` method (e.g. a scikit-learn
            scaler) applied to the targets. When ``None``, a fresh
            ``StandardScaler`` is used. It is kept on ``self.scaler`` so scaled
            predictions can be mapped back with its ``inverse_transform``.
        """
        # Set before the base constructor runs: processing happens during
        # construction and reads these attributes.
        self.min_node = min_node
        self.max_node = max_node
        self.scaler = StandardScaler() if scaler is None else scaler
        # super(QM7bDataset, self) deliberately skips QM7bDataset.__init__
        # (which would configure the QM7b name/URL) and calls the next class in
        # the MRO directly, with the QM9 name and no download URL.
        super(QM7bDataset, self).__init__(name='qm9',
                                          url=None,
                                          raw_dir=raw_dir,
                                          force_reload=force_reload,
                                          verbose=verbose)

    def process(self):
        """Build ``self.graphs`` and ``self.label`` from the QM9 source data."""
        self.graphs, self.label = self._load_graph()

    def _load_graph(self):
        """Load QM9 via ``load_dataset``, convert to DGL and scale the targets.

        Replaces ``QM7bDataset._load_graph``, which parses the QM7b file.

        Returns
        -------
        tuple
            ``(graphs, labels)``, where ``graphs`` is a list of DGL graphs
            (one per molecule) carrying the node feature ``'attr'``, and
            ``labels`` is the output of ``self.scaler.fit_transform(targets)``
            as a float32 tensor.
        """
        graphs, targets = load_dataset('QM9', min_node=self.min_node, max_node=self.max_node,
                                       node_attributes=True)
        # Relabel nodes to 0..n-1 before handing them to DGL. The loader may
        # use arbitrary node ids (not verified here).
        dgl_graphs = [
            dgl.from_networkx(nx.convert_node_labels_to_integers(g), node_attrs=['attr'])
            for g in graphs
        ]
        # fit_transform expects a 2-D (n_samples, n_targets) array-like.
        scaled_targets = self.scaler.fit_transform(targets)
        return dgl_graphs, torch.tensor(scaled_targets, dtype=torch.float)

    def download(self):
        """Do nothing: the data is obtained through ``load_dataset`` in ``_load_graph``."""

    @property
    def num_tasks(self):
        """Number of prediction targets per graph, i.e. columns of ``self.label``.

        Overridden because the value inherited from ``QM7bDataset`` describes
        QM7b's fixed target count, not QM9's.
        """
        return self.label.shape[-1]

    @property
    def num_labels(self):
        """Alias of ``num_tasks`` for code using the older DGL attribute name."""
        return self.num_tasks