import torch
import lpips
import numpy as np
from .utils import check_device


class ImageRangeError(ValueError, AssertionError):
    """Raised by ``check_image`` when a tensor falls outside the allowed range.

    Inherits from ``ValueError`` (the idiomatic type for a bad argument value)
    and from ``AssertionError`` so code written against the previous
    ``assert``-based check still catches it.
    """


def check_image(tensor, low=-1., high=1.):
    """Validate that every value of ``tensor`` lies in ``[low, high]``.

    Unlike a bare ``assert``, this check is not stripped when Python runs
    with ``-O``.

    Args:
        tensor: Tensor to validate.
        low: Inclusive lower bound (default ``-1.``).
        high: Inclusive upper bound (default ``1.``).

    Raises:
        ImageRangeError: if any value is outside ``[low, high]`` or is NaN.
    """
    # torch.min/torch.max raise on empty tensors; an empty tensor has no
    # out-of-range values, so it trivially passes.
    if tensor.numel() == 0:
        return
    t_min, t_max = torch.min(tensor), torch.max(tensor)
    # Phrased as a positive condition so NaN (which compares False) is rejected.
    if not (t_min >= low and t_max <= high):
        raise ImageRangeError(
            f"expected values in [{low}, {high}], "
            f"got min={t_min.item()}, max={t_max.item()}"
        )


class LPIPS:
    """Thin wrapper around ``lpips.LPIPS`` for scoring image batches.

    Args:
        base_model: Backbone name forwarded to ``lpips.LPIPS(net=...)``.
        device: Device the LPIPS network is moved to; inputs are routed to it
            via ``check_device`` before scoring.
    """

    def __init__(self, base_model='alex', device='cpu') -> None:
        self.device = device
        self.loss_fn = lpips.LPIPS(net=base_model).to(device)

    @torch.no_grad()
    def score(self, samples: torch.Tensor, references: torch.Tensor):
        """Return the LPIPS output for ``samples`` against ``references``.

        Both inputs must have all values in ``[-1, 1]``; ``check_image``
        enforces this before anything else runs.

        Args:
            samples: Images to score. NOTE(review): presumably (N, C, H, W)
                as ``lpips.LPIPS`` expects; confirm against callers.
            references: Reference images, same layout as ``samples``.

        Returns:
            The network output with all size-1 dimensions squeezed out, moved
            to CPU. Because ``.squeeze()`` drops every size-1 dim, a batch of
            one presumably yields a 0-d tensor.
        """
        # Notice that samples and references should be in [-1, 1].
        check_image(samples)
        check_image(references)
        # Batch sizes are deliberately not compared here (that check was
        # disabled upstream). TODO(review): confirm whether mismatched batch
        # sizes are expected to broadcast.
        samples = check_device(samples, self.device)
        references = check_device(references, self.device)
        # The method runs under torch.no_grad(), so the output carries no
        # autograd graph and needs no .detach().
        return self.loss_fn(samples, references).squeeze().cpu()

    def on_dir(self, *args, **kwargs):
        """Placeholder for directory-based scoring (inferred from the name).

        Raises:
            NotImplementedError: always; this feature is not implemented.
        """
        # The original stub had no ``self`` parameter, so calling it on an
        # instance raised a confusing TypeError. Fail explicitly instead.
        raise NotImplementedError("LPIPS.on_dir is not implemented yet")
