import json
import os
from pathlib import Path
from typing import (
    Any,
    Dict,
    Generator,
    Optional,
)

from src.generator.base_generator import BaseGenerator
from src.service.query_cache_service import QueryCacheService


class CachedGenerator(BaseGenerator):
    """Wrap another generator and serve repeated requests from a query cache.

    Calls to ``generate`` are looked up in a ``QueryCacheService`` rooted at
    ``cache_dir``; only cache misses are forwarded to the wrapped generator.
    """

    def __init__(self, generator: BaseGenerator, cache_dir: Path) -> None:
        """Store the wrapped generator and create the cache service.

        Args:
            generator: Generator that produces responses on a cache miss.
            cache_dir: Location handed to ``QueryCacheService``.
        """
        self.generator = generator
        # Mirror the wrapped generator's model name on the wrapper.
        self.model_name = self.generator.model_name
        self.cache_service = QueryCacheService(cache_dir)

    def _build_cache_query(
        self,
        prompt: str,
        json_mode: bool,
        json_schema: Optional[Dict[str, Any]],
    ) -> str:
        """Return the cache key for one ``generate`` request."""
        cache_query = self.generator.identifier + prompt
        if not json_mode and json_schema is None:
            # Default call: keep the plain key so entries written for such
            # calls before the output options were keyed remain valid.
            return cache_query

        # The requested output format changes what a valid response looks
        # like, so it has to be part of the key; otherwise a free-form answer
        # could be served for a JSON request with the same prompt.
        # sort_keys makes equal schemas map to the same key regardless of
        # dict insertion order.
        schema_repr = json.dumps(json_schema, sort_keys=True)
        return f"{cache_query}\n#json_mode={json_mode}\n#json_schema={schema_repr}"

    def generate(self, prompt: str, json_mode: bool = False, json_schema: Optional[Dict[str, Any]] = None) -> str:
        """Return the response for ``prompt``, using the cache when possible.

        Args:
            prompt: Prompt text forwarded to the wrapped generator on a miss.
            json_mode: Forwarded unchanged to the wrapped generator.
            json_schema: Forwarded unchanged to the wrapped generator.

        Returns:
            The cached value on a hit, otherwise the wrapped generator's
            fresh response.
        """
        cache_query = self._build_cache_query(prompt, json_mode, json_schema)
        cache_result = self.cache_service.access(cache_query)
        if cache_result.is_hit:
            return cache_result.value

        response = self.generator.generate(prompt=prompt, json_mode=json_mode, json_schema=json_schema)
        # Hand the fresh response back to the cache entry for this query.
        cache_result.set_callback(response)
        return response

    def generate_async(self, prompt: str, json_mode: bool = False) -> Generator[str, None, None]:
        """Not supported by the cached wrapper.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()


def maybe_apply_cache(generator: BaseGenerator, cache_dir: Path) -> BaseGenerator:
    """Wrap ``generator`` in a ``CachedGenerator`` when caching is enabled.

    Caching is enabled when the ``ENABLE_QUERY_CACHE`` environment variable
    equals ``"true"`` (case-insensitive). An unset variable means disabled.

    Args:
        generator: Generator to wrap or return as-is.
        cache_dir: Cache location passed to ``CachedGenerator``.

    Returns:
        A ``CachedGenerator`` around ``generator`` when caching is enabled,
        otherwise ``generator`` itself.
    """
    # Default to "" so an unset variable reads as disabled rather than
    # raising AttributeError on None.lower().
    enable_cache = os.getenv("ENABLE_QUERY_CACHE", "").lower() == "true"
    if enable_cache:
        return CachedGenerator(generator, cache_dir)
    return generator
