from piqa import PSNR, MS_SSIM, LPIPS

class ImageMetrics:
    """Evaluate a fixed set of piqa image-quality metrics on an image pair.

    One instance each of ``PSNR``, ``MS_SSIM`` and ``LPIPS`` is created, moved
    to ``device`` and stored in ``self.iqa`` under the keys ``'psnr'``,
    ``'ms_ssim'`` and ``'lpips'``.
    """

    def __init__(self, device):
        """Instantiate every metric and move it to ``device``.

        Args:
            device: Target handed to each metric's ``.to()`` call.
        """
        self.iqa = {
            'psnr': PSNR().to(device),
            'ms_ssim': MS_SSIM().to(device),
            'lpips': LPIPS().to(device),
        }

    @staticmethod
    def _require_unit_range(image, label):
        """Raise ``AssertionError`` unless all values of ``image`` lie in [0, 1].

        Args:
            image: Object exposing ``max()`` and ``min()``.
            label: Name of the image inserted into the error message.
        """
        # Raised explicitly instead of using an ``assert`` statement: asserts
        # are stripped under ``python -O``, which would silently disable this
        # validation. The exception type and message are kept unchanged so
        # existing callers observe the same failure.
        # The ``not (...)`` form also rejects NaN, since every comparison
        # against NaN is false.
        if not (image.max() <= 1 and 0 <= image.min()):
            raise AssertionError(
                f'image {label} is not in the correct range of [0, 1]'
            )

    def report(self, x, y):
        """Compute every metric in ``self.iqa`` for the pair ``(x, y)``.

        Args:
            x: First image; its values must lie in [0, 1].
                (presumably a torch tensor on the metrics' device -- confirm
                against callers)
            y: Second image; same requirements as ``x``.

        Returns:
            dict mapping each metric name to ``metric(x, y).cpu()``.

        Raises:
            AssertionError: If ``x`` or ``y`` has a value outside [0, 1].
                ``x`` is checked before ``y``.
        """
        self._require_unit_range(x, 'x')
        self._require_unit_range(y, 'y')

        return {name: metric(x, y).cpu() for name, metric in self.iqa.items()}