import torch

from .protocols import FeatureExtractorDelegate


class LPIPSVGG(torch.nn.Module):
    """Split a VGG ``features`` stack into five sequential slices for LPIPS.

    ``forward`` returns the activation produced after each slice. The tap
    names (``relu1_2`` ... ``relu5_3``) assume a torchvision-style VGG16
    ``features`` layout. NOTE(review): confirm for other backbones.
    Only ``features[0:30]`` is used; any later layers are ignored.

    The slices hold the *same* layer instances as ``pretrained_vgg.features``,
    not copies. Parameters are therefore shared with the caller's network,
    and freezing them here also freezes them there.
    """

    # Half-open [start, stop) index ranges into ``pretrained_vgg.features``
    # that make up slice1 .. slice5, in order.
    _SLICE_RANGES = ((0, 4), (4, 9), (9, 16), (16, 23), (23, 30))

    def __init__(self, pretrained_vgg: FeatureExtractorDelegate, requires_grad: bool = False):
        """Build the five slices from ``pretrained_vgg.features``.

        Args:
            pretrained_vgg: Object exposing a ``features`` attribute that must
                be a ``torch.nn.Sequential`` with at least 30 layers.
            requires_grad: If False (default), every parameter in the slices
                is frozen. Because layers are shared, this also freezes them
                in ``pretrained_vgg``.

        Raises:
            TypeError: If ``pretrained_vgg.features`` is not a
                ``torch.nn.Sequential``.
            IndexError: If ``features`` has fewer than 30 layers.
        """
        super().__init__()
        features = pretrained_vgg.features
        # Explicit raise rather than ``assert``: asserts are stripped under -O.
        if not isinstance(features, torch.nn.Sequential):
            raise TypeError(
                "Feature modules must be a valid `torch.nn.Sequential`, "
                f"got {type(features).__name__}."
            )
        slices = [self._make_slice(features, start, stop) for start, stop in self._SLICE_RANGES]
        self.slice1, self.slice2, self.slice3, self.slice4, self.slice5 = slices
        self.N_slices = len(slices)
        if not requires_grad:
            self.requires_grad_(False)

    @staticmethod
    def _make_slice(features: torch.nn.Sequential, start: int, stop: int) -> torch.nn.Sequential:
        """Return a Sequential of ``features[start:stop]``, keyed by global index."""
        block = torch.nn.Sequential()
        for idx in range(start, stop):
            # Child names keep the original global index ("0".."29") so that
            # state_dict keys such as "slice2.5.weight" stay stable.
            block.add_module(str(idx), features[idx])
        return block

    def forward(self, X: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Run ``X`` through the slices and return every intermediate output.

        Args:
            X: Input batch, presumably an (N, 3, H, W) image tensor as the VGG
                expects. TODO confirm normalization with the caller.

        Returns:
            A 5-tuple of the activations after slice1 .. slice5, in order.
        """
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        h_relu5_3 = self.slice5(h_relu4_3)
        return h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3

