import torch
import torch.nn as nn
import torchvision.models as thmodels

from .convnet import ConvNet
from .resnet import resnet18, resnet50, resnet101, resnet152


def load_model(model_name="resnet18", spec="full", pretrained=True, classes=None, input_size=224):
    """Build a classifier, restrict it to `classes`, and optionally load weights.

    Args:
        model_name: Which architecture to build.
            - "conv<depth>" (e.g. "conv4") builds the local ConvNet.
            - "resnet18", "resnet50", "resnet101" and "resnet152" use the
              local ResNet implementations.
            - "efficientNet-b0" builds torchvision's EfficientNet-B0.
            - Any other name is looked up in ``torchvision.models``.
        spec: Which pretrained weights to use when ``pretrained`` is true.
            - "woof" loads ``./pretrained_models/imagenet-woof_<model_name>.pth``.
            - "im100" loads ``./pretrained_models/imagenet-100_<model_name>.pth``.
            - Anything else selects ImageNet-1k weights. These are only
              available for "efficientNet-b0", "conv4" and "resnet18".
        pretrained: Whether to load pretrained weights at all.
        classes: Indices to keep along the first dimension of the model's
            last two parameters. These are presumably the final linear
            layer's weight and bias (TODO confirm for every architecture).
            For ConvNet models, ``len(classes)`` is also the number of outputs.
            Defaults to an empty list.
        input_size: Height and width of the (square) ConvNet input.

    Returns:
        The model, with its last two parameters indexed by ``classes``.

    Raises:
        AttributeError: ImageNet-1k weights were requested for a model that is
            not in the pre-trained pool.
        KeyError: ``model_name`` is not a known ``torchvision.models`` entry.
    """
    # A None default avoids the shared mutable-default pitfall; the effective
    # default is still an empty list, as before.
    classes = [] if classes is None else classes

    def get_model(model_name="resnet18"):
        """Instantiate `model_name` with randomly initialised weights."""
        # "conv4"-style names select the local ConvNet, and the trailing
        # digits give its depth. Requiring those digits keeps torchvision
        # names that merely contain "conv" (e.g. "convnext_tiny") on the
        # torchvision path.
        depth_digits = model_name[len(model_name.rstrip("0123456789")):]
        if "conv" in model_name and depth_digits:
            return ConvNet(
                num_classes=len(classes),
                net_norm="batch",
                net_act="relu",
                net_pooling="avgpooling",
                net_depth=int(depth_digits),
                net_width=128,
                channel=3,
                im_size=(input_size, input_size),
            )

        local_resnets = {
            "resnet18": resnet18,
            "resnet50": resnet50,
            "resnet101": resnet101,
            "resnet152": resnet152,
        }
        if model_name in local_resnets:
            return local_resnets[model_name](weights=None)
        if model_name == "efficientNet-b0":
            # torchvision registers this architecture as "efficientnet_b0";
            # a direct lookup of "efficientNet-b0" would raise KeyError.
            return thmodels.efficientnet_b0(weights=None)
        return thmodels.__dict__[model_name](weights=None)

    def pruning_classifier(model, classes):
        """Index the model's last two parameters by `classes` along dim 0.

        This is best effort. On any failure the error is printed and the
        model is returned unchanged.
        """
        try:
            tail = [param for _, param in model.named_parameters()][-2:]
            # Index every tensor before assigning any of them, so a failure
            # cannot leave the weight pruned while the bias is not.
            pruned = [param.data[classes] for param in tail]
        except Exception as err:
            print("ERROR in changing the number of classes.", err)
            return model
        for param, data in zip(tail, pruned):
            param.data = data
        return model

    def load_checkpoint(model, path):
        """Copy ``checkpoint["model"]`` from `path` into `model`."""
        # load_state_dict copies the values into the model's own tensors, so
        # loading onto the CPU first only avoids failures on GPU-less hosts.
        checkpoint = torch.load(path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    def load_pretrained_efficientnet_b0():
        """Build torchvision's EfficientNet-B0 with its default pretrained weights.

        The weights are downloaded with torchvision's hash check disabled.
        The hash check is presumably disabled because the published checkpoint
        fails it (TODO confirm). The patch is undone afterwards so it does not
        leak into other weight downloads in the process.
        """
        from torch.hub import load_state_dict_from_url
        from torchvision.models._api import WeightsEnum

        def get_state_dict(self, *args, **kwargs):
            kwargs.pop("check_hash", None)
            return load_state_dict_from_url(self.url, *args, **kwargs)

        original_get_state_dict = WeightsEnum.get_state_dict
        WeightsEnum.get_state_dict = get_state_dict
        try:
            return thmodels.efficientnet_b0(weights="DEFAULT")
        finally:
            WeightsEnum.get_state_dict = original_get_state_dict

    subset_checkpoints = {
        "woof": f"./pretrained_models/imagenet-woof_{model_name}.pth",
        "im100": f"./pretrained_models/imagenet-100_{model_name}.pth",
    }

    model = get_model(model_name)

    if not pretrained or spec in subset_checkpoints:
        # Subset checkpoints are loaded into the already-pruned model.
        model = pruning_classifier(model, classes)
        if pretrained:
            load_checkpoint(model, subset_checkpoints[spec])
        return model

    # ImageNet-1k weights are loaded into a full-size model, which is then
    # pruned exactly once.
    if model_name == "efficientNet-b0":
        model = load_pretrained_efficientnet_b0()
    elif model_name == "conv4":
        load_checkpoint(model, "./pretrained_models/imagenet-1k_conv4.pth")
    elif model_name == "resnet18":
        model = resnet18(weights="DEFAULT")
    else:
        raise AttributeError(f"{model_name} is not supported in the pre-trained pool")

    return pruning_classifier(model, classes)
