import torch

from .inference_record import InferenceRecord, to_inference_record


def with_wrappers(model: torch.nn.Module) -> torch.nn.Module:
    """Patch ``model.forward`` so it takes a batch and returns an ``InferenceRecord``.

    The model's current ``forward`` is captured and replaced, on this same
    object, by a wrapper that:

    1. reads the ``input`` attribute of the batch it receives,
    2. passes that value to the captured ``forward``,
    3. converts the result with ``to_inference_record``.

    Args:
        model: Module whose ``forward`` attribute is replaced in place.

    Returns:
        The same ``model`` object that was passed in (not a copy), now
        carrying the wrapped ``forward``.
    """
    # Grab the current forward before it is overwritten below, so the
    # wrapper can still delegate to it.
    inner_forward = model.forward

    # TODO: annotate batch type here
    # The only requirement visible from this function is that ``batch``
    # exposes an ``input`` attribute.
    def wrapped_forward(batch) -> InferenceRecord:
        return to_inference_record(inner_forward(batch.input))

    model.forward = wrapped_forward
    return model
