import hashlib
from pathlib import Path
import shutil


def hash_blocks(hasher, path, block_size=65536):
    """
    Feed the contents of the file at `path` into `hasher`, reading `block_size` bytes at a time
    so the whole file never has to be held in memory.

    `hasher` is any object exposing `update(bytes)` (e.g. a `hashlib` hash object). The same
    object is returned after it has been updated.
    """
    with open(path, 'rb') as stream:
        # Two-argument iter() keeps calling read() until it yields b'' (end of file).
        for chunk in iter(lambda: stream.read(block_size), b''):
            hasher.update(chunk)
    return hasher


class DatasetFileCache:
    """
    A caching mechanism for storing processed files in a `_cache/` subdirectory of a dataset directory.
    Stores a checksum for the original data alongside the cached files and checks for validity before
    returning a Path object to the cache dir.

    Call `init_cache_dir()` to clear the existing directory and save the hash of the data directory.
    Then write files to the Path returned by `get_cache_dir()`.

    Notes:
      * `data_dir` may be given as a `str` or a `pathlib.Path`.
      * The checksum covers the *contents* of every regular file below `data_dir` (subdirectories
        are walked recursively), excluding the cache directory itself. File names are not mixed
        into the checksum.
      * The checksum is computed once per instance and memoised; create a new instance to pick up
        changes made to the data directory afterwards.
    """

    _cache_dir_name = '_cache'
    _cache_value_filename = '_checksum.md5'

    def __init__(self, data_dir):
        """Set up the cache paths for `data_dir`; nothing is read or written yet."""
        # Coerce so that plain string paths work with the `/` operator below.
        self.data_dir = Path(data_dir)
        self._data_dir_hash_value = None
        self.data_cache_dir = self.data_dir / DatasetFileCache._cache_dir_name

    def get_cache_dir(self):
        """Return the cache directory if its saved checksum matches the data, otherwise None."""
        # A valid cache implies the checksum file, and hence the directory, exists.
        return self.data_cache_dir if self.is_cache_valid() else None

    def init_cache_dir(self):
        """
        Delete any existing cache directory, recreate it empty, and store the data checksum in it.

        Returns the Path of the fresh cache directory.
        """
        if self.data_cache_dir.exists():
            shutil.rmtree(self.data_cache_dir)
        self.data_cache_dir.mkdir()
        checksum_file = self.data_cache_dir / DatasetFileCache._cache_value_filename
        with open(checksum_file, 'w') as f:
            f.write(self.data_dir_hash)
        return self.data_cache_dir

    def remove_cache_dir(self):
        """Delete the cache directory and everything in it; do nothing if it does not exist."""
        if self.data_cache_dir.exists():
            shutil.rmtree(self.data_cache_dir)

    @property
    def data_dir_hash(self):
        """Hex MD5 digest of the data directory contents, computed on first access and memoised."""
        if self._data_dir_hash_value is None:
            self._data_dir_hash_value = self._calculate_hash()
        return self._data_dir_hash_value

    def _calculate_hash(self):
        """Hash the contents of all regular files under `data_dir`, in sorted path order."""
        hasher = hashlib.md5()
        for file in self._data_files():
            hash_blocks(hasher, file)
        return hasher.hexdigest()

    def _data_files(self):
        """Return the sorted regular files below `data_dir`, skipping the top-level cache entry."""
        # Only regular files can be opened for hashing; directories are walked, not hashed.
        # Anything whose first path component is the cache dir name is part of the cache itself.
        return sorted(
            f for f in self.data_dir.rglob('*')
            if f.is_file()
            and f.relative_to(self.data_dir).parts[0] != DatasetFileCache._cache_dir_name
        )

    def is_cache_valid(self):
        """Return True when a saved checksum exists and equals the current data checksum."""
        return self.data_dir_hash == self._read_saved_hash()

    def _read_saved_hash(self):
        """Return the checksum stored in the cache directory, or None if none has been saved."""
        cache_value_file = self.data_cache_dir / DatasetFileCache._cache_value_filename
        try:
            with open(cache_value_file) as f:
                return f.read().strip()
        except FileNotFoundError:
            return None

