import os
import pathlib

import networkx as nx
import torch
import torch_geometric.utils
from torch_geometric.data import InMemoryDataset

from src.datasets.abstract_dataset import (AbstractDataModule,
                                           AbstractDatasetInfos)


class NetDataset(InMemoryDataset):
    """In-memory dataset containing a single graph read from an edge-list file.

    The edge list is read from ``raw_data/<dataset_name>.txt`` under
    ``pathlib.Path(__file__).parents[2]``. The processed result is stored as
    ``processed.pt`` and loaded into ``self.data`` / ``self.slices``.

    Args:
        dataset_name: Stem of the edge-list file inside ``raw_data``.
        root: Root directory handed to ``InMemoryDataset``.
        transform: Optional transform forwarded to ``InMemoryDataset``.
        pre_transform: Optional callable applied to the graph before it is saved.
        pre_filter: Optional predicate; the graph is kept only if it returns a
            truthy value.
    """

    def __init__(self, dataset_name, root, transform=None, pre_transform=None, pre_filter=None):
        # Set before super().__init__() because process() reads it, and the base
        # constructor runs process() when the processed file does not exist yet.
        self.dataset_name = dataset_name
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Names of the files that ``process()`` is expected to produce."""
        return ['processed.pt']

    def process(self):
        """Read the edge list, attach dummy features and save the collated graph.

        Raises:
            ValueError: If ``pre_filter`` rejects the only graph, leaving nothing
                to collate.
        """
        edgelist_path = pathlib.Path(__file__).parents[2] / "raw_data"
        # 'p' is passed to networkx as the comment marker.
        # NOTE(review): presumably this skips a "p ..." header line in the raw
        # files -- confirm against the data.
        nxg: nx.Graph = nx.read_edgelist(edgelist_path / f'{self.dataset_name}.txt', comments='p')

        # Every node gets the scalar feature 1.0 and every edge the feature
        # [0., 1.]; both are stored under the attribute name 'dummy'.
        nx.set_node_attributes(nxg, 1., 'dummy')
        nx.set_edge_attributes(nxg, [0., 1.], 'dummy')
        data = torch_geometric.utils.from_networkx(
            nxg, group_node_attrs=['dummy'], group_edge_attrs=['dummy'])
        # Empty target of shape (1, 0).
        data.y = torch.zeros([1, 0], dtype=torch.float)
        data_list = [data]

        # Honour the hooks accepted by the constructor; previously they were
        # forwarded to the base class but never applied to the data.
        if self.pre_filter is not None:
            data_list = [graph for graph in data_list if self.pre_filter(graph)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(graph) for graph in data_list]
        if not data_list:
            raise ValueError(
                f"pre_filter rejected the only graph of dataset "
                f"'{self.dataset_name}'; nothing to save.")

        torch.save(self.collate(data_list), self.processed_paths[0])


class NetDataModule(AbstractDataModule):
    """Data module that uses one ``NetDataset`` for the train, val and test splits.

    Args:
        cfg: Configuration object; ``cfg.dataset.datadir`` and
            ``cfg.dataset.name`` are read here.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.datadir = cfg.dataset.datadir

        # The dataset root is ``datadir`` joined onto ``parents[2]`` of this
        # file's resolved (symlink-free) path.
        project_dir = pathlib.Path(os.path.realpath(__file__)).parents[2]
        dataset = NetDataset(
            dataset_name=self.cfg.dataset.name,
            root=os.path.join(project_dir, self.datadir),
        )

        # The very same dataset object backs all three splits.
        splits = dict.fromkeys(('train', 'val', 'test'), dataset)
        super().__init__(cfg, splits)

        # NOTE(review): ``train_dataset`` is assumed to be set by the parent
        # constructor, which is not visible in this file.
        self.inner = self.train_dataset

    def __getitem__(self, item):
        """Return ``self.inner[item]``."""
        return self.inner[item]


class NetDatasetInfos(AbstractDatasetInfos):
    """Dataset information built from a datamodule's node and edge counts.

    Args:
        datamodule: Object providing ``node_counts()`` and ``edge_counts()``.
    """

    def __init__(self, datamodule):
        self.name = 'nx_graphs'
        self.datamodule = datamodule

        # Node counts are queried first, then edge counts.
        dm = self.datamodule
        self.n_nodes = dm.node_counts()
        # There are no node types, so a single placeholder entry is used.
        self.node_types = torch.tensor([1])
        self.edge_types = dm.edge_counts()

        # The parent class fills in the remaining fields from these two values.
        super().complete_infos(self.n_nodes, self.node_types)
