import torchvision.models as models
from torchvision.models.utils import load_state_dict_from_url
import torch
import torch.nn as nn
import torch.nn.functional as F

# Reference download location(s) for ImageNet AlexNet weights, keyed by
# architecture name. (Not referenced by the code visible in this module.)
model_urls = dict(
    alexnet='https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
)

# Local checkpoint loaded by ``alexnet(pretrained=True, dataset='nus_wide')``.
ALEXNET_IMAGENET_64_PRETRAINED_PATH = './data/alexnet_best.pt'

class AlexNet(nn.Module):
    """AlexNet trunk with an embedding head.

    The convolutional ``features`` and the first two fully connected layers
    are taken from ``torchvision.models.alexnet(pretrained=True)`` (loading
    those weights may need network access the first time). A freshly
    initialised ``nn.Linear(4096, embedding_size)`` produces the embedding.

    Args:
        embedding_size: Width of the output embedding.
        is_norm: If truthy, ``forward`` L2-normalises each output row.
    """

    def __init__(self, embedding_size, is_norm):
        super().__init__()
        reference = models.alexnet(pretrained=True)
        pretrained_fc1 = reference.classifier[1]
        pretrained_fc2 = reference.classifier[4]

        # Assignment order fixes submodule registration order, which in turn
        # fixes state_dict key order and parameters() order.
        self.features = reference.features
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.is_norm = is_norm
        self.embedding_size = embedding_size

        def adopt(src, in_features, out_features):
            # Build a Linear and make it share the pretrained layer's
            # Parameter objects.
            layer = nn.Linear(in_features, out_features)
            layer.weight = src.weight
            layer.bias = src.bias
            return layer

        # Indices must stay 0..6 with the embedding Linear at [6]: the
        # ``alexnet`` factory swaps index 6 and maps 'classifier.*' keys
        # onto this Sequential.
        head = [
            nn.Dropout(),
            adopt(pretrained_fc1, 256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            adopt(pretrained_fc2, 4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, embedding_size),
        ]
        self.linear_projector = nn.Sequential(*head)

    def forward(self, x):
        """Embed a batch of inputs.

        Args:
            x: Batch fed to ``self.features``; presumably (N, 3, H, W)
                images -- TODO confirm expected resolution with callers.

        Returns:
            Tensor of shape (N, embedding_size), L2-normalised along dim 1
            when ``is_norm`` is truthy.
        """
        pooled = self.avgpool(self.features(x))
        embedding = self.linear_projector(torch.flatten(pooled, 1))
        return F.normalize(embedding, dim=1) if self.is_norm else embedding


def alexnet(pretrained=False, progress=True, dataset='cifar10', checkpoint_path=None, **kwargs):
    r"""Build an :class:`AlexNet` embedding model.

    Architecture from the `"One weird trick..."
    <https://arxiv.org/abs/1404.5997>`_ paper. :class:`AlexNet` always
    initialises its convolutional trunk and first two fully connected
    layers from torchvision's pretrained AlexNet. The ``pretrained`` flag
    here only controls whether a local checkpoint is additionally loaded on
    top, which currently happens only for ``dataset == 'nus_wide'``.

    Args:
        pretrained (bool): If True and ``dataset == 'nus_wide'``, load the
            checkpoint at ``checkpoint_path`` into the model. The final
            embedding layer is re-initialised afterwards in either case.
        progress (bool): Currently unused.
        dataset (str): Target dataset name; only ``'nus_wide'`` triggers
            checkpoint loading.
        checkpoint_path (str, optional): Checkpoint file to load; defaults
            to ``ALEXNET_IMAGENET_64_PRETRAINED_PATH``.
        **kwargs: Forwarded to :class:`AlexNet` (``embedding_size``,
            ``is_norm``).

    Returns:
        AlexNet: The constructed model.
    """
    model = AlexNet(**kwargs)
    embedding_size = model.embedding_size
    if pretrained and dataset == 'nus_wide':
        if checkpoint_path is None:
            checkpoint_path = ALEXNET_IMAGENET_64_PRETRAINED_PATH
        # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
        # hosts; load_state_dict copies values into the model's own tensors,
        # so the staging device is irrelevant.
        # NOTE(review): torch.load unpickles -- only load trusted files.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # load_state_dict rejects shape mismatches even with strict=False,
        # so temporarily install a 200-output head, presumably matching the
        # checkpoint's 'classifier.6' layer -- confirm against the file.
        model.linear_projector[6] = nn.Linear(4096, 200)

        # The checkpoint names the FC stack 'classifier.*'; this model calls
        # it 'linear_projector'. Only rename true prefixes: slicing off
        # len(prefix) chars is meaningless if the match is elsewhere.
        src_prefix, dst_prefix = 'classifier', 'linear_projector'
        renamed_checkpoint = {}
        for key, value in checkpoint.items():
            if key.startswith(src_prefix):
                key = dst_prefix + key[len(src_prefix):]
            renamed_checkpoint[key] = value

        # Report missing/unexpected keys (strict=False tolerates them).
        print(model.load_state_dict(renamed_checkpoint, strict=False))
        # Drop the loaded 200-way head; the embedding layer starts fresh.
        model.linear_projector[6] = nn.Linear(4096, embedding_size)
    return model

