from models.metamodel.conv4 import MetaConvModel
from models.metamodel.resnet12 import ResNet12
from models.metamodel.resnet12_barlow import ResNet12_selfsup
from models.metamodel.resnet18 import meta_resnet18
from models.metamodel.resnet18_barlow import meta_resnet18_selfsup

def ModelConvMiniImagenet(out_features, hidden_size=64, **kwargs):
    """Build a ``MetaConvModel`` configured for 3-channel input.

    Args:
        out_features: Number of output units, forwarded positionally to
            ``MetaConvModel``.
        hidden_size: Channel width, forwarded as ``hidden_size``. It also sets
            the flattened feature size (``5 * 5 * hidden_size``).
        **kwargs: Extra keyword arguments passed through to ``MetaConvModel``.

    Returns:
        The constructed ``MetaConvModel`` instance.
    """
    input_channels = 3
    # Presumably the final conv feature map is 5x5 spatially for the expected
    # input resolution -- confirm against MetaConvModel / the data pipeline.
    flattened = 5 * 5 * hidden_size
    return MetaConvModel(
        input_channels,
        out_features,
        hidden_size=hidden_size,
        feature_size=flattened,
        **kwargs
    )


def get_model(P, modelstr):
    """Construct the model selected by ``modelstr``.

    Args:
        P: Run configuration object. Attributes read:
            ``num_ways`` -- passed to the conv4 / resnet12 / resnet18 builders;
            ``barlow`` -- for ``'resnet12'`` and ``'resnet18'``, a truthy value
            selects the ``*_selfsup`` variant;
            ``dataset`` -- passed to the ``'conv_pose'`` builder;
            ``mode`` -- if equal to ``'anil'``, ``model.anil`` is set to True.
        modelstr: One of ``'conv4'``, ``'resnet12'``, ``'resnet18'`` or
            ``'conv_pose'``. ``'r2d2'`` is recognised but not implemented.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: If ``modelstr`` is ``'r2d2'`` or is not recognised.
    """
    if modelstr == 'conv4':
        model = ModelConvMiniImagenet(P.num_ways)
    elif modelstr == 'resnet12':
        model = ResNet12_selfsup(P.num_ways) if P.barlow else ResNet12(P.num_ways)
    elif modelstr == 'r2d2':
        # Must actually raise: merely constructing the exception would fall
        # through with ``model`` unbound and fail later with UnboundLocalError.
        raise NotImplementedError("model 'r2d2' is not implemented")
    elif modelstr == 'resnet18':
        model = meta_resnet18_selfsup(P.num_ways) if P.barlow else meta_resnet18(P.num_ways)
    elif modelstr == 'conv_pose':
        # NOTE(review): ConvRegressor is not imported in this module, so this
        # branch raises NameError as written -- add the correct import.
        model = ConvRegressor(dataset=P.dataset)
    else:
        raise NotImplementedError(f"unknown model: {modelstr!r}")

    if P.mode == 'anil':
        # Presumably the model reads this flag to restrict inner-loop
        # adaptation (ANIL) -- confirm in the model implementations.
        model.anil = True

    return model
