import torch
from torchvision import transforms
from torchmetrics.image.kid import KernelInceptionDistance

@torch.no_grad()
def compute_kid(images, ref_images, batch_size: int, device: str, max_subset_size: int = 50):
    """Compute the Kernel Inception Distance (KID) between generated and reference images.

    Every image is resized to 299x299 and converted with ``ToTensor`` (float values
    in [0, 1], hence ``normalize=True`` on the metric). The images are then streamed
    into torchmetrics' ``KernelInceptionDistance`` in chunks of ``batch_size``.

    NOTE(review): an earlier inline note read "len ref_images = 1000, target = -1".
    This presumably describes the caller's usual reference-set size. The meaning of
    "target" is unclear from here — TODO confirm.

    Args:
        images: Sliceable sequence of generated images, fed as the "fake"
            distribution. Presumably 3-channel PIL images (anything accepted by
            ``transforms.Resize`` and ``transforms.ToTensor``) — TODO confirm.
        ref_images: Sliceable sequence of reference images, fed as the "real"
            distribution. Same element type as ``images``.
        batch_size: Number of images transformed and moved to ``device`` per
            ``update`` call. Must be positive.
        device: Torch device string that the metric and each batch are placed on.
        max_subset_size: Upper bound on KID's ``subset_size``. The effective value
            is also capped by the sizes of both collections.

    Returns:
        The KID mean as a Python float. The std returned by ``compute`` is discarded.

    Raises:
        ValueError: If either collection is empty or ``batch_size`` is not positive.
    """
    n_fake, n_real = len(images), len(ref_images)
    if n_fake == 0 or n_real == 0:
        raise ValueError("compute_kid requires non-empty `images` and `ref_images`")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # torchmetrics requires subset_size <= the number of samples in BOTH the real
    # and fake sets. Capping only by len(images) made compute() raise whenever
    # ref_images was the smaller collection.
    subset_size = min(n_fake, n_real, max_subset_size)
    kid = KernelInceptionDistance(subset_size=subset_size, normalize=True).to(device)

    # Build the transform once instead of once per batch. A fixed (299, 299) target
    # (rather than resizing only the shorter edge) guarantees every image in a batch
    # has the same shape, so torch.stack cannot fail on mixed aspect ratios.
    tf = transforms.Compose([transforms.Resize((299, 299)), transforms.ToTensor()])

    def _to_batch(chunk):
        # Transform each image and stack into one (B, C, 299, 299) tensor on `device`.
        return torch.stack([tf(img) for img in chunk], dim=0).to(device)

    for start in range(0, n_real, batch_size):
        kid.update(_to_batch(ref_images[start:start + batch_size]), real=True)

    for start in range(0, n_fake, batch_size):
        kid.update(_to_batch(images[start:start + batch_size]), real=False)

    kid_mean, _kid_std = kid.compute()
    return kid_mean.item()