import numpy as np
import os
import pickle
import torch
from PIL import Image
from torch.utils.data import Dataset
from utils.api import check_exists, makedir_exist_ok, save, load
from .utils import download_url, extract_file, make_classes_counts


class CINIC10(Dataset):
    """CINIC-10 image-classification dataset.

    On first use the raw archive is downloaded into ``<root>/raw`` (if that
    folder is missing) and converted into pickled ``train.pt``, ``test.pt``
    and ``meta.pt`` files under ``<root>/processed``; later instantiations
    load those files directly.

    Each item is a dict with keys ``'id'`` (tensor), ``'data'`` (PIL image)
    and ``'target'`` (tensor), optionally passed through ``transform``.
    """
    data_name = 'CINIC10'
    # (url, md5) pairs to fetch; no checksum is supplied for this archive.
    file = [('https://datashare.is.ed.ac.uk/bitstream/handle/10283/3192/CINIC-10.tar.gz', None)]

    def __init__(self, root, split, transform=None):
        """Load the processed ``split``, building the processed files first if absent.

        Args:
            root: Dataset root directory; ``~`` is expanded.
            split: ``'train'`` or ``'test'`` -- the only splits :meth:`process` writes.
            transform: Optional callable invoked as
                ``transform(input, split, dataset_name, malicious_data_ids)``.
        """
        self.root = os.path.expanduser(root)
        self.split = split
        self.transform = transform
        # Forwarded to ``transform``; None here and presumably set by callers -- TODO confirm.
        self.malicious_data_ids = None
        if not check_exists(self.processed_folder):
            self.process()
        self.id, self.data, self.target = load(os.path.join(self.processed_folder, '{}.pt'.format(self.split)),
                                               mode='pickle')
        self.classes_counts = make_classes_counts(self.target)
        self.classes_to_labels, self.target_size = load(os.path.join(self.processed_folder, 'meta.pt'), mode='pickle')

    def __getitem__(self, index):
        """Return sample ``index`` as ``{'id', 'data', 'target'}``, transformed if configured."""
        id_ = torch.tensor(self.id[index])
        # Cast to uint8 so PIL can always build an image from the stored array.
        data = Image.fromarray(self.data[index].astype(np.uint8))
        target = torch.tensor(self.target[index])
        sample = {'id': id_, 'data': data, 'target': target}
        if self.transform is not None:
            # NOTE(review): passes 'CIFAR10' rather than self.data_name; presumably so
            # CIFAR-10 transform settings are reused for CINIC-10 -- confirm.
            sample = self.transform(sample, self.split, 'CIFAR10', self.malicious_data_ids)
        return sample

    def __len__(self):
        return len(self.data)

    @property
    def processed_folder(self):
        """Directory holding the pickled ``train.pt`` / ``test.pt`` / ``meta.pt`` files."""
        return os.path.join(self.root, 'processed')

    @property
    def raw_folder(self):
        """Directory holding the downloaded archive and its extracted split folders."""
        return os.path.join(self.root, 'raw')

    def process(self):
        """Download raw data if needed, then build and pickle the train/test/meta files."""
        if not check_exists(self.raw_folder):
            self.download()
        train_set, test_set, meta = self.make_data()
        save(train_set, os.path.join(self.processed_folder, 'train.pt'), mode='pickle')
        save(test_set, os.path.join(self.processed_folder, 'test.pt'), mode='pickle')
        save(meta, os.path.join(self.processed_folder, 'meta.pt'), mode='pickle')

    def download(self):
        """Fetch every archive listed in ``file`` into ``raw_folder`` and extract it via ``extract_file``."""
        makedir_exist_ok(self.raw_folder)
        for url, _md5 in self.file:
            archive_path = os.path.join(self.raw_folder, os.path.basename(url))
            download_url(url, archive_path)
            extract_file(archive_path)

    def __repr__(self):
        fmt_str = 'Dataset {}\nSize: {}\nRoot: {}\nSplit: {}\nTransforms: {}'.format(
            self.__class__.__name__, self.__len__(), self.root, self.split, self.transform.__repr__())
        return fmt_str

    def make_data(self):
        """Read the extracted CINIC-10 PNG files under ``raw_folder`` into memory.

        Images are read from ``raw_folder/<split>/<class>/*.png``. Only the
        ``train`` and ``test`` splits are used; the ``valid`` split is not read.
        For ``test``, only the first ninth of each class's files (in sorted
        filename order) is kept -- presumably to shrink the test set to a
        CIFAR-10-like size. Files are sorted so ids and the test subset are
        reproducible across filesystems.

        Returns:
            ``(train_id, train_data, train_target)``,
            ``(test_id, test_data, test_target)`` and
            ``(classes_to_labels, target_size)``, where ``*_id`` are int64
            ``np.arange`` arrays, ``*_data`` are lists of image arrays
            (grayscale images expanded to three channels) and ``*_target``
            are lists of integer labels.
        """
        import glob

        classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
        classes_to_labels = {name: label for label, name in enumerate(classes)}
        train_data, train_target = [], []
        test_data, test_target = [], []
        outputs = {'train': (train_data, train_target), 'test': (test_data, test_target)}
        for split, (data, target) in outputs.items():
            for c in classes:
                class_dir = os.path.join(self.raw_folder, split, c)
                # glob order is filesystem-dependent; sort for reproducible ids and subsets.
                filenames = sorted(glob.glob(os.path.join(glob.escape(class_dir), '*.png')))
                if split == 'test':
                    filenames = filenames[:len(filenames) // 9]
                for fn in filenames:
                    data.append(self._load_rgb_image(fn))
                    target.append(classes_to_labels[c])

        train_id = np.arange(len(train_data), dtype=np.int64)
        test_id = np.arange(len(test_data), dtype=np.int64)
        target_size = len(classes)
        return (train_id, train_data, train_target), (test_id, test_data, test_target), (classes_to_labels, target_size)

    @staticmethod
    def _load_rgb_image(path):
        """Read the image at ``path``, expanding a 2-D grayscale array to three channels."""
        # Imported lazily so scikit-image is only required when images are actually decoded.
        from skimage import color, io
        image = io.imread(path)
        if image.ndim == 2:
            image = color.gray2rgb(image, channel_axis=-1)
        return image

# class CIFAR100(CIFAR10):
#     data_name = 'CIFAR100'
#     file = [('https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz', 'eb9058c3a382ffc7106e4002c42a8d85')]

#     def make_data(self):
#         train_filenames = ['train']
#         test_filenames = ['test']
#         train_data, train_target = read_pickle_file(os.path.join(self.raw_folder, 'cifar-100-python'), train_filenames)
#         test_data, test_target = read_pickle_file(os.path.join(self.raw_folder, 'cifar-100-python'), test_filenames)
#         train_id, test_id = np.arange(len(train_data)).astype(np.int64), np.arange(len(test_data)).astype(np.int64)
#         with open(os.path.join(self.raw_folder, 'cifar-100-python', 'meta'), 'rb') as f:
#             data = pickle.load(f, encoding='latin1')
#             classes = data['fine_label_names']
#         classes_to_labels = {classes[i]: i for i in range(len(classes))}
#         target_size = len(classes)
#         return (train_id, train_data, train_target), (test_id, test_data, test_target), (classes_to_labels, target_size)


# def read_pickle_file(path, filenames):
#     img, label = [], []
#     for filename in filenames:
#         file_path = os.path.join(path, filename)
#         with open(file_path, 'rb') as f:
#             entry = pickle.load(f, encoding='latin1')
#             img.append(entry['data'])
#             label.extend(entry['labels']) if 'labels' in entry else label.extend(entry['fine_labels'])
#     img = np.vstack(img).reshape(-1, 3, 32, 32)
#     img = img.transpose((0, 2, 3, 1))
#     label = np.array(label).astype(np.int64)
#     return img, label


# CIFAR100_classes = {
#     'aquatic mammals': ['beaver', 'dolphin', 'otter', 'seal', 'whale'],
#     'fish': ['aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'],
#     'flowers': ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'],
#     'food containers': ['bottle', 'bowl', 'can', 'cup', 'plate'],
#     'fruit and vegetables': ['apple', 'mushroom', 'orange', 'pear', 'sweet_pepper'],
#     'household electrical devices': ['clock', 'keyboard', 'lamp', 'telephone', 'television'],
#     'household furniture': ['bed', 'chair', 'couch', 'table', 'wardrobe'],
#     'insects': ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'],
#     'large carnivores': ['bear', 'leopard', 'lion', 'tiger', 'wolf'],
#     'large man-made outdoor things': ['bridge', 'castle', 'house', 'road', 'skyscraper'],
#     'large natural outdoor scenes': ['cloud', 'forest', 'mountain', 'plain', 'sea'],
#     'large omnivores and herbivores': ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'],
#     'medium-sized mammals': ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'],
#     'non-insect invertebrates': ['crab', 'lobster', 'snail', 'spider', 'worm'],
#     'people': ['baby', 'boy', 'girl', 'man', 'woman'],
#     'reptiles': ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'],
#     'small mammals': ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],
#     'trees': ['maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree'],
#     'vehicles 1': ['bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train'],
#     'vehicles 2': ['lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor']
# }
