import torch.nn as nn

# from .densenet import densenet121, densenet161, densenet169, densenet201
from .efficientnet_pytorch.model import EfficientNet
from .inception_v3 import inception_v3
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from .resnext import resnext101_32x4d, resnext101_64x4d


class BaseModel(nn.Module):
    """Build a named backbone network and record its output feature width.

    Args:
        last_conv_stride: Forwarded as ``last_conv_stride`` to the ResNet /
            ResNeXt factories. Not used for EfficientNet backbones.
        basenet: Backbone name. One of ``'Resnet18'``, ``'Resnet34'``,
            ``'Resnet50'``, ``'Resnet101'``, ``'Resnet152'``,
            ``'resnext101_32x4d'``, ``'resnext101_64x4d'``, or any string
            starting with ``'efficient'``. Such a string is passed as-is to
            ``EfficientNet.from_pretrained_no_fc``.

    Attributes:
        basenet: The backbone name given at construction.
        base: The instantiated backbone module.
        feature_in_dim: Width of the backbone's output features. It is 2048
            for ResNet-50/101/152 and both ResNeXt-101 variants, 512 for
            ResNet-18/34, and ``base.out_channels`` for EfficientNet.
            Presumably this is the channel count of the final feature map;
            confirm against the backbone implementations.

    Raises:
        ValueError: If ``basenet`` is not one of the supported names.
    """

    def __init__(self, last_conv_stride=1, basenet='Resnet50'):
        super(BaseModel, self).__init__()
        self.basenet = basenet

        # Each ResNet-style branch picks (factory, pretrained flag, output
        # width) in one place. This keeps the backbone choice and
        # feature_in_dim from drifting apart, as they could with two
        # separate if-chains.
        if basenet == 'Resnet18':
            # NOTE(review): unlike every other ResNet variant, Resnet18 is
            # built WITHOUT pretrained weights. Kept as-is; confirm this is
            # intentional.
            factory, pretrained, feature_dim = resnet18, False, 512
        elif basenet == 'Resnet34':
            factory, pretrained, feature_dim = resnet34, True, 512
        elif basenet == 'Resnet50':
            factory, pretrained, feature_dim = resnet50, True, 2048
        elif basenet == 'Resnet152':
            factory, pretrained, feature_dim = resnet152, True, 2048
        elif basenet == 'Resnet101':
            factory, pretrained, feature_dim = resnet101, True, 2048
        elif basenet == 'resnext101_32x4d':
            factory, pretrained, feature_dim = resnext101_32x4d, True, 2048
        elif basenet == 'resnext101_64x4d':
            factory, pretrained, feature_dim = resnext101_64x4d, True, 2048
        elif basenet.startswith('efficient'):
            # EfficientNet reports its own output width; last_conv_stride
            # does not apply here.
            self.base = EfficientNet.from_pretrained_no_fc(basenet)
            self.feature_in_dim = self.base.out_channels
            return
        else:
            # Fail fast. Only printing here left the model without `base` and
            # `feature_in_dim`, which surfaced later as an AttributeError.
            raise ValueError('Unknown Base Network: {!r}'.format(basenet))

        self.base = factory(
            pretrained=pretrained, last_conv_stride=last_conv_stride)
        self.feature_in_dim = feature_dim
