from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
import sys
import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

class Imagenet32(Dataset):
    """Downsampled ImageNet (32x32) dataset read from pickled batch files.

    Each batch file under ``root`` is unpickled and is expected to contain a
    ``'data'`` array whose rows are flattened channel-first (3x32x32) images
    and a ``'labels'`` sequence. All batches of the selected split are stacked
    and stored channel-last as ``self.data`` with shape ``(N, 32, 32, 3)`` so
    individual samples can be turned into PIL images.

    Args:
        root (str): Directory holding ``train_data_batch_1`` ...
            ``train_data_batch_10`` and ``val_data``.
        train (bool): If True load the ten training batches, otherwise the
            single validation batch.
        download: Accepted for API compatibility only; it is not used.
        Augmentation (bool): If True (and ``train`` is True), apply random
            crop + horizontal flip before tensor conversion/normalization.
    """
    train_list = [
        ['train_data_batch_1'],
        ['train_data_batch_2'],
        ['train_data_batch_3'],
        ['train_data_batch_4'],
        ['train_data_batch_5'],
        ['train_data_batch_6'],
        ['train_data_batch_7'],
        ['train_data_batch_8'],
        ['train_data_batch_9'],
        ['train_data_batch_10'],
    ]
    test_list = [
        ['val_data'],
    ]

    def __init__(self, root, train=True, download=None, Augmentation=True):
        self.train = train  # True -> training split, False -> validation split
        self.toTensor = transforms.ToTensor()
        # ``transforms.Scale`` was deprecated in favour of ``transforms.Resize``
        # and has since been removed from torchvision; using it makes
        # construction fail with AttributeError on current releases. Prefer the
        # new name and fall back to the old one only on very old torchvision.
        try:
            resize = transforms.Resize
        except AttributeError:
            resize = transforms.Scale
        # RandomCrop(32) already yields a 32x32 image, so resize(32) is
        # effectively a no-op kept to preserve the original pipeline.
        self.dataaugmentation = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            resize(32),
            transforms.RandomHorizontalFlip(),
        ])
        # Commonly used ImageNet per-channel (RGB) mean / std.
        self.normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                              [0.229, 0.224, 0.225])
        data_list = self.train_list if self.train else self.test_list

        self.data = []
        self.targets = []
        self.root = root
        self.Augmentation = Augmentation
        # Load the pickled batches. NOTE: pickle.load can execute arbitrary
        # code embedded in the file -- only point ``root`` at trusted data.
        for file_name in data_list:
            file_path = os.path.join(self.root, file_name[0])
            with open(file_path, 'rb') as f:
                entry = pickle.load(f)
            self.data.append(entry['data'])
            self.targets.extend(entry['labels'])
        # Rows are flattened channel-first images; reorder to (N, H, W, C),
        # the layout Image.fromarray expects.
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))

    def __getitem__(self, index):
        """Return one preprocessed sample.

        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where image is the tensor produced by
            ``data_preproccess`` and target is a 0-d numpy array holding the
            stored label minus one.
        """
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        img = self.data_preproccess(img)
        # The stored labels are presumably 1-based (as in the downsampled
        # ImageNet release); subtracting 1 yields 0-based class indices.
        return img, np.array(target) - 1

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.data)

    def data_preproccess(self, data):
        """Turn a PIL image into a normalized tensor.

        Random augmentation is applied first only when both
        ``self.Augmentation`` and ``self.train`` are truthy.
        """
        if self.Augmentation and self.train:
            data = self.dataaugmentation(data)
        data = self.toTensor(data)
        data = self.normalize(data)
        return data



# class TorchDataset(Dataset):
#     def __init__(self, filename, image_dir, resize_height=256, resize_width=256, repeat=1):
#         '''
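#         :param filename: TXT data file, one sample per line in the format: image_name.jpg label1_id label2_id
#         :param image_dir: image directory; image_dir + image_name.jpg forms the full image path
#         :param resize_height: if None, no resizing is applied
#         :param resize_width: if None, no resizing is applied;
#                               NOTE: if only one of resize_height / resize_width is None, the image is scaled proportionally (aspect ratio preserved)
#         :param repeat: number of times the whole sample list is repeated; defaults to one pass. If repeat is None, loop "infinitely" (a large fixed length < sys.maxsize)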
#         '''
#         self.image_label_list = self.read_file(filename)
#         self.image_dir = image_dir
#         self.len = len(self.image_label_list)
#         self.repeat = repeat
#         self.resize_height = resize_height
#         self.resize_width = resize_width
 
#         self.toTensor = transforms.ToTensor()

 
#     def __getitem__(self, i):
#         index = i % self.len
#         # print("i={},index={}".format(i, index))
#         image_name, label = self.image_label_list[index]
#         image_path = os.path.join(self.image_dir, image_name)
#         img = self.load_data(image_path, self.resize_height, self.resize_width, normalization=False)
#         img = self.data_preproccess(img)
#         label=np.array(label)
#         return img, label
 
#     def __len__(self):
#         if self.repeat == None:
#             data_len = 10000000
#         else:
#             data_len = len(self.image_label_list) * self.repeat
#         return data_len
 
#     def read_file(self, filename):
#         image_label_list = []
#         with open(filename, 'r') as f:
#             lines = f.readlines()
#             for line in lines:
#                 # rstrip() strips trailing whitespace characters (\n, \r, \t and ' ', i.e. newline, carriage return, tab and space)
#                 content = line.rstrip().split(' ')
#                 name = content[0]
#                 labels = []
#                 for value in content[1:]:
#                     labels.append(int(value))
#                 image_label_list.append((name, labels))
#         return image_label_list
 
#     def load_data(self, path, resize_height, resize_width, normalization):
#         '''
#         Load one image.
#         :param path:
#         :param resize_height:
#         :param resize_width:
#         :param normalization: whether to normalize the image
#         :return:
#         '''
#         image = image_processing.read_image(path, resize_height, resize_width, normalization)
#         return image
 
#     def data_preproccess(self, data):
#         '''
#         Data preprocessing (convert to tensor).
#         :param data:
#         :return:
#         '''
#         data = self.toTensor(data)
#         return data


# train_loader = torch.utils.data.DataLoader(Imagenet32(root='data/imagenet/',train=True) ,batch_size=256, shuffle=True, num_workers=2,pin_memory=True)
