from .big_transfer import KNOWN_MODELS
import numpy as np
import os

# Directory (relative to the current working directory) that holds the
# pretrained BiT weight files, one "<model_name>.npz" per model.
BIG_TRANSFER_MODEL_DIR: str = "BigTransfer/"

def build_model_big_transfer(model_name, num_classes, pretrained=True, model_dir=None):
    """Build a Big Transfer model whose head is sized for ``num_classes``.

    Args:
        model_name: Key into ``KNOWN_MODELS``. When ``pretrained`` is true it is
            also the basename of the weight file ``<model_name>.npz``.
        num_classes: Passed to the model constructor as ``head_size``.
        pretrained: If true, load weights from ``<model_dir>/<model_name>.npz``
            via ``model.load_from``.
        model_dir: Directory containing the ``.npz`` weight files. ``None``
            (the default) means ``BIG_TRANSFER_MODEL_DIR``.

    Returns:
        A tuple ``(model, model_name)``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``KNOWN_MODELS``.
        OSError: If the weight file cannot be opened (from ``np.load``).
    """
    try:
        model_factory = KNOWN_MODELS[model_name]
    except KeyError:
        # Keep the original exception type, but tell the caller what is valid.
        raise KeyError(
            f"Unknown Big Transfer model {model_name!r}; "
            f"expected one of: {', '.join(map(str, KNOWN_MODELS))}"
        ) from None

    # zero_head=True is passed unconditionally, as in the original code.
    model = model_factory(head_size=num_classes, zero_head=True)

    if pretrained:
        weights_dir = BIG_TRANSFER_MODEL_DIR if model_dir is None else model_dir
        weights_path = os.path.join(weights_dir, f"{model_name}.npz")
        # np.load on an .npz archive returns an NpzFile that keeps the file
        # open; the context manager closes it once load_from has consumed it.
        with np.load(weights_path) as weights:
            model.load_from(weights)

    return model, model_name