import torch
import torchvision
from torch import nn


class Net(nn.Module):
    """VGG16-trunk classifier with a convolutional, globally pooled head.

    The network reuses torchvision's VGG16 convolutional trunk minus its last
    three layers. A 3x3 convolution maps the 512-channel feature map to
    ``num_classes`` channels. Each channel is then averaged over all spatial
    positions, and a softmax is applied across classes.

    ``forward`` returns *probabilities*, not logits. Pair it with a loss that
    expects probabilities (e.g. ``nn.BCELoss`` against one-hot targets).
    Do not use ``nn.CrossEntropyLoss``, which applies its own log-softmax.
    """

    def __init__(self, num_classes: int = 3):
        """Build the network.

        Args:
            num_classes: Number of output classes, i.e. the number of channels
                produced by ``cls_conv``. Defaults to 3, the value this model
                originally hard-coded.
        """
        super().__init__()
        # Randomly initialised VGG16; no pretrained weights are loaded.
        # NOTE(review): the whole VGG16 stays registered as a submodule,
        # including its avgpool and fully connected classifier, which forward()
        # never uses. Those parameters therefore show up in parameters() and
        # state_dict() without ever receiving gradients. It is kept as-is so
        # that existing checkpoints containing ``model.*`` keys still load
        # with strict=True.
        self.model = self._build_vgg16()
        # Reuse the trunk without its last three children. In torchvision's
        # VGG16 these are the final 3x3 conv, its ReLU and the last max-pool,
        # which leaves a 512-channel map (hence cls_conv's in_channels=512).
        # The layer objects are shared with self.model.features, not copied.
        self.features = nn.Sequential(*list(self.model.features.children())[:-3])
        self.cls_conv = nn.Conv2d(512, num_classes, kernel_size=3, stride=1, padding=1)
        self.cls_avgpool = nn.AdaptiveAvgPool2d(1)
        self.cls_flatten = nn.Flatten(start_dim=1)
        self.cls_classifier = nn.Softmax(dim=1)

    @staticmethod
    def _build_vgg16() -> nn.Module:
        """Return a randomly initialised torchvision VGG16.

        Since torchvision 0.13, ``pretrained=`` is deprecated in favour of
        ``weights=``. The new keyword is tried first. Older releases reject
        ``weights`` with a TypeError, in which case the legacy keyword is used.
        """
        try:
            return torchvision.models.vgg16(weights=None)
        except TypeError:
            return torchvision.models.vgg16(pretrained=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class probabilities for a batch of images.

        Args:
            x: Image batch. Presumably ``(N, 3, H, W)`` as expected by VGG16's
                first conv (TODO confirm against callers).

        Returns:
            Tensor of shape ``(N, num_classes)`` whose rows sum to 1.
        """
        x = self.features(x)        # VGG16 trunk -> 512-channel feature map
        x = self.cls_conv(x)        # one output channel per class
        x = self.cls_avgpool(x)     # global spatial average -> (N, C, 1, 1)
        x = self.cls_flatten(x)     # -> (N, C)
        x = self.cls_classifier(x)  # softmax over the class dimension
        return x

# if __name__ == '__main__':
#     img=torch.randn(1,3,224,224)
#     net=Net()
#     x=net(img)
#     label=torch.tensor([[1,0,0]],dtype=torch.float32)
#     loss_function=torch.nn.BCELoss()
#     loss=loss_function(x,label)
#     print(loss)