import os
import sys
sys.path.insert(0, './')
import numpy as np

import torch
import torch.nn as nn
from torchvision import datasets, transforms

from .Utility import SubsetRandomSampler, SubsetSampler, MislabelDataset

# Human-readable names of the 20 CIFAR-100 superclasses, each mapped to the
# names of its five subclasses.
# NOTE(review): the key order appears to follow the superclass ids 0-19 used by
# IndexedCIFAR100.subclass_to_superclass -- confirm before relying on it.
SUPERCLASS_DICT = {
    'aquatic mammals': ['beaver', 'dolphin', 'otter', 'seal', 'whale'],
    'fish': ['aquarium fish', 'flatfish', 'ray', 'shark', 'trout'],
    'flowers': ['orchids', 'poppies', 'roses', 'sunflowers', 'tulips'],
    'food containers': ['bottles', 'bowls', 'cans', 'cups', 'plates'],
    'fruit and vegetables': ['apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers'],
    'household electrical devices': ['clock', 'computer keyboard', 'lamp', 'telephone', 'television'],
    'household furniture': ['bed', 'chair', 'couch', 'table', 'wardrobe'],
    'insects': ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'],
    'large carnivores': ['bear', 'leopard', 'lion', 'tiger', 'wolf'],
    'large man-made outdoor things': ['bridge', 'castle', 'house', 'road', 'skyscraper'],
    'large natural outdoor scenes': ['cloud', 'forest', 'mountain', 'plain', 'sea'],
    'large omnivores and herbivores': ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'],
    'medium-sized mammals': ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'],
    'non-insect invertebrates': ['crab', 'lobster', 'snail', 'spider', 'worm'],
    'people': ['baby', 'boy', 'girl', 'man', 'woman'],
    'reptiles': ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'],
    'small mammals': ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],
    'trees': ['maple', 'oak', 'palm', 'pine', 'willow'],
    'vehicles 1': ['bicycle', 'bus', 'motorcycle', 'pickup truck', 'train'],
    'vehicles 2': ['lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'],
}


class IndexedCIFAR100(datasets.CIFAR100):
    '''
    CIFAR-100 dataset whose items also carry the index they were fetched with.

    root, train, download, transform: forwarded unchanged to datasets.CIFAR100.
    superclass: when truthy, the target returned by __getitem__ is translated
        through subclass_to_superclass (values 0-19) instead of being returned as is.
    '''

    def __init__(self, root, train, download, transform, superclass=False):
        super().__init__(root=root, train=train, transform=transform, download=download)
        self.superclass = superclass

        # Entry i is the superclass id assigned to subclass id i (100 entries, ten per row).
        superclass_of = [
            4, 1, 14, 8, 0, 6, 7, 7, 18, 3,
            3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
            6, 11, 5, 10, 7, 6, 13, 15, 3, 15,
            0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
            5, 19, 8, 8, 15, 13, 14, 17, 18, 10,
            16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
            10, 3, 2, 12, 12, 16, 12, 1, 9, 19,
            2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
            16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
            18, 1, 2, 15, 6, 0, 17, 8, 14, 13,
        ]
        # Kept as a dict keyed by subclass id, stored on the instance.
        self.subclass_to_superclass = dict(enumerate(superclass_of))

    def __getitem__(self, index):
        '''
        Return (img, target, index); target is the superclass id when self.superclass is set.
        '''
        img, target = super().__getitem__(index)
        if not self.superclass:
            return img, target, index
        return img, self.subclass_to_superclass[target], index


def cifar100(batch_size, root='./data/cifar100', valid_ratio=None, shuffle=True, augmentation=True, train_subset=None, test_subset=None,
             mislabel_ratio=0., mislabel_seed=0, class_subset_path=None, is_split=False, split_seed=0, is_shadow=False, shadow_ratio=0.8, shadow_seed=0,
             member_train=False, member_test=False, nonmember_train=False, nonmember_test=False, num_worker=0):
    '''
    Build the train / validation / test data loaders for CIFAR-100 (100 classes).

    batch_size: batch size used by every loader.
    root: where data is stored (the datasets are created with download = True).
    valid_ratio: fraction of the selected training indices, taken from the front,
        that goes to the validation loader; None or a non-positive value means no validation set.
    shuffle: if equal to True the training indices go through SubsetRandomSampler, otherwise SubsetSampler.
    augmentation: if equal to True, random crop (32, padding 4) and random horizontal flip
        are applied to the training set.
    train_subset / test_subset: indices to use from the training / test set, None for all of them.
        Both are overwritten when is_split is set.
    mislabel_ratio, mislabel_seed: when mislabel_ratio > 0 the training, test and validation
        datasets are wrapped in MislabelDataset with num_class=100.
    class_subset_path: must be None, class subsets are not supported here.
    is_split, split_seed: when is_split is set, a random half of the training indices and a
        random half of the test indices are drawn, each right after seeding numpy with split_seed.
    is_shadow, shadow_ratio, shadow_seed: only read when is_split is set; the complementary
        halves are used instead and a shadow_ratio fraction of each is drawn after seeding
        numpy with shadow_seed.
    member_train / member_test: keep only the first / second half of the selected training indices.
    nonmember_train / nonmember_test: keep only the first / second half of the selected test indices.
    num_worker: number of workers handed to every loader.

    Returns (train_loader, valid_loader, test_loader, classes); valid_loader is None when no
    validation set is requested and classes is list(range(100)).

    Raises AssertionError if class_subset_path is not None, and ValueError if both member_* flags
    or both nonmember_* flags are set.
    '''

    assert class_subset_path is None, 'Class subset is not supported for CIFAR-100 dataset.'
    if member_train and member_test:
        raise ValueError('member_train and member_test cannot be both True.')
    if nonmember_train and nonmember_test:
        raise ValueError('nonmember_train and nonmember_test cannot be both True.')

    # Only the training pipeline is ever augmented.
    if augmentation == True:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    else:
        transform_train = transforms.Compose([transforms.ToTensor()])
    transform_valid = transforms.Compose([transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.ToTensor()])

    # The validation set reads the training images, but with the non-augmented transform.
    trainset = IndexedCIFAR100(root=root, train=True, download=True, transform=transform_train)
    validset = IndexedCIFAR100(root=root, train=True, download=True, transform=transform_valid)
    testset = IndexedCIFAR100(root=root, train=False, download=True, transform=transform_test)

    # wrap the dataset with mislabel
    if mislabel_ratio > 0:
        noise_kwargs = dict(num_class=100, mislabel_ratio=mislabel_ratio, mislabel_seed=mislabel_seed)
        trainset = MislabelDataset(trainset, **noise_kwargs)
        testset = MislabelDataset(testset, **noise_kwargs)
        validset = MislabelDataset(validset, **noise_kwargs)

    classes = list(range(100))

    if is_split:
        def _pick_split(total):
            # Every draw is preceded by its own seeding, so the outcome depends only on
            # the seeds and on `total`, not on what was drawn before.
            np.random.seed(split_seed)
            picked = np.random.choice(total, size=total // 2, replace=False)
            if not is_shadow:
                return picked
            # Shadow mode: use the other half, then keep a shadow_ratio fraction of it.
            remainder = np.array(list(set(np.arange(total)) - set(picked)))
            np.random.seed(shadow_seed)
            return np.random.choice(remainder, size=int(len(remainder) * shadow_ratio), replace=False)

        train_subset = _pick_split(len(trainset))
        test_subset = _pick_split(len(testset))

    # Training indices: an explicit subset is permuted, the full set is kept in order.
    if train_subset is not None:
        train_indices = np.random.permutation(train_subset)
    else:
        train_indices = list(range(len(trainset)))
    train_half = len(train_indices) // 2
    if member_train:
        train_indices = train_indices[:train_half]
    elif member_test:
        train_indices = train_indices[train_half:]
    train_instance_num = len(train_indices)
    print('%d instances are picked from the training set' % train_instance_num)

    # Test indices are never permuted.
    test_indices = list(range(len(testset))) if test_subset is None else test_subset
    test_half = len(test_indices) // 2
    if nonmember_train:
        test_indices = test_indices[:test_half]
    elif nonmember_test:
        test_indices = test_indices[test_half:]
    test_instance_num = len(test_indices)
    print('%d instances are picked from the test set' % test_instance_num)
    test_sampler = SubsetSampler(test_indices)

    make_loader = torch.utils.data.DataLoader
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_worker, pin_memory=True)
    train_sampler_cls = SubsetRandomSampler if shuffle == True else SubsetSampler

    if valid_ratio is not None and valid_ratio > 0.:
        # The first split_pt training indices are held out for validation.
        split_pt = int(train_instance_num * valid_ratio)
        train_sampler = train_sampler_cls(train_indices[split_pt:])
        valid_sampler = SubsetSampler(train_indices[:split_pt])
        train_loader = make_loader(trainset, sampler=train_sampler, **loader_kwargs)
        valid_loader = make_loader(validset, sampler=valid_sampler, **loader_kwargs)
    else:
        train_sampler = train_sampler_cls(train_indices)
        train_loader = make_loader(trainset, sampler=train_sampler, **loader_kwargs)
        valid_loader = None

    test_loader = make_loader(testset, sampler=test_sampler, shuffle=False, **loader_kwargs)

    return train_loader, valid_loader, test_loader, classes


def cifar100_superclass(batch_size, root='./data/cifar100', valid_ratio=None, shuffle=True, augmentation=True, train_subset=None, test_subset=None,
             mislabel_ratio=0., mislabel_seed=0, class_subset_path=None, is_split=False, split_seed=0, is_shadow=False, shadow_ratio=0.8, shadow_seed=0,
             member_train=False, member_test=False, nonmember_train=False, nonmember_test=False, num_worker=0):
    '''
    Build the train / validation / test data loaders for CIFAR-100 with superclass labels.

    The datasets are created with superclass=True, so targets are the 20 superclass
    ids (0-19) rather than the 100 subclass ids.

    batch_size: batch size used by every loader.
    root: where data is stored (the datasets are created with download = True).
    valid_ratio: fraction of the selected training indices, taken from the front,
        that goes to the validation loader; None or a non-positive value means no validation set.
    shuffle: if equal to True the training indices go through SubsetRandomSampler, otherwise SubsetSampler.
    augmentation: if equal to True, random crop (32, padding 4) and random horizontal flip
        are applied to the training set.
    train_subset / test_subset: indices to use from the training / test set, None for all of them.
        Both are overwritten when is_split is set.
    mislabel_ratio, mislabel_seed: when mislabel_ratio > 0 the training, test and validation
        datasets are wrapped in MislabelDataset with num_class=20.
    class_subset_path: must be None, class subsets are not supported here.
    is_split, split_seed: when is_split is set, a random half of the training indices and a
        random half of the test indices are drawn, each right after seeding numpy with split_seed.
    is_shadow, shadow_ratio, shadow_seed: only read when is_split is set; the complementary
        halves are used instead and a shadow_ratio fraction of each is drawn after seeding
        numpy with shadow_seed.
    member_train / member_test: keep only the first / second half of the selected training indices.
    nonmember_train / nonmember_test: keep only the first / second half of the selected test indices.
    num_worker: number of workers handed to every loader.

    Returns (train_loader, valid_loader, test_loader, classes); valid_loader is None when no
    validation set is requested and classes is list(range(20)).

    Raises AssertionError if class_subset_path is not None, and ValueError if both member_* flags
    or both nonmember_* flags are set.
    '''

    assert class_subset_path is None, 'Class subset is not supported for CIFAR-100 dataset.'
    if member_train and member_test:
        raise ValueError('member_train and member_test cannot be both True.')
    if nonmember_train and nonmember_test:
        raise ValueError('nonmember_train and nonmember_test cannot be both True.')

    # `== True` is kept on purpose so that non-bool arguments behave exactly as before.
    if augmentation == True:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding = 4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
            ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor()
            ])
    transform_valid = transforms.Compose([
        transforms.ToTensor()
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor()
        ])

    # The validation set reads the training images, but with the non-augmented transform.
    trainset = IndexedCIFAR100(root = root, train = True, download = True, transform = transform_train, superclass=True)
    validset = IndexedCIFAR100(root = root, train = True, download = True, transform = transform_valid, superclass=True)
    testset = IndexedCIFAR100(root = root, train = False, download = True, transform = transform_test, superclass=True)

    # IndexedCIFAR100 maps the 100 subclass ids onto superclass ids 0-19, so the
    # label space seen by everything below has 20 classes (not 10, not 100).
    num_superclass = 20

    # wrap the dataset with mislabel
    if mislabel_ratio > 0:
        trainset = MislabelDataset(trainset, num_class=num_superclass, mislabel_ratio=mislabel_ratio, mislabel_seed=mislabel_seed)
        testset = MislabelDataset(testset, num_class=num_superclass, mislabel_ratio=mislabel_ratio, mislabel_seed=mislabel_seed)
        validset = MislabelDataset(validset, num_class=num_superclass, mislabel_ratio=mislabel_ratio, mislabel_seed=mislabel_seed)

    classes = list(range(num_superclass))

    if is_split:
        num_train_data, num_test_data = len(trainset), len(testset)
        # Re-seeding before each draw makes the train and test halves depend only on split_seed.
        np.random.seed(split_seed)
        train_subset = np.random.choice(num_train_data, size=num_train_data//2, replace=False)
        np.random.seed(split_seed)
        test_subset = np.random.choice(num_test_data, size=num_test_data//2, replace=False)
        if is_shadow:
            # Shadow mode works on the halves that were NOT drawn above ...
            train_subset = set(np.arange(num_train_data)) - set(train_subset)
            train_subset = np.array(list(train_subset))
            test_subset = set(np.arange(num_test_data)) - set(test_subset)
            test_subset = np.array(list(test_subset))
            # ... and keeps a shadow_ratio fraction of each, drawn with shadow_seed.
            np.random.seed(shadow_seed)
            train_subset = np.random.choice(train_subset, size=int(len(train_subset) * shadow_ratio), replace=False)
            np.random.seed(shadow_seed)
            test_subset = np.random.choice(test_subset, size=int(len(test_subset) * shadow_ratio), replace=False)

    # Training indices: an explicit subset is permuted, the full set is kept in order.
    if train_subset is None:
        train_indices = list(range(len(trainset)))
    else:
        train_indices = np.random.permutation(train_subset)
    # The two flags are mutually exclusive (checked above): first half vs. second half.
    if member_train:
        train_indices = train_indices[:len(train_indices) // 2]
    elif member_test:
        train_indices = train_indices[len(train_indices) // 2:]
    train_instance_num = len(train_indices)
    print('%d instances are picked from the training set' % train_instance_num)

    # Test indices are never permuted.
    if test_subset is None:
        test_indices = list(range(len(testset)))
    else:
        test_indices = test_subset
    if nonmember_train:
        test_indices = test_indices[:len(test_indices) // 2]
    elif nonmember_test:
        test_indices = test_indices[len(test_indices) // 2:]
    test_instance_num = len(test_indices)
    print('%d instances are picked from the test set' % test_instance_num)
    test_sampler = SubsetSampler(test_indices)

    if valid_ratio is not None and valid_ratio > 0.:
        # The first split_pt training indices are held out for validation.
        split_pt = int(train_instance_num * valid_ratio)
        train_idx, valid_idx = train_indices[split_pt:], train_indices[:split_pt]

        if shuffle == True:
            train_sampler, valid_sampler = SubsetRandomSampler(train_idx), SubsetSampler(valid_idx)
        else:
            train_sampler, valid_sampler = SubsetSampler(train_idx), SubsetSampler(valid_idx)

        train_loader = torch.utils.data.DataLoader(trainset, batch_size = batch_size, sampler = train_sampler, num_workers = num_worker, pin_memory = True)
        valid_loader = torch.utils.data.DataLoader(validset, batch_size = batch_size, sampler = valid_sampler, num_workers = num_worker, pin_memory = True)

    else:
        if shuffle == True:
            train_sampler = SubsetRandomSampler(train_indices)
        else:
            train_sampler = SubsetSampler(train_indices)

        train_loader = torch.utils.data.DataLoader(trainset, batch_size = batch_size, sampler = train_sampler, num_workers = num_worker, pin_memory = True)
        valid_loader = None

    # The test loader is built the same way whether or not a validation set exists.
    test_loader = torch.utils.data.DataLoader(testset, batch_size = batch_size, sampler = test_sampler, shuffle = False, num_workers = num_worker, pin_memory = True)

    return train_loader, valid_loader, test_loader, classes