from collections import OrderedDict
from models.resnet import resnet18 as _resnet18
from models.resnet import resnet50 as _resnet50
from models.mobilenetv2 import mobilenetv2 as _mobilenetv2

import torch
from torch.hub import load_state_dict_from_url

# Packages this hub entrypoint module needs at load time (the `dependencies`
# list that torch.hub looks for in a hubconf).
dependencies = ["torch"]

def resnet18(pretrained=False, **kwargs):
    """Build a ResNet-18 and optionally load torchvision's published weights.

    Args:
        pretrained: If truthy, download the torchvision ``resnet18-5c106cde``
            checkpoint (tensors mapped to CPU) and load it into the model.
        **kwargs: Passed unchanged to ``models.resnet.resnet18``.

    Returns:
        The constructed model.

    NOTE(review): ``load_state_dict`` is called with its default (strict)
    settings, so kwargs that change parameter names or shapes will presumably
    make loading the pretrained checkpoint fail. Confirm against the local
    constructor.
    """
    net = _resnet18(**kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(
        url='https://download.pytorch.org/models/resnet18-5c106cde.pth',
        map_location='cpu',
        progress=True,
    )
    net.load_state_dict(weights)
    return net


def resnet50(pretrained=False, **kwargs):
    """Build a ResNet-50 and optionally load torchvision's published weights.

    Args:
        pretrained: If truthy, download the torchvision ``resnet50-19c8e357``
            checkpoint (tensors mapped to CPU) and load it into the model.
        **kwargs: Passed unchanged to ``models.resnet.resnet50``.

    Returns:
        The constructed model.

    NOTE(review): ``load_state_dict`` is called with its default (strict)
    settings, so kwargs that change parameter names or shapes will presumably
    make loading the pretrained checkpoint fail. Confirm against the local
    constructor.
    """
    net = _resnet50(**kwargs)
    if not pretrained:
        return net

    weights = load_state_dict_from_url(
        url='https://download.pytorch.org/models/resnet50-19c8e357.pth',
        map_location='cpu',
        progress=True,
    )
    net.load_state_dict(weights)
    return net


def mobilenetv2(pretrained=False, **kwargs):
    """Build a MobileNetV2 and optionally load torchvision's published weights.

    The torchvision checkpoint names the sub-layers of each ``features.<i>.conv``
    block with nested indices (e.g. ``conv.0.0``, ``conv.1.1``). This model
    expects different, flat indices (e.g. ``conv.0``, ``conv.4``), so the
    checkpoint keys are renamed in place before loading.

    Args:
        pretrained: If truthy, download the torchvision ``mobilenet_v2-b0353104``
            checkpoint (tensors mapped to CPU), rename its keys, and load it.
        **kwargs: Passed unchanged to ``models.mobilenetv2.mobilenetv2``.

    Returns:
        The constructed model.
    """
    net = _mobilenetv2(**kwargs)
    if not pretrained:
        return net

    weights_url = 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth'
    state = load_state_dict_from_url(url=weights_url, map_location='cpu', progress=True)

    # Each pass is (rules for keys under 'features.1.', rules for all other keys).
    # Within a rule list the first matching substring wins, like an if/elif chain.
    # The trailing dot in 'features.1.' keeps features.10-17 out of the first group.
    first_pass = (
        (('conv.0.0', 'conv.0'), ('conv.1', 'conv.3'), ('conv.2', 'conv.4')),
        (('conv.0.0', 'conv.0'), ('conv.0.1', 'conv.1'), ('conv.1.1', 'conv.4'),
         ('conv.2', 'conv.6'), ('conv.3', 'conv.7')),
    )
    # These renames must wait for a second pass. Their targets ('conv.1' under
    # features.1, 'conv.3' elsewhere) are still live checkpoint keys until the
    # first pass moves them to 'conv.3' / 'conv.7'. Renaming earlier would
    # overwrite those tensors.
    second_pass = (
        (('conv.0.1', 'conv.1'),),
        (('conv.1.0', 'conv.3'),),
    )

    for first_block_rules, other_rules in (first_pass, second_pass):
        # Iterate over a snapshot: keys are added and removed during the loop.
        for name in list(state.keys()):
            rules = first_block_rules if 'features.1.' in name else other_rules
            for old, new in rules:
                if old in name:
                    state[name.replace(old, new)] = state[name]
                    del state[name]
                    break

    net.load_state_dict(state)
    return net
