import torch
from torchvision.models.densenet import DenseNet

import torch.nn.functional as F


class DenseNet121(DenseNet):
    """DenseNet-121 that returns both class logits and the pooled feature vector.

    Builds torchvision's ``DenseNet`` with the DenseNet-121 configuration
    (growth rate 32, block configuration ``(6, 12, 24, 16)``, 64 initial
    features) and overrides ``forward`` so callers also receive the
    globally average-pooled embedding that feeds the classifier.
    """

    def __init__(self, num_classes: int = 1000, **kwargs):
        """Build the network.

        Args:
            num_classes: Output size of the final linear classifier. Defaults
                to 1000, the value this class previously hard-coded.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``torchvision.models.densenet.DenseNet`` (e.g. ``bn_size``,
                ``drop_rate``, ``memory_efficient``). The architecture-defining
                arguments (``growth_rate``, ``block_config``,
                ``num_init_features``) are fixed by this class and must not be
                passed again.
        """
        super().__init__(
            growth_rate=32,
            block_config=(6, 12, 24, 16),
            num_init_features=64,
            num_classes=num_classes,
            **kwargs,
        )

    def forward(self, x: torch.Tensor):
        """Run the network and return ``(logit, feature)``.

        Args:
            x: Input image batch. Presumably ``(N, 3, H, W)``, since the
                torchvision stem convolution expects 3 input channels;
                confirm with callers.

        Returns:
            A tuple ``(logit, feature)``:

            - ``feature`` is the ``(N, C)`` globally average-pooled embedding
              (``C`` is 1024 for this block configuration).
            - ``logit`` is ``self.classifier(feature)``, shaped
              ``(N, num_classes)``. No softmax is applied.
        """
        feature_map = self.features(x)
        # Mirrors torchvision's DenseNet.forward: the feature stack ends with a
        # BatchNorm (norm5) and no activation, so ReLU is applied here. In-place
        # is fine because the pre-activation map is not used afterwards.
        feature_map = F.relu(feature_map, inplace=True)
        pooled = F.adaptive_avg_pool2d(feature_map, (1, 1))
        # Keep the batch dimension, collapse (C, 1, 1) -> (C,).
        feature = torch.flatten(pooled, 1)
        logit = self.classifier(feature)
        return logit, feature

    # Parameter name of the last convolution in the network (the 3x3 conv2 of
    # the 16th layer of the 4th dense block):
    #     features.denseblock4.denselayer16.conv2.weight
    # NOTE(review): presumably kept as a reference for hooks or layer lookups
    # (e.g. CAM / Grad-CAM); confirm how callers use it.