import torch
from PIL import Image
from .utils import normalize_tensor, check_device
from taming.modules.losses.lpips import OCR_CRAFT_LPIPS

class OCRLPIPS:
    """Wrapper that scores images with the ``OCR_CRAFT_LPIPS`` model.

    The model is built once in eval mode and moved to ``device``. ``score``
    normalizes its inputs with ``normalize_tensor`` and moves them to the same
    device before running the model.
    """

    def __init__(self, device="cpu") -> None:
        """Build the ``OCR_CRAFT_LPIPS`` model and move it to ``device``.

        Args:
            device: Device passed to the model's ``.to()``. It is also used by
                ``score`` to place its input tensors.
        """
        # ``score`` needs to know where the model lives so it can move its
        # inputs there; keep the device on the instance.
        self.device = device
        self.ocr_craft_lpips = OCR_CRAFT_LPIPS().eval()
        self.ocr_craft_lpips = self.ocr_craft_lpips.to(device)

    @torch.no_grad()
    def score(self, samples: torch.Tensor, references: torch.Tensor):
        """Compare ``samples`` against ``references`` with the OCR LPIPS model.

        Runs under ``torch.no_grad()``, so no autograd graph is recorded.

        Args:
            samples: Image batch, expected layout (B, C, H, W).
            references: Image batch of layout (B, C, H, W), or a single
                reference with batch size 1. A single reference is repeated so
                that every sample is compared against it.

        Returns:
            The output of ``OCR_CRAFT_LPIPS`` for the (sample, reference)
            pairs, returned unchanged. Its exact shape/meaning is defined by
            that module; presumably one score per pair. TODO confirm.

        Raises:
            ValueError: If the reference batch size is neither 1 nor B.
        """
        batch_size = samples.shape[0]
        # NOTE(review): normalize_tensor lives in .utils; its exact value
        # range / transform is not visible here. Confirm it matches what
        # OCR_CRAFT_LPIPS expects.
        samples = normalize_tensor(samples)
        references = normalize_tensor(references)

        num_refs = references.shape[0]
        if num_refs == 1:
            # Broadcast the single reference across the whole sample batch.
            references = references.repeat(batch_size, 1, 1, 1)
        elif num_refs != batch_size:
            raise ValueError(
                f"references batch size must be 1 or {batch_size} "
                f"(the samples batch size), got {num_refs}"
            )

        samples = check_device(samples, self.device)
        references = check_device(references, self.device)
        ocr_sim = self.ocr_craft_lpips(samples, references)
        return ocr_sim

    @staticmethod
    def on_dir(*args, **kwargs):
        """Placeholder that is not implemented yet; always returns ``None``.

        This is a staticmethod that accepts and ignores any arguments, so it
        can be called both as ``OCRLPIPS.on_dir()`` and on an instance.
        """
        # TODO(review): the name suggests scoring images read from a
        # directory. Confirm the intended behavior and implement it.
        return None