from spaghettini import quick_register

from torch import nn
from src.dl.models.wrappers.module_dict import ModuleDict


@quick_register
class ProverWrapper(nn.Module):
    """Chain a feature extractor, a proof generator and auxiliary heads.

    The wrapper computes features from the inputs, turns them into proofs
    (optionally passed through a discretizer), and evaluates every auxiliary
    head on the same features.
    """

    def __init__(self, feat_extractor, proof_generator, aux_heads, proof_discretizer=None, detach_aux_heads=False):
        """Store the sub-components of the prover.

        Args:
            feat_extractor: Callable applied to the raw inputs to produce features.
            proof_generator: Callable applied to the features; ``forward`` unpacks
                its result as ``(proofs, extra_dict)``.
            aux_heads: ``ModuleDict`` mapping a head name to a callable that is
                applied to the features.
            proof_discretizer: Optional callable applied to the proofs. Skipped
                when ``None``.
            detach_aux_heads: When true, each auxiliary head receives
                ``feats.clone().detach()`` instead of ``feats`` itself.

        Raises:
            AssertionError: If ``aux_heads`` is not a ``ModuleDict``.
        """
        super().__init__()
        # The message must be the string itself: the previous `print(...)` call
        # evaluated to None, so the AssertionError was raised without any text.
        assert isinstance(aux_heads, ModuleDict), (
            f"aux_heads must be a {ModuleDict} of PyTorch modules, got {type(aux_heads)}."
        )
        self.feat_extractor = feat_extractor
        self.proof_generator = proof_generator
        self.aux_heads = aux_heads
        self.detach_aux_heads = detach_aux_heads
        self.proof_discretizer = proof_discretizer

    def forward(self, inputs, *args, **kwargs):
        """Run the prover on ``inputs``.

        Extra positional and keyword arguments are accepted but not used.

        Returns:
            A tuple ``(proofs, aux_outs, model_dict)`` where ``proofs`` is the
            (possibly discretized) proof output, ``aux_outs`` maps each auxiliary
            head name to that head's output, and ``model_dict`` holds ``proofs``
            and ``prover_feats`` plus every entry of the dict returned by the
            proof generator.
        """
        # Get the features.
        feats = self.feat_extractor(inputs)

        # Construct the proof vectors.
        proofs, proof_generator_dict = self.proof_generator(feats)
        if self.proof_discretizer is not None:
            proofs = self.proof_discretizer(proofs)

        # Run the auxiliary heads. When detaching, every head gets its own
        # detached copy so that one head cannot affect what another one sees.
        aux_outs = {
            name: aux_head(feats.clone().detach() if self.detach_aux_heads else feats)
            for name, aux_head in self.aux_heads.items()
        }

        # Build the model dict. Entries coming from the proof generator are
        # applied last, so they take precedence on a key collision.
        model_dict = dict(proofs=proofs, prover_feats=feats)
        model_dict.update(proof_generator_dict)

        return proofs, aux_outs, model_dict
