from torchvision.datasets import ImageFolder
import socket
import pickle
import os

# Resolved once at import time; imagenet() compares it against known host
# names to decide which on-disk ImageNet format to load on this machine.
_hostname = socket.gethostname()

# Hosts on which ImageNet is stored as a plain directory tree
# (``<root>/train`` and ``<root>/val``) readable by torchvision's ImageFolder.
# Every other host is expected to provide ffrecord ``.ffr`` files instead.
_IMAGEFOLDER_HOSTS = ("zebra", "server231-SYS-4028GR-TR")


def _imagefolder_imagenet(root, train, transform):
    """Return an ``ImageFolder`` over ``<root>/train`` or ``<root>/val``."""
    split = "train" if train else "val"
    return ImageFolder(root=os.path.join(root, split), transform=transform)


def _ffrecord_imagenet(root, train, transform):
    """Return an ffrecord dataset over ``<root>/train.ffr`` or ``<root>/val.ffr``.

    Raises:
        ImportError: if the optional ``ffrecord`` package is not installed.
    """
    try:
        from ffrecord.torch import Dataset
    except ImportError as err:
        # Without ffrecord the subclass below cannot be defined at all, so
        # merely printing a warning would be followed by a confusing NameError.
        raise ImportError(
            f"{__file__}: No ffrecord installed! It is required to read the "
            f".ffr ImageNet files on host {_hostname!r}."
        ) from err

    class FireFlyerImageNet(Dataset):
        """ffrecord dataset whose records are pickled ``(image, label)`` pairs."""

        def __init__(self, root, transform=None):
            # ``root`` is the path of a single .ffr file (see call below).
            # check_data=True is forwarded to ffrecord as-is (presumably
            # enables integrity checking of the file -- TODO confirm).
            super().__init__(root, check_data=True)
            self.transform = transform

        def process(self, indexes, data):
            """Decode raw records into a list of ``(image, label)`` tuples.

            ``indexes`` is not used here; ``data`` is iterated as a sequence
            of pickled byte strings.
            """
            samples = []
            for payload in data:
                # SECURITY: pickle.loads can execute arbitrary code; only
                # point this dataset at trusted .ffr files.
                img, label = pickle.loads(payload)
                if self.transform is not None:
                    img = self.transform(img)
                samples.append((img, label))

            # default collate_fn would handle them
            return samples

    split = "train" if train else "val"
    return FireFlyerImageNet(
        root=os.path.join(root, f"{split}.ffr"), transform=transform
    )


def imagenet(root, train, transform=None, download=True):
    """Return the ImageNet train or val split for the current host.

    On hosts listed in ``_IMAGEFOLDER_HOSTS`` the split is read with
    torchvision's ``ImageFolder`` from ``<root>/train`` or ``<root>/val``.
    On every other host it is read via ffrecord from ``<root>/train.ffr``
    or ``<root>/val.ffr``.

    Args:
        root: Directory containing the dataset.
        train: Load the training split if true, otherwise the validation split.
        transform: Optional callable applied to each image.
        download: Unused; kept so the signature stays unchanged (presumably
            to mirror torchvision-style dataset constructors).

    Returns:
        A dataset yielding ``(image, label)`` samples.

    Raises:
        ImportError: on non-ImageFolder hosts when ``ffrecord`` is missing.
    """
    if _hostname in _IMAGEFOLDER_HOSTS:
        return _imagefolder_imagenet(root, train, transform)
    return _ffrecord_imagenet(root, train, transform)