from pathlib import Path

from datasets import load_dataset
from torch.utils.data import Dataset


class ImageNetDataset(Dataset):
    """Map-style PyTorch dataset backed by a Hugging Face ``datasets`` split.

    ``data_path`` is split into two parts: its parent directory is passed to
    ``datasets.load_dataset`` as the dataset source, and its stem (final path
    component without suffix) is used as the split name. For example,
    ``"data/imagenet/train"`` loads split ``"train"`` from ``"data/imagenet"``.

    Each row is expected to provide ``"image"`` and ``"label"`` fields.

    Args:
        data_path: Path whose parent is the dataset source and whose stem is
            the split name.
        transform: Optional callable applied to each image before it is
            returned.
        labels: Optional mapping or sequence used to translate a row's raw
            ``"label"`` value (e.g. class index -> class name). When ``None``
            the raw value is returned unchanged.
    """

    def __init__(self, data_path, transform=None, labels=None):
        self.data_path = data_path
        self.transform = transform
        # Optional label translation table; ``None`` means "return raw labels".
        self.labels = labels
        path = Path(data_path)
        self.dataset = load_dataset(str(path.parent), split=path.stem)

    def __len__(self):
        """Return the number of rows in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is passed through ``transform`` when one was given; the
        label is looked up in ``labels`` when one was given.
        """
        row = self.dataset[idx]
        img = row["image"]
        label = row["label"]
        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. defines __len__ == 0) is still applied.
        if self.transform is not None:
            img = self.transform(img)
        if self.labels is not None:
            label = self.labels[label]
        return img, label


if __name__ == "__main__":
    # Smoke test: load the training split and inspect the first sample.
    # Guarded so that importing this module does not load the dataset.
    # NOTE(review): the relative path resolves against the current working
    # directory, not this file's location -- confirm that is intended.
    dataset = ImageNetDataset("../../../data/vq/imagenet/train")
    print(f"Number of samples: {len(dataset)}")
    img, label = dataset[0]
    print(f"First sample: image type={type(img).__name__}, label={label}")
