import torch
import torch.nn as nn

from ..definitions import Model, InferenceRecord


class RecordedSequential(Model):
    """Sequential container that records every intermediate activation.

    Registered child modules are applied one after another in registration
    order, and the output of each one is stored under its registration name.
    """

    def __init__(
        self,
        layers: dict[str, nn.Module],
    ) -> None:
        """Register each entry of ``layers`` as a named child module.

        Args:
            layers: Mapping from layer name to module. The dict's iteration
                order is the order in which ``forward`` applies the layers.
                Names go through ``add_module``, so they must be valid
                submodule names (non-empty strings without ``"."``).
        """
        super().__init__()

        for layer_name, layer in layers.items():
            self.add_module(name=layer_name, module=layer)

    def forward(self, x: torch.Tensor) -> InferenceRecord:
        """Run ``x`` through every registered layer, recording each output.

        Args:
            x: Input to the first layer.

        Returns:
            An ``InferenceRecord`` whose ``output`` is the last layer's result
            (``x`` itself when no layers are registered). Its ``activations``
            map ``"input"`` to ``x`` and each layer name to that layer's
            output, in application order.
        """
        # NOTE: a layer registered under the name "input" overwrites the
        # recorded input tensor in ``activations``.
        layer_activations: dict[str, torch.Tensor] = {"input": x}

        # Iterate ``_modules`` directly (as ``nn.Sequential.forward`` does)
        # rather than ``named_children()``. The latter yields each distinct
        # module object only once, so a module reused under several names
        # (e.g. a shared ``nn.ReLU``) would silently be applied only once.
        for layer_name, layer in self._modules.items():
            if layer is None:
                # ``named_children()`` skipped ``None`` placeholders; keep that.
                continue
            x = layer(x)
            layer_activations[layer_name] = x

        return InferenceRecord(
            output=x,
            activations=layer_activations,
        )
