import os
import numpy as np
import torch
from torch.utils.data.dataset import Subset
from torchvision import datasets, transforms
import random
from sklearn.model_selection import train_test_split


def get_subset_with_len(dataset, length, shuffle=False):
    """Return a ``Subset`` of ``dataset`` containing exactly ``length`` items.

    The selection is reproducible: ``set_random_seed(0)`` is called first, so
    with ``shuffle=True`` the same items are always chosen. Note that this
    resets the global ``random``, NumPy and torch RNG state as a side effect.

    Args:
        dataset: object supporting ``len()`` and integer indexing.
        length: number of items to keep; must satisfy
            ``0 <= length <= len(dataset)``.
        shuffle: if True, take the first ``length`` items of a seed-0
            permutation; otherwise take the first ``length`` items in order.

    Returns:
        torch.utils.data.Subset over the selected indices.

    Raises:
        ValueError: if ``length`` is outside ``[0, len(dataset)]``.
    """
    set_random_seed(0)
    dataset_size = len(dataset)

    # Explicit check instead of a post-hoc assert: slicing silently truncates
    # when length > dataset_size, and asserts vanish under ``python -O``.
    if not 0 <= length <= dataset_size:
        raise ValueError(
            'length must be between 0 and len(dataset)={}, got {}'.format(
                dataset_size, length))

    index = np.arange(dataset_size)
    if shuffle:
        np.random.shuffle(index)

    index = torch.from_numpy(index[:length])
    return Subset(dataset, index)


def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and CUDA) with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)


def get_transform(args):
    """Build the train/test image transforms and set ``args.num_classes``.

    Both pipelines resize to 224x224, convert to a tensor and normalize with
    the standard ImageNet per-channel mean/std; the train pipeline also
    applies a random horizontal flip.

    Args:
        args: namespace-like object. ``args.id_dataset`` is read and
            ``args.num_classes`` is written.

    Returns:
        ``(train_transform, test_transform)``.

    Raises:
        NotImplementedError: if ``args.id_dataset`` is not one of
            ``cifar10``, ``cifar100``, ``food`` or ``caltech``.
    """
    # Must cover every in-distribution dataset that get_dataset /
    # get_dataset_eval support, since both call this function first.
    num_classes_by_id = {'cifar10': 10, 'cifar100': 100, 'food': 101, 'caltech': 99}
    if args.id_dataset not in num_classes_by_id:
        raise NotImplementedError()
    args.num_classes = num_classes_by_id[args.id_dataset]

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    image_size = (224, 224, 3)  # (height, width, channels); only H and W are used

    normalize = transforms.Normalize(mean, std)

    train_transform = transforms.Compose([
        transforms.Resize((image_size[0], image_size[1])),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize((image_size[0], image_size[1])),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, test_transform


def get_dataset(args, dataset):
    """Load the train and test splits of ``dataset`` for training.

    The train split uses the augmenting train transform and the test split
    the deterministic test transform (both from ``get_transform``).
    ``args.num_classes`` is set from ``args.id_dataset``.

    Args:
        args: namespace-like object; reads ``id_dataset`` and ``data_path``,
            writes ``num_classes``.
        dataset: one of ``cifar10``, ``cifar100``, ``food``, ``caltech``.

    Returns:
        ``(train_set, test_set)``. For ``caltech`` these are ``Subset`` views
        over a fixed split (1000 held-out images, ``random_state=0``).

    Raises:
        NotImplementedError: for an unknown ``dataset`` (``get_transform``
            may also raise for an unsupported ``args.id_dataset``).
    """
    train_transform, test_transform = get_transform(args)

    # In-distribution class counts; unknown ids leave args.num_classes as is.
    num_classes = {'cifar10': 10, 'cifar100': 100, 'food': 101, 'caltech': 99}.get(args.id_dataset)
    if num_classes is not None:
        args.num_classes = num_classes

    if dataset == 'cifar10':
        train_set = datasets.CIFAR10(args.data_path, train=True,
                                     download=True, transform=train_transform)
        test_set = datasets.CIFAR10(args.data_path, train=False,
                                    download=True, transform=test_transform)
    elif dataset == 'cifar100':
        train_set = datasets.CIFAR100(args.data_path, train=True,
                                      download=True, transform=train_transform)
        test_set = datasets.CIFAR100(args.data_path, train=False,
                                     download=True, transform=test_transform)
    elif dataset == 'food':
        train_set = datasets.Food101(args.data_path, split='train',
                                     download=True, transform=train_transform)
        test_set = datasets.Food101(args.data_path, split='test',
                                    download=True, transform=test_transform)
    elif dataset == 'caltech':
        root = os.path.join(args.data_path, 'caltech99/99_ObjectCategories')
        # Two views of the same folder so the train split is augmented and
        # the held-out split is not.
        dataset_train = datasets.ImageFolder(root, transform=train_transform)
        dataset_test = datasets.ImageFolder(root, transform=test_transform)
        # Split indices, not the datasets: given a Dataset, train_test_split
        # indexes every item up front, decoding all images into memory and
        # freezing the random augmentation. The split depends only on the
        # sample count and random_state, so the same images are selected.
        train_idx, test_idx = train_test_split(
            np.arange(len(dataset_train)), test_size=1000, random_state=0)
        train_set = Subset(dataset_train, train_idx.tolist())
        test_set = Subset(dataset_test, test_idx.tolist())
    else:
        raise NotImplementedError()

    return train_set, test_set


def get_dataset_eval(args, dataset):
    """Load ``dataset`` for evaluation, using the deterministic test transform.

    ``args.num_classes`` is set from ``args.id_dataset`` (cifar10, cifar100,
    food, caltech); other ids leave it untouched.

    Return value per ``dataset``:
      * cifar10 / cifar100: ``(train_set, test_subset)`` when ``dataset`` is
        ``args.id_dataset``, otherwise only ``test_subset``. ``test_subset``
        is a seeded 1000-image subset of the official test split.
      * food: ``(train_set, test_subset)``.
      * caltech: ``(train_split, test_split)`` from a fixed split with 1000
        held-out images (``random_state=0``).
      * lsun / isun: a seeded 1000-image subset.
      * imagenet30 / imagenet20: the folder, subsampled to 1000 images only
        when it is larger than that.

    NOTE(review): food and caltech return a tuple even when they are not
    ``args.id_dataset``, unlike cifar10/cifar100 -- confirm callers only use
    them as in-distribution datasets.

    Raises:
        NotImplementedError: for an unknown ``dataset``, or ``imagenet30``
            with an ``args.id_dataset`` that has no matching test folder.
    """
    _, test_transform = get_transform(args)

    num_sample = 1000

    # In-distribution class counts; unknown ids leave args.num_classes as is.
    num_classes = {'cifar10': 10, 'cifar100': 100, 'food': 101, 'caltech': 99}.get(args.id_dataset)
    if num_classes is not None:
        args.num_classes = num_classes

    if dataset in ('cifar10', 'cifar100'):
        cifar_cls = datasets.CIFAR10 if dataset == 'cifar10' else datasets.CIFAR100
        train_set = cifar_cls(args.data_path, train=True,
                              download=True, transform=test_transform)
        test_set = cifar_cls(args.data_path, train=False,
                             download=True, transform=test_transform)
        test_set = get_subset_with_len(test_set, length=num_sample, shuffle=True)
        if args.id_dataset == dataset:
            return train_set, test_set
        return test_set

    if dataset == 'food':
        # args.num_classes was set from args.id_dataset above; overriding it
        # here would corrupt it when food is not the in-distribution dataset.
        train_set = datasets.Food101(args.data_path, split='train',
                                     download=True, transform=test_transform)
        test_set = datasets.Food101(args.data_path, split='test',
                                    download=True, transform=test_transform)
        test_set = get_subset_with_len(test_set, length=num_sample, shuffle=True)
        return train_set, test_set

    if dataset == 'caltech':
        # Both splits use test_transform, so one ImageFolder backs both.
        full_set = datasets.ImageFolder(
            os.path.join(args.data_path, 'caltech99/99_ObjectCategories'),
            transform=test_transform)
        # Split indices rather than the dataset: given a Dataset,
        # train_test_split would load and transform every image up front.
        train_idx, test_idx = train_test_split(
            np.arange(len(full_set)), test_size=1000, random_state=0)
        return Subset(full_set, train_idx.tolist()), Subset(full_set, test_idx.tolist())

    if dataset in ('lsun', 'isun'):
        subdir = 'LSUN_resize' if dataset == 'lsun' else 'iSUN'
        test_set = datasets.ImageFolder(os.path.join(args.data_path, subdir),
                                        transform=test_transform)
        return get_subset_with_len(test_set, length=num_sample, shuffle=True)

    if dataset in ('imagenet30', 'imagenet20'):
        if dataset == 'imagenet20':
            subdir = 'IN20'
        else:
            in30_dirs = {
                'cifar10': 'IN30_test_cifar10',
                'cifar100': 'IN30_test_cifar100',
                'food': 'IN30_test_food',
            }
            if args.id_dataset not in in30_dirs:
                raise NotImplementedError(
                    'imagenet30 has no test folder for id_dataset={!r}'.format(args.id_dataset))
            subdir = in30_dirs[args.id_dataset]
        test_set = datasets.ImageFolder(os.path.join(args.data_path, subdir),
                                        transform=test_transform)
        # Subsample only when the folder holds more than num_sample images.
        if len(test_set) > num_sample:
            test_set = get_subset_with_len(test_set, length=num_sample, shuffle=True)
        return test_set

    raise NotImplementedError()


class MySubset(torch.utils.data.Dataset):
    """Dataset view exposing only the positions of ``dataset`` listed in ``indices``.

    Item ``i`` of this view is ``dataset[indices[i]]``; its length is
    ``len(indices)``.
    """

    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]