import numpy as np
import torch
import torchvision.transforms as TF
import os, glob
import lpips
from PIL import Image


class ImagePathDataset(torch.utils.data.Dataset):
    """Paired dataset of (real, fake) images loaded from file paths.

    Item ``i`` pairs ``real_files[i]`` with ``fake_files[i]``. Both images are
    converted to RGB, cast to float32 and linearly rescaled from [0, 255] to
    [-1, 1]; the optional ``transforms`` callable is then applied to each
    image separately.

    Args:
        real_files: Sequence of paths to the reference images.
        fake_files: Sequence of paths to the generated images, aligned
            index-by-index with ``real_files``.
        transforms: Optional callable applied to each ``(H, W, 3)`` float32
            array (for example torchvision's ``ToTensor``).

    Raises:
        ValueError: If ``real_files`` and ``fake_files`` differ in length.
    """

    def __init__(self, real_files, fake_files, transforms=None):
        # Pairs are matched by index, so both lists must have the same
        # length; otherwise __getitem__ would raise IndexError late (inside a
        # DataLoader worker) or silently ignore surplus fake images.
        if len(real_files) != len(fake_files):
            raise ValueError(
                f"real_files and fake_files must have the same length, "
                f"got {len(real_files)} and {len(fake_files)}"
            )
        self.real_files = real_files
        self.fake_files = fake_files
        self.transforms = transforms

    def __len__(self):
        return len(self.real_files)

    @staticmethod
    def _load_normalized(path):
        """Load ``path`` as an RGB float32 ``(H, W, 3)`` array scaled to [-1, 1]."""
        # The context manager closes the underlying file deterministically
        # instead of leaving it to Pillow / garbage collection.
        with Image.open(path) as img:
            arr = np.asarray(img.convert('RGB'), dtype=np.float32)
        return (arr - 127.5) / 127.5

    def __getitem__(self, i):
        real_img = self._load_normalized(self.real_files[i])
        fake_img = self._load_normalized(self.fake_files[i])
        if self.transforms is not None:
            real_img = self.transforms(real_img)
            fake_img = self.transforms(fake_img)
        return real_img, fake_img
    

def cal_lpips_given_paths(dirs, batch_size, device, num_workers=6):
    """Compute the mean LPIPS distance between paired real and fake images.

    Every ``*.png`` file directly inside ``real_dir`` is paired with the
    ``*.png`` file at the same position in ``fake_dir`` after both listings
    are sorted by path. Corresponding images should therefore share file
    names, or at least sort identically. Distances come from
    ``lpips.LPIPS`` with the AlexNet backbone.

    Args:
        dirs: Two-element sequence ``(real_dir, fake_dir)``.
        batch_size: Number of image pairs per batch.
        device: Torch device (or device string) the LPIPS network runs on.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        float: The LPIPS distance averaged over all image pairs.

    Raises:
        ValueError: If ``real_dir`` contains no PNG files, or the two
            directories contain different numbers of PNG files.
    """
    real_dir, fake_dir = dirs
    # glob returns files in filesystem-dependent order; sort both listings so
    # the i-th real image is compared against the matching i-th fake image.
    real_files = sorted(glob.glob(os.path.join(real_dir, '*.png')))
    fake_files = sorted(glob.glob(os.path.join(fake_dir, '*.png')))

    # Validate before loading the network so bad input fails fast, and so an
    # empty directory does not end in a ZeroDivisionError.
    if not real_files:
        raise ValueError(f"no .png files found in {real_dir!r}")
    if len(real_files) != len(fake_files):
        raise ValueError(
            f"found {len(real_files)} .png files in {real_dir!r} but "
            f"{len(fake_files)} in {fake_dir!r}; images must be paired"
        )

    lpips_fn = lpips.LPIPS(net='alex').to(device).eval()

    # ImagePathDataset scales pixels to [-1, 1]. ToTensor only reorders
    # float arrays from (H, W, C) to (C, H, W); it does not rescale them.
    dataset = ImagePathDataset(real_files, fake_files, transforms=TF.ToTensor())
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )

    total_lpips = 0.0
    # Pure evaluation: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for real_batch, fake_batch in dataloader:
            dist = lpips_fn(real_batch.to(device), fake_batch.to(device))
            total_lpips += dist.sum().item()

    return total_lpips / len(dataset)

