from spaghettini import quick_register

import torch
from torch import nn
from src.dl.models.wrappers.module_dict import ModuleDict


@quick_register
class VerifierWrapper(nn.Module):
    """Combine a proof-input processor, a classifier and optional auxiliary heads.

    The processor turns ``(inputs, proofs)`` into a feature tensor plus a dict of
    extra outputs. The classifier maps those features to logits, and each
    auxiliary head (if any) is applied to the same features.

    Args:
        proof_input_processor: Module called as ``processor(inputs, proofs)``;
            must return a ``(feats, feats_dict)`` pair.
        classifier: Module applied to ``feats`` to produce the logits.
        aux_heads: Optional project ``ModuleDict`` mapping names to modules that
            are each applied to ``feats``.
        detach_aux_heads: If True, each auxiliary head receives a detached clone
            of ``feats``, so gradients from the head outputs do not flow back
            through ``feats``.

    Raises:
        TypeError: If ``aux_heads`` is neither ``None`` nor a ``ModuleDict``.
    """

    def __init__(self, proof_input_processor, classifier, aux_heads=None, detach_aux_heads=False):
        super().__init__()
        # Explicit check instead of ``assert``: the previous assert used ``print(...)``
        # as its message (so the AssertionError carried ``None``), and asserts are
        # skipped entirely under ``python -O``.
        if aux_heads is not None and not isinstance(aux_heads, ModuleDict):
            raise TypeError(
                f"aux_heads must be a {ModuleDict} of PyTorch modules, "
                f"got {type(aux_heads).__name__}."
            )
        self.proof_input_processor = proof_input_processor
        self.classifier = classifier
        self.aux_heads = aux_heads
        self.detach_aux_heads = detach_aux_heads

    def forward(self, inputs, proofs_list, *args, **kwargs):
        """Run the processor, the classifier and any auxiliary heads.

        Args:
            inputs: Passed unchanged to ``proof_input_processor``.
            proofs_list: Either the proofs object itself, or a list/tuple whose
                first element is used (any remaining elements are ignored).
            *args: Accepted for call-site compatibility; unused.
            **kwargs: Accepted for call-site compatibility; unused.

        Returns:
            A tuple ``(logits, aux_outs, model_dict)``:

            - ``aux_outs`` maps each auxiliary head name to its output. It is
              empty when there are no heads.
            - ``model_dict`` holds ``verifier_logits`` plus every entry of the
              processor's ``feats_dict``. On a key collision, ``feats_dict`` wins.
        """
        # Only the first proofs entry is consumed when a sequence is given.
        proofs = proofs_list[0] if isinstance(proofs_list, (list, tuple)) else proofs_list

        feats, feats_dict = self.proof_input_processor(inputs, proofs)
        logits = self.classifier(feats)

        aux_outs = dict()
        if self.aux_heads is not None:
            for name, aux_head in self.aux_heads.items():
                # Each head gets its own detached clone. This cuts the autograd link
                # back to feats, and separate storage keeps an in-place op inside
                # one head from altering feats or the input seen by other heads.
                feats_for_aux = feats.clone().detach() if self.detach_aux_heads else feats
                aux_outs[name] = aux_head(feats_for_aux)

        # feats_dict is merged after verifier_logits, so its keys take precedence.
        model_dict = dict(verifier_logits=logits)
        model_dict.update(feats_dict)

        return logits, aux_outs, model_dict
