import os
import pickle

import pandas as pd
from dgl.data import TUDataset
import torch
from dgl.data.utils import download

from dataset_loader.utils import filter_nb_nodes, filter_only_connected, is_true


class Letter(TUDataset):
    """Letter graph-classification dataset with optional graph filtering.

    Wraps DGL's ``TUDataset`` for the ``Letter-low`` / ``Letter-med`` /
    ``Letter-high`` archives. After the base class has parsed the raw files,
    graphs are filtered by node count (and optionally by connectivity) and the
    ``'node_attr'`` node feature is exposed under the ``'attr'`` key as well.

    Parameters
    ----------
    only_connected : bool or str
        Passed through ``is_true``; when truthy, ``filter_only_connected`` is
        applied to the graphs that survived the node-count filter.
    distorsion : {'low', 'med', 'high'}
        Selects which archive is loaded (dataset name ``Letter-<distorsion>``).
    min_node, max_node : int
        Bounds forwarded to ``filter_nb_nodes``.
        NOTE(review): inclusiveness of the bounds is defined by that helper.
    raw_dir : str, optional
        Forwarded to ``TUDataset``.
    force_reload : bool
        Forwarded to ``TUDataset``. Defaults to ``True`` because the filtering
        happens in ``process`` and depends on the constructor arguments.
    verbose : bool
        Forwarded to ``TUDataset``; also enables the graph-count printouts
        emitted while filtering.
    """

    # URL template: TUDataset calls ``_url.format(name)``, so the placeholder
    # is required for the requested distortion level to be downloaded (a fixed
    # 'Letter-med.zip' URL would be fetched for every distortion level).
    _url = 'https://www.chrsmrrs.com/graphkerneldatasets/{}.zip'
    # Not referenced anywhere in this class.
    # NOTE(review): presumably the SHA-1 of the Letter-med archive - confirm.
    _sha1_str = 'a877ec22944ae62f42c3f264f14e14789b2543e1'

    def __init__(self, only_connected=False, distorsion='med', min_node=4, max_node=15, raw_dir=None, force_reload=True, verbose=False):
        # force_reload=True as the filtering and all is done while loading.
        # The filter settings must be stored BEFORE calling the base
        # constructor, since that is what ends up invoking ``process``.
        self.min_node = min_node
        self.max_node = max_node
        self.only_connected = is_true(only_connected)
        assert distorsion in ['low', 'med', 'high'], \
            f"distorsion must be 'low', 'med' or 'high', got {distorsion!r}"
        super().__init__(name=f'Letter-{distorsion}',
                         raw_dir=raw_dir,
                         force_reload=force_reload,
                         verbose=verbose)

    # Extending base functions of DGLDataset
    # Ref.: https://docs.dgl.ai/generated/dgl.data.DGLDataset.html#dgl.data.DGLDataset
    def process(self):
        """Parse the raw files, then filter the graphs and rebuild the labels.

        Overwrites ``graph_lists`` and ``graph_labels`` with the filtered
        data and additionally stores the label tensor as ``labels``.
        """
        super().process()

        # Filter on number of nodes; labels are filtered alongside the graphs.
        graphs, labels = filter_nb_nodes(self.graph_lists, self.min_node, self.max_node, list(self.graph_labels))
        if self.verbose:
            print(len(graphs), len(labels))
        if self.only_connected:
            graphs, labels = filter_only_connected(graphs, labels)
        if self.verbose:
            print(len(graphs), len(labels))

        # Expose the node attributes under the 'attr' key as well.
        for g in graphs:
            g.ndata['attr'] = g.ndata['node_attr']

        # Build the flat label tensor once; the second attribute gets its own
        # copy so that the two never alias each other.
        label_tensor = torch.tensor(labels).reshape((-1,))

        self.graph_lists = graphs
        self.labels = label_tensor
        self.graph_labels = label_tensor.clone()

    @property
    def graphs(self):
        """List of graphs kept after filtering."""
        return self.graph_lists

    @property
    def label(self):
        """1-D tensor of graph labels, aligned with ``graphs``."""
        return self.graph_labels

    @property
    def n_labels(self):
        """Number of labels for each graph, i.e. number of prediction tasks.

        NOTE(review): hard-coded to 15 and not derived from the loaded data;
        this looks like the number of letter classes - confirm how callers
        interpret it.
        """
        return 15


if __name__ == '__main__':
    # Manual smoke test: load the medium-distortion split restricted to
    # graphs with 4 to 5 nodes (bounds as interpreted by the node-count
    # filter) and display the first remaining graph.
    letter_dataset = Letter(distorsion='med', min_node=4, max_node=5)
    first_graph = letter_dataset.graphs[0]
    print(first_graph)