import functools

import torch


# Lookup table from short precision names to the corresponding torch dtypes.
DTYPES = dict(
    fp16=torch.float16,
    fp32=torch.float32,
)


def no_train(func):
    def wrapper(model, *args, **kwargs):
        training = model.training
        model.eval()
        output = func(model, *args, **kwargs)
        model.train(training)
        return output
    return wrapper