import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.nn.utils.weight_norm as weightNorm

class resnet18(nn.Module):
    """ResNet-18 feature extractor built from torchvision's pretrained model.

    Every stage of ``torchvision.models.resnet18(pretrained=True)`` up to and
    including global average pooling is reused. The final ``fc`` layer is
    dropped, so ``forward`` returns pooled features rather than class logits.

    Args:
        args: configuration namespace; this class reads ``args.fea_norm``
            (truthy -> L2-normalise the output features).

    Attributes:
        fdim: input width of the discarded ``fc`` layer, i.e. the size of the
            pooled feature vector the backbone produces.
        restored: initialised to ``False`` and never read inside this class
            (presumably toggled by checkpoint-loading code elsewhere -- verify).
    """

    # Backbone sub-modules in execution order. The same tuple drives both the
    # copy in __init__ (which fixes module registration / state_dict order)
    # and the sequential application in forward.
    _STAGES = (
        "conv1",
        "bn1",
        "relu",
        "maxpool",
        "layer1",
        "layer2",
        "layer3",
        "layer4",
        "avgpool",
    )

    def __init__(self, args):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        self.restored = False
        self.fdim = backbone.fc.in_features
        # Re-register the pretrained stages on this module; `fc` is left out.
        for stage_name in self._STAGES:
            setattr(self, stage_name, getattr(backbone, stage_name))
        self.args = args

    def forward(self, x):
        """Run the backbone and return flattened (optionally normalised) features.

        Args:
            x: input image batch (assumed NCHW, as torchvision's ResNet expects).

        Returns:
            Tensor of shape ``(batch, -1)``, L2-normalised along dim 1 when
            ``args.fea_norm`` is truthy.
        """
        feats = x
        for stage_name in self._STAGES:
            feats = getattr(self, stage_name)(feats)
        # Collapse everything after the batch axis into one feature axis.
        feats = feats.view(feats.size(0), -1)
        return F.normalize(feats) if self.args.fea_norm else feats

class Classifier(nn.Module):
    """Linear classification head whose logits are divided by a temperature.

    Args:
        args: configuration namespace; this class reads ``num_classes``,
            ``classifier_wn``, ``no_bias`` (only when ``classifier_wn`` is
            falsy) and ``fea_norm2``.
        in_dim: width of the incoming feature vectors.
        temp: scalar divisor applied to the logits in ``forward``.

    Head selection:
        * ``classifier_wn`` truthy -> weight-normalised ``nn.Linear`` with a
          bias (``no_bias`` is not consulted in this case).
        * otherwise -> plain ``nn.Linear``, bias omitted iff ``no_bias``.
    """

    def __init__(self, args, in_dim, temp):
        super().__init__()
        self.args = args
        self.in_dim = in_dim
        self.temp = temp
        if args.classifier_wn:
            self.classifier = weightNorm(nn.Linear(in_dim, args.num_classes))
        else:
            # `no_bias` truthy -> bias=False; falsy -> nn.Linear's default True.
            self.classifier = nn.Linear(
                in_dim, args.num_classes, bias=not args.no_bias
            )

    def forward(self, x):
        """Return ``classifier(x) / temp``.

        ``x`` is first L2-normalised along dim 1 when ``args.fea_norm2`` is truthy.
        """
        feats = F.normalize(x) if self.args.fea_norm2 else x
        return self.classifier(feats) / self.temp