import os

import torch

from models.resnet import resnet18 as _resnet18
from models.resnet import resnet50 as _resnet50
from models.mobilenetv2 import mobilenetv2 as _mobilenetv2
from models.mnasnet import mnasnet as _mnasnet
from models.regnet import regnetx_600m as _regnetx_600m
from models.regnet import regnetx_3200m as _regnetx_3200m

# Packages torch.hub checks for before running any entrypoint in this file.
dependencies = ['torch']

# Directory containing this hubconf.py. Checkpoint paths are anchored here
# instead of being relative ('./pretrained/...'), because relative paths are
# resolved against the caller's current working directory, which is
# arbitrary (e.g. when this file is imported via torch.hub.load).
_HUB_DIR = os.path.dirname(os.path.abspath(__file__))
_PRETRAINED_DIR = os.path.join(_HUB_DIR, 'pretrained')

# Local pretrained-weight file for each entrypoint, keyed by entrypoint name.
checkpoints = {
    'resnet18': os.path.join(_PRETRAINED_DIR, 'resnet18_imagenet.pth.tar'),
    'resnet50': os.path.join(_PRETRAINED_DIR, 'resnet50_imagenet.pth.tar'),
    'mobilenetv2': os.path.join(_PRETRAINED_DIR, 'mobilenetv2.pth.tar'),
    'regnetx_600m': os.path.join(_PRETRAINED_DIR, 'regnet_600m.pth.tar'),
    'regnetx_3200m': os.path.join(_PRETRAINED_DIR, 'regnet_3200m.pth.tar'),
    'mnasnet': os.path.join(_PRETRAINED_DIR, 'mnasnet.pth.tar'),
}


def resnet18(pretrained=False, **kwargs):
    """Build the ``resnet18`` model from ``models.resnet``.

    Args:
        pretrained: if True, load weights from the local checkpoint file
            registered under ``'resnet18'`` in ``checkpoints``.
        **kwargs: forwarded unchanged to the model constructor.

    Returns:
        The constructed model, with pretrained weights loaded when requested.
    """
    model = _resnet18(**kwargs)
    if pretrained:
        # Deserialize onto CPU so checkpoints saved from a GPU also load on
        # CPU-only hosts; load_state_dict copies the tensors into the model's
        # own parameters regardless.
        state_dict = torch.load(checkpoints['resnet18'], map_location='cpu')
        model.load_state_dict(state_dict)
    return model


def resnet50(pretrained=False, **kwargs):
    """Build the ``resnet50`` model from ``models.resnet``.

    Args:
        pretrained: if True, load weights from the local checkpoint file
            registered under ``'resnet50'`` in ``checkpoints``.
        **kwargs: forwarded unchanged to the model constructor.

    Returns:
        The constructed model, with pretrained weights loaded when requested.
    """
    model = _resnet50(**kwargs)
    if pretrained:
        # Deserialize onto CPU so checkpoints saved from a GPU also load on
        # CPU-only hosts; load_state_dict copies the tensors into the model's
        # own parameters regardless.
        state_dict = torch.load(checkpoints['resnet50'], map_location='cpu')
        model.load_state_dict(state_dict)
    return model


def mobilenetv2(pretrained=False, **kwargs):
    """Build the ``mobilenetv2`` model from ``models.mobilenetv2``.

    Args:
        pretrained: if True, load weights from the local checkpoint file
            registered under ``'mobilenetv2'`` in ``checkpoints``.
        **kwargs: forwarded unchanged to the model constructor.

    Returns:
        The constructed model, with pretrained weights loaded when requested.
    """
    model = _mobilenetv2(**kwargs)
    if pretrained:
        # Deserialize onto CPU so checkpoints saved from a GPU also load on
        # CPU-only hosts; load_state_dict copies the tensors into the model's
        # own parameters regardless.
        state_dict = torch.load(checkpoints['mobilenetv2'], map_location='cpu')
        model.load_state_dict(state_dict)
    return model


def regnetx_600m(pretrained=False, **kwargs):
    """Build the ``regnetx_600m`` model from ``models.regnet``.

    Args:
        pretrained: if True, load weights from the local checkpoint file
            registered under ``'regnetx_600m'`` in ``checkpoints``.
        **kwargs: forwarded unchanged to the model constructor.

    Returns:
        The constructed model, with pretrained weights loaded when requested.
    """
    model = _regnetx_600m(**kwargs)
    if pretrained:
        # Deserialize onto CPU so checkpoints saved from a GPU also load on
        # CPU-only hosts; load_state_dict copies the tensors into the model's
        # own parameters regardless.
        state_dict = torch.load(checkpoints['regnetx_600m'], map_location='cpu')
        model.load_state_dict(state_dict)
    return model


def regnetx_3200m(pretrained=False, **kwargs):
    """Build the ``regnetx_3200m`` model from ``models.regnet``.

    Args:
        pretrained: if True, load weights from the local checkpoint file
            registered under ``'regnetx_3200m'`` in ``checkpoints``.
        **kwargs: forwarded unchanged to the model constructor.

    Returns:
        The constructed model, with pretrained weights loaded when requested.
    """
    model = _regnetx_3200m(**kwargs)
    if pretrained:
        # Deserialize onto CPU so checkpoints saved from a GPU also load on
        # CPU-only hosts; load_state_dict copies the tensors into the model's
        # own parameters regardless.
        state_dict = torch.load(checkpoints['regnetx_3200m'], map_location='cpu')
        model.load_state_dict(state_dict)
    return model


def mnasnet(pretrained=False, **kwargs):
    """Build the ``mnasnet`` model from ``models.mnasnet``.

    Args:
        pretrained: if True, load weights from the local checkpoint file
            registered under ``'mnasnet'`` in ``checkpoints``.
        **kwargs: forwarded unchanged to the model constructor.

    Returns:
        The constructed model, with pretrained weights loaded when requested.
    """
    model = _mnasnet(**kwargs)
    if pretrained:
        # Deserialize onto CPU so checkpoints saved from a GPU also load on
        # CPU-only hosts; load_state_dict copies the tensors into the model's
        # own parameters regardless.
        state_dict = torch.load(checkpoints['mnasnet'], map_location='cpu')
        model.load_state_dict(state_dict)
    return model
