import timm
import torch.nn as nn
import torchvision.models as models

def _cross_entropy_losses(logits, targets):
    """Return the classification loss dict exposed as ``model.losses``.

    This is defined at module level rather than as a lambda inside
    ``build_model``. A module-level function pickles by reference, so a model
    that carries it as an attribute can still be serialized, e.g. with
    ``torch.save(model, path)``. A lambda attribute makes that fail.

    Args:
        logits: Passed as the input argument of ``nn.CrossEntropyLoss``.
        targets: Passed as the target argument of ``nn.CrossEntropyLoss``.

    Returns:
        ``{"cls_loss": loss}``, where ``loss`` uses ``nn.CrossEntropyLoss``
        with its default settings.
    """
    return {"cls_loss": nn.CrossEntropyLoss()(logits, targets)}


def build_model(args):
    """Build an untrained classifier and attach a loss function to it.

    Args:
        args: Config object exposing ``backbone`` (str) and ``num_classes``,
            which is forwarded as the model's ``num_classes``.
            - If ``backbone`` contains ``"timm"``, the model is created with
              ``timm.create_model``. Every ``"timm_"`` substring is removed
              from the name first, e.g. ``"timm_resnet50"`` -> ``"resnet50"``.
            - Otherwise ``backbone`` is looked up as an attribute of
              ``torchvision.models``, e.g. ``"resnet18"``, and called.

    Returns:
        The randomly initialised model with an extra ``losses`` attribute.
        ``model.losses(logits, targets)`` returns
        ``{"cls_loss": nn.CrossEntropyLoss()(logits, targets)}``.

    Raises:
        AttributeError: if a non-timm ``backbone`` is not an attribute of
            ``torchvision.models``.
    """
    backbone = args.backbone
    if "timm" in backbone:
        model = timm.create_model(
            backbone.replace("timm_", ""),
            pretrained=False,
            num_classes=args.num_classes,
        )
    else:
        builder = getattr(models, backbone)
        # Random initialisation is the builders' default. Omitting the flag
        # avoids the `pretrained` deprecation warning on torchvision >= 0.13,
        # which expects `weights=None`. It also keeps working on older
        # releases, which do not accept `weights`.
        model = builder(num_classes=args.num_classes)
    # Use a module-level function rather than a lambda so the model stays
    # picklable.
    model.losses = _cross_entropy_losses
    return model

