from torchvision import models
import torch
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.resnet import ResNet,BasicBlock,Bottleneck
from torchvision.models.mobilenet import MobileNetV2

'''
Models are official PyTorch (torchvision) models
'''


class MobileNetV2_I(MobileNetV2):
    """torchvision ``MobileNetV2`` tagged with a ``model_type`` identifier.

    The network itself is the stock torchvision implementation; this subclass
    only records ``self.model_type = 'mobilenetv2_I'`` after construction.
    """

    def __init__(self,
                 num_classes=1000,
                 width_mult=1.0,
                 inverted_residual_setting=None,
                 round_nearest=8,
                 block=None,
                 norm_layer=None,
                 **kwargs):
        """Build the network; see torchvision's ``MobileNetV2`` for arguments.

        Arguments are forwarded by keyword so this call does not depend on the
        positional order of the parent constructor. Any extra keyword
        arguments (e.g. ``dropout`` on newer torchvision releases -- confirm
        for the installed version) are passed through unchanged.
        """
        super(MobileNetV2_I, self).__init__(
            num_classes=num_classes,
            width_mult=width_mult,
            inverted_residual_setting=inverted_residual_setting,
            round_nearest=round_nearest,
            block=block,
            norm_layer=norm_layer,
            **kwargs)
        self.model_type = 'mobilenetv2_I'




def _resnet_I(arch, block, layers, type, inital_macs, pretrained, progress, **kwargs):
    """Construct a ``ResNet_I`` from a block type and per-stage layer counts.

    Only ``block``, ``layers``, ``type``, ``inital_macs`` and ``**kwargs`` are
    used. ``arch``, ``pretrained`` and ``progress`` are accepted but ignored;
    the public factories load any pretrained weights themselves.
    """
    return ResNet_I(block, layers, type, inital_macs, **kwargs)

class ResNet_I(ResNet):
    """torchvision ``ResNet`` carrying a model-type tag and a MAC count.

    Args:
        block: Residual block class passed to ``ResNet`` (e.g. ``BasicBlock``
            or ``Bottleneck``).
        layers: Number of blocks per stage, passed to ``ResNet``.
        type: Identifier stored in ``self.model_type``; ``None`` falls back to
            ``'resnet_I'``. (The name shadows the builtin ``type`` inside
            ``__init__``; it is kept for compatibility with keyword callers.)
        inital_macs: Value stored verbatim in ``self.macs_forward`` (may be
            ``None``). Presumably the multiply-accumulate count of one forward
            pass -- confirm against the code that reads ``macs_forward``.
        **kwargs: Forwarded to ``ResNet.__init__``.
    """

    def __init__(self, block, layers, type=None, inital_macs=None, **kwargs):
        super(ResNet_I, self).__init__(block, layers, **kwargs)
        # Explicit ``is None`` check so falsy-but-set tags (e.g. '') are kept.
        self.model_type = 'resnet_I' if type is None else type
        self.macs_forward = inital_macs





def resnet50_I(load_pretrained=False, pretrained_model_path=None, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        load_pretrained (bool): If True, load a state dict from
            ``pretrained_model_path`` into the freshly built model.
        pretrained_model_path (str or path-like): Checkpoint file read with
            ``torch.load``; required when ``load_pretrained`` is True.
        progress (bool): Accepted for API compatibility; currently unused.
        **kwargs: Forwarded to ``ResNet_I`` / torchvision ``ResNet``.

    Returns:
        ResNet_I: model with ``model_type == 'resnet50_I'`` and
        ``macs_forward`` set to the constant below.

    Raises:
        ValueError: if ``load_pretrained`` is True but no path is given.
    """
    # Fail fast, before spending time building the network.
    if load_pretrained and pretrained_model_path is None:
        raise ValueError('pretrained_model_path must be given when load_pretrained=True')
    # Stored on model.macs_forward; presumably the MAC count of one forward
    # pass at 224x224 input -- confirm against the consumer of macs_forward.
    inital_macs = 3850000000
    model = _resnet_I('resnet50', Bottleneck, [3, 4, 6, 3], 'resnet50_I', inital_macs,
                      load_pretrained, progress, **kwargs)
    if load_pretrained:
        print('Loading Pretrained Model:')
        print(pretrained_model_path)
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors into the model's parameters.
        state_dict = torch.load(pretrained_model_path, map_location='cpu')
        model.load_state_dict(state_dict)
    return model


def resnet18_I(load_pretrained=False, pretrained_model_path=None, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        load_pretrained (bool): If True, load a state dict from
            ``pretrained_model_path`` into the freshly built model.
        pretrained_model_path (str or path-like): Checkpoint file read with
            ``torch.load``; required when ``load_pretrained`` is True.
        progress (bool): Accepted for API compatibility; currently unused.
        **kwargs: Forwarded to ``ResNet_I`` / torchvision ``ResNet``.

    Returns:
        ResNet_I: model with ``model_type == 'resnet18_I'`` and
        ``macs_forward`` set to the constant below.

    Raises:
        ValueError: if ``load_pretrained`` is True but no path is given.
    """
    # Fail fast, before spending time building the network.
    if load_pretrained and pretrained_model_path is None:
        raise ValueError('pretrained_model_path must be given when load_pretrained=True')
    # Stored on model.macs_forward; presumably the MAC count of one forward
    # pass at 224x224 input -- confirm. 1820000000 was a rounded earlier value.
    inital_macs = 1814073344
    model = _resnet_I('resnet18', BasicBlock, [2, 2, 2, 2], 'resnet18_I', inital_macs,
                      load_pretrained, progress, **kwargs)
    if load_pretrained:
        print('Loading Pretrained Model:')
        print(pretrained_model_path)
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors into the model's parameters.
        state_dict = torch.load(pretrained_model_path, map_location='cpu')
        model.load_state_dict(state_dict)
    return model

def mobilenet_v2_I(load_pretrained=False, pretrained_model_path=None, progress=True, **kwargs):
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        load_pretrained (bool): If True, load a state dict from
            ``pretrained_model_path`` into the freshly built model.
        pretrained_model_path (str or path-like): Checkpoint file read with
            ``torch.load``; required when ``load_pretrained`` is True.
        progress (bool): Accepted for API compatibility; currently unused.
        **kwargs: Forwarded to ``MobileNetV2_I``.

    Returns:
        MobileNetV2_I: the constructed model.

    Raises:
        ValueError: if ``load_pretrained`` is True but no path is given.
    """
    # Fail fast, before spending time building the network.
    if load_pretrained and pretrained_model_path is None:
        raise ValueError('pretrained_model_path must be given when load_pretrained=True')
    model = MobileNetV2_I(**kwargs)
    if load_pretrained:
        print('Loading Pretrained Model:')
        print(pretrained_model_path)
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors into the model's parameters.
        state_dict = torch.load(pretrained_model_path, map_location='cpu')
        model.load_state_dict(state_dict)
    return model

def resnet50(**kwargs):
    """Alias for ``resnet50_I``; every keyword argument is forwarded as-is."""
    return resnet50_I(**kwargs)


def resnet18(**kwargs):
    """Alias for ``resnet18_I``; every keyword argument is forwarded as-is."""
    return resnet18_I(**kwargs)

def mobilenetv2(**kwargs):
    """Alias for ``mobilenet_v2_I``; every keyword argument is forwarded as-is."""
    return mobilenet_v2_I(**kwargs)