from typing import Callable, Iterator
from contextlib import contextmanager

import torch
from torch import nn
# from torchvision import models

from vis_datasets.wrappers.data_sample import DataSample
from .inference_record import InferenceRecord, to_inference_record


@contextmanager
def intermediate_representations(
    model: nn.Module,
    layers: list[tuple[str, nn.Module]],
) -> Iterator[nn.Module]:
    """Temporarily instrument ``model`` to record intermediate layer outputs.

    While the context is active, a forward hook is registered on every module
    in ``layers`` and ``model.forward`` is replaced by a wrapper that:

    1. clears the activations recorded so far,
    2. calls the original forward with ``batch.input``,
    3. converts the result with ``to_inference_record`` and returns a new
       ``InferenceRecord`` holding its ``output`` plus a copy of the
       activations captured during that call.

    Inside the context the original forward is also reachable as
    ``model.prev_forward``. On exit (normal or via an exception) every hook is
    removed and the model's ``forward`` / ``prev_forward`` attributes are put
    back exactly as they were on entry.

    Args:
        model: Module to instrument. It is patched in place and yielded.
        layers: ``(name, module)`` pairs. Each module's forward output is
            stored, detached, under ``name`` in ``InferenceRecord.layer_reps``.

    Yields:
        The same ``model`` object, with the monitored forward installed.
    """
    layer_reps: dict[str, torch.Tensor] = {}

    def activation_monitor_hook(
        layer_name: str,
    ) -> Callable[[nn.Module, tuple, torch.Tensor], None]:
        """Build a forward hook that stores a layer's output under ``layer_name``."""

        def hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
            # If the same module runs several times in one forward pass, only
            # its last output is kept. detach() drops the autograd graph link.
            layer_reps[layer_name] = output.detach()

        return hook

    # Snapshot the instance-level attributes we are about to overwrite so they
    # can be restored exactly. A missing entry means "no instance override",
    # i.e. the class attribute should become visible again on exit.
    patched_names = ("forward", "prev_forward")
    saved_attrs = {
        name: model.__dict__[name]
        for name in patched_names
        if name in model.__dict__
    }
    # Captured in a local (rather than re-read from model.prev_forward on each
    # call) so that nesting this context manager on the same model does not
    # make the wrapper call itself recursively.
    original_forward = model.forward

    def monitored_forward(batch: DataSample) -> InferenceRecord:
        layer_reps.clear()
        output = to_inference_record(original_forward(batch.input))
        return InferenceRecord(
            output=output.output,
            layer_reps=layer_reps.copy(),
        )

    hooks = []
    try:
        for layer_name, layer in layers:
            hooks.append(
                layer.register_forward_hook(activation_monitor_hook(layer_name))
            )
        model.prev_forward = original_forward
        model.forward = monitored_forward
        yield model
    finally:
        # Runs even if hook registration or the `with` body raised, so the
        # model is never left instrumented.
        for handle in hooks:
            handle.remove()
        for name in patched_names:
            if name in saved_attrs:
                setattr(model, name, saved_attrs[name])
            else:
                # Drop the instance override so the class attribute (e.g. the
                # class's own forward method) is used again.
                model.__dict__.pop(name, None)
