# -*- coding: utf-8 -*-
"""

@author: Anonymous Author
"""

import numpy as np
import torch.utils.data as data
import torch
import os
from glob import glob
from PIL import Image

class ImageNetSubset(data.Dataset):
    """Image-folder dataset restricted to a listed subset of ImageNet classes.

    Expected layout on disk::

        <root>/<split>/<subdir>/*.JPEG

    ``subset_file`` is a text file with one class per line, formatted as
    ``"<subdir> <class name>"``. The line is split on the first space, so the
    class name may itself contain spaces. Blank lines are ignored. The integer
    target of a sample is the 0-based position of its ``subdir`` among the
    non-blank lines of ``subset_file``.

    Args:
        subset_file: Path to the class-list file described above.
        root: Dataset root directory; ``split`` is appended to it.
        index: Optional sequence of integer positions into the full sample
            list (classes in file order, files sorted by path within each
            class). When given, only those samples are kept, in that order.
            ``None`` (the default) keeps every sample.
        split: Sub-directory of ``root`` to read, e.g. ``'train'``.
        transform: Optional callable applied to the RGB image in
            ``__getitem__``.

    Raises:
        ValueError: If a non-blank line of ``subset_file`` contains no space.
    """

    def __init__(self, subset_file, root, index=None, split='train',
                 transform=None):
        super().__init__()

        self.root = os.path.join(root, split)
        self.transform = transform
        self.split = split

        # Read the subset of classes to include. File order defines the
        # integer class index of each sub-directory.
        with open(subset_file, 'r') as fh:
            lines = fh.read().splitlines()
        subdirs, class_names = [], []
        for line in lines:
            if not line.strip():
                # Blank lines (e.g. a stray empty line) carry no class.
                continue
            subdir, sep, class_name = line.partition(' ')
            if not sep:
                raise ValueError(
                    'Malformed line in subset file %r: %r '
                    '(expected "<subdir> <class name>")' % (subset_file, line))
            subdirs.append(subdir)
            class_names.append(class_name)

        # Gather the files. Sorting makes the sample order, and therefore any
        # ``index`` selection, independent of the filesystem's listing order.
        imgs = []
        for target, subdir in enumerate(subdirs):
            pattern = os.path.join(self.root, subdir, '*.JPEG')
            imgs.extend((path, target) for path in sorted(glob(pattern)))
        self.classes = class_names

        if index is not None:
            self.imgs = [imgs[i] for i in index]
        else:
            self.imgs = imgs

    def get_image(self, index):
        """Load sample ``index`` as an RGB ``PIL.Image`` (no transform applied)."""
        path, _ = self.imgs[index]
        with open(path, 'rb') as fh:
            # ``convert`` returns a new image, so the pixel data is read
            # before ``fh`` is closed.
            img = Image.open(fh).convert('RGB')
        return img

    def __len__(self):
        """Return the number of (selected) samples."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return sample ``index`` as a dict.

        Keys:
            ``'image'``: the RGB image, passed through ``transform`` if set.
            ``'target'``: integer class index.
            ``'meta'``: dict with ``'im_size'`` (``PIL.Image.size`` of the
            untransformed image), ``'index'`` (the requested index) and
            ``'class_name'``.
        """
        img = self.get_image(index)
        _, target = self.imgs[index]
        im_size = img.size
        class_name = self.classes[target]

        if self.transform is not None:
            img = self.transform(img)

        return {'image': img, 'target': target,
                'meta': {'im_size': im_size, 'index': index,
                         'class_name': class_name}}
