import torch_fidelity
import torch.nn.functional as F
import torch
import lpips
import torch.nn as nn
import torchvision.models as models


class loss_fn:
    """Reconstruction loss: pixel-wise MSE plus a weighted LPIPS (VGG) term.

    ``get_loss`` returns ``loss = mse + lam * lpips`` together with the
    individual components.

    Args:
        lam: Weight applied to the LPIPS term when forming the total loss.
        device: Device the LPIPS network is moved to. Defaults to ``'cuda'``,
            which matches the previous hard-coded ``.cuda()`` call.
    """

    def __init__(self, lam=0.1, device='cuda'):
        # LPIPS perceptual distance with a VGG backbone.
        self.loss_fn_vgg = lpips.LPIPS(net='vgg').to(device)
        self.lam = lam

    @torch.enable_grad()
    def get_loss(self, pred, gt):
        """Compute the combined loss between ``pred`` and ``gt``.

        Gradient tracking is forced on via ``torch.enable_grad`` so the
        returned loss is differentiable even if the caller is inside a
        ``torch.no_grad()`` context.

        NOTE(review): the ``lpips`` package documents its default inputs as
        scaled to [-1, 1]; confirm the value range of ``pred``/``gt``.

        Args:
            pred: Predicted tensor.
            gt: Target tensor compared against ``pred``.

        Returns:
            dict with keys:
              - ``'loss'``: ``loss_mse + lam * loss_lpips``.
              - ``'loss_mse'``: mean of ``(pred - gt) ** 2`` over all elements.
              - ``'loss_lpips'``: mean of the LPIPS network's output.
              - ``'loss_inception'``: constant zero on the same device as
                ``'loss'``; presumably a placeholder so the dict matches other
                loss implementations — verify against callers.
        """
        loss_mse = torch.mean((pred - gt) ** 2)
        loss_lpips = self.loss_fn_vgg(pred, gt).mean()
        loss = loss_mse + loss_lpips * self.lam
        # Keep the placeholder on the same device as the other entries; a bare
        # torch.tensor(0.0) always lands on the CPU, which breaks stacking or
        # all-reducing the returned values when training on GPU.
        loss_inception = torch.zeros((), device=loss.device)
        return {
            'loss': loss,
            'loss_mse': loss_mse,
            'loss_lpips': loss_lpips,
            'loss_inception': loss_inception,
        }
