from typing import Optional, Callable, List

import os
import glob
import os.path as osp

import torch
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_tar, extract_zip)
from torch_geometric.utils import remove_isolated_nodes


class SingleDataset(InMemoryDataset):
    r"""In-memory dataset used to store a single graph.

    Nothing is downloaded. :meth:`_read_data_list` must be supplied from
    outside this class (subclass override or assignment). It must return a
    ``(train_data_list, val_data_list, test_data_list)`` triple of lists.
    :meth:`process` handles the lists in this order:

    1. Concatenates them.
    2. Applies ``pre_filter`` / ``pre_transform``.
    3. Collates the result into ``processed_paths[0]``.
    4. Saves a split dictionary into ``processed_paths[1]``.

    Args:
        root: Root directory, forwarded to :class:`InMemoryDataset`.
        pkl_dir: Stored as ``self.pkl_dir``. This class never reads it
            itself; presumably the external ``_read_data_list`` does
            (TODO confirm).
        transform: Optional PyG ``transform`` hook.
        pre_transform: Optional PyG ``pre_transform`` hook.
        pre_filter: Optional PyG ``pre_filter`` hook.
    """

    def __init__(self, root: str, pkl_dir: str,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        # Set before super().__init__(). Per PyG's Dataset contract, that
        # call runs process() (and hence _read_data_list) when the
        # processed files are missing.
        self.pkl_dir = pkl_dir
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True,
        # which may reject pickled PyG Data objects. Confirm against the
        # installed torch/PyG versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        # No raw files: the graph comes from the externally supplied
        # _read_data_list, so there is nothing to look for on disk.
        return []

    @property
    def processed_file_names(self) -> List[str]:
        # [0] holds the collated graph(s); [1] holds the split dictionary.
        return ['data.pt', 'split_dict.pt']

    def download(self):
        # Intentionally a no-op; see raw_file_names.
        return

    def _read_data_list(self):
        """Return ``(train_data_list, val_data_list, test_data_list)``.

        The real implementation is provided outside this class. This base
        version only signals that it is missing.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("This method is implemented outside")

    def process(self):
        """Read, filter, pre-transform and collate the graphs, then save
        the collated data and the split dictionary to ``processed_paths``.

        Raises:
            ValueError: if no graph remains after ``pre_filter``.
        """
        train_data_list, val_data_list, test_data_list = \
            self._read_data_list()
        data_list = train_data_list + val_data_list + test_data_list

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # Without this check, collate() would fail obscurely and the split
        # indices below would point at a graph that does not exist.
        if not data_list:
            raise ValueError(
                "SingleDataset needs at least one graph after pre_filter, "
                "but the data list is empty")

        # Every split points at the first stored graph.
        # NOTE(review): this assumes all three splits are meant to share one
        # graph (e.g. node-level masks inside it). If the val/test lists hold
        # distinct graphs, these indices ignore them. Confirm with the
        # _read_data_list provider.
        split_dict = {'train': [0], 'valid': [0], 'test': [0]}

        torch.save(self.collate(data_list), self.processed_paths[0])
        torch.save(split_dict, self.processed_paths[1])

    def get_idx_split(self) -> dict:
        """Return the split dictionary saved by :meth:`process`.

        Its keys are ``'train'``, ``'valid'`` and ``'test'``, each mapped
        to a list of graph indices.
        """
        return torch.load(self.processed_paths[1])
