import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset
from torchvision.io import read_image

class CelebAHQDataset(Dataset):
    """Map-style dataset over the ``*.png`` files found under a directory.

    Images are discovered recursively and ordered by the integer name of the
    directory that directly contains each file, so every image's parent
    directory must have a name that parses as an ``int``.

    The dataset exposes at most ``n_samples`` items: ``len()``, indexing and
    plain iteration all honour that limit.
    """

    def __init__(self, path_data: str, path_metadata: str, n_samples: int):
        """Collect the image paths and load the metadata table.

        Args:
            path_data: Root directory searched recursively for ``*.png`` files.
            path_metadata: CSV file read with its first column as the index.
            n_samples: Upper bound on the number of items exposed. Values
                larger than the number of images found are capped; negative
                values yield an empty dataset.
        """
        super().__init__()

        self.paths = self.get_paths(path_data)
        self.metadata = self.get_metadata(path_metadata)
        # Clamp at zero: a negative length would make ``len(self)`` raise.
        self.length = max(0, min(len(self.paths), n_samples))

    def get_paths(self, path) -> list:
        """Return every ``*.png`` below ``path`` in a deterministic order.

        Files are ordered by the integer value of their parent directory
        name; the file name breaks ties because ``rglob`` yields entries in a
        filesystem-dependent order.

        Raises:
            ValueError: If a parent directory name is not an integer.
        """
        paths = Path(path).rglob('*.png')
        return sorted(paths, key=lambda x: (int(x.parts[-2]), x.name))

    def get_metadata(self, path) -> pd.DataFrame:
        """Read the metadata CSV, using its first column as the index."""
        return pd.read_csv(path, index_col=0)

    def __len__(self) -> int:
        """Return the number of items exposed (capped by ``n_samples``)."""
        return self.length

    def __getitem__(self, index):
        """Return ``(image, index)`` for the item at ``index``.

        The image is the result of ``read_image`` divided by 255 (presumably
        scaling 8-bit pixel values into ``[0, 1]`` -- confirm against the
        torchvision version in use). The returned index is the normalised,
        non-negative position of the item.

        Raises:
            IndexError: If ``index`` falls outside ``[-len(self), len(self))``.
                This is also what terminates plain ``for`` iteration after
                ``len(self)`` items rather than after every file on disk.
        """
        # NOTE: this requires refactoring to fit current loop
        if index < 0:
            # Resolve negative indices against the truncated length, not
            # against the full list of discovered files.
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError(
                f"index {index} is out of range for dataset of length {self.length}"
            )

        path = self.paths[index]
        img = read_image(str(path)) / 255
        return img, index