import torch

def load_model(arch, device=torch.device("cpu"), norm='Linf', num_classes=10):
    """Build a model for the named architecture and call its ``load()`` method.

    Args:
        arch: Architecture identifier. Only ``'Robust_Overfitting'`` is
            currently supported.
        device: Passed through to the model constructor.
        norm: Passed through to the model constructor (default ``'Linf'``).
        num_classes: Passed through to the model constructor (default 10).

    Returns:
        The constructed model after ``model.load()`` has been called.
        Presumably ``load()`` restores pretrained weights; confirm this in
        ``models.robust_overfitting``.

    Raises:
        ValueError: If ``arch`` is not a supported architecture name.
    """
    supported = ('Robust_Overfitting',)
    if arch == 'Robust_Overfitting':
        # Import lazily so callers that never request this architecture
        # do not pay for, or depend on, the models package import.
        from models.robust_overfitting import Robust_Overfitting_10
        model = Robust_Overfitting_10(num_classes, norm, device)
        model.load()
        return model
    # Previously an unknown arch fell through to `return model` and raised
    # an opaque UnboundLocalError; report the real problem instead.
    raise ValueError(
        f"Unsupported arch {arch!r}; expected one of {supported}"
    )