import os
import torch as th
import torch.nn as nn
import tqdm


class PBar(object):
    def __enter__(self):
        self.t = None
        return self

    def __call__(self, blockno, readsize, totalsize):
        if self.t is None:
            self.t = tqdm.tqdm(total=totalsize)
        self.t.update(readsize)

    def __exit__(self, exc_type, exc_value, traceback):
        self.t.close()


class AminerDataset(object):
    """
    Download Aminer Dataset from Amazon S3 bucket. 
    """
    def __init__(self, path):

        self.url = 'https://data.dgl.ai/dataset/aminer.zip'

        if not os.path.exists(os.path.join(path, 'aminer.txt')):
            print('File not found. Downloading from', self.url)
            self._download_and_extract(path, 'aminer.zip')
        self.fn = os.path.join(path, 'aminer.txt')

    def _download_and_extract(self, path, filename):
        import shutil, zipfile, zlib
        from tqdm import tqdm
        import urllib.request

        fn = os.path.join(path, filename)
        with PBar() as pb:
            urllib.request.urlretrieve(self.url, fn, pb)
        print('Download finished. Unzipping the file...')

        with zipfile.ZipFile(fn) as zf:
            zf.extractall(path)
        print('Unzip finished.')


class CustomDataset(object):
    """
    Custom dataset generated by sampler.py (e.g. NetDBIS)
    """
    def __init__(self, path):
        self.fn = path
