import os
import torch
import numpy as np
from torchvision.transforms import functional as F

from pathlib import Path

from PIL import Image

class CFDataset:
    """Map-style dataset yielding ``(real, cf)`` image tensor pairs.

    Each item is loaded from a pair of image paths, decoded as 8-bit RGB,
    scaled to ``[0, 1]`` and returned as ``float32`` tensors laid out as
    ``C x H x W``. When ``normalize_for_s3`` is true, each tensor is
    additionally normalized per channel with ``_NORM_MEAN`` / ``_NORM_STD``.

    Note: constructing an instance sets the process-wide
    ``PIL.ImageFile.LOAD_TRUNCATED_IMAGES`` flag so truncated files load
    instead of raising.
    """

    # Per-channel statistics handed to ``F.normalize`` (these match the
    # widely used ImageNet mean / std values).
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, real_cf_pairs: list[tuple[Path, Path]], normalize_for_s3=False):
        """Store the path pairs after checking that every file exists.

        Args:
            real_cf_pairs: Sequence of ``(real_path, cf_path)`` tuples.
            normalize_for_s3: If true, apply channel-wise normalization in
                :meth:`transform`.

        Raises:
            FileNotFoundError: If any referenced file is missing.
        """
        # Validate with a real exception rather than ``assert``: assertions are
        # stripped under ``python -O``, which would silently disable the check.
        # Done first so a failed construction does not touch PIL's global state.
        for real, cf in real_cf_pairs:
            for path in (real, cf):
                if not path.exists():
                    raise FileNotFoundError(f"File does not exist: {path}")

        from PIL import ImageFile
        ImageFile.LOAD_TRUNCATED_IMAGES = True

        self.real_cf_pairs = real_cf_pairs
        self.normalize_for_s3 = normalize_for_s3

    def __len__(self) -> int:
        """Return the number of image pairs."""
        return len(self.real_cf_pairs)

    def __getitem__(self, idx):
        """Load and return the ``(real, cf)`` tensors for pair ``idx``."""
        real_path, cf_path = self.real_cf_pairs[idx]
        return self.load_img(real_path), self.load_img(cf_path)

    def load_img(self, path):
        """Read the image at ``path`` and return it as a transformed tensor."""
        # The context manager releases the file handle deterministically.
        with Image.open(path) as img:
            # Force three channels: grayscale / palette images would otherwise
            # decode to a 2-D array and break the H x W x C transpose, and an
            # alpha channel would break the 3-channel normalization.
            pixels = np.array(img.convert("RGB"), dtype=np.uint8)
        return self.transform(pixels)

    def transform(self, img):
        """Convert an ``H x W x C`` uint8 array to a ``C x H x W`` float tensor.

        Values are scaled from ``[0, 255]`` to ``[0, 1]``; if
        ``self.normalize_for_s3`` is set the result is then normalized per
        channel (which requires exactly three channels).
        """
        scaled = img.astype(np.float32) / 255
        chw = scaled.transpose(2, 0, 1)  # H x W x C -> C x H x W
        tensor = torch.from_numpy(chw)
        if self.normalize_for_s3:
            tensor = F.normalize(tensor, self._NORM_MEAN, self._NORM_STD, inplace=True)
        return tensor