"""Base classes and decorators for metrics.

This module provides the foundation for the metrics system:
- Decorators for timing, validation, and caching
- Protocol definitions for metric interfaces
- BaseMetric abstract base class

Example:
    >>> from expected_gradcam.metrics.base import BaseMetric, timed, validate_inputs
    >>>
    >>> class MyMetric(BaseMetric):
    ...     metric_name = "my_metric"
    ...
    ...     @timed
    ...     @validate_inputs
    ...     def compute(self, alpha, I_samples) -> float:
    ...         return float((alpha ** 2).sum())
"""

from __future__ import annotations

import functools
import hashlib
import time
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ParamSpec,
    Protocol,
    TypeVar,
    runtime_checkable,
)

import torch

from expected_gradcam.metrics.exceptions import (
    InvalidMetricInputError,
    NumericalInstabilityError,
)
from expected_gradcam.metrics.registry import MetricRegistryMeta

if TYPE_CHECKING:
    from torch import Tensor


# Generic placeholders used in the decorator signatures below so that a
# decorated callable keeps its parameter and return types for type checkers.
P = ParamSpec("P")  # parameter specification of the wrapped callable
R = TypeVar("R")  # return type of the wrapped callable
T = TypeVar("T")  # general-purpose type variable (not referenced in the visible part of this module)


# =============================================================================
# Decorators
# =============================================================================


def timed(func: Callable[P, R]) -> Callable[P, R]:
    """Decorator to measure execution time of a function.

    Records the execution time in `func.last_execution_time` (seconds).
    Also accumulates total time in `func.total_execution_time`.

    Example:
        >>> @timed
        ... def slow_function():
        ...     time.sleep(0.1)
        ...     return 42
        >>>
        >>> result = slow_function()
        >>> print(f"Took {slow_function.last_execution_time:.3f}s")
    """

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start

        # Store timing information
        wrapper.last_execution_time = elapsed  # type: ignore
        wrapper.total_execution_time = getattr(wrapper, "total_execution_time", 0.0) + elapsed  # type: ignore
        wrapper.call_count = getattr(wrapper, "call_count", 0) + 1  # type: ignore

        return result

    # Initialize timing attributes
    wrapper.last_execution_time = 0.0  # type: ignore
    wrapper.total_execution_time = 0.0  # type: ignore
    wrapper.call_count = 0  # type: ignore

    return wrapper


def validate_inputs(func: Callable[P, R]) -> Callable[P, R]:
    """Decorator to validate inputs before computation.

    Calls `self.validate_inputs(**kwargs)` before executing the function.
    The validate_inputs method should raise appropriate exceptions
    if validation fails.

    Example:
        >>> class MyMetric(BaseMetric):
        ...     def validate_inputs(self, alpha=None, **kwargs):
        ...         if alpha is None:
        ...             raise InvalidMetricInputError("my_metric", "alpha", "Tensor", "None")
        ...
        ...     @validate_inputs
        ...     def compute(self, alpha):
        ...         return float(alpha.sum())
    """

    @functools.wraps(func)
    def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> R:
        # Call validate_inputs if it exists
        if hasattr(self, "validate_inputs"):
            self.validate_inputs(**kwargs)
        return func(self, *args, **kwargs)

    return wrapper  # type: ignore


def cached(
    ttl_seconds: float = 60.0,
    maxsize: int = 128,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator for caching metric results with TTL expiration.

    Caches results based on a hash of the arguments. Supports tensor
    arguments by hashing their content. Cache entries expire after
    `ttl_seconds`.

    Args:
        ttl_seconds: Time-to-live for cache entries in seconds.
        maxsize: Maximum number of entries in the cache.

    Example:
        >>> @cached(ttl_seconds=30.0)
        ... def expensive_computation(x):
        ...     return x ** 2
    """

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        cache: dict[str, tuple[R, float]] = {}

        def _hash_arg(arg: Any) -> str:
            """Hash an argument for cache key generation."""
            if isinstance(arg, torch.Tensor):
                # Hash tensor data
                data = arg.detach().cpu().numpy().tobytes()
                return hashlib.md5(data).hexdigest()
            elif isinstance(arg, (list, tuple)):
                return str([_hash_arg(a) for a in arg])
            elif isinstance(arg, dict):
                return str({k: _hash_arg(v) for k, v in sorted(arg.items())})
            else:
                return str(arg)

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Generate cache key
            key_parts = [_hash_arg(a) for a in args]
            key_parts.extend(f"{k}={_hash_arg(v)}" for k, v in sorted(kwargs.items()))
            cache_key = "|".join(key_parts)

            now = time.time()

            # Check cache
            if cache_key in cache:
                result, timestamp = cache[cache_key]
                if now - timestamp < ttl_seconds:
                    return result

            # Compute and cache result
            result = func(*args, **kwargs)
            cache[cache_key] = (result, now)

            # Evict old entries if over maxsize
            if len(cache) > maxsize:
                # Remove oldest entries
                sorted_keys = sorted(cache.keys(), key=lambda k: cache[k][1])
                for old_key in sorted_keys[: len(cache) - maxsize]:
                    del cache[old_key]

            return result

        # Add cache management methods
        wrapper.cache_clear = cache.clear  # type: ignore
        wrapper.cache_info = lambda: {"size": len(cache), "maxsize": maxsize}  # type: ignore

        return wrapper

    return decorator


def check_finite(func: Callable[P, float]) -> Callable[P, float]:
    """Decorator to check that the result is finite (not NaN or Inf).

    Raises NumericalInstabilityError if the result is NaN or Inf.

    Example:
        >>> @check_finite
        ... def compute_ratio(a, b):
        ...     return a / b  # Raises error if b is 0
    """

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> float:
        result = func(*args, **kwargs)

        if not isinstance(result, (int, float)):
            return result

        import math

        if math.isnan(result) or math.isinf(result):
            # Try to get metric name from self
            metric_name = "unknown"
            if args and hasattr(args[0], "metric_name"):
                metric_name = args[0].metric_name

            raise NumericalInstabilityError(metric_name, result)

        return result

    return wrapper


def require_grad(func: Callable[P, R]) -> Callable[P, R]:
    """Decorator to ensure gradients are enabled during computation.

    Example:
        >>> @require_grad
        ... def compute_gradient_metric(x):
        ...     return x.grad.sum()
    """

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        with torch.enable_grad():
            return func(*args, **kwargs)

    return wrapper


def no_grad(func: Callable[P, R]) -> Callable[P, R]:
    """Decorator to disable gradients during computation.

    Example:
        >>> @no_grad
        ... def compute_metric(x):
        ...     return float(x.sum())  # No gradients needed
    """

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        with torch.no_grad():
            return func(*args, **kwargs)

    return wrapper


def inference_mode(func: Callable[P, R]) -> Callable[P, R]:
    """Decorator to use inference mode (faster than no_grad).

    Example:
        >>> @inference_mode
        ... def fast_metric(x):
        ...     return float(x.sum())
    """

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        with torch.inference_mode():
            return func(*args, **kwargs)

    return wrapper


def gpu_sync(func: Callable[P, R]) -> Callable[P, R]:
    """Decorator to synchronize CUDA before and after computation.

    Useful for accurate timing of GPU operations.

    Example:
        >>> @gpu_sync
        ... @timed
        ... def gpu_metric(x):
        ...     return float(x.cuda().sum())
    """

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        result = func(*args, **kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return result

    return wrapper


# =============================================================================
# Protocols
# =============================================================================


@runtime_checkable
class MetricProtocol(Protocol):
    """Protocol defining the interface for all metrics.

    This protocol uses structural subtyping - any class with these
    methods and properties is considered a valid metric; inheriting from
    this class is not required.

    Because the protocol is ``@runtime_checkable`` it can be used with
    ``isinstance``. Such a runtime check only verifies that the listed
    members exist on the object; it does not verify their signatures or
    return types.

    Required members:
        name: Property returning the metric's unique identifier.
        display_name: Property returning a human-readable name.
        lower_is_better: Property telling whether smaller values are better.
        compute: Method returning the metric value as a float.
        validate_inputs: Method that checks the inputs before computation.

    Example:
        >>> class MyMetric:
        ...     @property
        ...     def name(self) -> str:
        ...         return "my_metric"
        ...     @property
        ...     def display_name(self) -> str:
        ...         return "My Metric"
        ...     @property
        ...     def lower_is_better(self) -> bool:
        ...         return True
        ...     def compute(self, **kwargs) -> float:
        ...         return 0.0
        ...     def validate_inputs(self, **kwargs) -> None:
        ...         pass
        >>>
        >>> isinstance(MyMetric(), MetricProtocol)
        True
    """

    @property
    def name(self) -> str:
        """Unique identifier for the metric."""
        ...

    @property
    def display_name(self) -> str:
        """Human-readable name for display."""
        ...

    @property
    def lower_is_better(self) -> bool:
        """Whether lower values indicate better performance."""
        ...

    def compute(self, **kwargs: Any) -> float:
        """Compute the metric value.

        Args:
            **kwargs: Inputs required by the concrete metric.

        Returns:
            The metric value.
        """
        ...

    def validate_inputs(self, **kwargs: Any) -> None:
        """Validate inputs before computation.

        Args:
            **kwargs: The inputs that are about to be passed to ``compute``.
        """
        ...


@runtime_checkable
class StreamingMetricProtocol(MetricProtocol, Protocol):
    """Protocol for metrics that can be computed incrementally.

    Streaming metrics maintain internal state and can be updated
    with batches of data, useful for real-time visualization.

    This protocol extends ``MetricProtocol``: in addition to that
    protocol's members, an implementation provides ``update``,
    ``get_current_value`` and ``reset``.

    The example below shows only the three streaming methods; it
    subclasses the protocol explicitly and therefore inherits the
    placeholder bodies of the remaining ``MetricProtocol`` members.

    Example:
        >>> class RunningMean(StreamingMetricProtocol):
        ...     def __init__(self):
        ...         self._sum = 0.0
        ...         self._count = 0
        ...
        ...     def update(self, batch_data: dict) -> None:
        ...         self._sum += batch_data["value"]
        ...         self._count += 1
        ...
        ...     def get_current_value(self) -> float:
        ...         return self._sum / max(self._count, 1)
        ...
        ...     def reset(self) -> None:
        ...         self._sum = 0.0
        ...         self._count = 0
    """

    def update(self, batch_data: dict[str, Any]) -> None:
        """Update the metric with a batch of data.

        Args:
            batch_data: Mapping holding the data of one batch; the expected
                keys are defined by the concrete metric.
        """
        ...

    def get_current_value(self) -> float:
        """Get the current metric value.

        Returns:
            The metric value for the data seen so far.
        """
        ...

    def reset(self) -> None:
        """Reset the metric state."""
        ...


# =============================================================================
# Base Classes
# =============================================================================


class BaseMetric(metaclass=MetricRegistryMeta):
    """Abstract base class for all metrics.

    Provides common functionality including:
    - Automatic registration via metaclass
    - Property access for metadata
    - Default validation implementation
    - Common utility methods

    Subclasses must define:
    - `metric_name`: Unique identifier for registration
    - `compute()`: The actual metric computation

    Instances are callable: ``metric(...)`` forwards its positional and
    keyword arguments to ``compute``.

    Example:
        >>> from expected_gradcam.metrics.registry import register_metric
        >>>
        >>> @register_metric("my_metric", display_name="My Metric")
        ... class MyMetric(BaseMetric):
        ...     def compute(self, alpha, I_samples, **kwargs) -> float:
        ...         return float((alpha ** 2).sum())
    """

    _abstract = True  # Skip registration for base class

    # Class attributes set by @register_metric or directly
    metric_name: str = ""  # unique identifier; exposed through ``name``
    _display_name: str = ""  # empty means "derive from metric_name"
    _lower_is_better: bool = True
    _streamable: bool = False
    _category: str = "general"

    @property
    def name(self) -> str:
        """Unique identifier for the metric."""
        return self.metric_name

    @property
    def display_name(self) -> str:
        """Human-readable name for display.

        Falls back to ``metric_name`` with underscores replaced by spaces
        and title-cased when no explicit display name is set.
        """
        return self._display_name or self.metric_name.replace("_", " ").title()

    @property
    def lower_is_better(self) -> bool:
        """Whether lower values indicate better performance."""
        return self._lower_is_better

    @property
    def streamable(self) -> bool:
        """Whether the metric can be computed incrementally."""
        return self._streamable

    @property
    def category(self) -> str:
        """Category for grouping metrics."""
        return self._category

    def validate_inputs(self, **kwargs: Any) -> None:
        """Validate inputs before computation.

        The base implementation accepts everything. Override this method
        to add input validation; it should raise appropriate exceptions
        (InvalidMetricInputError, etc.) on failure.

        Args:
            **kwargs: Input parameters to validate.
        """
        pass

    def compute(self, **kwargs: Any) -> float:
        """Compute the metric value.

        Must be overridden by subclasses.

        Args:
            **kwargs: Parameters required for computation.

        Returns:
            The computed metric value.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} must implement compute()"
        )

    def __repr__(self) -> str:
        """Return string representation."""
        return (
            f"{self.__class__.__name__}("
            f"name={self.metric_name!r}, "
            f"lower_is_better={self.lower_is_better})"
        )

    def __call__(self, *args: Any, **kwargs: Any) -> float:
        """Allow calling the metric as a function.

        All arguments are forwarded unchanged to ``compute``. Positional
        arguments are forwarded as well, because subclasses commonly
        declare ``compute`` with named positional parameters (for example
        ``compute(self, alpha, I_samples)``).

        Args:
            *args: Positional arguments for ``compute``.
            **kwargs: Keyword arguments for ``compute``.

        Returns:
            Whatever ``compute`` returns.

        Example:
            >>> metric = MyMetric()
            >>> value = metric(alpha=alpha, I_samples=I_samples)
            >>> value = metric(alpha, I_samples)
        """
        return self.compute(*args, **kwargs)


class StreamingMetric(BaseMetric):
    """Base class for streaming (incremental) metrics.

    Adds an incremental interface on top of BaseMetric: data arrives
    through ``update``, the running result is read with
    ``get_current_value`` and ``reset`` starts over. ``compute`` simply
    reports the running result. Subclasses must implement the three
    streaming methods; the versions here raise ``NotImplementedError``.

    Example:
        >>> class RunningInfidelity(StreamingMetric):
        ...     _streamable = True
        ...
        ...     def __init__(self):
        ...         super().__init__()
        ...         self._squared_errors: list[float] = []
        ...
        ...     def update(self, batch_data: dict) -> None:
        ...         error = batch_data["predicted"] - batch_data["actual"]
        ...         self._squared_errors.append(float(error ** 2))
        ...
        ...     def get_current_value(self) -> float:
        ...         if not self._squared_errors:
        ...             return 0.0
        ...         return sum(self._squared_errors) / len(self._squared_errors)
        ...
        ...     def reset(self) -> None:
        ...         self._squared_errors.clear()
    """

    # NOTE(review): presumably keeps this base class out of the registry,
    # as for BaseMetric - confirm against MetricRegistryMeta.
    _abstract = True
    _streamable = True

    def __init__(self) -> None:
        """Initialize the streaming metric (no state is set up here)."""

    def update(self, batch_data: dict[str, Any]) -> None:
        """Feed one batch of data into the metric.

        Args:
            batch_data: Dictionary containing batch data for update.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(f"{owner} must implement update()")

    def get_current_value(self) -> float:
        """Return the metric value for the data accumulated so far.

        Returns:
            Current metric value.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(f"{owner} must implement get_current_value()")

    def reset(self) -> None:
        """Discard accumulated state so a new computation can begin.

        Raises:
            NotImplementedError: Always, until a subclass overrides it.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(f"{owner} must implement reset()")

    def compute(self, **kwargs: Any) -> float:
        """Return the current accumulated value.

        The keyword arguments are accepted for interface compatibility and
        are not used.
        """
        current = self.get_current_value()
        return current
