from models.model import WideResnet, WideResnetLarge


def get_params(dataset):
    '''
    Look up the per-dataset constants used by the rest of the pipeline.

    dataset --> one of 'CIFAR10', 'SVHN', 'CIFAR100', 'STL10', 'TinyImageNet'
                (exact, case-sensitive match)

    Returns a tuple (n_classes, num_val).
    NOTE(review): num_val is presumably the size of the held-out validation
    split -- confirm against the caller.

    Raises ValueError when the dataset name is not one of the supported ones
    (previously this surfaced as an UnboundLocalError at the return statement).
    '''
    # dataset name -> (n_classes, num_val)
    known = {
        'CIFAR10': (10, 5000),
        'SVHN': (10, 7320),
        'CIFAR100': (100, 5000),
        'STL10': (10, 4000),
        'TinyImageNet': (200, 10000),
    }

    params = known.get(dataset)
    if params is None:
        # Fail loudly with the list of valid names instead of falling through
        # with unbound locals.
        raise ValueError(
            'Unknown dataset {!r}; expected one of: {}'.format(
                dataset, ', '.join(sorted(known))
            )
        )

    return params


def set_model(n_classes, wresnet_k, wresnet_n, stl=False, large=False):
    '''
    Build a Wide-ResNet, switch it to training mode and call .cuda() on it.

    n_classes --> forwarded to the network constructor as n_classes
    wresnet_k --> forwarded to the network constructor as k
    wresnet_n --> forwarded to the network constructor as n
    stl       --> forwarded to WideResnet only (set for STL-10 and
                  Tiny ImageNet training); ignored when large is truthy
    large     --> when truthy, WideResnetLarge is built instead of WideResnet

    Returns the constructed model after .train() and .cuda() have been called.
    '''
    # Arguments common to both architectures.
    shared_kwargs = dict(n_classes=n_classes, k=wresnet_k, n=wresnet_n)

    net = (
        WideResnetLarge(**shared_kwargs)
        if large
        else WideResnet(stl=stl, **shared_kwargs)
    )

    net.train()
    net.cuda()
    return net
