from __future__ import annotations

import fcntl
import json
import math
import os
import time
from typing import Any, Callable, Dict, Optional


def _truthy(value: Any) -> bool:
    return str(value).strip().lower() in ("1", "true", "yes", "on")


class TraceLogger:
    """Structured trace logger with redaction and bounded JSONL rotation.

    Each call to :meth:`log` appends one JSON object per line to
    ``<experiments_dir>/<exp_id>/trace/<stream>.jsonl``. Writers serialize on
    an exclusive ``flock`` held on a sibling ``<stream>.lock`` file. Once an
    append would push the data file past ``max_file_bytes``, the file is
    rotated to ``.1`` (older backups shift up to ``.<keep_files>``).

    Events are sanitized before writing:

    - values under keys containing a redaction token become ``"<redacted>"``;
    - long strings are truncated;
    - containers are capped at ``max_items`` entries;
    - nesting at ``MAX_DEPTH`` is cut off.
    """

    DEFAULT_REDACT_KEYS = (
        "api_key",
        "authorization",
        "token",
        "secret",
        "password",
        "access_key",
        "bearer",
    )

    # Nesting depth at which _sanitize stops descending and emits a marker.
    MAX_DEPTH = 8

    def __init__(
        self,
        *,
        exp_id_getter: Callable[[], Optional[str]],
        experiments_dir: str,
        component: str,
        enabled: bool,
        level: str,
        max_file_bytes: int,
        keep_files: int,
        max_string_chars: int,
        max_items: int,
        redact_keys: tuple[str, ...],
    ) -> None:
        """Create a logger; :meth:`from_env` builds one from environment variables.

        Args:
            exp_id_getter: Called on every :meth:`log`. A falsy result skips
                the record.
            experiments_dir: Root directory holding per-experiment folders.
            component: Value written to each record's ``component`` field.
            enabled: Master switch. When false, :meth:`log` is a no-op.
            level: Only ``"full"`` lets records logged with ``full_only=True``
                through.
            max_file_bytes: Rotation threshold per stream file (floored at 1).
            keep_files: Number of rotated backups kept (floored at 1).
            max_string_chars: String truncation length (floored at 128).
            max_items: Per-container item cap (floored at 8).
            redact_keys: Case-insensitive substrings. Dict keys containing any
                of them have their values redacted. Blank tokens are ignored.
        """
        self._exp_id_getter = exp_id_getter
        self._experiments_dir = experiments_dir
        self._component = component
        self._enabled = bool(enabled)
        self._level = level
        self._max_file_bytes = max(1, int(max_file_bytes))
        self._keep_files = max(1, int(keep_files))
        self._max_string_chars = max(128, int(max_string_chars))
        self._max_items = max(8, int(max_items))
        # A blank token would match every key ("" in s is always True) and
        # redact the whole event, so drop blanks.
        self._redact_keys = tuple(
            token.strip().lower() for token in redact_keys if token.strip()
        )

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Read env var *name* as an int; use *default* if unset, blank or malformed."""
        raw = os.getenv(name)
        if raw is None or not raw.strip():
            return default
        try:
            return int(raw.strip())
        except ValueError:
            return default

    @classmethod
    def from_env(
        cls,
        *,
        exp_id: Optional[str] = None,
        exp_id_getter: Optional[Callable[[], Optional[str]]] = None,
        state_manager: Optional[Any] = None,
        component: str,
    ) -> "TraceLogger":
        """Build a logger configured from ``ANUM_TRACE_*`` environment variables.

        Variables (defaults in parentheses):

        - ``ANUM_TRACE_ENABLED`` ("1").
        - ``ANUM_TRACE_LEVEL`` ("basic"; values other than "basic"/"full"
          fall back to "basic").
        - ``ANUM_TRACE_MAX_BYTES_PER_FILE`` (50 MiB).
        - ``ANUM_TRACE_KEEP_FILES`` (4).
        - ``ANUM_TRACE_MAX_STRING_CHARS`` (64000 at "full" level, else 12000).
        - ``ANUM_TRACE_MAX_ITEMS`` (128).
        - ``ANUM_TRACE_REDACT_KEYS``: comma-separated tokens appended to
          ``DEFAULT_REDACT_KEYS``.

        Malformed integer values fall back to their defaults, matching the
        lenient handling of the level variable.

        The experiment id comes from *exp_id_getter* when given, otherwise
        from the fixed *exp_id*. The experiments directory is
        ``state_manager.experiments_dir`` when that attribute is truthy, else
        ``<cwd>/experiments``.
        """
        enabled = _truthy(os.getenv("ANUM_TRACE_ENABLED", "1"))
        level = str(os.getenv("ANUM_TRACE_LEVEL", "basic")).strip().lower()
        if level not in ("basic", "full"):
            level = "basic"
        max_file_bytes = cls._env_int("ANUM_TRACE_MAX_BYTES_PER_FILE", 50 * 1024 * 1024)
        keep_files = cls._env_int("ANUM_TRACE_KEEP_FILES", 4)
        max_string_chars = cls._env_int(
            "ANUM_TRACE_MAX_STRING_CHARS", 64000 if level == "full" else 12000
        )
        max_items = cls._env_int("ANUM_TRACE_MAX_ITEMS", 128)
        extra_redact = tuple(
            token.strip().lower()
            for token in str(os.getenv("ANUM_TRACE_REDACT_KEYS", "")).split(",")
            if token.strip()
        )
        redact_keys = cls.DEFAULT_REDACT_KEYS + extra_redact

        default_dir = os.path.join(os.getcwd(), "experiments")
        experiments_dir = getattr(state_manager, "experiments_dir", None) or default_dir

        if exp_id_getter is None:
            static_exp_id = exp_id

            def _getter() -> Optional[str]:
                return static_exp_id

            exp_id_getter = _getter

        return cls(
            exp_id_getter=exp_id_getter,
            experiments_dir=experiments_dir,
            component=component,
            enabled=enabled,
            level=level,
            max_file_bytes=max_file_bytes,
            keep_files=keep_files,
            max_string_chars=max_string_chars,
            max_items=max_items,
            redact_keys=redact_keys,
        )

    @property
    def enabled(self) -> bool:
        """Whether :meth:`log` writes anything at all."""
        return self._enabled

    def log(self, stream: str, event: Dict[str, Any], *, full_only: bool = False) -> None:
        """Append one sanitized *event* record to the JSONL file for *stream*.

        Does nothing when:

        - the logger is disabled;
        - *full_only* is set and the level is not "full";
        - the experiment-id getter returns a falsy value.

        Each record holds the sanitized event plus a wall-clock ``ts``, the
        ``component`` name and the writer's ``pid``. Rotation and the append
        both happen while holding an exclusive ``flock`` on ``<stream>.lock``.
        Filesystem errors propagate to the caller.
        """
        if not self._enabled or (full_only and self._level != "full"):
            return
        exp_id = self._exp_id_getter()
        if not exp_id:
            return
        trace_dir = os.path.join(self._experiments_dir, exp_id, "trace")
        os.makedirs(trace_dir, exist_ok=True)
        # "/" would escape trace_dir; a blank name would yield a hidden ".jsonl".
        stream_name = str(stream or "").strip().replace("/", "_") or "trace"
        data_path = os.path.join(trace_dir, f"{stream_name}.jsonl")
        lock_path = os.path.join(trace_dir, f"{stream_name}.lock")
        record = {
            "ts": time.time(),
            "component": self._component,
            "pid": os.getpid(),
            "event": self._sanitize(event, depth=0),
        }
        # ensure_ascii=True makes len(payload) equal its UTF-8 byte count.
        payload = json.dumps(record, ensure_ascii=True, separators=(",", ":"))

        with open(lock_path, "w", encoding="utf-8") as lock_file:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
            try:
                self._rotate_if_needed(data_path, len(payload) + 1)
                with open(data_path, "a", encoding="utf-8") as handle:
                    handle.write(payload)
                    handle.write("\n")
            finally:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)

    def _rotate_if_needed(self, data_path: str, incoming_bytes: int) -> None:
        """Rotate *data_path* if appending *incoming_bytes* would exceed the cap.

        Backups shift ``.1`` -> ``.2`` ... up to ``.<keep_files>``, overwriting
        the oldest. The live file then becomes ``.1``.

        A missing or empty live file is never rotated, so one oversized record
        cannot push real history out of the backups. Rename failures are
        ignored: rotation is best-effort and the caller still appends.
        :meth:`log` calls this while holding the stream lock.
        """
        try:
            current_size = os.path.getsize(data_path)
        except OSError:
            return  # nothing (readable) to rotate
        if current_size == 0 or current_size + incoming_bytes <= self._max_file_bytes:
            return
        for idx in range(self._keep_files - 1, 0, -1):
            try:
                os.replace(f"{data_path}.{idx}", f"{data_path}.{idx + 1}")
            except OSError:
                pass  # backup slot absent or rename failed; best-effort
        try:
            os.replace(data_path, f"{data_path}.1")
        except OSError:
            pass

    def _redact_key(self, key: str) -> bool:
        """True if *key* (case-insensitively) contains any redaction token."""
        lower = str(key).lower()
        return any(token in lower for token in self._redact_keys)

    def _sanitize(self, value: Any, *, depth: int) -> Any:
        """Return a JSON-safe, size-bounded copy of *value*.

        Handling by type:

        - dicts: keys are stringified. Values under redacted keys become
          ``"<redacted>"``. Entries past ``max_items`` are dropped and counted
          in ``"__truncated_keys__"``.
        - lists/tuples: the first ``max_items`` items are kept, followed by
          ``{"__truncated_items__": n}`` when items were dropped.
        - bytes/bytearray: replaced by ``"<bytes:N>"``.
        - strings longer than ``max_string_chars``: truncated and suffixed
          with ``"...<truncated>"``.
        - non-finite floats: stringified (``"nan"``, ``"inf"``, ``"-inf"``),
          because ``json.dumps`` would otherwise emit non-standard tokens.
        - other non-JSON types: ``str(value)``.

        Any value reached at nesting depth ``MAX_DEPTH`` becomes
        ``"<truncated_depth>"``.
        """
        if depth >= self.MAX_DEPTH:
            return "<truncated_depth>"

        if isinstance(value, dict):
            out: Dict[str, Any] = {}
            for idx, (key, item) in enumerate(value.items()):
                if idx >= self._max_items:
                    out["__truncated_keys__"] = len(value) - self._max_items
                    break
                name = str(key)
                if self._redact_key(name):
                    out[name] = "<redacted>"
                else:
                    out[name] = self._sanitize(item, depth=depth + 1)
            return out

        # Tuples are treated exactly like lists so they cost one depth level.
        if isinstance(value, (list, tuple)):
            kept = [self._sanitize(item, depth=depth + 1) for item in value[: self._max_items]]
            overflow = len(value) - self._max_items
            if overflow > 0:
                kept.append({"__truncated_items__": overflow})
            return kept

        if isinstance(value, (bytes, bytearray)):
            return f"<bytes:{len(value)}>"

        if isinstance(value, str):
            if len(value) <= self._max_string_chars:
                return value
            return value[: self._max_string_chars] + "...<truncated>"

        if isinstance(value, float):
            return value if math.isfinite(value) else str(value)

        if value is None or isinstance(value, (int, bool)):
            return value

        return str(value)


class TracedToolClient:
    """Thin wrapper that emits structured trace records around tool calls.

    :meth:`call` forwards to ``base_client.call``. When a trace logger is
    attached and enabled, it also logs one event to the ``"tool_calls"``
    stream containing:

    - the tool name and payload;
    - the elapsed milliseconds;
    - either the response and its ``status``, or the raised error.

    Every other attribute lookup is delegated to the wrapped client.
    """

    def __init__(
        self,
        base_client: Any,
        trace: Optional[TraceLogger],
    ) -> None:
        self._base_client = base_client
        self._trace = trace

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails. If ``_base_client`` itself is
        # missing (instance built via __new__, mid-copy/unpickle), fail fast
        # instead of recursing into __getattr__ forever.
        if name == "_base_client":
            raise AttributeError(name)
        return getattr(self._base_client, name)

    def call(self, name: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Invoke ``base_client.call(name, payload)`` and trace the outcome.

        The response is returned and any exception is re-raised unchanged.
        Errors raised while writing the trace are suppressed, so tracing
        cannot alter the result of the underlying call.
        """
        started = time.monotonic()
        response: Any = None
        error: Optional[BaseException] = None
        try:
            response = self._base_client.call(name, payload)
            return response
        except BaseException as exc:
            # BaseException, not just Exception: otherwise a KeyboardInterrupt
            # would reach ``finally`` with no response and no recorded error.
            error = exc
            raise
        finally:
            self._emit_trace(name, payload, started, response, error)

    def _emit_trace(
        self,
        name: str,
        payload: Dict[str, Any],
        started: float,
        response: Any,
        error: Optional[BaseException],
    ) -> None:
        """Log one ``tool_calls`` event when tracing is on; suppresses trace failures."""
        trace = self._trace
        if not trace or not trace.enabled:
            return
        try:
            event: Dict[str, Any] = {
                "tool": name,
                # Monotonic clock: immune to wall-clock jumps during the call.
                "duration_ms": int((time.monotonic() - started) * 1000),
                "payload": payload,
            }
            if error is None:
                event["response"] = response
                event["status"] = (
                    response.get("status") if isinstance(response, dict) else "unknown"
                )
            else:
                event["status"] = "exception"
                event["error"] = str(error)
                event["error_type"] = type(error).__name__
            trace.log("tool_calls", event)
        except Exception:
            # Best-effort: an unwritable trace dir or full disk must not turn a
            # successful call into a failure or mask the tool's own exception.
            pass


def wrap_tool_client(base_client: Any, trace: Optional[TraceLogger]) -> Any:
    """Return *base_client* wrapped in a TracedToolClient when tracing is active.

    The client is returned untouched when *trace* is missing or disabled, or
    when it is already a TracedToolClient (so wrapping is never doubled).
    """
    tracing_active = bool(trace) and trace.enabled
    if tracing_active and not isinstance(base_client, TracedToolClient):
        return TracedToolClient(base_client, trace)
    return base_client
