import os
import sys
sys.path.insert(0, './')
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from .Utility import SubsetRandomSampler, SubsetSampler, MislabelDataset


class IndexedImageFolder(datasets.ImageFolder):
    '''
    ImageFolder over one split directory whose items also carry their index.

    The images are read from ``<root>/<split>``; every item is returned as
    ``(img, target, index)`` instead of the usual ``(img, target)`` pair.
    '''

    def __init__(self, root, transform, split):
        # The split name ('train', 'val', ...) is a sub-directory of root.
        split_dir = os.path.join(root, split)
        super().__init__(root=split_dir, transform=transform)

    def __getitem__(self, index):
        # Let ImageFolder load the sample, then append the index it was asked for.
        image, label = super().__getitem__(index)
        return image, label, index


class MetaClassImageFolder(datasets.ImageFolder):
    '''
    ImageFolder restricted to a subset of classes, with labels re-indexed.

    Only samples whose class name appears in ``class_list`` are kept. The new
    label of a kept class is its position in ``class_list``, so labels run
    from 0 to ``len(class_list) - 1`` regardless of the on-disk class order.
    Class names in ``class_list`` that do not exist on disk simply contribute
    no samples.

     root: directory that contains the split sub-directories.
     transform: transform applied to every loaded image.
     split: name of the sub-directory of root to read (e.g. 'train', 'val').
     class_list: class (folder) names to keep, in the desired label order.

    Items are returned as ``(img, target, index)``.
    '''

    def __init__(self, root, transform, split, class_list):
        self.class_list = class_list  # remaining classes
        super().__init__(os.path.join(root, split), transform=transform)

        # class names as found on disk, indexed by their original label
        original_classes = self.classes

        # rebuild class_to_idx / idx_to_class for the reduced label space
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_list)}
        self.classes = class_list
        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}

        # Filter out the classes not in class_list AND remap the label stored
        # in each sample. __getitem__ of the parent class reads the label from
        # self.samples, so remapping only self.targets would still return the
        # original on-disk indices.
        self.samples = [
            (path, self.class_to_idx[original_classes[old_target]])
            for path, old_target in self.samples
            if original_classes[old_target] in self.class_to_idx
        ]
        # ImageFolder exposes `imgs` as an alias of `samples`; keep it in sync
        # so it does not keep pointing at the unfiltered list.
        self.imgs = self.samples
        self.targets = [target for _, target in self.samples]

    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        return img, target, index


def imagenet(batch_size, root = './data/imagenet', valid_ratio = None, shuffle = True, augmentation = True, train_subset = None, test_subset=None,
             is_split=False, split_seed=0, is_shadow=False, shadow_ratio=0.8, shadow_seed=0,
             mislabel_ratio=0., mislabel_seed=0, class_subset_path=None,
             member_train=False, member_test=False, nonmember_train=False, nonmember_test=False, num_worker=0):
    '''
    Build the train / validation / test data loaders for ImageNet.

     batch_size: batch size.
     root: where data is stored; must contain 'train' and 'val' sub-directories.
     valid_ratio: the ratio of validation data, None (or 0) if no validation set.
         The validation indices are the first int(n * valid_ratio) of the
         selected training indices and are served without augmentation.
     shuffle: whether or not the training set is shuffled.
     augmentation: whether or not the augmentation is applied to training data.
     train_subset: the specified subset (indices) for training, None if we use
         the whole training set. A given subset is permuted with the global
         numpy RNG before use. Overwritten when is_split is True.
     test_subset: the specified subset (indices) of the 'val' split used for
         testing, None for the whole split. Overwritten when is_split is True.
     is_split: if True, a random half of the training set and a random half of
         the test set are drawn with split_seed and used as the subsets.
     split_seed: numpy seed used for the is_split draw.
     is_shadow: only used when is_split is True. The complementary halves are
         taken instead, and a shadow_ratio fraction of each is drawn with
         shadow_seed.
     shadow_ratio: fraction of the complementary half kept when is_shadow.
     shadow_seed: numpy seed used for the is_shadow draw.
     mislabel_ratio: if > 0, all three datasets are wrapped in MislabelDataset
         with this ratio.
     mislabel_seed: seed forwarded to MislabelDataset.
     class_subset_path: optional csv file with an 'ID' column listing the class
         names to keep; labels are then re-indexed by MetaClassImageFolder.
     member_train / member_test: keep only the first / second half of the
         selected training indices. Mutually exclusive.
     nonmember_train / nonmember_test: keep only the first / second half of the
         selected test indices. Mutually exclusive.
     num_worker: number of DataLoader worker processes.

    Returns (train_loader, valid_loader, test_loader, classes). valid_loader is
    None when no validation set is requested; classes is the list of class
    labels 0 .. num_classes - 1.

    Raises ValueError if both member_* flags or both nonmember_* flags are set.
    '''

    if member_train and member_test:
        raise ValueError('member_train and member_test cannot be both True.')
    if nonmember_train and nonmember_test:
        raise ValueError('nonmember_train and nonmember_test cannot be both True.')

    def eval_transform():
        # deterministic preprocessing shared by validation and test data
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            ])

    # Explicit comparison kept so non-bool arguments behave exactly as before.
    if augmentation == True:
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            ])
    else:
        transform_train = eval_transform()
    transform_valid = eval_transform()
    transform_test = eval_transform()

    # validset reads the same 'train' split as trainset but without augmentation;
    # testset reads the 'val' split.
    if class_subset_path is not None:
        # load the class subset from the csv file
        df = pd.read_csv(class_subset_path)
        class_subset = df['ID'].tolist()
        trainset = MetaClassImageFolder(root = root, transform = transform_train, split = 'train', class_list = class_subset)
        testset = MetaClassImageFolder(root = root, transform = transform_test, split = 'val', class_list = class_subset)
        validset = MetaClassImageFolder(root = root, transform = transform_valid, split = 'train', class_list = class_subset)
    else:
        trainset = IndexedImageFolder(root = root, split = 'train', transform = transform_train)
        testset = IndexedImageFolder(root = root, split = 'val', transform = transform_test)
        validset = IndexedImageFolder(root = root, split = 'train', transform = transform_valid)

    # Read the class list from the raw dataset, before any wrapping, so it does
    # not depend on the wrapper exposing `class_to_idx`.
    classes = list(range(len(trainset.class_to_idx)))

    # wrap the dataset with mislabel (order kept: train, test, valid)
    if mislabel_ratio > 0:
        trainset = MislabelDataset(trainset, num_class=len(trainset.class_to_idx), mislabel_ratio = mislabel_ratio, mislabel_seed=mislabel_seed)
        testset = MislabelDataset(testset, num_class=len(testset.class_to_idx), mislabel_ratio = mislabel_ratio, mislabel_seed=mislabel_seed)
        validset = MislabelDataset(validset, num_class=len(validset.class_to_idx), mislabel_ratio = mislabel_ratio, mislabel_seed=mislabel_seed)

    if is_split:
        num_train_data, num_test_data = len(trainset), len(testset)
        # The seed is reset before each draw so the train and test halves are
        # each reproducible on their own.
        np.random.seed(split_seed)
        train_subset = np.random.choice(num_train_data, size=num_train_data//2, replace=False)
        np.random.seed(split_seed)
        test_subset = np.random.choice(num_test_data, size=num_test_data//2, replace=False)
        if is_shadow:
            # Shadow data comes from the halves NOT selected above. The
            # set -> list -> array conversion is kept as is because the element
            # order feeds the seeded draws below.
            train_subset = set(np.arange(num_train_data)) - set(train_subset)
            train_subset = np.array(list(train_subset))
            test_subset = set(np.arange(num_test_data)) - set(test_subset)
            test_subset = np.array(list(test_subset))
            np.random.seed(shadow_seed)
            train_subset = np.random.choice(train_subset, size=int(len(train_subset) * shadow_ratio), replace=False)
            np.random.seed(shadow_seed)
            test_subset = np.random.choice(test_subset, size=int(len(test_subset) * shadow_ratio), replace=False)

    if train_subset is None:
        train_indices = list(range(len(trainset)))
    else:
        # uses the global numpy RNG state (seeded above only when is_split)
        train_indices = np.random.permutation(train_subset)
    if member_train:
        train_indices = train_indices[:len(train_indices) // 2]
    if member_test:
        train_indices = train_indices[len(train_indices) // 2:]
    train_instance_num = len(train_indices)
    print('%d instances are picked from the training set' % train_instance_num)

    if test_subset is None:
        test_indices = list(range(len(testset)))
    else:
        test_indices = test_subset
    if nonmember_train:
        test_indices = test_indices[:len(test_indices) // 2]
    if nonmember_test:
        test_indices = test_indices[len(test_indices) // 2:]
    test_instance_num = len(test_indices)
    print('%d instances are picked from the test set' % test_instance_num)
    test_sampler = SubsetSampler(test_indices)

    use_valid = valid_ratio is not None and valid_ratio > 0.
    if use_valid:
        # The head of the selected indices is held out for validation.
        # NOTE(review): without train_subset the indices are in dataset order,
        # so this slice is the first samples of the dataset, not a random
        # draw - confirm this is intended.
        split_pt = int(train_instance_num * valid_ratio)
        train_idx, valid_idx = train_indices[split_pt:], train_indices[:split_pt]
    else:
        train_idx, valid_idx = train_indices, None

    # Explicit comparison kept so non-bool arguments behave exactly as before.
    if shuffle == True:
        train_sampler = SubsetRandomSampler(train_idx)
    else:
        train_sampler = SubsetSampler(train_idx)
    valid_sampler = SubsetSampler(valid_idx) if use_valid else None

    def make_loader(dataset, sampler):
        # Ordering is decided by the sampler, so DataLoader's shuffle stays off.
        return torch.utils.data.DataLoader(dataset, batch_size = batch_size, sampler = sampler, shuffle = False,
                                           num_workers = num_worker, pin_memory = True)

    train_loader = make_loader(trainset, train_sampler)
    valid_loader = make_loader(validset, valid_sampler) if use_valid else None
    test_loader = make_loader(testset, test_sampler)

    return train_loader, valid_loader, test_loader, classes


if __name__ == '__main__':
    # Manual smoke test: build the loaders for the class subset listed in the
    # csv file, using batches of 100 images.
    (train_loader, valid_loader,
     test_loader, classes) = imagenet(
        batch_size=100,
        root='/data/imagenet',
        class_subset_path='../config/data/metaclass_member.csv',
    )