"""Hook management utilities for PyTorch modules.

This module provides context managers and utilities for registering,
managing, and removing PyTorch hooks. It supports both forward and
backward hooks with automatic cleanup.

Example:
    >>> with HookManager() as hooks:
    ...     hooks.register_forward(model.layer4, "layer4", capture_output=True)
    ...     output = model(x)
    ...     activation = hooks.get_activation("layer4")
    >>>
    >>> # Multi-layer hook capture
    >>> with MultiLayerHooks(model, ["layer1", "layer4"]) as hooks:
    ...     output = model(x)
    ...     for name, activation in hooks.items():
    ...         print(f"{name}: {activation.shape}")
"""

from __future__ import annotations

import weakref
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Iterator,
    Literal,
)

import torch
from torch import Tensor, nn
from torch.utils.hooks import RemovableHandle

if TYPE_CHECKING:
    pass


@dataclass
class CapturedActivation:
    """Mutable record of the tensors observed on one hooked layer.

    Every tensor-valued field defaults to ``None`` and is meant to be filled
    in by a hook callback; ``layer_name`` only labels the record.

    Attributes:
        output: Output tensor of the layer, or ``None`` if nothing is stored.
        input: Tuple of input tensors of the layer, or ``None``.
        layer_name: Name identifying the layer this record belongs to.
        grad_output: Gradients w.r.t. the layer outputs, or ``None``.
        grad_input: Gradients w.r.t. the layer inputs, or ``None``.
    """

    output: Tensor | None = None
    input: tuple[Tensor, ...] | None = None
    layer_name: str = ""
    grad_output: tuple[Tensor | None, ...] | None = None
    grad_input: tuple[Tensor | None, ...] | None = None

    @property
    def has_gradients(self) -> bool:
        """Return ``True`` if ``grad_output`` or ``grad_input`` is set."""
        return not (self.grad_output is None and self.grad_input is None)

    def clear(self) -> None:
        """Reset every captured field to ``None``; ``layer_name`` is kept."""
        for attr in ("output", "input", "grad_output", "grad_input"):
            setattr(self, attr, None)


@dataclass
class HookState:
    """Bookkeeping record for one hook registered through ``HookManager``.

    Bundles the handle needed to detach the hook with the layer it is
    attached to and the ``CapturedActivation`` the hook callback writes into.

    Attributes:
        handle: The RemovableHandle for the hook; ``handle.remove()`` is what
            the manager calls to detach the hook from the layer.
        layer: The layer the hook is registered on.
        layer_name: Identifier for the layer (the key the manager stores
            this record under).
        hook_type: Type of hook ("forward", "backward", "full_backward").
        capture: The CapturedActivation for storing results.
    """

    # Returned by layer.register_*_hook(); needed to unregister the hook.
    handle: RemovableHandle
    # Module the hook is attached to.
    layer: nn.Module
    # Name under which the hook was registered.
    layer_name: str
    # Within this file only "forward" and "full_backward" are ever assigned.
    hook_type: Literal["forward", "backward", "full_backward"]
    # Fresh, empty capture record unless the caller supplies one.
    capture: CapturedActivation = field(default_factory=CapturedActivation)


class HookManager:
    """Context manager for PyTorch hook registration and cleanup.

    Hooks are registered on individual layers under a string name. Each hook
    writes what it observes into a ``CapturedActivation`` which can be read
    back with ``get_activation`` / ``get_output`` / ``get_grad_output``.

    Leaving the ``with`` block removes every hook and drops its captured
    data, so results must be read (or copied out) before the block ends.

    Registering a hook under a name that is already in use replaces the
    earlier hook: the old handle is removed from its layer first, so no hook
    is left attached without a way to remove it.

    Example:
        >>> manager = HookManager()
        >>> with manager:
        ...     manager.register_forward(model.layer4, "layer4")
        ...     output = model(x)
        ...     activation = manager.get_activation("layer4")
        >>>
        >>> # Hooks are automatically removed after context exit
    """

    def __init__(self) -> None:
        # Registered hooks keyed by their name.
        self._hooks: dict[str, HookState] = {}
        # True between __enter__ and __exit__; not consulted by this class.
        self._active = False

    def __enter__(self) -> "HookManager":
        self._active = True
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        # Returns None, so an exception raised inside the block propagates.
        self.remove_all()
        self._active = False

    def _auto_name(self, suffixes: tuple[str, ...] = ("",)) -> str:
        """Return a ``layer_<n>`` base name whose derived names are unused.

        Starts at ``len(self._hooks)`` (the historical naming scheme) and
        counts upwards until ``base + suffix`` is free for every suffix.
        Using the length alone can hand out a name that is still registered
        once any hook has been removed.
        """
        index = len(self._hooks)
        while any(f"layer_{index}{suffix}" in self._hooks for suffix in suffixes):
            index += 1
        return f"layer_{index}"

    def _track(
        self,
        name: str,
        handle: RemovableHandle,
        layer: nn.Module,
        hook_type: Literal["forward", "backward", "full_backward"],
        capture: CapturedActivation,
    ) -> None:
        """Store a newly registered hook, detaching any hook it replaces."""
        previous = self._hooks.get(name)
        if previous is not None:
            # Without this the old hook would stay attached to its layer and
            # keep firing, with no handle left to remove it.
            previous.handle.remove()
        self._hooks[name] = HookState(
            handle=handle,
            layer=layer,
            layer_name=name,
            hook_type=hook_type,
            capture=capture,
        )

    def register_forward(
        self,
        layer: nn.Module,
        name: str | None = None,
        capture_input: bool = False,
        capture_output: bool = True,
    ) -> str:
        """Register a forward hook on a layer.

        Args:
            layer: The layer to hook.
            name: Identifier for this hook. Auto-generated if not provided.
                An existing hook with the same name is removed and replaced.
            capture_input: Whether to capture layer input. Only the
                positional inputs that are tensors are stored (detached).
            capture_output: Whether to capture layer output. A tensor output
                is stored detached; for a tuple output the first element is
                stored if it is a tensor, otherwise ``None``. Other output
                types leave the stored output untouched.

        Returns:
            The name/identifier for this hook.
        """
        if name is None:
            name = self._auto_name()

        capture = CapturedActivation(layer_name=name)

        def hook_fn(
            module: nn.Module,
            input: tuple[Tensor, ...],
            output: Tensor,
        ) -> None:
            if capture_output:
                # Detach to avoid keeping computation graph
                if isinstance(output, Tensor):
                    capture.output = output.detach()
                elif isinstance(output, tuple):
                    first = output[0] if output else None
                    # Guard against non-tensor first elements so that the
                    # hook never raises inside the model's forward pass.
                    capture.output = (
                        first.detach() if isinstance(first, Tensor) else None
                    )
            if capture_input:
                capture.input = tuple(
                    t.detach() for t in input if isinstance(t, Tensor)
                )

        handle = layer.register_forward_hook(hook_fn)
        self._track(name, handle, layer, "forward", capture)
        return name

    def register_backward(
        self,
        layer: nn.Module,
        name: str | None = None,
        capture_grad_input: bool = False,
        capture_grad_output: bool = True,
    ) -> str:
        """Register a backward hook on a layer.

        The hook is installed with ``register_full_backward_hook`` and is
        recorded with hook type ``"full_backward"``.

        Args:
            layer: The layer to hook.
            name: Identifier for this hook. Auto-generated if not provided.
                An existing hook with the same name is removed and replaced.
            capture_grad_input: Whether to capture input gradients.
            capture_grad_output: Whether to capture output gradients.

        Returns:
            The name/identifier for this hook.
        """
        if name is None:
            name = self._auto_name()

        capture = CapturedActivation(layer_name=name)

        def hook_fn(
            module: nn.Module,
            grad_input: tuple[Tensor | None, ...],
            grad_output: tuple[Tensor | None, ...],
        ) -> None:
            # Entries may be None; those are kept as None in the stored tuple.
            if capture_grad_output:
                capture.grad_output = tuple(
                    g.detach() if g is not None else None for g in grad_output
                )
            if capture_grad_input:
                capture.grad_input = tuple(
                    g.detach() if g is not None else None for g in grad_input
                )

        handle = layer.register_full_backward_hook(hook_fn)
        self._track(name, handle, layer, "full_backward", capture)
        return name

    def register_forward_and_backward(
        self,
        layer: nn.Module,
        name: str | None = None,
    ) -> tuple[str, str]:
        """Register both forward and backward hooks on a layer.

        The hooks are stored as ``f"{name}_fwd"`` and ``f"{name}_bwd"``.

        Args:
            layer: The layer to hook.
            name: Base identifier for these hooks. Auto-generated if not
                provided.

        Returns:
            Tuple of (forward_name, backward_name).
        """
        if name is None:
            # Both derived names must be free, not just the base name.
            name = self._auto_name(suffixes=("_fwd", "_bwd"))

        forward_name = self.register_forward(layer, f"{name}_fwd")
        backward_name = self.register_backward(layer, f"{name}_bwd")

        return forward_name, backward_name

    def get_activation(self, name: str) -> CapturedActivation:
        """Get captured activation for a hook.

        Args:
            name: Hook identifier.

        Returns:
            CapturedActivation containing captured data.

        Raises:
            KeyError: If hook not found (including after the hook has been
                removed, e.g. by leaving the ``with`` block).
        """
        if name not in self._hooks:
            raise KeyError(f"Hook '{name}' not found. Available: {list(self._hooks)}")
        return self._hooks[name].capture

    def get_output(self, name: str) -> Tensor | None:
        """Get captured output tensor for a hook.

        Args:
            name: Hook identifier.

        Returns:
            Output tensor or None if not captured.

        Raises:
            KeyError: If hook not found.
        """
        return self.get_activation(name).output

    def get_grad_output(self, name: str) -> tuple[Tensor | None, ...] | None:
        """Get captured gradient output for a hook.

        Args:
            name: Hook identifier.

        Returns:
            Gradient tuple or None if not captured.

        Raises:
            KeyError: If hook not found.
        """
        return self.get_activation(name).grad_output

    def clear_captures(self) -> None:
        """Clear all captured data but keep hooks registered."""
        for hook_state in self._hooks.values():
            hook_state.capture.clear()

    def remove(self, name: str) -> None:
        """Remove a specific hook; unknown names are ignored.

        Args:
            name: Hook identifier to remove.
        """
        hook_state = self._hooks.pop(name, None)
        if hook_state is not None:
            hook_state.handle.remove()

    def remove_all(self) -> None:
        """Remove all registered hooks and forget their captured data."""
        for hook_state in self._hooks.values():
            hook_state.handle.remove()
        self._hooks.clear()

    @property
    def hook_names(self) -> list[str]:
        """List of all registered hook names."""
        return list(self._hooks)

    def __len__(self) -> int:
        return len(self._hooks)

    def __contains__(self, name: str) -> bool:
        return name in self._hooks


class MultiLayerHooks:
    """Context manager for capturing activations from multiple layers.

    Provides a convenient way to capture activations from multiple
    named layers in a model. Hooks are registered on ``__enter__`` and
    removed on ``__exit__``; because removal also drops the captured data,
    activations and gradients must be read inside the ``with`` block.

    Example:
        >>> layers = {"early": model.layer1, "late": model.layer4}
        >>> with MultiLayerHooks(layers) as hooks:
        ...     output = model(x)
        ...     early_act = hooks["early"]
        ...     late_act = hooks["late"]
    """

    def __init__(
        self,
        layers: dict[str, nn.Module] | nn.Module,
        layer_names: list[str] | None = None,
        capture_gradients: bool = False,
    ) -> None:
        """Initialize multi-layer hook manager.

        Args:
            layers: Either a dict of {name: layer} or a model with named children.
            layer_names: If layers is a model, names of layers to hook
                (dot notation is resolved with ``_get_layer``).
            capture_gradients: Whether to also capture gradients.

        Raises:
            ValueError: If ``layers`` is a model and ``layer_names`` is None.
            AttributeError: If a name in ``layer_names`` cannot be resolved.
        """
        self._manager = HookManager()
        self._capture_gradients = capture_gradients

        # Handle model input
        if isinstance(layers, nn.Module):
            if layer_names is None:
                raise ValueError(
                    "layer_names must be provided when layers is a model"
                )
            self._layers = {name: self._get_layer(layers, name) for name in layer_names}
        else:
            self._layers = layers

        # Layer name -> hook name inside the manager; filled by __enter__.
        self._forward_names: dict[str, str] = {}
        self._backward_names: dict[str, str] = {}

    @staticmethod
    def _get_layer(model: nn.Module, name: str) -> nn.Module:
        """Get layer from model by name (supports dot notation).

        Each dot-separated part is looked up as an attribute first; a part
        made only of digits that is not an attribute is used as an index.

        Raises:
            AttributeError: If a part is neither an attribute nor numeric.
        """
        parts = name.split(".")
        layer = model
        for part in parts:
            if hasattr(layer, part):
                layer = getattr(layer, part)
            elif part.isdigit():
                layer = layer[int(part)]
            else:
                raise AttributeError(f"Module has no attribute '{part}'")
        return layer

    def __enter__(self) -> "MultiLayerHooks":
        self._manager.__enter__()

        try:
            for name, layer in self._layers.items():
                fwd_name = self._manager.register_forward(layer, name)
                self._forward_names[name] = fwd_name

                if self._capture_gradients:
                    bwd_name = self._manager.register_backward(layer, f"{name}_grad")
                    self._backward_names[name] = bwd_name
        except BaseException:
            # ``__exit__`` is never called when ``__enter__`` raises, so the
            # hooks registered before the failure must be removed here or
            # they would stay attached to the model.
            self._manager.__exit__(None, None, None)
            self._forward_names.clear()
            self._backward_names.clear()
            raise

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        self._manager.__exit__(exc_type, exc_val, exc_tb)

    def __getitem__(self, name: str) -> Tensor | None:
        """Get activation for a layer by name.

        Raises:
            KeyError: If no forward hook was registered for ``name``, or the
                hooks have already been removed.
        """
        if name not in self._forward_names:
            raise KeyError(f"Layer '{name}' not found")
        return self._manager.get_output(self._forward_names[name])

    def get_gradient(self, name: str) -> tuple[Tensor | None, ...] | None:
        """Get gradient for a layer by name.

        Raises:
            ValueError: If gradient capture was not enabled.
            KeyError: If no backward hook is registered for ``name``.
        """
        if name not in self._backward_names:
            if not self._capture_gradients:
                raise ValueError("Gradient capture was not enabled")
            raise KeyError(f"Layer '{name}' not found")
        return self._manager.get_grad_output(self._backward_names[name])

    def items(self) -> Iterator[tuple[str, Tensor | None]]:
        """Iterate over (name, activation) pairs."""
        for name in self._layers:
            yield name, self[name]

    def clear(self) -> None:
        """Clear all captured activations."""
        self._manager.clear_captures()


@contextmanager
def capture_activations(
    layers: dict[str, nn.Module],
    with_gradients: bool = False,
) -> Iterator[dict[str, Tensor | None]]:
    """Capture forward activations of several layers into a plain dict.

    A thin wrapper around ``MultiLayerHooks`` for callers that only want the
    resulting ``{name: activation}`` mapping.

    The yielded dict is empty while the ``with`` body runs. It is filled in
    once the body finishes normally, just before the hooks are removed, so it
    should be read after the ``with`` block. If the body raises, the dict is
    left unfilled.

    Args:
        layers: Dict mapping names to layer modules.
        with_gradients: Forwarded as ``capture_gradients`` to
            ``MultiLayerHooks``. NOTE(review): the captured gradients are not
            copied into the yielded dict; only forward outputs are.

    Yields:
        Dict that will map each layer name to its captured activation.

    Example:
        >>> layers = {"conv": model.conv1, "pool": model.pool1}
        >>> with capture_activations(layers) as acts:
        ...     output = model(x)
        >>> print(acts["conv"].shape)
    """
    captured: dict[str, Tensor | None] = {}
    hook_context = MultiLayerHooks(layers, capture_gradients=with_gradients)

    with hook_context as hooks:
        yield captured
        # Reached only when the caller's block completed without raising;
        # the hooks are still registered here, so their outputs are readable.
        captured.update((layer_name, hooks[layer_name]) for layer_name in layers)


class GradientHook:
    """Hook for capturing and optionally modifying gradients.

    Can be used to implement gradient-based attribution methods
    like Guided Backpropagation or CAM variants.

    While active, a full backward hook is installed on ``layer``:

    * if ``capture`` is true, the gradient w.r.t. the layer's first output
      is stored (detached copy, unmodified) and exposed as ``grad``;
    * if ``modifier`` is given, it is applied to the gradient w.r.t. the
      layer's first input and the result is what flows further back.

    Example:
        >>> # Guided backprop: clamp negative gradients
        >>> def guide_relu(grad):
        ...     return torch.clamp(grad, min=0)
        >>>
        >>> with GradientHook(model.relu, modifier=guide_relu):
        ...     output = model(x)
        ...     output.backward(target_grad)
    """

    def __init__(
        self,
        layer: nn.Module,
        modifier: Callable[[Tensor], Tensor] | None = None,
        capture: bool = True,
    ) -> None:
        """Initialize gradient hook.

        Args:
            layer: Layer to hook.
            modifier: Optional function to modify gradients. It receives the
                gradient w.r.t. the layer's first input and must return a
                tensor to use in its place.
            capture: Whether to store captured gradients.
        """
        self.layer = layer
        self.modifier = modifier
        self.capture = capture
        self._handle: RemovableHandle | None = None
        self._captured_grad: Tensor | None = None

    def __enter__(self) -> "GradientHook":
        def hook_fn(
            module: nn.Module,
            grad_input: tuple[Tensor | None, ...],
            grad_output: tuple[Tensor | None, ...],
        ) -> tuple[Tensor | None, ...] | None:
            upstream = grad_output[0] if grad_output else None
            if self.capture and upstream is not None:
                self._captured_grad = upstream.detach().clone()

            # Whatever a full backward hook returns is used in place of
            # grad_input, so the replacement has to be derived from
            # grad_input (same length, same shapes) and not from grad_output.
            # When the first input gradient is None (e.g. the input does not
            # require grad) there is nothing to modify, so return None.
            if self.modifier is None or not grad_input or grad_input[0] is None:
                return None

            return (self.modifier(grad_input[0]),) + tuple(grad_input[1:])

        if self._handle is not None:
            # Entered again without exiting: drop the previous hook so that
            # it is not left attached with no handle to remove it.
            self._handle.remove()
        self._handle = self.layer.register_full_backward_hook(hook_fn)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        # The captured gradient is kept; only the hook itself is removed.
        if self._handle is not None:
            self._handle.remove()
            self._handle = None

    @property
    def grad(self) -> Tensor | None:
        """Get captured gradient (w.r.t. the first output), if any."""
        return self._captured_grad


class FeatureMapHook:
    """Specialized hook for capturing feature maps in CAM methods.

    Designed specifically for GradCAM-style methods that need both
    the feature maps and their gradients. A forward hook stores the layer's
    (first) output and a full backward hook stores the gradient w.r.t. that
    output. Captured tensors are detached and remain available after the
    ``with`` block until ``clear`` is called or they are overwritten.

    Example:
        >>> with FeatureMapHook(model.layer4) as hook:
        ...     output = model(x)
        ...     score = output[0, target_class]
        ...     score.backward()
        ...
        ...     features = hook.features  # [B, K, H, W]
        ...     gradients = hook.gradients  # [B, K, H, W]
    """

    def __init__(self, layer: nn.Module) -> None:
        """Initialize feature map hook.

        Args:
            layer: The layer to capture features from.
        """
        self.layer = layer
        self._features: Tensor | None = None
        self._gradients: Tensor | None = None
        self._fwd_handle: RemovableHandle | None = None
        self._bwd_handle: RemovableHandle | None = None

    def __enter__(self) -> "FeatureMapHook":
        def fwd_hook(
            module: nn.Module,
            input: tuple[Tensor, ...],
            output: Tensor,
        ) -> None:
            if isinstance(output, Tensor):
                self._features = output.detach()
            elif isinstance(output, tuple) and output and isinstance(output[0], Tensor):
                # Same convention as HookManager.register_forward: for tuple
                # outputs keep the first element, which is also the one the
                # backward hook reads the gradient for.
                self._features = output[0].detach()
            # Any other output type is ignored rather than raising inside
            # the model's forward pass.

        def bwd_hook(
            module: nn.Module,
            grad_input: tuple[Tensor | None, ...],
            grad_output: tuple[Tensor | None, ...],
        ) -> None:
            if grad_output and grad_output[0] is not None:
                self._gradients = grad_output[0].detach()

        self._fwd_handle = self.layer.register_forward_hook(fwd_hook)
        try:
            self._bwd_handle = self.layer.register_full_backward_hook(bwd_hook)
        except BaseException:
            # __exit__ will not run if __enter__ raises, so undo the forward
            # hook here instead of leaving it attached to the layer.
            self._fwd_handle.remove()
            self._fwd_handle = None
            raise

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        if self._fwd_handle is not None:
            self._fwd_handle.remove()
        if self._bwd_handle is not None:
            self._bwd_handle.remove()
        self._fwd_handle = None
        self._bwd_handle = None

    @property
    def features(self) -> Tensor | None:
        """Get captured feature maps."""
        return self._features

    @property
    def gradients(self) -> Tensor | None:
        """Get captured gradients."""
        return self._gradients

    def clear(self) -> None:
        """Clear captured data."""
        self._features = None
        self._gradients = None


# Public API of this module: the names exported by a star-import.
__all__ = [
    # Dataclasses
    "CapturedActivation",
    "HookState",
    # Managers
    "HookManager",
    "MultiLayerHooks",
    # Context managers
    "capture_activations",
    # Specialized hooks
    "GradientHook",
    "FeatureMapHook",
]
