import os
import pickle

import pandas as pd
from dgl.data import TUDataset
import torch
from dgl.data.utils import download

from dataset_loader.utils import filter_nb_nodes


class DBLP_v1(TUDataset):
    """DBLP_v1 TU dataset, filtered by node count and enriched with node features.

    After the regular ``TUDataset`` processing, :meth:`process` filters the
    graphs through ``filter_nb_nodes`` and stores, for every remaining graph,
    a per-node feature matrix in ``g.ndata['attr']``: one ``is_topic`` column
    followed by ``EMBEDDING_DIM`` word-embedding columns (all zeros for nodes
    that are not topics).

    Args:
        min_node: Lower bound forwarded to ``filter_nb_nodes``.
        max_node: Upper bound forwarded to ``filter_nb_nodes``.
        raw_dir: Forwarded to ``TUDataset``.
        force_reload: Forwarded to ``TUDataset``. Defaults to ``True`` because
            the filtering and feature computation happen while loading.
        verbose: Forwarded to ``TUDataset``.
    """

    def __init__(self, min_node=4, max_node=15, raw_dir=None, force_reload=True, verbose=False):
        # The bounds must be set before the parent constructor runs, since
        # process() reads them.
        # NOTE(review): presumably the parent constructor is what triggers
        # process() -- confirm against the DGL version in use.
        self.min_node = min_node
        self.max_node = max_node
        super(DBLP_v1, self).__init__(name='DBLP_v1',
                                      raw_dir=raw_dir,
                                      force_reload=force_reload,
                                      verbose=verbose)

    def _download_embeddings(self):
        """Record where the word-embedding pickle is expected to live.

        Despite its name, this method does not download anything: it only sets
        ``self.embedding_pkl_file_path`` to ``<raw_path>/word_embeddings.pkl``.
        """
        self.embedding_pkl_file_path = os.path.join(self.raw_path, 'word_embeddings.pkl')

    def _load_embeddings(self):
        """Load the word embeddings into ``self.embeddings`` (word -> vector).

        The pickle is expected to hold a mapping with a ``'vocab'`` and a
        ``'weights'`` entry, which are zipped together pairwise.

        The file is read from ``self.embedding_pkl_file_path`` when that file
        exists; otherwise the historical location ``word_embeddings.pkl``
        (relative to the current working directory) is used.

        Raises:
            FileNotFoundError: If the pickle is found in neither location.
        """
        pkl_path = getattr(self, 'embedding_pkl_file_path', None)
        if pkl_path is None or not os.path.isfile(pkl_path):
            # Backward compatibility: this used to be the only location read.
            pkl_path = 'word_embeddings.pkl'

        # SECURITY: pickle can execute arbitrary code while loading; only use
        # embedding files coming from a trusted source.
        with open(pkl_path, 'rb') as pkl_file:
            embedding_data = pickle.load(pkl_file)
        self.embeddings = dict(zip(embedding_data['vocab'], embedding_data['weights']))

    def _load_topics(self):
        """Load the node-label mapping table into ``self.map_df``.

        The table is read from ``<raw_path>/DBLP_v1/readme.txt``: the first 49
        lines are skipped and at most ``41374 - 49`` tab-separated rows are
        parsed. Columns 1 and 2 are kept as ``id`` and ``val``, with ``id`` as
        the index. An extra boolean column ``is_digit`` tells whether ``val``
        is made of digits only.
        """
        # Load topics | sed -n 50,41374p DBLP_v1/readme.txt > DBLP_v1/map.tsv
        map_df = pd.read_csv(
            os.path.join(self.raw_path, 'DBLP_v1', 'readme.txt'),
            sep="\t", header=None, skiprows=49, nrows=41374 - 49
        ).rename(columns={1: 'id', 2: 'val'})[['id', 'val']].set_index('id')
        map_df['is_digit'] = map_df['val'].apply(lambda v: v.isdigit())
        self.map_df = map_df

    # Extending base functions of DGLDataset
    # Ref.: https://docs.dgl.ai/generated/dgl.data.DGLDataset.html#dgl.data.DGLDataset
    def process(self):
        """Run the parent processing, then filter graphs and attach node features.

        Side effects:
            * sets ``self.embedding_pkl_file_path``, ``self.embeddings`` and
              ``self.map_df``;
            * adds ``is_topic``, ``embedding`` and ``attr`` to the ``ndata`` of
              every kept graph;
            * replaces ``self.graph_lists`` with the kept graphs and
              ``self.graph_labels`` with a ``[nb_graphs, 1]`` tensor.

        Raises:
            KeyError: If a topic word has no entry in ``self.embeddings``.
        """
        super().process()

        # ADDING CUSTOM METHODS TO FILTER AND COMPUTE ATTRIBUTES
        self._download_embeddings()
        self._load_embeddings()
        self._load_topics()

        EMBEDDING_DIM = 50  # using GloVe with 50d embeddings
        # Filter on number of nodes (wrapped around zip / unzip of graphs with labels)
        graphs, labels = filter_nb_nodes(self.graph_lists, self.min_node, self.max_node, list(self.graph_labels))

        # Extract the lookup columns once as plain lists. Node labels are used
        # as *positional* row numbers of map_df (same semantics as ``.iloc``),
        # and list indexing avoids building a pandas row object per node.
        is_digit_by_pos = self.map_df['is_digit'].tolist()
        value_by_pos = self.map_df['val'].tolist()
        embeddings = self.embeddings

        # Collect additional information on the nodes
        # Node type (topic or paper) & word2vec embedding for topics
        # M1/ 0: P2P; 1: P2W; 2: W2W => avg and compare to 1
        # M2/ Enrich node IDs
        def nodes_udf(nodes):
            """Return the ``is_topic`` and ``embedding`` features of a node batch."""
            # node_labels has shape [nb_nodes, 1]; flatten it to plain Python ints.
            node_ids = torch.reshape(nodes.data['node_labels'], (-1,)).tolist()
            # A node is a topic when its mapped value is not purely numeric.
            topic_flags = [not is_digit_by_pos[node_id] for node_id in node_ids]

            # [nb_nodes, 1] tensor for is_topic
            is_topic = torch.tensor([[int(flag)] for flag in topic_flags])
            # [nb_nodes, EMBEDDING_DIM] tensor; rows of non-topic nodes stay zero
            embedding = torch.zeros((len(nodes.nodes()), EMBEDDING_DIM))
            for row, (node_id, flag) in enumerate(zip(node_ids, topic_flags)):
                if flag:
                    embedding[row] = torch.tensor(embeddings[value_by_pos[node_id]])
            return {'is_topic': is_topic, 'embedding': embedding}

        for g in graphs:
            g.apply_nodes(nodes_udf)
            # Final node attribute: is_topic flag followed by the embedding.
            g.ndata['attr'] = torch.cat([g.ndata['is_topic'], g.ndata['embedding']], dim=1)

        self.graph_lists = graphs
        self.graph_labels = torch.tensor(labels).reshape((len(labels), 1))

    @property
    def graphs(self):
        """The list of graphs kept after filtering (``self.graph_lists``)."""
        return self.graph_lists

    @property
    def label(self):
        """The graph labels (``self.graph_labels``)."""
        return self.graph_labels

    @property
    def n_labels(self):
        """Constant ``2``.

        NOTE(review): originally documented as "number of labels for each
        graph, i.e. number of prediction tasks"; the value may instead be the
        number of classes -- confirm against callers.
        """
        return 2


if __name__ == '__main__':
    # Manual smoke test: build the dataset with a narrow node-count window
    # and display the first graph that was kept.
    dataset = DBLP_v1(
        min_node=4,
        max_node=5,
    )
    first_graph = dataset.graphs[0]
    print(first_graph)
