from typing import Dict

from torch_geometric.data import extract_zip

import src.data.transforms as base_t
from src.data.transforms import (
    downloads as down_t
)
import src.data.dataset as ds


class URLDatasetDownloadPipeline:
    """Factory for a pipeline that downloads a single dataset file from a URL.

    The file name is looked up in ``dataset_filenames_map`` under
    ``dataset_name`` and concatenated to ``dataset_url`` to obtain the
    download address. Calling the instance builds the ``base_t.DFPipeline``.
    """

    def __init__(self, dataset_name: str, dataset_url: str = None, dataset_filenames_map: Dict = None):
        """Resolve which file has to be downloaded.

        Args:
            dataset_name: key used to look up the file name.
            dataset_url: base URL; the file name is appended to it as-is, so
                it must already contain any trailing separator.
            dataset_filenames_map: mapping from dataset name to file name.

        Raises:
            ValueError: if ``dataset_url`` or ``dataset_filenames_map`` is
                missing (both are needed to build the download address).
            KeyError: if ``dataset_name`` is not in ``dataset_filenames_map``.
        """
        # Explicit check instead of `assert`: assertions vanish under `-O`,
        # and both values are required further down.
        if dataset_url is None or dataset_filenames_map is None:
            raise ValueError('Both dataset_url and dataset_filenames_map must be provided')

        try:
            filename = dataset_filenames_map[dataset_name]
        except KeyError:
            known = ', '.join(repr(name) for name in dataset_filenames_map)
            raise KeyError(
                f'Unknown dataset {dataset_name!r}; available datasets: {known}'
            ) from None

        self.dataset_url = dataset_url
        self.filename = filename

    def __call__(self) -> 'base_t.DFPipeline':
        """Build the download pipeline.

        Returns:
            A ``base_t.DFPipeline`` whose ``data_file`` output is the resolved
            file name and whose transforms set up the ``raw_path`` datafield,
            create that folder and download the file into it.
        """
        return base_t.DFPipeline(
            output_files={
                'data_file': self.filename
            },
            transforms=[
                #########################  FOLDERS SETUP  #########################
                base_t.DFAddDatafield('raw_path', ds.STD_FOLDER_RAW),
                # NOTE(review): presumably joins the root with the raw folder
                # name and stores the result back in 'raw_path' -- confirm.
                base_t.DFSPecializePath([ds.KEY_ROOT, 'raw_path'], 'raw_path'),

                base_t.DFCreateFolder(
                    destination_df='raw_path'
                ),

                ########################  DOWNLOAD DATASET  #######################
                down_t.DFDownloadFromURL(
                    download_to_df='raw_path',
                    url={
                        # Plain concatenation: base URL followed by file name.
                        'data_file': self.dataset_url + self.filename
                    }
                )
            ]
        )



# Base URL of the data folder in the SPECTRE GitHub repository.
_SPECTRE_DATASET_URL = 'https://github.com/KarolisMart/SPECTRE/raw/main/data/'

# Dataset name -> name of the file to fetch from _SPECTRE_DATASET_URL.
_SPECTRE_DATASET_NAME_TO_FILENAME = {
    'community-20': 'community_12_21_100.pt',
    'planar': 'planar_64_200.pt',
    'sbm': 'sbm_200.pt',
}


class SpectreDatasetDownloadPipeline(URLDatasetDownloadPipeline):
    """Download pipeline preconfigured with the SPECTRE URL and file table.

    ``dataset_name`` is looked up in ``_SPECTRE_DATASET_NAME_TO_FILENAME``.
    """

    def __init__(self, dataset_name: str):
        # Only the dataset name varies; URL and file-name table are fixed.
        super().__init__(
            dataset_name,
            dataset_filenames_map=_SPECTRE_DATASET_NAME_TO_FILENAME,
            dataset_url=_SPECTRE_DATASET_URL,
        )
    

# Base URL of the raw data folder in the GraphGDP GitHub repository.
_CDGS_DATASET_URL = 'https://github.com/GRAPH-0/GraphGDP/raw/main/data/raw/'

# Dataset name -> name of the file to fetch from _CDGS_DATASET_URL.
_CDGS_DATASET_NAME_TO_FILENAME = {
    'enzymes': 'ENZYMES.pkl',
    'ego': 'Ego.pkl',
    'ego-small': 'Ego_small.pkl',
}

class CDGSDatasetDownloadPipeline(URLDatasetDownloadPipeline):
    """Download pipeline preconfigured with the CDGS URL and file table.

    ``dataset_name`` is looked up in ``_CDGS_DATASET_NAME_TO_FILENAME``.
    """

    def __init__(self, dataset_name: str):
        # Only the dataset name varies; URL and file-name table are fixed.
        super().__init__(
            dataset_name,
            dataset_filenames_map=_CDGS_DATASET_NAME_TO_FILENAME,
            dataset_url=_CDGS_DATASET_URL,
        )
