"""utils
~~~~~~~~~~~~~~~~~~~~~~~~~

A grab-bag of small, reusable helpers that multiple parts of the project may
import.  Add new utilities here rather than creating many one-off modules.

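Currently provides
------------------
get_logger()
    Configure the root logger (once), honouring the ``LOG_LEVEL`` environment
    variable, and return a named logger.
load() / save()
    Read and write ``.json`` / ``.json.gz`` files.
get_cache() / get_cache_key()
    On-disk ``diskcache`` caches stored under ``<project root>/.cache``.
find_project_root() / get_project_root()
    Locate the project root via ``pyproject.toml`` or ``.git`` markers.
generate_permutations(), extract_last_json(), safe_filename(), get_eval_path()
    Miscellaneous helpers.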
"""

import gzip
import hashlib
import json
import logging
import os
import re
import sys
from pathlib import Path
from typing import Any, Literal, cast

import numpy as np
from diskcache import Cache

logger: logging.Logger  # initialized later

_LogLevel = Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"]


def get_project_root() -> Path:
    """Return the project root directory that contains this module.

    Walks upward from the directory holding this file and returns the first
    directory containing ``pyproject.toml`` or a ``.git`` entry. If no marker
    is found, the module's own directory is returned -- always a directory,
    never the file itself, so callers can safely join sub-paths onto it.
    """
    # Start at the containing directory: probing for markers *inside* the
    # file path can never succeed, and returning the file on fallback would
    # make paths like ``root / ".cache"`` point underneath a regular file.
    here = Path(__file__).resolve().parent
    for candidate in (here, *here.parents):
        if (candidate / "pyproject.toml").exists() or (candidate / ".git").exists():
            return candidate
    return here


def get_cache(
    cache_name: str,
    size_limit: int = 10 * 1024**3,  # 10 GiB
    cull_limit: int = 10_000,  # purge up to 10k rows when full
) -> Cache:
    """Return a :class:`diskcache.Cache` stored at ``<project root>/.cache/<cache_name>``.

    Args:
        cache_name: Sub-directory name for this cache under ``.cache``.
        size_limit: Size limit in bytes passed through to ``Cache``.
        cull_limit: Passed through to ``Cache`` (see diskcache docs for the
            exact eviction semantics).

    Returns:
        The opened cache. This function never returns ``None``.
    """
    cache_dir = get_project_root() / ".cache" / cache_name
    return Cache(
        directory=cache_dir,
        size_limit=size_limit,
        cull_limit=cull_limit,
    )


def get_cache_key(obj: Any) -> str:
    """Derive a cache key: the hex SHA-256 digest of ``repr(obj)``.

    Two objects map to the same key exactly when their ``repr`` strings are
    equal, so this is only stable for objects whose ``repr`` is deterministic.
    """
    hasher = hashlib.sha256()
    hasher.update(repr(obj).encode())
    return hasher.hexdigest()


def _resolve_env_level(default: _LogLevel = "INFO") -> _LogLevel:
    """Pick the logging level from ``$LOG_LEVEL``, falling back to *default*.

    The chosen value is upper-cased before validation, so ``debug`` is
    accepted. When the variable is unset, *default* is upper-cased and
    validated the same way; an empty or unrecognised value yields *default*.
    """
    requested = os.getenv("LOG_LEVEL", default).upper()
    known = ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET")
    return requested if requested in known else default  # type: ignore[return-value]


def get_logger(name: str | None = None, *, default: _LogLevel = "INFO") -> logging.Logger:
    """Configure root logging (once) **and return a logger instance**.

    Usage
    -----
    >>> logger = get_logger(__name__)

    If the root logger has no handlers yet, it is configured via
    ``logging.basicConfig`` to write to stdout at the level taken from
    ``$LOG_LEVEL`` (or *default* when that is unset/invalid). If any root
    handler already exists -- from an earlier call or from other code -- the
    configuration is left untouched, and the call simply returns
    ``logging.getLogger(name)``.

    Args:
        name: Logger name, typically ``__name__``. ``None`` or ``""`` returns
            the root logger.
        default: Fallback level name; only affects the call that actually
            performs the configuration.

    Returns:
        The requested :class:`logging.Logger`.
    """
    root = logging.getLogger()
    # ``force=True`` would be a no-op here: this branch only runs when the
    # root logger has no handlers to remove.
    if not root.handlers:
        logging.basicConfig(
            level=_resolve_env_level(default),
            format="%(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            stream=sys.stdout,
        )

    return logging.getLogger(name) if name else root


# Module-level logger for this file; obtaining it may also configure root
# logging if nothing has done so yet (see get_logger).
logger = get_logger(name=__name__)


def _human_readable_size(num_bytes: float) -> str:
    """Return a human-readable file size string."""
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if num_bytes < 1024.0:
            return f"{num_bytes:.1f} {unit}"
        num_bytes /= 1024.0
    return f"{num_bytes:.1f} PB"


def load(path: str | Path) -> dict:
    path = Path(path)
    if path.suffix == ".json":
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    elif path.suffixes[-2:] == [".json", ".gz"]:
        with gzip.open(path, "rt", encoding="utf-8") as f:
            return json.load(f)
    else:
        raise AssertionError("Path must end with .json or .json.gz")


def save(data: dict, path: str | Path, create_dirs: bool = True) -> None:
    path = Path(path)
    if create_dirs:
        path.parent.mkdir(parents=True, exist_ok=True)

    if path.suffix == ".json":
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
    elif path.suffixes[-2:] == [".json", ".gz"]:
        with gzip.open(path, "wt", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
    else:
        raise AssertionError("Path must end with .json or .json.gz")

    size = path.stat().st_size
    logger.info(f"Saved {path} ({_human_readable_size(size)})")


def safe_filename(s: str) -> str:
    """Return *s* with every character outside ``[A-Za-z0-9_]`` replaced by ``_``.

    Replacement is one-for-one, so the result has the same length as *s*.
    """
    disallowed = re.compile(r"[^a-zA-Z0-9_]")
    return disallowed.sub("_", s)


def get_eval_path(translations_path: str, evaluator: str) -> tuple[str, Path]:
    """Return ``(stem, eval_dir)`` for a translations file and an evaluator.

    ``stem`` is ``Path(translations_path).stem``. Note that this strips only
    the last suffix, so ``x.json.gz`` gives ``x.json``. ``eval_dir`` is a
    sibling directory of the file, named after the sanitised *evaluator*.

    Raises:
        FileNotFoundError: if *translations_path* does not exist.
    """
    source = Path(translations_path)
    if source.exists():
        return source.stem, source.parent / safe_filename(evaluator)
    raise FileNotFoundError(f"Translations file not found: {source}")


def find_project_root(start: str | Path | None = None) -> Path:
    """Return the top-level project directory.

    Starting from *start* (defaults to the current working directory), walk
    upwards until a directory containing either ``pyproject.toml`` **or** a
    ``.git`` folder is found.  If no such marker is encountered, the search
    stops at the filesystem root and the original starting directory is
    returned.
    """

    path = Path(start or Path.cwd()).resolve()
    for parent in [path, *path.parents]:
        if (parent / "pyproject.toml").exists() or (parent / ".git").exists():
            return parent

    logger.warning("Could not find project root; falling back to current working directory")
    return path


def generate_permutations(chunks: list[str], n_permutations: int, seed: int = 0) -> list[list[int]]:
    """
    Generate distinct, non-identity random permutations of chunk indices.

    The RNG is seeded with ``seed`` plus a value derived from the SHA-256 of
    the concatenated chunk text. The same chunks and seed therefore always
    yield the same permutations.

    Args:
        chunks: List of text chunks to permute (only their count and content hash are used).
        n_permutations: Number of distinct permutations to generate (identity excluded).
        seed: Base random seed, combined with the content hash.

    Returns:
        List of ``n_permutations`` distinct permutations, none equal to the identity;
        each permutation is a list of indices into *chunks*.

    Raises:
        AssertionError: if more permutations are requested than exist
            (``len(chunks)! - 1``).
    """
    import math

    n_chunks = len(chunks)

    # Raised explicitly rather than via ``assert`` so the check survives
    # ``python -O``; without it the rejection loop below would never finish.
    max_permutations = math.factorial(n_chunks) - 1  # -1 because we exclude identity
    if n_permutations > max_permutations:
        raise AssertionError(
            f"Cannot generate {n_permutations} distinct permutations for {n_chunks} chunks (max: {max_permutations})"
        )

    # Content-dependent seed component. NOTE: chunks are joined without a
    # separator, so e.g. ["ab", "c"] and ["a", "bc"] share the same component.
    content_hash = hashlib.sha256("".join(chunks).encode()).hexdigest()
    content_seed = int(content_hash[:8], 16) % (2**31)
    rng = np.random.default_rng(seed + content_seed)

    permutations: list[list[int]] = []
    # Pre-seed with the identity so rejection sampling never emits it.
    seen: set[tuple[int, ...]] = {tuple(range(n_chunks))}

    while len(permutations) < n_permutations:
        perm: list[int] = cast(list[int], rng.permutation(n_chunks).tolist())
        key = tuple(perm)
        if key not in seen:
            seen.add(key)
            permutations.append(perm)

    return permutations


def extract_last_json(s: str) -> dict[str, Any] | None:
    """Finds the last JSON list or dict in a string, returns None if not found."""
    s = s[: max(s.rfind("]"), s.rfind("}")) + 1]
    for i, c in enumerate(s):
        if c == "[" or c == "{":
            try:
                return json.loads(s[i:])
            except Exception:
                pass
    return None
