from spaghettini import quick_register

from torch import nn


@quick_register
class ProbeWrapper(nn.Module):
    """Apply several named probe models to the same input.

    Args:
        probe_models_dict: Mapping from a probe name to a callable probe
            model. The mapping itself is kept on ``self.probe_models_dict``
            unchanged.

    Every value of the mapping that is an ``nn.Module`` is also registered as
    a sub-module of this wrapper, so its parameters show up in
    ``parameters()`` / ``state_dict()`` and it follows ``.to()``,
    ``.train()`` and ``.eval()`` calls made on the wrapper.
    """

    def __init__(self, probe_models_dict):
        super().__init__()
        self.probe_models_dict = probe_models_dict

        # A plain dict attribute is invisible to nn.Module's bookkeeping, so
        # the probes would otherwise be skipped by optimizers, device moves
        # and checkpoints. Registration goes through a ModuleList rather than
        # a ModuleDict because list indices put no constraint on the probe
        # names (ModuleDict rejects keys containing "." or clashing with
        # Module attributes) and non-module callables can simply be left out.
        # If the caller already passed an nn.Module container, the attribute
        # assignment above registered it and nothing more is needed.
        if not isinstance(probe_models_dict, nn.Module):
            self._registered_probes = nn.ModuleList(
                [
                    probe_model
                    for probe_model in probe_models_dict.values()
                    if isinstance(probe_model, nn.Module)
                ]
            )

    def forward(self, proofs):
        """Call every probe on ``proofs``.

        Args:
            proofs: Input handed to each probe model as its only argument.

        Returns:
            A new dict mapping each probe name to that probe's output.
        """
        return {
            name: probe_model(proofs)
            for name, probe_model in self.probe_models_dict.items()
        }
