import hashlib

from dataclasses import dataclass
from pathlib import Path
from typing import (
    Callable,
    Optional,
)

from loguru import logger


@dataclass()
class CacheAccessResult:
    """Outcome of a single cache lookup.

    Bundles whether the lookup hit, the cached text (if any), and a callback
    that stores a new value for the query that was looked up.
    """

    # True when an entry for the query already existed.
    is_hit: bool
    # Cached text on a hit; None when nothing was cached.
    value: Optional[str]
    # Stores the given string as the cached value for the looked-up query.
    set_callback: Callable[[str], None]


class QueryCacheService:
    """Disk-backed cache mapping query strings to string values.

    Each entry is stored as two UTF-8 text files in ``cache_dir``, both named
    after the SHA-256 hex digest of the query:

    * ``<hash>.query.txt`` -- the original query, used to detect mismatches;
    * ``<hash>.value.txt`` -- the cached value.

    Known hashes are loaded once from the ``*.value.txt`` files at
    construction time. After that, the in-memory set is only extended by
    ``set_callback``; it is not rescanned from disk.
    """

    # File-name suffixes of the two files that make up one cache entry.
    _VALUE_SUFFIX = ".value.txt"
    _QUERY_SUFFIX = ".query.txt"

    def __init__(self, cache_dir: Path) -> None:
        """Create ``cache_dir`` if needed and index the entries already in it."""
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # The glob guarantees every name ends with the suffix, so slice it off.
        # str.replace would also strip any earlier occurrence inside the name.
        suffix_len = len(self._VALUE_SUFFIX)
        self.cached_hash = {
            entry.name[:-suffix_len]
            for entry in self.cache_dir.glob("*" + self._VALUE_SUFFIX)
        }

    def _hash_query(self, query: str) -> str:
        """Return the SHA-256 hex digest of ``query`` encoded as UTF-8."""
        return hashlib.sha256(query.encode("utf-8")).hexdigest()

    def value_path(self, hash: str) -> Path:
        """Return the path of the value file for the entry ``hash``."""
        return self.cache_dir / f"{hash}{self._VALUE_SUFFIX}"

    def query_path(self, hash: str) -> Path:
        """Return the path of the query file for the entry ``hash``."""
        return self.cache_dir / f"{hash}{self._QUERY_SUFFIX}"

    def read(self, query: str, hash: str) -> str:
        """Return the cached value stored under ``hash``.

        The stored query is compared with ``query``, ignoring surrounding
        whitespace. A mismatch is logged as an error, but the cached value is
        still returned.

        Raises:
            FileNotFoundError: if either file of the entry is missing.
        """
        query_file = self.query_path(hash)
        stored_query = query_file.read_text(encoding="utf-8")
        if stored_query.strip() != query.strip():
            # Log the concrete path, not the bound method object.
            logger.error(f"[QueryCacheService] query mismatch for {query_file}")
        return self.value_path(hash).read_text(encoding="utf-8")

    def write(self, query: str, hash: str, value: str) -> None:
        """Persist ``query`` and ``value`` as the cache entry ``hash``.

        The query file is written before the value file, and each file is
        written to a temporary sibling and then renamed into place. Entries
        are discovered through their value files, so an interrupted write
        cannot leave a truncated value that a later run would treat as a hit.
        """
        self._write_atomic(self.query_path(hash), query)
        self._write_atomic(self.value_path(hash), value)

    @staticmethod
    def _write_atomic(path: Path, text: str) -> None:
        """Write ``text`` (UTF-8) to ``path`` via a temp file and a rename."""
        # The ".tmp" name does not match the "*.value.txt" glob used in __init__.
        tmp_path = path.with_name(path.name + ".tmp")
        tmp_path.write_text(text, encoding="utf-8")
        tmp_path.replace(path)

    def access(self, query: str) -> CacheAccessResult:
        """Look up ``query`` and return the hit/miss result plus a setter.

        On a hit, ``value`` is the cached text. On a miss, it is ``None``.
        Calling ``set_callback(value)`` writes ``value`` for this query to
        disk, overwriting any existing entry, and records the hash as cached.
        """
        query_hash = self._hash_query(query)

        if query_hash in self.cached_hash:
            logger.info("cache hit")
            is_hit = True
            value = self.read(query, query_hash)
        else:
            logger.info("cache miss")
            is_hit = False
            value = None

        def set_callback(value: str) -> None:
            # Empty or whitespace-only values are still stored; only a warning is emitted.
            if not value.strip():
                logger.warning(f"Whitespace text passed to set_callback. Hash: {query_hash}")
            self.write(query, query_hash, value)
            self.cached_hash.add(query_hash)

        return CacheAccessResult(is_hit=is_hit, value=value, set_callback=set_callback)
