import torch.nn as nn
import torchvision
from .resnet import *
from .resnet_imagenet import *


class BarTwins(nn.Module):
    """ResNet encoder followed by an MLP projection head.

    The encoder computes ``h = encoder(x)`` and the projector maps ``h`` to
    ``z``. The projector is Linear -> BatchNorm1d -> ReLU -> Linear, with a
    hidden width equal to the encoder's ``feat_dim`` and an output width of
    ``args.projection_dim``.

    Args:
        args: Configuration object. The attributes read here are
            ``args.resnet`` (backbone name, e.g. ``"resnet50"``),
            ``args.projection_dim`` (output width of the projector),
            ``args.normalize`` (whether ``forward`` L2-normalizes ``z``) and,
            for the non-ImageNet backbones, ``args.dataset``.
        data: ``'imagenet'`` selects the ``*_imagenet`` backbones; any other
            value selects the backbones built with ``pool_len=4``.

    Raises:
        KeyError: If ``args.resnet`` is not a supported backbone name.

    NOTE(review): only the requested backbone is constructed now. Earlier
    versions instantiated all five variants on every lookup, so seeded
    initial weights may differ from those produced before -- confirm if
    bit-exact reproducibility with old runs matters.
    """

    def __init__(self, args, data='non_imagenet'):
        super().__init__()
        self.args = args
        if data == 'imagenet':
            self.encoder = self.get_imagenet_resnet(args.resnet)
        else:
            self.encoder = self.get_resnet(args.resnet)

        # The encoder advertises its output width through ``feat_dim``.
        self.n_features = self.encoder.feat_dim
        self.projector = nn.Sequential(
            nn.Linear(self.n_features, self.n_features),
            nn.BatchNorm1d(self.n_features),
            nn.ReLU(inplace=True),
            nn.Linear(self.n_features, args.projection_dim),
        )

    def get_resnet(self, name):
        """Build the non-ImageNet ResNet called ``name``.

        Only the requested network is instantiated; it is created with
        ``pool_len=4`` and ``data=self.args.dataset``.

        Raises:
            KeyError: If ``name`` is not a supported ResNet version.
        """
        # Map names to constructors (not instances) so that a lookup does not
        # allocate every ResNet variant just to return one of them.
        constructors = {
            "resnet18": resnet18,
            "resnet34": resnet34,
            "resnet50": resnet50,
            "resnet101": resnet101,
            "resnet152": resnet152,
        }
        if name not in constructors:
            raise KeyError(f"{name} is not a valid ResNet version")
        return constructors[name](pool_len=4, data=self.args.dataset)

    def get_imagenet_resnet(self, name):
        """Build the ImageNet ResNet called ``name``.

        Only the requested network is instantiated.

        Raises:
            KeyError: If ``name`` is not a supported ResNet version.
        """
        # Constructors are stored uncalled; see get_resnet for the rationale.
        constructors = {
            "resnet18": resnet18_imagenet,
            "resnet34": resnet34_imagenet,
            "resnet50": resnet50_imagenet,
            "resnet101": resnet101_imagenet,
            "resnet152": resnet152_imagenet,
        }
        if name not in constructors:
            raise KeyError(f"{name} is not a valid ResNet version")
        return constructors[name]()

    def forward(self, x):
        """Return ``(h, z)``: the encoder output and its projection.

        When ``self.args.normalize`` is truthy, ``z`` is L2-normalized along
        ``dim=1`` before being returned; ``h`` is returned as produced by the
        encoder.
        """
        h = self.encoder(x)
        z = self.projector(h)
        if self.args.normalize:
            z = nn.functional.normalize(z, dim=1)
        return h, z