from typing import Optional, Tuple, Union

import lpips
import torch


class LPIPSLoss(torch.nn.Module):
    """LPIPS perceptual distance exposed as a ``torch.nn.Module`` loss.

    Wraps ``lpips.LPIPS`` and can optionally resize both inputs to a common
    spatial size (bilinear) before they are compared.

    Args:
        net: Backbone name, forwarded unchanged to ``lpips.LPIPS(net=...)``.
        size: Target spatial size passed to
            ``torch.nn.functional.interpolate`` for both inputs -- an ``int``
            or an ``(H, W)`` pair. ``None`` (default) disables resizing.
    """

    def __init__(self, net: str, size: Optional[Union[int, Tuple[int, int]]] = None):
        super().__init__()
        self.lpips: lpips.LPIPS = lpips.LPIPS(net=net)
        self.size = size

    def train(self, mode: bool = True) -> "LPIPSLoss":
        """Set this module's training flag but keep the LPIPS network in eval mode.

        ``lpips.LPIPS`` builds its linear heads with dropout by default, so
        letting a parent ``.train()`` propagate into it would make the loss
        value random from call to call. ``eval()`` delegates to
        ``train(False)``, so it goes through here as well.

        Args:
            mode: Training flag applied to this module (and, via
                ``nn.Module.train``, to its children before the reset below).

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        # The wrapped metric network is a fixed evaluator; never let it train.
        self.lpips.eval()
        return self

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the LPIPS distance between ``x`` and ``y``.

        Bilinear resizing requires 4-D ``(N, C, H, W)`` input, so when
        ``size`` is set both tensors are presumably image batches.
        NOTE(review): upstream lpips documents inputs as RGB scaled to
        [-1, 1]; confirm callers normalise accordingly.

        Args:
            x: First image batch.
            y: Second image batch, same layout as ``x``.

        Returns:
            Whatever ``self.lpips(x, y)`` returns, unreduced; callers
            presumably reduce it (e.g. ``.mean()``) themselves.
        """
        if self.size is not None:
            x = torch.nn.functional.interpolate(x, mode='bilinear', size=self.size)
            y = torch.nn.functional.interpolate(y, mode='bilinear', size=self.size)
        return self.lpips(x, y)
