from graph_learning.data_setting import DataTransform, DataSettingConfig
import pickle
import argparse
from pathlib import Path

@DataSettingConfig.register('cache',
                            help='Cache the current data if exists, load the cache else.')
class CacheConfig(DataSettingConfig):
    """Data setting registered as ``'cache'``.

    ``builder`` returns :class:`CachePickle`, whose constructor takes
    ``cache_root`` and ``cache_name`` — presumably filled from the
    ``--cache-root`` / ``--cache-name`` options defined below (TODO confirm
    against the ``DataSettingConfig`` framework).
    """

    @property
    def builder(self):
        """Transform class this setting builds."""
        return CachePickle

    @classmethod
    def define_parser(cls, parser):
        """Extend the base parser with the two mandatory cache options."""
        super().define_parser(parser)
        # (flag, help text) pairs, registered in this order.
        cache_options = (
            ('--cache-root', 'cache folder.'),
            ('--cache-name', 'cache file name.'),
        )
        for flag, description in cache_options:
            parser.add_argument(flag, required=True, help=description)

class CachePickle(DataTransform):
    """Pickle-backed cache transform.

    ``transform(data)`` writes ``data`` to ``cache_root / cache_name`` when
    ``data`` is not ``None``, and returns it unchanged. When ``data`` is
    ``None``, it loads and returns the previously cached object instead.

    Args:
        cache_root: Directory holding the cache file (anything ``Path`` accepts).
        cache_name: File name of the cache, relative to ``cache_root``.
    """

    def __init__(self, cache_root, cache_name):
        self.cache_root = Path(cache_root)
        self.cache_name = cache_name

    @property
    def cache_path(self):
        """Full path of the cache file."""
        return self.cache_root / self.cache_name

    def load_data(self):
        """Load and return the cached object.

        Raises:
            FileNotFoundError: If the cache file does not exist yet.
        """
        # SECURITY: pickle.load can execute arbitrary code. Only point the
        # cache at files written by this program, never at untrusted input.
        with open(self.cache_path, 'rb') as f:
            return pickle.load(f)

    def save_data(self, data):
        """Pickle ``data`` to the cache file, replacing any previous cache.

        The object is first written to a sibling ``.tmp`` file and then
        renamed over the target. A failing or interrupted dump therefore never
        leaves a truncated cache that a later ``load_data`` would fail on.
        """
        cache_path = self.cache_path
        # Create the parent of the final path (not just cache_root), so a
        # cache_name with sub-directories, e.g. 'split/train.pkl', also works.
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = cache_path.with_name(cache_path.name + '.tmp')
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(data, f)
            tmp_path.replace(cache_path)
        except BaseException:
            # Don't leave a half-written temp file behind; propagate the error.
            if tmp_path.exists():
                tmp_path.unlink()
            raise

    def transform(self, data):
        """Cache ``data`` and return it, or return the cached object if ``data`` is None."""
        if data is None:
            return self.load_data()
        self.save_data(data)
        return data
