import abc
import hashlib
import random
import time
from pathlib import Path
from typing import Any

import utils
from diskcache import Cache

# Module-level logger from the project's ``utils`` helper, named after this module.
logger = utils.get_logger(__name__)


def get_dir_size(path):
    total = 0
    for p in path.rglob("*"):
        if p.is_file():
            total += p.stat().st_size
    return total


def human_readable_size(size_bytes):
    """Render a byte count as a string with two decimals and a binary unit.

    The value is repeatedly divided by 1024 until it is below 1024, and the
    matching unit from "B" up to "TB" is appended. Anything that is still at
    least 1024 after the "TB" step is reported in "PB".
    """
    units = ("B", "KB", "MB", "GB", "TB")
    value = size_bytes
    idx = 0
    while idx < len(units):
        if value < 1024:
            return f"{value:.2f} {units[idx]}"
        value /= 1024
        idx += 1
    return f"{value:.2f} PB"


class LM(abc.ABC):
    """Abstract base-class for simple one-prompt language models that can cache responses.

    Concrete subclasses wrap specific model back-ends (OpenAI, OpenRouter, etc.)
    and expose a uniform :py:meth:`generate` interface. Subclasses must
    implement :py:meth:`_generate`. They may override
    :py:meth:`_generate_with_details` to also return provider metadata.

    Caching: every instance constructed with ``cache=True`` uses one shared
    on-disk :class:`diskcache.Cache` stored in ``LM._global_cache``. The first
    such instance creates it lazily under ``<project root>/.cache``.
    """

    # Single global cache shared by *all* LM subclasses (initially ``None``)
    _global_cache: Cache | None = None
    # Log a warning on every cache miss (the backend is still called).
    warn_cache_misses: bool = False
    # Raise ``RuntimeError`` on a cache miss instead of calling the backend;
    # turn this on if you want to stop all LMs and only use the cache.
    error_on_cache_miss: bool = False

    # Cache directory name, resolved relative to the project root.
    global_cache_path = ".cache"

    def __init__(
        self, model_name: str = "", cache: bool = True, context_length: int | None = None
    ) -> None:
        """Initialise the LM base.

        Parameters
        ----------
        model_name
            Name of the model to use.  This should be unique to the model. This enables
            us to cache OpenAI and OpenRouter models together if they are the same model, e.g.,
            two different ways to access gpt-4o.
        cache
            Enable (``True``) or disable (``False``) on-disk caching.
        context_length
            Maximum supported context window (in tokens) for this model
            instance, or ``None`` if unknown.
        """
        self._cache_enabled: bool = cache
        self.model_name = model_name
        # Subclasses may set this based on concrete model capabilities.
        self.context_length: int | None = context_length

        if self._cache_enabled and LM._global_cache is None:  # first-time setup
            project_root = utils.find_project_root(Path(__file__).resolve())
            cache_dir = project_root / LM.global_cache_path
            cache_dir.mkdir(parents=True, exist_ok=True)
            LM._global_cache = Cache(
                directory=cache_dir,
                size_limit=10 * 1024**3,  # 10 GiB
                cull_limit=10_000,  # purge up to 10k rows when full
            )
            # Report the current on-disk cache size in human-friendly form.
            total_cache_size = get_dir_size(cache_dir)
            logger.info(
                f"LM caching enabled at {cache_dir}: {human_readable_size(total_cache_size)}"
            )

        # Convenience handle on the instance
        self._cache: Cache | None = LM._global_cache if self._cache_enabled else None

    def is_cached(
        self,
        prompt: str,
        system_message: str | None = None,
        seed: str | int | None = None,
        max_retries: int | None = 3,
        **kwargs: Any,
    ) -> bool:
        """Return ``True`` if an identical :py:meth:`generate` call would hit the cache.

        Accepts the same arguments as :py:meth:`generate` so callers can pass
        them through unchanged. ``max_retries`` is not part of the cache key
        and is ignored. Always ``False`` when caching is disabled.
        """
        del max_retries  # not part of the cache key
        if self._cache is None:
            return False
        key = self._make_key(prompt, system_message, str(seed), kwargs)
        return self.cached_get(key) is not None

    def generate(
        self,
        prompt: str,
        system_message: str | None = None,
        seed: str | int | None = None,
        max_retries: int | None = 3,
        **kwargs: Any,
    ) -> Any:
        """Generate *one* response and return it.

        The method transparently handles on-disk caching *and* simple retry
        logic.  All additional keyword arguments are forwarded to the backend
        implementation and also form part of the cache key so that different
        parameter combinations are cached independently.
        """
        # Use the details-aware path and return only the primary result.
        result, _ = self.generate_with_details(
            prompt=prompt,
            system_message=system_message,
            seed=seed,
            max_retries=max_retries,
            **kwargs,
        )
        return result

    def generate_with_details(
        self,
        prompt: str,
        system_message: str | None = None,
        seed: str | int | None = None,
        max_retries: int | None = 3,
        **kwargs: Any,
    ) -> tuple[Any, dict]:
        """Generate one response and return ``(result, details_dict)``.

        - ``result``: the legacy return value of ``_generate`` (str or dict).
        - ``details_dict``: provider-specific metadata (e.g., full completion dict).
          If it is a ``dict`` without a ``"duration"`` key, the wall-clock
          seconds of the successful backend call are stored under that key.

        The backend is attempted up to ``max(1, max_retries or 1)`` times.
        A successful response is written to the cache (when enabled). A failure
        to *store* it is logged and does not trigger another backend call.

        Raises
        ------
        RuntimeError
            If caching is enabled, the call misses the cache and
            ``LM.error_on_cache_miss`` is set; or if every attempt fails.
            In the latter case the last backend error is chained as
            ``__cause__``.
        """
        seed_str = str(seed)

        key: str | None = None
        if self._cache_enabled:
            key = self._make_key(prompt, system_message, seed_str, kwargs)
            cached = self.cached_get(key)
            if cached is not None:
                # Back-compat with older cache entries: wrap strings/dicts
                if isinstance(cached, tuple) and len(cached) == 2:
                    return cached  # (result, details)
                if isinstance(cached, (str, dict)):
                    return (cached, {})
                # Anything else: try a dict-like ``.get("response")`` and fall
                # back to stringification if that is not supported.
                try:
                    return (cached.get("response"), cached)
                except Exception:
                    return (str(cached), {})

            if LM.warn_cache_misses:
                logger.warning(f"LM.warn_cache_misses=True but cache missing {prompt[:100]=!r}")

            if LM.error_on_cache_miss:
                raise RuntimeError(
                    f"LM.error_on_cache_miss=True but cache missing {len(system_message or '')=} {len(prompt)=} {prompt[:500]=!r}"
                )

        attempts = max(1, max_retries or 1)
        last_err: Exception | None = None
        for attempt in range(1, attempts + 1):
            t0 = time.perf_counter()
            try:
                result, details = self._generate_with_details(
                    prompt=prompt,
                    system_message=system_message,
                    seed_str=seed_str,
                    **kwargs,
                )
            except Exception as e:
                last_err = e
                # Only announce a retry when another attempt will actually follow.
                if attempt < attempts:
                    logger.warning(f"{self.model_name} failed on {prompt[:50]!r}: {e}. Retrying…")
                continue
            elapsed = time.perf_counter() - t0
            break
        else:
            # Every attempt failed; keep the original error as the cause.
            raise RuntimeError(
                f"{self.model_name} failed on prompt {prompt[:1000]!r}: {last_err}"
            ) from last_err

        if isinstance(details, dict) and "duration" not in details:
            try:
                details["duration"] = elapsed
            except Exception:
                pass  # best-effort metadata (e.g. a read-only dict subclass)

        if key is not None:
            # Caching is best-effort. A value that cannot be stored (e.g. an
            # unpicklable details object) must not cause the successful
            # backend call to be repeated or reported as a failure.
            try:
                self.cached_set(key, (result, details))
            except Exception as e:
                logger.warning(f"{self.model_name}: failed to cache response: {e}")

        return (result, details)

    # Optional detailed generation; subclasses may override to add usage/metadata.
    def _generate_with_details(
        self,
        prompt: str,
        system_message: str | None,
        seed_str: str,
        **kwargs: Any,
    ) -> tuple[Any, dict]:
        """Default implementation: call ``_generate`` and return (result, {})."""
        result = self._generate(prompt, system_message, seed_str, **kwargs)
        return (result, {})

    # ------------------------------------------------------------------
    # Abstract API that concrete subclasses must implement
    # ------------------------------------------------------------------
    @abc.abstractmethod
    def _generate(
        self,
        prompt: str,
        system_message: str | None,
        seed_str: str,
        **kwargs: Any,
    ) -> Any:
        """Generate a response for a single prompt."""
        pass

    def truncate_to_token_len(self, text: str, max_tokens: int) -> str:
        """Return ``text`` truncated to at most ``max_tokens`` according to the model tokenizer.

        Subclasses can implement this using an appropriate tokenizer for the
        underlying model family.
        """
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Cache helpers
    # ------------------------------------------------------------------
    def cached_get(self, key: str) -> Any | None:
        """Retrieve *key* from the disk cache, or ``None`` if absent or caching is off."""
        if self._cache is None:
            return None
        return self._cache.get(key)

    def cached_set(self, key: str, value: Any, expire: int | None = None) -> None:
        """Store *value* under *key* in the disk cache (no-op when caching is off)."""
        if self._cache is not None:
            self._cache.set(key, value, expire=expire)

    # ------------------------------------------------------------------
    # Key builder helper
    # ------------------------------------------------------------------
    def _make_key(
        self,
        prompt: str,
        system_message: str | None,
        seed_str: str,
        kwargs: dict[str, Any],
        hash: bool = True,
    ) -> str:
        """Create a stable cache key for a single prompt call.

        The key covers ``model_name``, system message, prompt, seed and the
        keyword arguments sorted by name. With ``hash=True`` (default) the
        SHA-256 hex digest of the key tuple's ``repr`` is returned. Otherwise
        the readable ``str`` of the tuple is returned. Returns ``""`` when
        caching is disabled.

        NOTE(review): nested dict values in *kwargs* are keyed by their
        ``repr``, which depends on insertion order.
        """
        if not self._cache_enabled:
            return ""

        key_tuple = (
            self.model_name,
            system_message,
            prompt,
            seed_str,
            tuple(sorted(kwargs.items())),  # canonical order
        )
        if hash:
            return hashlib.sha256(repr(key_tuple).encode()).hexdigest()
        return str(key_tuple)


class RandomLM(LM):
    """Cache-testing stub whose every backend call embeds a fresh random float.

    Because regenerated responses differ, two equal responses indicate that
    the second one was served from the cache.
    """

    def __init__(self, model_name: str, cache: bool = True) -> None:
        # Pure pass-through; this stub keeps no state of its own.
        super().__init__(model_name=model_name, cache=cache)

    def _generate(
        self,
        prompt: str,
        system_message: str | None,
        seed_str: str,
        **kwargs: Any,
    ) -> str:
        # The f-string uses self-documenting ``{name=}`` fields, so these
        # local names appear verbatim in the response text.
        r = random.random()
        response = (
            f"Response {r=} for {prompt=} with {system_message=} "
            f"and {seed_str=} and {kwargs=}"
        )
        return response

    def truncate_to_token_len(self, text: str, max_tokens: int) -> str:
        """Return *text* unchanged; this stub has no tokenizer."""
        return text


def test_lm_cache():
    """Extended regression tests for the global LM cache.

    The test exercises the following scenarios:

    1. *Caching disabled* – duplicate calls always regenerate.
    2. *Caching enabled* – same call returns cached result.
    3. *Seed affects key* – changing the seed bypasses the cache.
    4. *Cross-instance sharing* – two distinct objects with the same
       ``model_name`` hit the same cache.
    5. *Model name isolation* – caches for different ``model_name`` values do
       not collide.
    6. *Details round-trip* – ``generate_with_details`` records a
       ``"duration"`` and the same ``(result, details)`` comes back from cache.

    ``LM._global_cache`` is temporarily replaced by a fresh cache in a
    temporary directory. The test therefore neither reads stale entries from
    nor writes throw-away entries into the project's real cache. The previous
    cache object is restored afterwards.
    """
    import tempfile

    prompt = "p1"
    previous_cache = LM._global_cache

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_cache = Cache(directory=tmp_dir)
        LM._global_cache = tmp_cache
        try:
            # 1) Caching disabled -------------------------------------------
            no_cache = RandomLM("no-cache-model", cache=False)
            r1 = no_cache.generate(prompt, seed=1)
            r1_again = no_cache.generate(prompt, seed=1)
            assert r1 != r1_again, "Caching disabled: repeated call should differ"

            # 2) Caching enabled (same instance) ----------------------------
            cached = RandomLM("cache-model", cache=True)
            rc1 = cached.generate(prompt, seed=2)
            rc1_again = cached.generate(prompt, seed=2)
            assert rc1 == rc1_again, "Same seed + same instance should return cache"

            # 3) Seed changes key -------------------------------------------
            rc1_diff_seed = cached.generate(prompt, seed=3)
            assert rc1 != rc1_diff_seed, "Different seed should bypass cache"

            # 4) Cross-instance cache sharing -------------------------------
            cached_2 = RandomLM("cache-model", cache=True)  # same model_name
            rc2 = cached_2.generate(prompt, seed=2)
            assert rc2 == rc1, "Cache should persist across instances"

            # 5) Model-name isolation ---------------------------------------
            other_model = RandomLM("other-model", cache=True)
            ro = other_model.generate(prompt, seed=2)
            assert ro != rc1, "Different model_name should have independent keys"

            # 6) Details round-trip -----------------------------------------
            res, details = cached.generate_with_details("p-details", seed=4)
            assert "duration" in details, "Details should record call duration"
            again = cached.generate_with_details("p-details", seed=4)
            assert again == (res, details), "Details should be served from cache"
        finally:
            LM._global_cache = previous_cache
            # Release the cache's files before the temp directory is removed.
            tmp_cache.close()

    print("All cache tests passed!")


if __name__ == "__main__":
    # Running this module directly executes the LM cache self-test.
    test_lm_cache()
