#%%
import os
from pathlib import Path

import torch
from torchvision.datasets import ImageNet

# Root directory of the ImageNet dataset, taken from the IMAGENET_PATH
# environment variable (a KeyError is raised if the variable is not set).
imagenet_path = Path(os.environ["IMAGENET_PATH"])
# torchvision ImageNet dataset object for the 'val' split under that root.
imagenet_ds = ImageNet(imagenet_path, split='val')

# Stride used when subsampling `imagenet_ds.imgs` below: one entry out of
# every `sampling_step` is kept.
sampling_step = 50
def get_imagenet_relative_path(path: str) -> Path:
    """Return *path* relative to the directory three levels above the file.

    Example:
        ``~/imagenet/val/n01440764/ILSVRC2012_val_00000293.JPEG``
        becomes ``val/n01440764/ILSVRC2012_val_00000293.JPEG``.

    Args:
        path: Location of an image file, laid out as
            ``<root>/<split>/<class dir>/<file>``.

    Returns:
        The trailing ``<split>/<class dir>/<file>`` portion as a ``Path``.
    """
    image_file = Path(path)
    class_dir = image_file.parent      # ~/imagenet/val/n01440764
    split_dir = class_dir.parent       # ~/imagenet/val
    dataset_root = split_dir.parent    # ~/imagenet
    return image_file.relative_to(dataset_root)

# Take every `sampling_step`-th entry of the dataset's image list and keep
# only the dataset-relative form of its file path (first element of the entry).
sampled_entries = imagenet_ds.imgs[::sampling_step]
subset_paths = [
    str(get_imagenet_relative_path(entry[0])) for entry in sampled_entries
]

# The list is written two directories above this file; the file name records
# how many paths the subset contains.
output_path = (
    Path(__file__).parent.parent / f"imagenet_subset_{len(subset_paths)}.txt"
)
# One relative path per line, no trailing newline.
output = "\n".join(subset_paths)

print(f"Writing to {output_path}")
output_path.write_text(output)
print("Done")

# %%
import torch.utils.data


class ImageNetWithPaths(torch.utils.data.Dataset):
    """Dataset wrapper that also returns each item's ``imgs`` entry.

    Indexing yields ``(sample, target, path)``, where ``sample`` and
    ``target`` come from the wrapped dataset and ``path`` is the matching
    element of the wrapped dataset's ``imgs`` list.

    Args:
        imagenet_ds: Dataset to wrap. It must support integer indexing that
            returns a ``(sample, target)`` pair, ``len()``, and expose an
            ``imgs`` sequence aligned with its indices.
    """

    def __init__(self, imagenet_ds):
        super().__init__()
        self.imagenet_ds = imagenet_ds

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset.

        ``torch.utils.data.Dataset`` does not supply ``__len__`` itself, so
        without this method ``len()`` on the wrapper raises ``TypeError``.
        """
        return len(self.imagenet_ds)

    def __getitem__(self, index: int):
        """Return ``(sample, target, path)`` for the item at *index*."""
        sample, target = self.imagenet_ds[index]
        # NOTE(review): this is the raw ``imgs`` entry, not a bare string;
        # presumably a (file path, class index) pair as in torchvision's
        # ImageFolder-style datasets. Kept as-is so the return shape does
        # not change for existing callers; confirm whether only the path
        # was intended.
        path = self.imagenet_ds.imgs[index]
        return sample, target, path

# Wrap the validation dataset so that indexing also returns the `imgs` entry.
inwp = ImageNetWithPaths(imagenet_ds)
# Interactive check: fetch one item by index (the value is only displayed
# when run as a notebook cell; it is not stored).
inwp[10000]
# %%
# Interactive check of the dataset size; this requires ImageNetWithPaths to
# implement __len__.
len(inwp)