from dgl.data import QM7bDataset
from scipy import io
import dgl
import torch
from dgl.convert import graph as dgl_graph
from dataset_loader.utils import filter_nb_nodes
import numpy as np


def is_true(val):
    """Interpret ``val`` as a boolean flag.

    Intended for options that may arrive either as real Python values or as
    strings from the command line. Strings are matched case-insensitively,
    ignoring surrounding whitespace, against the usual truthy spellings
    ``'true'``, ``'t'``, ``'yes'``, ``'y'``, ``'on'`` and ``'1'``. Any other
    string is false, including ``'False'``, ``'0'`` and ``''``. Non-string
    values fall back to ordinary Python truthiness.

    Args:
        val: Value to interpret (e.g. ``bool``, ``int``, ``str`` or ``None``).

    Returns:
        bool: ``True`` if ``val`` denotes an enabled flag, ``False`` otherwise.
    """
    if isinstance(val, str):
        # Every non-empty string is truthy in Python, so strings such as
        # 'False' or '0' must be parsed rather than tested with bool().
        return val.strip().lower() in {'true', 't', 'yes', 'y', 'on', '1'}
    return bool(val)


class QM7Dataset(QM7bDataset):
    """QM7 molecules as DGL graphs, with one scalar label per graph.

    Reuses DGL's QM7b dataset machinery. It overrides ``_url`` /
    ``_sha1_str`` to point at the QM7 ``.mat`` file and overrides
    ``_load_graph`` to parse it.

    Per-graph data produced by ``_load_graph``:

    * edges: the nonzero entries of the molecule's ``'X'`` matrix, with the
      entry value stored as edge feature ``'h'`` of shape ``(E, 1)``;
    * node features ``'R'`` and ``'Z'``: rows of the file's ``'R'`` and
      ``'Z'`` arrays, which hold coordinates and atomic numbers according to
      the original author;
    * ``'Z_one_hot'``: one-hot encoding of ``'Z'`` over every distinct Z value
      in the file;
    * ``'attr'``: ``'Z'`` (as a column, or one-hot when ``embed_z``), optionally
      concatenated with ``'R'`` when ``use_positions``.
    """

    _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/qm7.mat'
    _sha1_str = 'a877ec22944ae62f42c3f264f14e14789b2543e1'

    def __init__(self, use_positions=True, embed_z=False, min_node=4, max_node=15, raw_dir=None, force_reload=True, verbose=False):
        """Load (and, when needed, download) the QM7 dataset.

        Args:
            use_positions: Append ``'R'`` to the node ``'attr'`` feature. May be
                a bool or a CLI-style string; parsed with ``is_true``.
            embed_z: Use the one-hot encoding of ``'Z'`` in ``'attr'`` instead
                of the raw value. Parsed with ``is_true``.
            min_node: Lower node-count bound passed to ``filter_nb_nodes``.
            max_node: Upper node-count bound passed to ``filter_nb_nodes``.
            raw_dir: Directory for the raw ``.mat`` file (forwarded to DGL).
            force_reload: Forwarded to DGL. It defaults to True because
                filtering and attribute construction happen while loading.
                Presumably this avoids reusing a cache built with different
                settings.
            verbose: Forwarded to DGL.
        """
        # _load_graph runs inside the base-class constructor and reads these,
        # so they must be assigned before the super() call below.
        self.use_positions = is_true(use_positions)
        self.embed_z = is_true(embed_z)
        self.min_node = min_node
        self.max_node = max_node
        # super(QM7bDataset, self) deliberately skips QM7bDataset.__init__ and
        # calls the next constructor in the MRO directly, so the dataset can be
        # created under the name 'qm7' with the QM7 URL.
        super(QM7bDataset, self).__init__(name='qm7',
                                          url=self._url,
                                          raw_dir=raw_dir,
                                          force_reload=force_reload,
                                          verbose=verbose)

    def _load_graph(self, filename):
        """Build QM7 graphs and labels from the ``.mat`` file at ``filename``.

        Overrides the QM7b loader. Graphs are filtered on node count with
        ``filter_nb_nodes``. Node attributes are then computed as described
        in the class docstring.

        Returns:
            tuple: ``(graphs, labels)``. ``graphs`` is the list of DGL graphs
            kept by the filter, and ``labels`` is a tensor of shape
            ``(len(graphs), 1)``.
        """
        data = io.loadmat(filename)
        # Keys used: 'X' (2-D matrix per molecule; nonzero entries become
        # edges), 'R' and 'Z' (per-atom arrays), 'T' (targets).
        labels = dgl.backend.tensor(data['T'], dtype=dgl.backend.data_type_dict['float32']).reshape(-1, 1)
        feats = data['X']
        num_graphs = labels.shape[0]
        graphs = []
        for i in range(num_graphs):
            edge_list = feats[i].nonzero()
            g = dgl_graph(edge_list)
            g.edata['h'] = dgl.backend.tensor(feats[i][edge_list[0], edge_list[1]].reshape(-1, 1),
                                              dtype=dgl.backend.data_type_dict['float32'])
            # dgl.graph infers the node count from the largest node id in
            # edge_list. The per-atom arrays are truncated to that length,
            # which presumably drops trailing padding rows.
            nb_nodes = g.num_nodes()
            g.ndata['R'] = dgl.backend.tensor(data['R'][i][:nb_nodes])
            g.ndata['Z'] = dgl.backend.tensor(data['Z'][i][:nb_nodes])
            graphs.append(g)

        # Keep only graphs whose node count satisfies min_node / max_node (see
        # filter_nb_nodes for the exact bounds). The helper keeps each graph
        # paired with its label.
        graphs, labels = filter_nb_nodes(graphs, self.min_node, self.max_node, list(labels))

        # For QM7, atomic number and coordinates are usable | Ref.: https://stackoverflow.com/a/66301026/10115198
        # The one-hot vocabulary comes from every Z value in the file, not only
        # the kept graphs, so its width does not depend on the filter.
        # NOTE(review): this includes any padding value present in data['Z']
        # (likely 0), i.e. a column that real atoms may never use.
        self.different_z = list(np.unique(data['Z']))
        # Lookup tables built once, replacing a per-atom self.one_hot call:
        # row k of the identity matrix is the one-hot vector for
        # self.different_z[k].
        z_index = {float(z): k for k, z in enumerate(self.different_z)}
        one_hot_rows = torch.eye(len(self.different_z))
        for g in graphs:
            positions = torch.tensor([z_index[z] for z in g.ndata['Z'].tolist()], dtype=torch.long)
            g.ndata['Z_one_hot'] = one_hot_rows[positions]
            z = g.ndata['Z_one_hot'] if self.embed_z else g.ndata['Z'].unsqueeze(1)
            if self.use_positions:
                g.ndata['attr'] = torch.hstack((z, g.ndata['R']))
            else:
                g.ndata['attr'] = z
        return graphs, torch.tensor(labels).reshape((len(labels), 1))

    def one_hot(self, val):
        """Return the one-hot vector for ``val`` over ``self.different_z``.

        The result is a tensor of length ``len(self.different_z)``. It has a 1
        at the position of ``val`` in ``self.different_z`` and 0 elsewhere.

        Raises:
            ValueError: if ``val`` is not in ``self.different_z``.
        """
        i = self.different_z.index(val)
        n = len(self.different_z)
        return torch.zeros(n).scatter_(0, torch.tensor([i]), 1)

    @property
    def n_labels(self):
        """Number of labels for each graph, i.e. number of prediction tasks."""
        return 1

    @property
    def num_tasks(self):
        """Number of prediction tasks (DGL API name); same as ``n_labels``.

        Shadows the QM7b parent's value, which describes QM7b's multi-target
        labels rather than this dataset's single target.
        """
        return self.n_labels

    @property
    def num_labels(self):
        """Legacy DGL alias of ``num_tasks``; same as ``n_labels``."""
        return self.n_labels


if __name__ == '__main__':
    # Manual smoke test: load a small node-count slice of QM7 and show the
    # first resulting graph.
    demo_dataset = QM7Dataset(min_node=4, max_node=5)
    first_graph = demo_dataset.graphs[0]
    print(first_graph)
