"""Decorators for Expected GradCAM.

This module provides utility decorators for timing, caching, input/output
validation, gradient control, and retrying failed calls. They are designed
to work with PyTorch tensors and GPU operations.

Example:
    >>> @timed(log=True)
    ... def expensive_computation(x):
    ...     return x @ x.T
    >>>
    >>> @validate_input(x={"shape": (None, 3, 224, 224), "dtype": torch.float32})
    ... def process_image(x):
    ...     return model(x)
"""

from __future__ import annotations

import functools
import hashlib
import logging
import pickle
import threading
import time
import weakref
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generic,
    ParamSpec,
    TypeVar,
    overload,
)

import torch
from torch import Tensor

if TYPE_CHECKING:
    from collections.abc import Hashable

# Module-level logger; @timed logs per-call timings and @retry logs failed
# attempts through it.
logger = logging.getLogger(__name__)

# Type variables used in decorator signatures:
#   P - parameter specification of a wrapped callable (used by @cached)
#   R - return type of a wrapped callable / value type stored in LRUCache
#   F - any callable; decorators annotated ``(F) -> F`` present their wrapper
#       to type checkers as the same callable type they received.
P = ParamSpec("P")
R = TypeVar("R")
F = TypeVar("F", bound=Callable[..., Any])


# =============================================================================
# Timing Decorator
# =============================================================================


@dataclass
class TimingStats:
    """Statistics for timed function calls.

    The aggregate fields cover every recorded call. Individual samples are
    also kept in ``_times``, but only the most recent ``max_history`` of
    them, so a long-running decorated function cannot grow memory without
    bound.

    Attributes:
        total_time: Total accumulated time in seconds.
        call_count: Number of times the function was called.
        min_time: Minimum execution time in seconds (``inf`` before any call).
        max_time: Maximum execution time in seconds.
        max_history: Maximum number of individual samples retained in
            ``_times``; the oldest are dropped first. ``None`` keeps all.
    """

    total_time: float = 0.0
    call_count: int = 0
    min_time: float = float("inf")
    max_time: float = 0.0
    _times: list[float] = field(default_factory=list)
    max_history: int | None = 1000

    @property
    def avg_time(self) -> float:
        """Average execution time in seconds (0.0 when nothing was recorded)."""
        return self.total_time / self.call_count if self.call_count > 0 else 0.0

    def record(self, elapsed: float) -> None:
        """Record a timing measurement of ``elapsed`` seconds."""
        self.total_time += elapsed
        self.call_count += 1
        self.min_time = min(self.min_time, elapsed)
        self.max_time = max(self.max_time, elapsed)
        self._times.append(elapsed)
        if self.max_history is not None:
            excess = len(self._times) - max(self.max_history, 0)
            if excess > 0:
                # Drop the oldest samples; the aggregates above already
                # account for them, so nothing is lost from the summary.
                del self._times[:excess]

    def reset(self) -> None:
        """Reset all statistics (``max_history`` is kept)."""
        self.total_time = 0.0
        self.call_count = 0
        self.min_time = float("inf")
        self.max_time = 0.0
        self._times.clear()

    def __repr__(self) -> str:
        if self.call_count == 0:
            return "TimingStats(no calls)"
        return (
            f"TimingStats(calls={self.call_count}, "
            f"avg={self.avg_time*1000:.2f}ms, "
            f"min={self.min_time*1000:.2f}ms, "
            f"max={self.max_time*1000:.2f}ms)"
        )


class timed:
    """Decorator that measures the wall-clock duration of each call.

    CUDA kernels run asynchronously, so when ``sync_cuda`` is set and a GPU
    is available, the device is synchronized both before the clock starts
    and before it is read again. Calls that raise are still timed,
    recorded and logged; the exception then propagates.

    Args:
        log: Emit an INFO log line with the elapsed time after each call.
        collect_stats: Accumulate every timing into a ``TimingStats``.
        sync_cuda: Call ``torch.cuda.synchronize()`` around the call when
            CUDA is available.
        name: Label for log lines (defaults to the function's ``__name__``).

    The returned wrapper carries ``.stats`` (its ``TimingStats``) and
    ``.reset_stats()``. The options are read from the decorator instance on
    every call.

    Example:
        >>> @timed(log=True)
        ... def compute(x):
        ...     return x @ x.T
        >>>
        >>> @timed(collect_stats=True)
        ... def batch_process(batch):
        ...     return model(batch)
        >>>
        >>> batch_process.stats  # Access timing stats
        TimingStats(calls=100, avg=15.23ms, ...)
    """

    def __init__(
        self,
        log: bool = False,
        collect_stats: bool = True,
        sync_cuda: bool = True,
        name: str | None = None,
    ) -> None:
        self.log = log
        self.collect_stats = collect_stats
        self.sync_cuda = sync_cuda
        self.name = name

    def __call__(self, func: F) -> F:
        timing = TimingStats()
        label = self.name or func.__name__
        config = self

        def _barrier() -> None:
            # Wait for queued GPU work so it is attributed to the right call.
            if config.sync_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            _barrier()
            t0 = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                _barrier()
                duration = time.perf_counter() - t0
                if config.collect_stats:
                    timing.record(duration)
                if config.log:
                    logger.info(f"{label}: {duration*1000:.2f}ms")

        wrapper.stats = timing  # type: ignore[attr-defined]
        wrapper.reset_stats = timing.reset  # type: ignore[attr-defined]

        return wrapper  # type: ignore[return-value]


def gpu_sync(func: F) -> F:
    """Wrap ``func`` so the CUDA device is synchronized around every call.

    This is useful for accurate timing of GPU work, or to make sure queued
    kernels have finished before control returns. Without a CUDA device the
    wrapper simply forwards the call. If ``func`` raises, the trailing
    synchronization is skipped and the exception propagates unchanged.

    Example:
        >>> @gpu_sync
        ... def gpu_compute(x):
        ...     return torch.matmul(x, x.T)
    """

    def _barrier() -> None:
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        _barrier()
        outcome = func(*args, **kwargs)
        _barrier()
        return outcome

    return wrapper  # type: ignore[return-value]


# =============================================================================
# Caching Decorator
# =============================================================================


def _tensor_hash(tensor: Tensor) -> str:
    """Compute hash for a tensor."""
    # Use data pointer and shape for fast hashing
    # For small tensors, include actual data
    if tensor.numel() <= 1000:
        data_bytes = tensor.detach().cpu().numpy().tobytes()
        return hashlib.md5(data_bytes).hexdigest()
    else:
        # For large tensors, use statistical signature
        stats = (
            tensor.shape,
            tensor.dtype,
            float(tensor.mean()),
            float(tensor.std()),
            float(tensor.min()),
            float(tensor.max()),
        )
        return hashlib.md5(pickle.dumps(stats)).hexdigest()


def _make_hashable(obj: Any) -> Hashable:
    """Convert object to hashable representation."""
    if isinstance(obj, Tensor):
        return ("tensor", _tensor_hash(obj))
    elif isinstance(obj, dict):
        return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
    elif isinstance(obj, (list, tuple)):
        return tuple(_make_hashable(x) for x in obj)
    elif isinstance(obj, set):
        return frozenset(_make_hashable(x) for x in obj)
    elif hasattr(obj, "__hash__") and obj.__hash__ is not None:
        try:
            hash(obj)
            return obj
        except TypeError:
            return str(obj)
    else:
        return str(obj)


class LRUCache(Generic[R]):
    """Thread-safe LRU cache with optional TTL expiry.

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used. ``get`` (on a hit) and ``set`` both mark the key as most recently
    used. Every operation on the store holds an internal re-entrant lock.

    Args:
        maxsize: Maximum number of entries. ``None`` means unbounded. ``0``
            or a negative value disables storage, so every ``get`` misses,
            mirroring ``functools.lru_cache``.
        ttl: Time-to-live in seconds (None for no expiry). Entry ages are
            measured with ``time.monotonic()`` so wall-clock adjustments
            cannot prematurely expire or extend entries.

    Attributes:
        hits: Successful lookups since creation or the last ``clear``.
        misses: Missing or expired lookups since creation or the last ``clear``.
    """

    def __init__(self, maxsize: int | None = 128, ttl: float | None = None) -> None:
        self.maxsize = maxsize
        self.ttl = ttl
        self._cache: OrderedDict[Hashable, tuple[R, float]] = OrderedDict()
        self._lock = threading.RLock()
        self.hits = 0
        self.misses = 0

    def get(self, key: Hashable) -> tuple[bool, R | None]:
        """Look up ``key``.

        Returns:
            ``(True, value)`` on a hit, ``(False, None)`` on a miss. An expired
            entry is removed and counted as a miss.
        """
        with self._lock:
            if key not in self._cache:
                self.misses += 1
                return False, None

            value, stamp = self._cache[key]

            if self.ttl is not None and time.monotonic() - stamp > self.ttl:
                del self._cache[key]
                self.misses += 1
                return False, None

            self._cache.move_to_end(key)  # now the most recently used
            self.hits += 1
            return True, value

    def set(self, key: Hashable, value: R) -> None:
        """Store ``value`` under ``key`` as the most recently used entry."""
        with self._lock:
            if self.maxsize is not None and self.maxsize <= 0:
                return  # caching disabled

            if key in self._cache:
                # Refreshing an existing key: no eviction needed, but it must
                # become the most recently used entry.
                self._cache.move_to_end(key)
            elif self.maxsize is not None:
                # `while` rather than `if` in case maxsize was lowered.
                while len(self._cache) >= self.maxsize:
                    self._cache.popitem(last=False)

            self._cache[key] = (value, time.monotonic())

    def clear(self) -> None:
        """Remove all entries and reset the hit/miss counters."""
        with self._lock:
            self._cache.clear()
            self.hits = 0
            self.misses = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups that were hits (0.0 before any lookup)."""
        with self._lock:
            total = self.hits + self.misses
            return self.hits / total if total > 0 else 0.0

    def __len__(self) -> int:
        with self._lock:
            return len(self._cache)


class cached:
    """Decorator for caching function results.

    Arguments are converted to a cache key with ``_make_hashable``, so
    tensors are keyed by content hash. Results are stored in a per-function
    ``LRUCache`` with optional TTL expiration. Exceptions are not cached.

    Args:
        maxsize: Maximum cache size.
        ttl: Time-to-live for cache entries in seconds.
        typed: Whether the argument types (positional and keyword) are part
            of the key, so e.g. ``f(1)`` and ``f(1.0)`` are cached separately.
        key_func: Optional custom key function. It is called with the same
            arguments as the wrapped function and must return a hashable key.

    The wrapper exposes ``cache`` (the ``LRUCache``), ``cache_info()`` (a
    dict) and ``cache_clear()``.

    Example:
        >>> @cached(maxsize=100, ttl=300)
        ... def expensive_computation(x: Tensor, config: dict) -> Tensor:
        ...     # Complex computation
        ...     return result
        >>>
        >>> expensive_computation.cache_info()
        {'hits': 50, 'misses': 10, 'size': 10, 'maxsize': 100, 'hit_rate': 0.83...}
        >>>
        >>> expensive_computation.cache_clear()
    """

    def __init__(
        self,
        maxsize: int = 128,
        ttl: float | None = None,
        typed: bool = False,
        key_func: Callable[..., Hashable] | None = None,
    ) -> None:
        self.maxsize = maxsize
        self.ttl = ttl
        self.typed = typed
        self.key_func = key_func

    def __call__(self, func: Callable[P, R]) -> Callable[P, R]:
        cache: LRUCache[R] = LRUCache(maxsize=self.maxsize, ttl=self.ttl)

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            if self.key_func is not None:
                key = self.key_func(*args, **kwargs)
            else:
                kw_items = sorted(kwargs.items())
                # Positional and keyword parts live in separate tuples so a
                # positional argument that looks like a ``(name, value)``
                # pair cannot collide with a keyword argument.
                key_parts: tuple[Any, ...] = (
                    tuple(_make_hashable(arg) for arg in args),
                    tuple((k, _make_hashable(v)) for k, v in kw_items),
                )
                if self.typed:
                    # Type objects (not names) so same-named classes differ.
                    key_parts += (
                        tuple(type(arg) for arg in args),
                        tuple(type(v) for _, v in kw_items),
                    )
                key = key_parts

            found, value = cache.get(key)
            if found:
                return value  # type: ignore[return-value]

            result = func(*args, **kwargs)
            cache.set(key, result)
            return result

        def cache_info() -> dict[str, Any]:
            """Snapshot of the cache counters and size."""
            return {
                "hits": cache.hits,
                "misses": cache.misses,
                "size": len(cache),
                "maxsize": cache.maxsize,
                "hit_rate": cache.hit_rate,
            }

        wrapper.cache = cache  # type: ignore[attr-defined]
        wrapper.cache_info = cache_info  # type: ignore[attr-defined]
        wrapper.cache_clear = cache.clear  # type: ignore[attr-defined]

        return wrapper


# =============================================================================
# Validation Decorators
# =============================================================================


@dataclass
class TensorSpec:
    """Specification for tensor validation.

    Every field left as ``None`` is not checked.

    Attributes:
        shape: Expected shape (use None for dynamic dimensions).
        dtype: Expected dtype.
        device: Expected device type ('cpu', 'cuda', or None for any); it is
            compared against ``tensor.device.type``.
        requires_grad: Expected value of ``tensor.requires_grad``.
        min_val: Minimum allowed value.
        max_val: Maximum allowed value.
        ndim: Expected number of dimensions.
    """

    shape: tuple[int | None, ...] | None = None
    dtype: torch.dtype | None = None
    device: str | None = None
    requires_grad: bool | None = None
    min_val: float | None = None
    max_val: float | None = None
    ndim: int | None = None

    def validate(self, tensor: Tensor, name: str) -> list[str]:
        """Validate ``tensor`` against this spec.

        Args:
            tensor: Tensor to check.
            name: Label used as the prefix of every error message.

        Returns:
            List of validation error messages (empty if valid). When a value
            range is specified, empty tensors satisfy it trivially and
            floating tensors containing NaN are reported.
        """
        errors: list[str] = []

        # Check ndim
        if self.ndim is not None and tensor.ndim != self.ndim:
            errors.append(f"{name}: expected {self.ndim}D, got {tensor.ndim}D")

        # Check shape
        if self.shape is not None:
            if len(self.shape) != tensor.ndim:
                errors.append(
                    f"{name}: expected {len(self.shape)}D shape, "
                    f"got {tensor.ndim}D {tuple(tensor.shape)}"
                )
            else:
                for i, (expected, actual) in enumerate(zip(self.shape, tensor.shape)):
                    if expected is not None and expected != actual:
                        errors.append(
                            f"{name}: dimension {i} expected {expected}, got {actual}"
                        )

        # Check dtype
        if self.dtype is not None and tensor.dtype != self.dtype:
            errors.append(f"{name}: expected dtype {self.dtype}, got {tensor.dtype}")

        # Check device
        if self.device is not None:
            device_type = tensor.device.type
            if device_type != self.device:
                errors.append(
                    f"{name}: expected device {self.device}, got {device_type}"
                )

        # Check requires_grad
        if self.requires_grad is not None:
            if tensor.requires_grad != self.requires_grad:
                errors.append(
                    f"{name}: expected requires_grad={self.requires_grad}, "
                    f"got {tensor.requires_grad}"
                )

        # Check value range. Tensor.min()/max() raise on empty tensors, which
        # have no values to violate a range, so they are skipped.
        has_range = self.min_val is not None or self.max_val is not None
        if has_range and tensor.numel() > 0:
            # NaN compares False against any bound, so without this check a
            # NaN-containing tensor would pass every range test.
            if tensor.is_floating_point() and bool(torch.isnan(tensor).any()):
                errors.append(f"{name}: contains NaN values")

            if self.min_val is not None:
                actual_min = float(tensor.min())
                if actual_min < self.min_val:
                    errors.append(
                        f"{name}: min value {actual_min:.4f} < {self.min_val}"
                    )

            if self.max_val is not None:
                actual_max = float(tensor.max())
                if actual_max > self.max_val:
                    errors.append(
                        f"{name}: max value {actual_max:.4f} > {self.max_val}"
                    )

        return errors


def validate_input(
    **specs: TensorSpec | dict[str, Any],
) -> Callable[[F], F]:
    """Decorator to validate function input tensors.

    Each keyword maps a parameter name of the decorated function to a
    ``TensorSpec`` (or a dict of ``TensorSpec`` fields). On every call, the
    arguments are bound to the function signature, with defaults applied.
    Each named argument is then checked:

    * a Tensor is validated against its spec;
    * ``None`` is accepted (treated as "not provided");
    * anything else is an error.

    Names that do not appear in a call are skipped. Arguments collected by a
    ``**kwargs`` parameter can be targeted by name as well.

    Args:
        **specs: Mapping of parameter names to TensorSpec or spec dicts.

    Raises:
        ValueError: If validation fails (all problems are reported at once).
        TypeError: If the call's arguments do not match the signature.

    Example:
        >>> @validate_input(
        ...     x=TensorSpec(shape=(None, 3, 224, 224), dtype=torch.float32),
        ...     mask={"ndim": 2, "dtype": torch.bool},
        ... )
        ... def process(x: Tensor, mask: Tensor) -> Tensor:
        ...     return x * mask.unsqueeze(-1).unsqueeze(-1)
    """
    import inspect

    # Normalize every spec to a TensorSpec once, at decoration time.
    tensor_specs: dict[str, TensorSpec] = {
        name: spec if isinstance(spec, TensorSpec) else TensorSpec(**spec)
        for name, spec in specs.items()
    }

    def decorator(func: F) -> F:
        # Resolve the signature once per decorated function, not per call.
        sig = inspect.signature(func)
        # Name of a ``**kwargs`` parameter, if the function has one.
        var_kw_name = next(
            (
                p.name
                for p in sig.parameters.values()
                if p.kind is inspect.Parameter.VAR_KEYWORD
            ),
            None,
        )

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            arguments = bound.arguments
            extra_kwargs = (
                arguments.get(var_kw_name, {}) if var_kw_name is not None else {}
            )

            all_errors: list[str] = []
            for name, spec in tensor_specs.items():
                if name in arguments:
                    value = arguments[name]
                elif name in extra_kwargs:
                    value = extra_kwargs[name]
                else:
                    continue
                if isinstance(value, Tensor):
                    all_errors.extend(spec.validate(value, name))
                elif value is not None:
                    all_errors.append(
                        f"{name}: expected Tensor, got {type(value).__name__}"
                    )

            if all_errors:
                raise ValueError(
                    "Input validation failed:\n  " + "\n  ".join(all_errors)
                )

            return func(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator


def validate_output(spec: TensorSpec | dict[str, Any]) -> Callable[[F], F]:
    """Check the value returned by the decorated function against a spec.

    Only Tensor return values are validated; any other return value is
    passed through untouched. Error messages use the label ``"output"``.

    Args:
        spec: TensorSpec, or a dict of TensorSpec fields, describing the
            expected output.

    Raises:
        ValueError: If the returned tensor violates the spec.

    Example:
        >>> @validate_output(TensorSpec(ndim=4, min_val=0.0, max_val=1.0))
        ... def generate_heatmap(x: Tensor) -> Tensor:
        ...     return torch.sigmoid(x)
    """
    if isinstance(spec, TensorSpec):
        output_spec = spec
    else:
        output_spec = TensorSpec(**spec)

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            out = func(*args, **kwargs)
            if not isinstance(out, Tensor):
                return out

            problems = output_spec.validate(out, "output")
            if problems:
                raise ValueError(
                    "Output validation failed:\n  " + "\n  ".join(problems)
                )
            return out

        return wrapper  # type: ignore[return-value]

    return decorator


# =============================================================================
# Gradient Control Decorators
# =============================================================================


def require_grad(func: F) -> F:
    """Refuse to run ``func`` while autograd is globally disabled.

    The check uses ``torch.is_grad_enabled()`` at call time, so calling the
    wrapped function inside ``torch.no_grad()`` fails fast instead of
    producing tensors with no graph.

    Raises:
        RuntimeError: If gradients are disabled.

    Example:
        >>> @require_grad
        ... def compute_gradients(model, x, target):
        ...     output = model(x)
        ...     loss = criterion(output, target)
        ...     return torch.autograd.grad(loss, x)[0]
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if torch.is_grad_enabled():
            return func(*args, **kwargs)
        raise RuntimeError(
            f"{func.__name__} requires gradients to be enabled. "
            "Ensure you're not inside a torch.no_grad() context."
        )

    return wrapper  # type: ignore[return-value]


def no_grad(func: F) -> F:
    """Run ``func`` with autograd disabled.

    Each call executes inside a fresh ``torch.no_grad()`` context, which is
    exited (restoring the previous grad mode) before the result is returned
    or an exception propagates.

    Example:
        >>> @no_grad
        ... def inference(model, x):
        ...     return model(x)
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        guard = torch.no_grad()
        with guard:
            result = func(*args, **kwargs)
        return result

    return wrapper  # type: ignore[return-value]


def inference_mode(func: F) -> F:
    """Run ``func`` under ``torch.inference_mode()``.

    Inference mode disables autograd like ``no_grad`` but also skips
    version-counter and view tracking, which makes it cheaper for pure
    inference. Tensors created inside cannot later be used in autograd.

    Example:
        >>> @inference_mode
        ... def batch_inference(model, loader):
        ...     return [model(batch) for batch in loader]
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        guard = torch.inference_mode()
        with guard:
            result = func(*args, **kwargs)
        return result

    return wrapper  # type: ignore[return-value]


# =============================================================================
# Method Caching with Weak References
# =============================================================================


class cached_method:
    """Descriptor for caching method results per instance.

    Unlike @cached, each instance gets its own ``LRUCache``. The caches are
    held in a ``weakref.WeakKeyDictionary`` keyed by the instance, so the
    descriptor itself does not keep instances alive. Consequences of that
    storage:

    * instances must be hashable and weak-referenceable (classes defining
      ``__eq__`` without ``__hash__``, or using ``__slots__`` without
      ``__weakref__``, raise ``TypeError``);
    * instances that compare equal share one cache;
    * a cached value that references its own instance keeps that instance
      (and its cache) alive.

    Can be applied with arguments (``@cached_method(maxsize=100)``) or bare
    (``@cached_method``).

    Args:
        maxsize: Maximum cache size per instance.
        ttl: Time-to-live for cache entries.

    Example:
        >>> class Model:
        ...     @cached_method(maxsize=100)
        ...     def forward(self, x: Tensor) -> Tensor:
        ...         return self.layers(x)
    """

    def __init__(
        self,
        maxsize: int | Callable[..., Any] = 128,
        ttl: float | None = None,
    ) -> None:
        func: Callable[..., Any] | None = None
        if callable(maxsize):
            # Bare ``@cached_method`` usage: the decorated function arrives
            # as the first positional argument.
            func, maxsize = maxsize, 128
        self.maxsize = maxsize
        self.ttl = ttl
        self._caches: weakref.WeakKeyDictionary[Any, LRUCache[Any]] = (
            weakref.WeakKeyDictionary()
        )
        if func is not None:
            self(func)

    def __call__(self, func: Callable[..., R]) -> "cached_method":
        self.func = func
        functools.update_wrapper(self, func)
        return self

    def __get__(self, obj: Any, objtype: type | None = None) -> Callable[..., R]:
        if obj is None:
            # Accessed on the class: expose the descriptor itself.
            return self  # type: ignore[return-value]

        cache = self._caches.get(obj)
        if cache is None:
            # setdefault keeps two concurrent first accesses from installing
            # different caches for the same instance.
            cache = self._caches.setdefault(
                obj, LRUCache(maxsize=self.maxsize, ttl=self.ttl)
            )

        @functools.wraps(self.func)
        def wrapper(*args: Any, **kwargs: Any) -> R:
            # Positional and keyword parts are kept in separate tuples so a
            # positional ``(name, value)`` pair cannot collide with a kwarg.
            key = (
                tuple(_make_hashable(arg) for arg in args),
                tuple((k, _make_hashable(v)) for k, v in sorted(kwargs.items())),
            )

            found, value = cache.get(key)
            if found:
                return value  # type: ignore[return-value]

            result = self.func(obj, *args, **kwargs)
            cache.set(key, result)
            return result

        wrapper.cache = cache  # type: ignore[attr-defined]
        wrapper.cache_clear = cache.clear  # type: ignore[attr-defined]

        return wrapper


# =============================================================================
# Retry Decorator
# =============================================================================


def retry(
    max_attempts: int = 3,
    delay: float = 0.1,
    backoff: float = 2.0,
    exceptions: tuple[type[Exception], ...] = (Exception,),
) -> Callable[[F], F]:
    """Decorator to retry a function on failure.

    The function is called up to ``max_attempts`` times. After each failed
    attempt except the last, a warning is logged and the wrapper sleeps for
    the current delay, which is then multiplied by ``backoff``. Exceptions
    not listed in ``exceptions`` propagate immediately. When every attempt
    fails, the last exception is re-raised with its original traceback.

    Args:
        max_attempts: Maximum number of attempts (must be at least 1).
        delay: Initial delay between attempts in seconds.
        backoff: Multiplier for delay after each attempt.
        exceptions: Tuple of exceptions to catch and retry.

    Raises:
        ValueError: If ``max_attempts`` is less than 1 (raised when the
            decorator is created).

    Example:
        >>> @retry(max_attempts=3, delay=0.5)
        ... def unreliable_operation():
        ...     # May fail due to transient issues
        ...     pass
    """
    if max_attempts < 1:
        # Otherwise the wrapper would never call func and end in `raise None`.
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            current_delay = delay

            # Always returns or re-raises, because max_attempts >= 1.
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except exceptions as exc:
                    if attempt == max_attempts:
                        raise  # out of attempts: propagate the last failure
                    logger.warning(
                        "%s failed (attempt %d/%d): %s",
                        func.__name__,
                        attempt,
                        max_attempts,
                        exc,
                    )
                    time.sleep(current_delay)
                    current_delay *= backoff

        return wrapper  # type: ignore[return-value]

    return decorator


# Public API of this module. The private helpers (_tensor_hash,
# _make_hashable) are intentionally not exported.
__all__ = [
    # Timing
    "timed",
    "TimingStats",
    "gpu_sync",
    # Caching
    "cached",
    "cached_method",
    "LRUCache",
    # Validation
    "validate_input",
    "validate_output",
    "TensorSpec",
    # Gradient control
    "require_grad",
    "no_grad",
    "inference_mode",
    # Retry
    "retry",
]
