from common_imports import torch, nn
from torchvision.models import regnet_y_16gf, RegNet_Y_16GF_Weights

# RegNet classifier; by default it is built with ImageNet-1k pretrained weights.
class RegNet(nn.Module):
    """RegNet backbone with its classification head split off.

    The torchvision model's final ``fc`` layer is moved onto this module as
    ``self.fc`` and replaced inside the encoder by ``nn.Identity``, so that
    ``self.encoder(x)`` returns the tensor that feeds the classification
    head. That tensor can optionally be returned alongside the logits (see
    ``activate_registration``).
    """

    def __init__(self, model_name: str = "regnet-y-16gf-swag-e2e-v1"):
        """Build the encoder and head for ``model_name``.

        Args:
            model_name: Identifier of the pretrained variant to load. Only
                ``"regnet-y-16gf-swag-e2e-v1"`` is supported.

        Raises:
            ValueError: If ``model_name`` is not a supported identifier.
        """
        super().__init__()

        if model_name == "regnet-y-16gf-swag-e2e-v1":
            self.encoder = regnet_y_16gf(
                weights=RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1
            )
        else:
            # Fail loudly here: without this, ``self.encoder`` would be
            # undefined and the attribute access below would raise an
            # opaque AttributeError.
            raise ValueError(
                f"Unsupported model_name {model_name!r}; "
                "expected 'regnet-y-16gf-swag-e2e-v1'."
            )

        # Detach the classification head so the encoder yields its input.
        self.fc = self.encoder.fc
        self.encoder.fc = nn.Identity()

        # Self-defined parameters for activation levels registration
        self.regist_actLevel = False
        # Presumably the number of activation tensors recorded by forward()
        # (one: the encoder output) -- TODO confirm against callers.
        self.nb_hidden_layers = 1

    def activate_registration(self):
        """Make ``forward`` also return the recorded activations."""
        self.regist_actLevel = True

    def deactivate_registration(self):
        """Make ``forward`` return only the head output."""
        self.regist_actLevel = False

    def forward(self, x):
        """Run the encoder followed by the classification head.

        Args:
            x: Input passed to the encoder.

        Returns:
            The head output alone when registration is off; otherwise a
            tuple ``(output, activations)`` where ``activations`` is a
            one-element list holding the encoder output.
        """
        features = self.encoder(x)
        logits = self.fc(features)
        if self.regist_actLevel:
            return logits, [features]
        return logits