import os
import glob
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class CelebAHQ256(Dataset):
    """CelebA-HQ PNG images served from a single pre-processed tensor cache.

    On first use, every ``*.png`` under ``<image_folder>/train`` (or
    ``<image_folder>/valid``) is converted to RGB. Each image goes through
    ``ToTensor`` and is normalized to [-1, 1]. The images are then stacked
    into one tensor and saved to ``<cache_dir>/CelebAHQ256_<split>.pt``.
    Later instantiations load that file directly.

    Random horizontal flipping is applied per access in ``__getitem__`` for
    the training split only. It is deliberately NOT baked into the cache, so
    each epoch sees a fresh random mix of orientations.

    NOTE: the cache file name depends only on the split, not on
    ``image_folder``. Pointing at a different folder will silently reuse an
    existing cache; delete the ``.pt`` file to rebuild. Caches written by
    older versions of this class contain a one-off random flip baked in;
    rebuilding them is recommended.

    Args:
        image_folder: directory containing ``train/`` and ``valid/`` subfolders.
        train: select the training split (with flip augmentation) if True,
            otherwise the validation split (no augmentation).
        cache_dir: directory where the ``.pt`` cache is read from / written to.
    """

    def __init__(self, image_folder, train=True, cache_dir="./dataset"):
        # Deterministic preprocessing only; this is what gets cached.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        # Flip augmentation is a training-time concern only.
        self.hflip = train

        split = "train" if train else "valid"
        self.image_dir = os.path.join(image_folder, split)
        # glob() returns files in arbitrary order; sort so the cached tensor's
        # index -> image mapping is reproducible across machines.
        self.image_paths = sorted(glob.glob(os.path.join(self.image_dir, "*.png")))
        self.cache_file = os.path.join(cache_dir, f"CelebAHQ256_{split}.pt")

        if os.path.exists(self.cache_file):
            print(f"Loading dataset from {self.cache_file}...", flush=True)
            self.images = torch.load(self.cache_file)
        else:
            print(f"Processing images from {image_folder} and caching to {self.cache_file}...", flush=True)
            self._create_cache()  # Process and save images

    def _create_cache(self):
        """Load and normalize every image, stack them, and save the tensor.

        Sets ``self.images`` to the stacked tensor. The tensor holds the
        float32 output of ``ToTensor``, so it is roughly 4x the size of the
        raw 8-bit pixels and is held entirely in memory.

        Raises:
            FileNotFoundError: if no ``*.png`` files were found for the split.
        """
        if not self.image_paths:
            # torch.stack([]) would otherwise fail with an opaque RuntimeError.
            raise FileNotFoundError(f"No .png images found in {self.image_dir}")

        images = []
        for path in self.image_paths:
            # Context manager closes the underlying file handle promptly.
            with Image.open(path) as img:
                images.append(self.transform(img.convert("RGB")))

        self.images = torch.stack(images)  # (N, C, H, W) stack of per-image tensors

        # The cache directory may not exist yet on a fresh checkout.
        cache_dir = os.path.dirname(self.cache_file)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        torch.save(self.images, self.cache_file)

    def __len__(self):
        """Return the number of cached images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx``.

        The label is always an int32 zero: this dataset is used unconditionally.
        In the training split the image is flipped horizontally (last axis)
        with probability 0.5, drawn fresh on every access.
        """
        image = self.images[idx]
        if self.hflip and torch.rand(1).item() < 0.5:
            image = torch.flip(image, dims=[-1])
        return image, torch.tensor(0, dtype=torch.int)

