import torch
import torch.nn as nn
from torch.nn.utils.weight_norm import WeightNorm
from torchvision import models


def init_weights(m):
    """Initialise the parameters of a single module in place.

    The module kind is detected from its class name, so this function can be
    passed to ``nn.Module.apply``:

    * ``Conv2d`` / ``ConvTranspose2d``: Kaiming-uniform weight, zero bias.
    * ``BatchNorm*``: weight drawn from N(1.0, 0.02), zero bias.
    * ``Linear``: Xavier-normal weight, zero bias.

    Modules of any other kind are left untouched. A missing ``weight`` or
    ``bias`` (e.g. ``bias=False``, ``affine=False``, or a wrapper class whose
    name merely contains "Linear") is skipped instead of raising.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    # Either attribute may be absent or None depending on how the layer was built.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv2d" in classname or "ConvTranspose2d" in classname:
        if weight is not None:
            nn.init.kaiming_uniform_(weight)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
    elif "Linear" in classname:
        if weight is not None:
            nn.init.xavier_normal_(weight)
    else:
        # Unknown module kind: do not touch its bias either.
        return

    if bias is not None:
        nn.init.zeros_(bias)


# Maps the architecture names accepted by ResNetFc ("ResNet18", "ResNet34", ...)
# to the torchvision constructor of the same depth.
resnet_dict = {
    "ResNet{}".format(depth): getattr(models, "resnet{}".format(depth))
    for depth in (18, 34, 50, 101, 152)
}


class distLinear(nn.Module):
    """Bias-free linear head whose output is multiplied by a fixed scale factor.

    The head is meant to be fed L2-normalised features (see the parameter name
    of ``forward``). The inner linear layer is wrapped with weight
    normalisation, which splits the weight into a direction and a norm; as a
    result each output is the cosine distance multiplied by a class-wise
    learnable norm (see the issue#4&8 in the github).

    Args:
        indim: Size of each input feature vector.
        outdim: Number of outputs (classes).
        scale_factor: Constant that multiplies the scores. In omniglot, a
            larger scale factor is required to handle >1000 output classes.
    """

    def __init__(self, indim, outdim, scale_factor=10):
        super(distLinear, self).__init__()
        self.indim, self.outdim = indim, outdim

        linear = nn.Linear(indim, outdim, bias=False)
        nn.init.xavier_normal_(linear.weight)
        self.L = linear

        # See the issue#4&8 in the github
        self.class_wise_learnable_norm = True
        if self.class_wise_learnable_norm:
            # Split the weight update component into direction and norm.
            WeightNorm.apply(self.L, "weight", dim=0)

        self.scale_factor = scale_factor

    def forward(self, x_normalized):
        """Return ``scale_factor * L(x_normalized)``.

        Args:
            x_normalized: Input features; expected to be L2-normalised by the
                caller (this method does not normalise them).

        Returns:
            The scaled output of the inner linear layer.
        """
        if not self.class_wise_learnable_norm:
            # Without the learnable norm, rescale every weight row to
            # (approximately) unit length before the matrix product.
            weight = self.L.weight.data
            row_norm = torch.norm(weight, p=2, dim=1).unsqueeze(1).expand_as(weight)
            self.L.weight.data = weight.div(row_norm + 0.00001)

        # Matrix product by the linear layer; with WeightNorm this also
        # multiplies the cosine distance by a class-wise learnable norm.
        return self.scale_factor * self.L(x_normalized)


class ResNetFc(nn.Module):
    """torchvision ResNet backbone with an optional bottleneck and a replaceable head.

    The backbone is built with ``pretrained=True``. Everything up to and
    including the average pooling is exposed as ``feature_layers``.

    Args:
        resnet_name: Key into ``resnet_dict`` (e.g. ``"ResNet50"``).
        use_bottleneck: When ``new_cls`` is set, insert a linear bottleneck of
            width ``bottleneck_dim`` between the backbone and the classifier.
        bottleneck_dim: Output width of the bottleneck layer.
        new_cls: Build a fresh classifier for ``class_num`` classes instead of
            reusing the backbone's own ``fc`` layer.
        class_num: Number of outputs of the fresh classifier.
        cos_dist: Falsy for a plain ``nn.Linear`` classifier. A truthy value
            selects a ``distLinear`` classifier and is passed to it as its
            ``scale_factor``; it also makes ``forward`` L2-normalise the
            features.

    Raises:
        KeyError: If ``resnet_name`` is not a key of ``resnet_dict``.
    """

    def __init__(
        self, resnet_name, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000, cos_dist=False
    ):
        super(ResNetFc, self).__init__()
        model_resnet = resnet_dict[resnet_name](pretrained=True)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.feature_layers = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.avgpool,
        )

        self.use_bottleneck = use_bottleneck
        self.new_cls = new_cls
        self.cos_dist = cos_dist
        self.class_num = class_num

        # Width of the flattened backbone output, i.e. the input of the stock fc.
        feature_dim = model_resnet.fc.in_features
        if new_cls:
            if self.use_bottleneck:
                self.bottleneck = nn.Linear(feature_dim, bottleneck_dim)
                self.bottleneck.apply(init_weights)
                head_in = bottleneck_dim
            else:
                # Without a bottleneck the classifier sees the raw backbone
                # features, so it must be sized by feature_dim for both head
                # types (the cosine head previously used bottleneck_dim here).
                head_in = feature_dim

            if cos_dist:
                # The truthy cos_dist value doubles as the head's scale factor.
                self.fc = distLinear(head_in, class_num, scale_factor=cos_dist)
            else:
                self.fc = nn.Linear(head_in, class_num)
                self.fc.apply(init_weights)
            self.__in_features = head_in
        else:
            self.fc = model_resnet.fc
            self.__in_features = feature_dim

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch for the backbone.

        Returns:
            A tuple ``(features, outputs)``: the features that were fed to the
            classifier (after the bottleneck and, if ``cos_dist`` is truthy,
            after L2 normalisation) and the classifier outputs.
        """
        x = self.feature_layers(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch dimension
        if self.use_bottleneck and self.new_cls:
            x = self.bottleneck(x)
        # NOTE(review): normalisation depends on cos_dist only, so it is also
        # applied when new_cls is False and the stock fc is used - confirm intended.
        if self.cos_dist:
            x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x)
            x = x.div(x_norm + 0.00001)  # epsilon avoids division by zero
        y = self.fc(x)
        return x, y

    def output_num(self):
        """Return the width of the features returned by ``forward``."""
        return self.__in_features

    def get_parameters(self):
        """Return parameter groups tagged with ``lr_mult`` and ``decay_mult``.

        With a fresh classifier the backbone gets ``lr_mult`` 1 while the
        bottleneck (if any) and the classifier get ``lr_mult`` 10. Otherwise a
        single group with all parameters and ``lr_mult`` 1 is returned. The
        multipliers are plain dict entries; how they are applied is up to the
        caller.

        Returns:
            A list of dicts with keys ``"params"``, ``"lr_mult"`` and
            ``"decay_mult"``.
        """
        if not self.new_cls:
            return [{"params": self.parameters(), "lr_mult": 1, "decay_mult": 2}]

        parameter_list = [{"params": self.feature_layers.parameters(), "lr_mult": 1, "decay_mult": 2}]
        if self.use_bottleneck:
            parameter_list.append({"params": self.bottleneck.parameters(), "lr_mult": 10, "decay_mult": 2})
        parameter_list.append({"params": self.fc.parameters(), "lr_mult": 10, "decay_mult": 2})
        return parameter_list


# Maps the architecture names accepted by VGGFc ("VGG16", "VGG16BN", ...) to the
# matching torchvision constructor; plain variants come first, then the "_bn" ones.
vgg_dict = {
    "VGG{}{}".format(depth, tag): getattr(models, "vgg{}{}".format(depth, suffix))
    for tag, suffix in (("", ""), ("BN", "_bn"))
    for depth in (11, 13, 16, 19)
}


class VGGFc(nn.Module):
    """torchvision VGG backbone with an optional bottleneck and a replaceable head.

    The backbone is built with ``pretrained=True``. Its convolutional
    ``features`` and the first six modules of its classifier are kept; the
    seventh classifier module is reused as ``fc`` only when ``new_cls`` is
    False.

    Args:
        vgg_name: Key into ``vgg_dict`` (e.g. ``"VGG16"``).
        use_bottleneck: When ``new_cls`` is set, insert a linear bottleneck of
            width ``bottleneck_dim`` before the classifier.
        bottleneck_dim: Output width of the bottleneck layer.
        new_cls: Build a fresh classifier for ``class_num`` classes instead of
            reusing the backbone's last classifier layer.
        class_num: Number of outputs of the fresh classifier.

    Raises:
        KeyError: If ``vgg_name`` is not a key of ``vgg_dict``.
    """

    def __init__(self, vgg_name, use_bottleneck=True, bottleneck_dim=256, new_cls=False, class_num=1000):
        super(VGGFc, self).__init__()
        model_vgg = vgg_dict[vgg_name](pretrained=True)
        self.features = model_vgg.features
        # Keep classifier modules 0..5; module 6 is the final prediction layer.
        self.classifier = nn.Sequential()
        for i in range(6):
            self.classifier.add_module("classifier" + str(i), model_vgg.classifier[i])
        self.feature_layers = nn.Sequential(self.features, self.classifier)

        self.use_bottleneck = use_bottleneck
        self.new_cls = new_cls
        if new_cls:
            if self.use_bottleneck:
                self.bottleneck = nn.Linear(4096, bottleneck_dim)
                self.fc = nn.Linear(bottleneck_dim, class_num)
                self.bottleneck.apply(init_weights)
                self.fc.apply(init_weights)
                self.__in_features = bottleneck_dim
            else:
                self.fc = nn.Linear(4096, class_num)
                self.fc.apply(init_weights)
                self.__in_features = 4096
        else:
            self.fc = model_vgg.classifier[6]
            self.__in_features = 4096

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch for the backbone.

        Returns:
            A tuple ``(features, outputs)``: the features that were fed to the
            final classifier (after the bottleneck, if one is used) and the
            classifier outputs.
        """
        x = self.features(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch dimension
        x = self.classifier(x)
        if self.use_bottleneck and self.new_cls:
            x = self.bottleneck(x)
        y = self.fc(x)
        return x, y

    def output_num(self):
        """Return the width of the features returned by ``forward``."""
        return self.__in_features

    def get_parameters(self):
        """Return parameter groups tagged with ``lr_mult`` and ``decay_mult``.

        With a fresh classifier, ``features`` and ``classifier`` get
        ``lr_mult`` 1 while the bottleneck (if any) and ``fc`` get ``lr_mult``
        10. Otherwise a single group with all parameters and ``lr_mult`` 1 is
        returned. Every parameter appears in exactly one group.

        Returns:
            A list of dicts with keys ``"params"``, ``"lr_mult"`` and
            ``"decay_mult"``.
        """
        if not self.new_cls:
            return [{"params": self.parameters(), "lr_mult": 1, "decay_mult": 2}]

        # Use `features` and `classifier` separately: `feature_layers` wraps
        # both, so listing it next to `classifier` would put the classifier
        # parameters into two groups, which torch optimizers reject.
        parameter_list = [
            {"params": self.features.parameters(), "lr_mult": 1, "decay_mult": 2},
            {"params": self.classifier.parameters(), "lr_mult": 1, "decay_mult": 2},
        ]
        if self.use_bottleneck:
            parameter_list.append({"params": self.bottleneck.parameters(), "lr_mult": 10, "decay_mult": 2})
        parameter_list.append({"params": self.fc.parameters(), "lr_mult": 10, "decay_mult": 2})
        return parameter_list