from typing import Optional, Callable, List

import os
import glob
import os.path as osp

import torch
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_tar, extract_zip)
from torch_geometric.utils import remove_isolated_nodes


class GSMetabolicNetworkDataset(InMemoryDataset):
    """PyG in-memory dataset for the ``GSMetabolicNetwork`` graphs.

    Two files, ``data.pt`` and ``split_dict.pt``, are fetched from figshare
    into ``raw_dir`` and re-saved into ``processed_dir``:

    * ``processed_paths[0]`` (``data.pt``) stores the collated
      ``(data, slices)`` pair that the constructor loads.
    * ``processed_paths[1]`` (``split_dict.pt``) stores the split dictionary
      returned by :meth:`get_idx_split`.

    Args:
        root: Root directory forwarded to :class:`InMemoryDataset`; the raw
            and processed files are stored beneath it.
        transform: Forwarded to :class:`InMemoryDataset`.
        pre_transform: Forwarded to :class:`InMemoryDataset`; applied to each
            graph in :meth:`process` when the raw file holds individual
            graphs rather than an already-collated pair.
    """

    # Remote location of each raw file, keyed by its local file name.
    # NOTE(review): these are figshare private links; confirm they stay valid.
    _RAW_URLS = {
        'data.pt': 'https://figshare.com/ndownloader/files/44260634?private_link=28d3130996b349e05912',
        'split_dict.pt': 'https://figshare.com/ndownloader/files/44260472?private_link=28d3130996b349e05912',
    }

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # process() writes a collated (data, slices) pair to this path.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Files that must exist in ``raw_dir`` for :meth:`download` to be skipped."""
        return ['data.pt', 'split_dict.pt']

    @property
    def processed_file_names(self):
        """Files that must exist in ``processed_dir`` for :meth:`process` to be skipped.

        Index 0 holds the collated graphs, index 1 the split dictionary.
        """
        return ['data.pt', 'split_dict.pt']

    def download(self):
        """Download every raw file listed in ``_RAW_URLS`` into ``raw_dir``."""
        for filename, url in self._RAW_URLS.items():
            download_url(url, self.raw_dir, filename=filename)

    def process(self):
        """Convert the raw files into the two processed files.

        Raises:
            TypeError: if the raw ``data.pt`` has an unsupported layout
                (see :meth:`_to_collated`).
        """
        # NOTE(review): torch.load unpickles arbitrary Python objects and these
        # files are downloaded from a remote URL -- only use a trusted source.
        raw_data = torch.load(os.path.join(self.raw_dir, 'data.pt'))
        split_dict = torch.load(os.path.join(self.raw_dir, 'split_dict.pt'))

        # Each processed path gets its own object: the constructor unpacks
        # processed_paths[0] as (data, slices) and get_idx_split() reads
        # processed_paths[1]. Writing both also means every file listed in
        # processed_file_names exists, so process() is not re-run needlessly.
        torch.save(self._to_collated(raw_data), self.processed_paths[0])
        torch.save(split_dict, self.processed_paths[1])

    def _to_collated(self, raw_data):
        """Return *raw_data* as a collated ``(data, slices)`` pair.

        Accepted inputs:

        * a 2-tuple whose second item is a ``dict`` or ``None`` -- treated as
          an already-collated pair and returned unchanged (``pre_transform``
          cannot be applied to it);
        * a single :class:`Data` object;
        * a non-empty list/tuple of :class:`Data` objects.

        For the last two forms ``self.pre_transform`` (if set) is applied to
        every graph before collation.

        Raises:
            TypeError: if *raw_data* matches none of the accepted forms.
        """
        # NOTE(review): the raw file's layout is not documented in this file;
        # confirm which of the forms below the published data.pt uses.
        if (isinstance(raw_data, tuple) and len(raw_data) == 2
                and (raw_data[1] is None or isinstance(raw_data[1], dict))):
            return raw_data

        if isinstance(raw_data, Data):
            data_list = [raw_data]
        elif (isinstance(raw_data, (list, tuple)) and raw_data
                and all(isinstance(item, Data) for item in raw_data)):
            data_list = list(raw_data)
        else:
            raise TypeError(
                'Unsupported contents in raw data.pt: expected a collated '
                '(data, slices) pair, a Data object or a non-empty list of '
                f'Data objects, got {type(raw_data).__name__}')

        if self.pre_transform is not None:
            data_list = [self.pre_transform(item) for item in data_list]
        return self.collate(data_list)

    def get_idx_split(self):
        """Return the split dictionary saved by :meth:`process`.

        The returned object is exactly what the raw ``split_dict.pt``
        contained. It is re-read from disk on every call.
        """
        return torch.load(self.processed_paths[1])
