"""
The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch.
"""
import torch, torch.nn as nn
import pretrainedmodels as ptm





"""============================================================="""
class Network(torch.nn.Module):
    """SE-ResNet50 backbone whose classifier is replaced by an embedding layer.

    The backbone is loaded from ``pretrainedmodels`` (ImageNet weights unless
    ``opt.not_pretrained`` is set) and its ``last_linear`` layer is swapped for
    a fresh ``Linear`` layer producing ``opt.embed_dim`` outputs.

    Options read from ``opt``:
        not_pretrained: if truthy, the backbone is created without weights.
        arch: architecture string; the substring ``'frozen'`` freezes all
            ``BatchNorm2d`` layers in eval mode, and ``'normalize'`` enables
            L2-normalisation of the embedding in ``forward``.
        embed_dim: output size of the embedding layer.
        loss: loss name; read in ``forward`` (see there for ``'daml'``).
    """

    def __init__(self, opt):
        super().__init__()

        self.pars = opt
        self.model = ptm.__dict__['se_resnet50'](
            num_classes=1000,
            pretrained='imagenet' if not opt.not_pretrained else None,
        )

        self.name = opt.arch

        if 'frozen' in opt.arch:
            for module in self.model.modules():
                if isinstance(module, nn.BatchNorm2d):
                    module.eval()
                    # Shadow train() with a no-op so that later .train() calls
                    # on a parent module cannot switch this layer back to
                    # training mode. ``mode`` is optional, mirroring the
                    # signature of nn.Module.train.
                    module.train = lambda mode=True: None

        self.model.last_linear = torch.nn.Linear(
            self.model.last_linear.in_features, opt.embed_dim
        )

        # Optional callable applied to the embedding at evaluation time only;
        # may be assigned externally after construction.
        self.out_adjust = None

    def forward(self, x, **kwargs):
        """Compute the embedding together with intermediate features.

        Args:
            x: input batch passed to ``self.model.features``.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            A tuple ``(out, (enc_out, no_avg_feat))`` where ``no_avg_feat`` is
            the output of ``self.model.features``, ``enc_out`` is that output
            after pooling (and dropout, if the backbone has one) flattened to
            ``(batch, -1)``, and ``out`` is the embedding. While training with
            a loss whose name contains ``'daml'``, the embedding layer is
            skipped and ``out`` is ``enc_out`` itself.
        """
        no_avg_feat = self.model.features(x)

        pooled = self.model.avg_pool(no_avg_feat)
        if self.model.dropout is not None:
            pooled = self.model.dropout(pooled)
        enc_out = pooled.view(pooled.size(0), -1)

        # 'daml' losses receive the raw pooled features during training.
        if 'daml' in self.pars.loss and self.training:
            return enc_out, (enc_out, no_avg_feat)

        embedding = self.model.last_linear(enc_out)
        if 'normalize' in self.pars.arch:
            embedding = torch.nn.functional.normalize(embedding, dim=-1)
        # Use the ``training`` flag here: ``self.train`` is a bound method and
        # therefore always truthy, which would disable this branch entirely.
        if self.out_adjust is not None and not self.training:
            embedding = self.out_adjust(embedding)

        return embedding, (enc_out, no_avg_feat)

