import torch.nn as nn
import torchvision
from .resnet import *

class Pretrain(nn.Module):
    """ResNet encoder followed by a single linear prediction head.

    ``forward`` returns both the encoder output and the head output.

    Args:
        args: Configuration object. This class reads ``args.resnet`` (the
            encoder name passed to ``get_resnet``) and stores ``args`` on the
            instance as ``self.args``.
        nclass: Output size of the linear head.
    """

    def __init__(self, args, nclass):
        super().__init__()
        self.args = args
        self.encoder = self.get_resnet(args.resnet)
        # Width of the encoder output, as exposed by the project ResNet
        # builders via ``feat_dim``; it sizes the head's input.
        self.n_features = self.encoder.feat_dim
        self.nclass = nclass
        self.predictor = nn.Linear(self.n_features, self.nclass, bias=True)

    def get_resnet(self, name, pool_len=4):
        """Construct and return the ResNet encoder called *name*.

        Args:
            name: One of ``"resnet18"``, ``"resnet34"``, ``"resnet50"``,
                ``"resnet101"``, ``"resnet152"``.
            pool_len: Forwarded as the ``pool_len`` keyword to the selected
                ResNet builder. The default of 4 matches the value used by
                ``__init__``.

        Returns:
            The newly built encoder module.

        Raises:
            KeyError: If *name* is not one of the supported versions.
        """
        # Store the builders uncalled so that only the requested network is
        # allocated and initialized, instead of all five on every call.
        builders = {
            "resnet18": resnet18,
            "resnet34": resnet34,
            "resnet50": resnet50,
            "resnet101": resnet101,
            "resnet152": resnet152,
        }
        if name not in builders:
            raise KeyError(f"{name} is not a valid ResNet version")
        return builders[name](pool_len=pool_len)

    def forward(self, x):
        """Encode *x* and apply the linear head.

        Args:
            x: Encoder input. NOTE(review): the expected shape depends on the
                project ResNet implementation and is not visible here.

        Returns:
            Tuple ``(z, y)``. ``z`` is the encoder output, whose last
            dimension must equal ``self.n_features`` because ``nn.Linear``
            acts on the last dimension. ``y`` is ``self.predictor(z)``, whose
            last dimension is ``self.nclass``.
        """
        z = self.encoder(x)
        y = self.predictor(z)
        return z, y