import torch

from graph_learning.dataset.graph import GLGraph
from graph_learning.dataset import DatasetConfig
import networkx as nx
import numpy as np
import dgl

def name_batch(graphs, names):
    """Return the shared base name for a batch of graphs.

    The base name is the part of the first entry of ``names`` that precedes
    its first underscore. If that entry contains no underscore, the whole
    entry is returned unchanged.

    Args:
        graphs: Unused; kept so the signature stays unchanged for callers.
        names: Non-empty sequence of name strings; only ``names[0]`` is read.

    Returns:
        The prefix of ``names[0]`` before the first ``'_'``.
    """
    # ``partition`` returns the full string as the head when the separator is
    # missing, whereas slicing with ``str.find`` (-1) would drop the last
    # character of an underscore-free name.
    return names[0].partition('_')[0]
def name_unbatch(graph, base_name):
    """Build one name per graph obtained from ``dgl.unbatch(graph)``.

    Args:
        graph: Batched graph handed to ``dgl.unbatch``.
        base_name: Prefix shared by all generated names.

    Returns:
        A list ``['<base_name>_0', '<base_name>_1', ...]`` with one entry per
        unbatched graph.
    """
    parts = dgl.unbatch(graph)
    # Only the position of each unbatched graph is needed for its name.
    return [f'{base_name}_{position}' for position, _ in enumerate(parts)]

@DatasetConfig.register('email',
                        help='Email dataset for pairwise node classification.')
class EmailDatasetConfig(DatasetConfig):
    """Dataset config that builds the email graph from ``dataset/email``.

    Every node receives a constant one-element feature ``x`` and a ``labels``
    entry read from column 1 of ``email_labels.txt``. In ``pnc-full`` mode an
    extra ``pair_labels`` node field marks pairs of nodes sharing a label.
    """

    def __init__(self, args, context):
        super().__init__(args, context)

    @classmethod
    def define_parser(cls, parser):
        """Add the base-class options and the ``--mode`` option to ``parser``."""
        super().define_parser(parser)
        parser.add_argument('--mode', choices=['pnc-full', 'nc'])

    def build_dataset(self):
        """Load the email edge list and labels and wrap them in a ``GLGraph``.

        Reads ``self.mode`` (presumably populated by ``DatasetConfig`` from
        the parsed ``--mode`` argument; not visible in this file).

        Returns:
            A ``GLGraph`` named ``'email'`` carrying node fields ``x`` and
            ``labels``, plus ``pair_labels`` when the mode is ``'pnc-full'``.
        """
        with open('dataset/email/email.txt', 'rb') as f:
            graph = nx.read_edgelist(f)

        labels = np.loadtxt('dataset/email/email_labels.txt')

        n = graph.number_of_nodes()
        feature = np.ones((n, 1))

        # Node names are parsed as integers and used as row indices into the
        # label table; column 1 of that table holds the class label.
        idx = [int(node) for node in graph.nodes()]
        label_n = labels[idx, 1]

        graph.remove_edges_from(nx.selfloop_edges(graph))

        # Removing edges leaves the node set (and its order) untouched, so
        # position m here still matches position m of ``idx``/``label_n``.
        for m, node in enumerate(list(graph.nodes)):
            graph.nodes[node]['x'] = torch.from_numpy(feature[m]).float()
            graph.nodes[node]['labels'] = torch.tensor(label_n[m]).long()

        dgl_g = dgl.from_networkx(graph, node_attrs=['x', 'labels'])
        # Add the reverse of every edge; duplicates are dropped by the
        # ``to_simple()`` call below.
        srcs, dsts = dgl_g.all_edges()
        dgl_g.add_edges(dsts, srcs)

        if self.mode == 'pnc-full':
            # Assumes ``ndata['labels']`` is 1-D with one entry per node.
            node_labels = dgl_g.ndata['labels']
            # same_label[i, j] is True when nodes i and j share a label.
            same_label = node_labels.unsqueeze(1) == node_labels.unsqueeze(0)
            # Keep only the strictly lower triangle (i > j) so each unordered
            # pair is marked once and the diagonal stays zero.
            dgl_g.ndata['pair_labels'] = torch.tril(same_label.long(),
                                                    diagonal=-1)

        gl_g = GLGraph(dgl_g.to_simple())
        gl_g.gdata['name'] = 'email'

        return gl_g

@DatasetConfig.register('de-linkpred',
                        help='Link prediction datasets: NS(ns), C.ele(celegans), PB(pb).')
class DELinkPredDatasetConfig(DatasetConfig):
    """Dataset config for the link-prediction graphs under
    ``dataset/link_prediction/<name>``.
    """

    def __init__(self, args, context):
        super().__init__(args, context)

    @classmethod
    def define_parser(cls, parser):
        """Add the base-class options and the ``--name`` option to ``parser``."""
        super().define_parser(parser)
        parser.add_argument('--name', choices=['ns', 'celegans', 'pb'])

    def build_dataset(self):
        """Load ``edges.txt`` for ``self.name`` and wrap it in a ``GLGraph``.

        Reads ``self.name`` (presumably populated by ``DatasetConfig`` from
        the parsed ``--name`` argument; not visible in this file).

        Returns:
            A ``GLGraph`` named ``self.name`` whose nodes carry a constant
            one-element feature ``x``.
        """
        # Only the first two columns (the edge endpoints) are used.
        # ``np.int64`` replaces the ``np.long`` alias, which was removed in
        # NumPy 1.24 and is platform-dependent in NumPy 2.x.
        edges = np.loadtxt(
            f'dataset/link_prediction/{self.name}/edges.txt'
        )[:, :2].astype(np.int64)
        graph = nx.from_edgelist(edges)

        dgl_g = dgl.to_bidirected(dgl.from_networkx(graph))
        dgl_g.ndata['x'] = torch.ones(dgl_g.number_of_nodes(), 1)

        gl_g = GLGraph(dgl_g.to_simple())
        gl_g.gdata['name'] = self.name

        return gl_g

@DatasetConfig.register('de-nodecls',
                        help='Node classification datasets: Europe(europe-airports), USA(usa-airports).')
class DENodeClsDatasetConfig(DatasetConfig):
    """Dataset config for the node-classification graphs under
    ``dataset/node_classification/<name>``.
    """

    def __init__(self, args, context):
        super().__init__(args, context)

    @classmethod
    def define_parser(cls, parser):
        """Add the base-class options and the ``--name`` option to ``parser``."""
        super().define_parser(parser)
        parser.add_argument('--name', choices=['europe-airports', 'usa-airports'])

    def build_dataset(self):
        """Load edges and node labels for ``self.name`` into a ``GLGraph``.

        Reads ``self.name`` (presumably populated by ``DatasetConfig`` from
        the parsed ``--name`` argument; not visible in this file).

        Returns:
            A ``GLGraph`` named ``self.name`` whose nodes carry a ``labels``
            field and a constant one-element feature ``x``.
        """
        base_dir = f'dataset/node_classification/{self.name}'
        # ``np.int64`` replaces the ``np.long`` alias, which was removed in
        # NumPy 1.24 and is platform-dependent in NumPy 2.x.
        edges = np.loadtxt(f'{base_dir}/edges.txt')[:, :2].astype(np.int64)
        labels = np.loadtxt(f'{base_dir}/labels.txt')[:, :2].astype(np.int64)

        graph = nx.from_edgelist(edges)
        # Each label row is (node id, class label).
        node_to_label = {node: label for node, label in labels.tolist()}
        nx.set_node_attributes(graph, node_to_label, 'labels')

        # ``copy_ndata=True`` carries the ``labels`` field over to the
        # bidirected graph.
        dgl_g = dgl.to_bidirected(
            dgl.from_networkx(graph, node_attrs=['labels']),
            copy_ndata=True,
        )
        dgl_g.ndata['x'] = torch.ones(dgl_g.number_of_nodes(), 1)

        gl_g = GLGraph(dgl_g.to_simple())
        gl_g.gdata['name'] = self.name

        return gl_g
