import timm
import torch.nn as nn

def classification_losses(x, y):
    """Return the classification loss dict for logits ``x`` and targets ``y``.

    Equivalent to ``{"cls_loss": nn.CrossEntropyLoss()(x, y)}`` with default
    settings (mean reduction, ``ignore_index=-100``, no label smoothing).
    Presumably ``x`` is ``(batch, num_classes)`` logits and ``y`` holds class
    indices; accepted shapes are whatever ``cross_entropy`` accepts.
    """
    return {"cls_loss": nn.functional.cross_entropy(x, y)}


def build_model(args):
    """Create a randomly initialised timm classifier with a ``losses`` hook.

    Args:
        args: Object exposing ``backbone`` (timm model name passed to
            ``timm.create_model``) and ``num_classes`` (size of the
            classification head).

    Returns:
        The timm model, with a ``losses(x, y)`` attribute that returns
        ``{"cls_loss": <cross-entropy tensor>}``.
    """
    model = timm.create_model(args.backbone, pretrained=False, num_classes=args.num_classes)
    # Attach a module-level function rather than a lambda: lambdas cannot be
    # pickled, which would break torch.save(model) and spawn-based workers.
    model.losses = classification_losses
    return model
