import torch.nn as nn
import torchvision.models as models

class VGGFeatureExtractor(nn.Module):
    def __init__(self):
        super(VGGFeatureExtractor, self).__init__()
        # self.conv1x1 = nn.Conv2d(input_channels, 3, kernel_size=1)  # 映射到 3 通道
        vgg = models.vgg19(pretrained=True).features
        self.slice1 = nn.Sequential(*[vgg[x] for x in range(2)])  # 第一个卷积层输出
        self.slice2 = nn.Sequential(*[vgg[x] for x in range(2, 7)])  # 第二个卷积层输出
        self.slice3 = nn.Sequential(*[vgg[x] for x in range(7, 12)])  # 第三个卷积层输出
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        # x = self.conv1x1(x)  # 将 32 通道转换为 3 通道
        h1 = self.slice1(x)  # 输出第一个特征图
        h2 = self.slice2(h1)  # 输出第二个特征图
        h3 = self.slice3(h2)  # 输出第三个特征图
        return [h1, h2, h3]

