import os
from pathlib import Path
import shutil

import torch
import torch.utils.data as Data
import torchvision.transforms as transforms

from PIL import Image


# Default dataset root; the DATA_ROOT environment variable overrides it.
DATA_ROOT = os.getenv("DATA_ROOT", '/project/')


def parse_imagenet(data_root: str,
                   num_classes: int = 1000,
                   seed: int = 2025):
    """Build ``(path, label)`` records for a seeded random subset of classes.

    Reads ``{data_root}/imagenet/meta/train.txt`` and
    ``{data_root}/imagenet/meta/val.txt``.  Every non-blank line is expected
    to hold an image path (relative to the split directory) followed by an
    integer class label.  Records whose label is not among the selected
    classes are dropped; the remaining labels are remapped to
    ``0 .. len(selected) - 1``.

    Args:
        data_root: Directory that contains the ``imagenet`` folder.
        num_classes: Number of classes to keep out of the 1000 available.
        seed: Seed for the class selection; the same seed always selects
            the same classes.

    Returns:
        A two-element list ``[train_records, val_records]``; each element is
        a tuple of ``(image_path, remapped_label)`` pairs.

    Raises:
        OSError: If a meta file cannot be opened.
        ValueError: If a non-blank line does not contain a path followed by
            an integer label.
    """
    # Draw from a private generator so the selection is reproducible without
    # reseeding (and thereby clobbering) the caller's global RNG state.
    generator = torch.Generator()
    generator.manual_seed(seed)
    selected_classes = torch.randperm(1000, generator=generator)[:num_classes]

    # NOTE: new label ids follow the iteration order of this set, which is
    # kept deliberately so the ids stay compatible with earlier runs.
    lookup = {
        label: idx
        for idx, label in enumerate(set(selected_classes.tolist()))
    }

    data_lists = []
    for mode in 'train', 'val':
        data_file = f'{data_root}/imagenet/meta/{mode}.txt'
        data_list = []
        with open(data_file) as f:
            for line in f:
                entry = line.strip()
                if not entry:
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                # Split on the last whitespace run only, so a path that
                # itself contains spaces is kept intact.
                path, label = entry.rsplit(maxsplit=1)
                new_label = lookup.get(int(label))
                if new_label is None:
                    continue
                data_list.append(
                    (f'{data_root}/imagenet/{mode}/{path}', new_label)
                )
        data_lists.append(tuple(data_list))

    return data_lists


class imagenet_loader(Data.Dataset):
    """Dataset over ``(path, label)`` records with a file cache in shared memory.

    Images are decoded with PIL, converted to RGB and passed through a
    mode-specific torchvision pipeline:

    * ``'train'``: RandAugment, random resized crop to 224 (scale 0.3-1.0),
      random horizontal flip, ``ToTensor``.
    * ``'val'``: resize to 256, center crop to 224, ``ToTensor``.

    Args:
        datalist: Indexable sequence of ``(image_path, label)`` pairs.
        data_root: Directory the image paths live under; it determines where
            each file is mirrored below ``cache_root``.
        mode: Either ``'train'`` or ``'val'``.
    """

    # Directory that mirrors the files found under ``data_root``.
    cache_root = "/dev/shm"

    def __init__(self, datalist, data_root, mode):
        assert mode in ['train', 'val']

        self.datalist = datalist
        self.data_root = data_root

        if mode == 'train':
            self.transform = transforms.Compose([
                transforms.RandAugment(),
                transforms.RandomResizedCrop(224, scale=(0.3, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])
        elif mode == 'val':
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor()
            ])

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the record at ``index``."""
        path, label = self.datalist[index]
        image = self.read_image(path)
        image = self.transform(image)
        return image, label

    def __len__(self):
        """Return the number of records."""
        return len(self.datalist)

    def _cache_path(self, path):
        """Return where ``path`` is mirrored in the cache, or ``None``.

        ``None`` means ``path`` does not live under ``self.data_root`` (or
        the root is unusable), so the file is read in place.
        """
        try:
            relative = Path(path).relative_to(self.data_root)
        except (TypeError, ValueError):
            return None
        return Path(self.cache_root) / relative

    def _populate_cache(self, path, shm_path):
        """Copy ``path`` to ``shm_path`` so readers never see a partial file."""
        shm_path.parent.mkdir(parents=True, exist_ok=True)
        partial = shm_path.with_name(f'{shm_path.name}.{os.getpid()}.tmp')
        try:
            shutil.copyfile(path, partial)
            # The rename is atomic, so another worker process either finds
            # the complete file or no file at all.
            os.replace(partial, shm_path)
        finally:
            # After a successful rename there is nothing left to remove;
            # after a failed copy this discards the partial file.
            try:
                partial.unlink()
            except OSError:
                pass

    def read_image(self, path):
        """Load ``path`` as an RGB PIL image, going through the cache if possible.

        Any ``OSError`` while filling or reading the cache (directory
        missing, no space left, unreadable cached copy) falls back to
        reading the original file.

        Raises:
            OSError: If the original file cannot be read either.
        """
        shm_path = self._cache_path(path)
        if shm_path is not None:
            try:
                if not shm_path.exists():
                    self._populate_cache(path, shm_path)
                with Image.open(shm_path) as image:
                    return image.convert('RGB')
            except OSError:
                # The cache is best-effort only: use the source file below.
                pass

        with Image.open(path) as image:
            return image.convert('RGB')


def imagenet_dataset(num_classes=1000,
                     data_root=DATA_ROOT,
                     seed=2025):
    """Create the train and val datasets for a seeded class subset.

    The records come from ``parse_imagenet(data_root, num_classes, seed)``;
    each split is wrapped in an ``imagenet_loader`` of the matching mode.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple.
    """
    train_records, val_records = parse_imagenet(data_root, num_classes, seed)
    return (
        imagenet_loader(train_records, data_root, mode='train'),
        imagenet_loader(val_records, data_root, mode='val'),
    )


