"""Compute and store the activation values of a neural network.

The class GetHookVals stores the activations of a neural network by using
PyTorch's forward-hook mechanism."""

from __future__ import annotations

from typing import Callable, TYPE_CHECKING

if TYPE_CHECKING:
    from torch import Module, Tensor

class GetHookVals:
    """Compute and store the activation values of a neural network.

    Call :meth:`getActivation` with a name to obtain a forward hook. Each
    time that hook is invoked, the module output it receives is detached
    and stored in :attr:`activation` under that name. Any value previously
    stored under the same name is overwritten.
    """

    def __init__(self) -> None:
        # Maps each hook name to the most recent detached output recorded
        # by the hook created for that name.
        self.activation: dict[str, Tensor] = {}

    def getActivation(
            self,
            name: str,
    ) -> Callable[[Module, tuple[object, ...], Tensor], None]:
        """Return a forward hook that records a module's output under *name*.

        Args:
            name: Key under which the output is stored in ``self.activation``.

        Returns:
            A callable ``hook(model, inputs, output)`` matching PyTorch's
            forward-hook signature. It stores ``output.detach()`` in
            ``self.activation[name]`` and returns ``None``, so the module's
            output is left unmodified.
        """
        def hook(
                model: Module,
                inputs: tuple[object, ...],
                output: Tensor,
        ) -> None:
            # ``model`` and ``inputs`` are part of the hook signature PyTorch
            # calls with; only ``output`` is needed here. detach() stores a
            # tensor that is not attached to the autograd graph.
            self.activation[name] = output.detach()
        return hook
    
    