from torchvision import models
import torch.nn as nn

class CFE(nn.Module):
    """Common Feature Extractor: a frozen, ImageNet-pretrained DenseNet-201 backbone.

    Maps an image batch to one pooled feature vector per sample by running
    ``densenet201.features``, global-average-pooling the spatial map to 1x1
    and flattening it to ``[B, 1920]``.

    Every backbone parameter has ``requires_grad = False``. The backbone is
    also pinned to eval mode: :meth:`train` is overridden so that calling
    ``.train()`` on this module (or on any model that contains it) never
    flips the backbone back to training mode. Without that, BatchNorm layers
    would normalise with per-batch statistics and overwrite their pretrained
    running mean/var, even though the weights themselves are frozen.
    """

    def __init__(self):
        super().__init__()
        densenet = models.densenet201(weights=models.DenseNet201_Weights.IMAGENET1K_V1)
        # Freeze the weights; note this alone does NOT stop BatchNorm running-stat
        # updates, which is why the backbone is also kept in eval mode (see train()).
        for param in densenet.parameters():
            param.requires_grad = False
        # NOTE(review): torchvision's own DenseNet.forward applies a ReLU to the
        # output of `features` before pooling; no ReLU is applied here. Kept as-is
        # to preserve existing feature values — confirm this is intended.
        self.cfe = nn.Sequential(
            densenet.features,              # output shape: [B, 1920, H', W'] (7x7 for 224x224 input)
            nn.AdaptiveAvgPool2d((1, 1)),   #           --> [B, 1920, 1, 1]
            nn.Flatten()                    #           --> [B, 1920]
        )
        self.cfe.eval()

    def train(self, mode: bool = True):
        """Set this module's training flag while keeping the backbone in eval mode.

        Args:
            mode: Passed through to :meth:`torch.nn.Module.train`; it controls
                ``self.training`` and any other submodules, but the frozen
                backbone ``self.cfe`` is always left in eval mode.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract.
        """
        super().train(mode)
        # nn.Module.train recurses into children; undo that for the frozen backbone.
        self.cfe.eval()
        return self

    def forward(self, inputs):
        """Return pooled DenseNet-201 features for ``inputs``.

        Args:
            inputs: Image batch; presumably ``[B, 3, H, W]`` normalised with the
                ImageNet statistics the weights were trained on — TODO confirm
                against callers.

        Returns:
            A ``[B, 1920]`` tensor of globally average-pooled features. Gradients
            can still flow back to ``inputs``, since the forward pass is not
            wrapped in ``torch.no_grad()``; only the backbone parameters are frozen.
        """
        return self.cfe(inputs)