import torchvision
import numpy as np
import os
from copy import deepcopy
from config import cfg

from .bases import MergedDatasetMask
from utils.datasets import subsample_instances, subsample_dataset_ratio

class Struct:
    """Minimal attribute bag: every keyword argument becomes an instance attribute.

    Example: ``Struct(a=1).a == 1``.
    """

    def __init__(self, **entries):
        for attr_name, attr_value in entries.items():
            setattr(self, attr_name, attr_value)

class ImageNetBase(torchvision.datasets.ImageFolder):
    """ImageFolder variant whose items also carry a per-sample unique index.

    Each item is returned as ``(image, int(label), int(uq_idx))``, where
    ``uq_idx`` is looked up in ``self.uq_idxs`` (initially the sample's
    position; callers may overwrite ``uq_idxs`` later).
    """

    def __init__(self, root, transform):
        super().__init__(root, transform)
        # Start with one id per sample, equal to its position in the dataset.
        self.uq_idxs = np.array(range(len(self)))

    def __len__(self):
        # Length follows ``targets`` so that reassigning targets resizes the dataset.
        return len(self.targets)

    def __getitem__(self, item):
        image, target = super().__getitem__(item)
        unique_id = self.uq_idxs[item]
        return image, int(target), int(unique_id)

def subsample_dataset(dataset, idxs):
    """Restrict ``dataset`` in place to the entries at positions ``idxs``.

    ``imgs`` and ``samples`` are replaced by lists, ``targets`` by a list and
    ``uq_idxs`` by a numpy array, all ordered like ``idxs``.

    Returns:
        The same (mutated) dataset object.
    """
    # Each right-hand side is evaluated against the old attribute before it is rebound.
    dataset.imgs = [dataset.imgs[pos] for pos in idxs]
    dataset.samples = [dataset.samples[pos] for pos in idxs]

    dataset.targets = np.array(dataset.targets)[idxs].tolist()
    dataset.uq_idxs = dataset.uq_idxs[idxs]

    return dataset


def subsample_classes(dataset, include_classes=None):
    """Keep only the entries of ``dataset`` whose target is in ``include_classes``.

    Args:
        dataset: dataset with ``targets`` (and the attributes used by
            ``subsample_dataset``); it is mutated in place.
        include_classes: iterable of class ids to keep. ``None`` (the default)
            means classes ``0..999``.

    Returns:
        The subsampled dataset (same object as ``dataset``).
    """
    if include_classes is None:
        include_classes = range(1000)
    # Build the lookup set once: list membership per target was O(#classes).
    wanted = set(include_classes)
    cls_idxs = [pos for pos, target in enumerate(dataset.targets) if target in wanted]
    return subsample_dataset(dataset, cls_idxs)


def get_train_val_indices(train_dataset, val_split=0.2):
    """Split the positions of ``train_dataset`` into train/val lists, per class.

    For every class, ``int(val_split * class_size)`` positions are drawn
    without replacement (using the global ``np.random`` state) for
    validation. The remaining positions of that class go to training.

    Returns:
        ``(train_idxs, val_idxs)`` as lists of positions.
    """
    # Hoisted: previously rebuilt for every class.
    targets = np.array(train_dataset.targets)
    train_classes = list(set(train_dataset.targets))

    train_idxs = []
    val_idxs = []
    for cls in train_classes:
        cls_idxs = np.where(targets == cls)[0]

        v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
        # Set membership instead of scanning the numpy array for every element.
        val_set = set(v_.tolist())
        t_ = [x for x in cls_idxs if x not in val_set]

        train_idxs.extend(t_)
        val_idxs.extend(v_)

    return train_idxs, val_idxs

def merge_dataset(dataset1, dataset2):
    """Append the entries of ``dataset2`` to ``dataset1`` in place.

    ``samples`` are stacked row-wise into a numpy array, ``imgs`` (when both
    datasets have it) and ``targets`` are concatenated, and ``uq_idxs`` are
    concatenated with ``np.hstack``. ``dataset2`` is not modified.

    Returns:
        ``dataset1`` (mutated).
    """
    # Skip empty sides: ``np.vstack`` turns ``[]`` into a (1, 0) array that
    # cannot be stacked with (N, 2) sample rows.
    sample_blocks = [
        np.asarray(samples)
        for samples in (dataset1.samples, dataset2.samples)
        if len(samples) > 0
    ]
    if sample_blocks:
        dataset1.samples = np.vstack(sample_blocks)
    else:
        dataset1.samples = np.asarray(dataset1.samples)

    # Keep ``imgs`` aligned with ``samples`` so later positional subsampling
    # (which indexes ``imgs``) does not run past its end.
    if hasattr(dataset1, 'imgs') and hasattr(dataset2, 'imgs'):
        dataset1.imgs = list(dataset1.imgs) + list(dataset2.imgs)

    dataset1.targets.extend(dataset2.targets)
    dataset1.uq_idxs = np.hstack((dataset1.uq_idxs, dataset2.uq_idxs))
    return dataset1

def _reset_image_folder(dataset):
    """Normalise an ImageFolder-style dataset in place (samples, targets, uq_idxs).

    ``samples`` becomes a numpy array built from ``(path, int(label))`` tuples.
    numpy stores this as a 2-D string array, so labels inside ``samples`` are
    strings. ``targets`` is rebuilt as a list of Python ints, ``uq_idxs`` is
    reset to positions, and ``target_transform`` is cleared.
    """
    dataset.samples = np.array([(s[0], int(s[1])) for s in dataset.samples])
    # Convert back to int: the string-typed labels in ``samples`` would
    # otherwise leak into ``targets`` (previously the case for the test set).
    dataset.targets = [int(s[1]) for s in dataset.samples]
    dataset.uq_idxs = np.array(range(len(dataset)))
    dataset.target_transform = None


def get_imagenet_100_datasets(
    train_transforms, test_transforms,
    common_classes=None, private_classes=None,
    prop_train_labels=0.8,
    split_train_val=False, seed=0,
    merge_mask=True
    ):
    """Build the labelled / unlabelled / test splits of ImageNet-100.

    The labelled set contains every training instance of ``private_classes``
    plus a subset of the instances of ``common_classes``, chosen by
    ``subsample_instances`` with ``prop_train_labels``. All remaining training
    instances form the unlabelled set.

    Args:
        train_transforms: transform for the training images.
        test_transforms: transform for the test (and optional val) images.
        common_classes: class ids partially labelled. ``None`` means
            ``set(range(20, 80))``.
        private_classes: class ids fully labelled and excluded from
            ``unlabeled_classes``. ``None`` means an empty set.
        prop_train_labels: fraction passed to ``subsample_instances``.
        split_train_val: if True, split the labelled set per class into
            train/val (80/20 via ``get_train_val_indices``).
        seed: currently unused by this function.
        merge_mask: if True, build a ``MergedDatasetMask`` over labelled +
            unlabelled data; otherwise ``train_mask`` is ``None``.

    Returns:
        ``Struct`` with attributes ``train``, ``train_mask``,
        ``train_labelled``, ``train_unlabelled``, ``val``, ``test``,
        ``num_known_classes``, ``num_unknown_classes``, ``common_classes``,
        ``private_classes`` and ``unlabeled_classes``.
    """
    # Fresh sets per call: the class sets are handed back to the caller, so a
    # shared default object could be mutated across calls.
    if common_classes is None:
        common_classes = set(range(20, 80))
    if private_classes is None:
        private_classes = set()

    ### Init entire set ###
    whole_training_set = ImageNetBase(root=os.path.join(cfg.DATASETS.IMAGENET_100_ROOT, 'train'), transform=train_transforms)
    test_dataset = ImageNetBase(root=os.path.join(cfg.DATASETS.IMAGENET_100_ROOT, 'val'), transform=test_transforms)

    _reset_image_folder(whole_training_set)
    _reset_image_folder(test_dataset)

    if cfg.DEBUG.IS_DEBUG:
        import pickle
        if cfg.DEBUG.LOAD_FILEPATH == '':
            whole_training_set = subsample_dataset_ratio(deepcopy(whole_training_set), float(cfg.DEBUG.DS_RATIO))
            test_dataset = subsample_dataset_ratio(deepcopy(test_dataset), float(cfg.DEBUG.DS_RATIO))
            os.makedirs(cfg.DEBUG.SAVE_FILEPATH, exist_ok=True)
            with open(os.path.join(cfg.DEBUG.SAVE_FILEPATH, 'train'), 'wb') as f:
                pickle.dump(whole_training_set, f)
            with open(os.path.join(cfg.DEBUG.SAVE_FILEPATH, 'test'), 'wb') as f:
                pickle.dump(test_dataset, f)
        else:
            # SECURITY NOTE: pickle.load can execute arbitrary code. Only point
            # DEBUG.LOAD_FILEPATH at caches written by this function.
            with open(os.path.join(cfg.DEBUG.LOAD_FILEPATH, 'train'), 'rb') as f:
                whole_training_set = pickle.load(f, encoding='latin1')
            with open(os.path.join(cfg.DEBUG.LOAD_FILEPATH, 'test'), 'rb') as f:
                test_dataset = pickle.load(f, encoding='latin1')

    unlabeled_classes = set(np.unique((whole_training_set.targets))) - set(private_classes)
    labeled_classes = set(common_classes).union(set(private_classes))

    # Labelled set: all private-class instances plus a subsample of the
    # common-class instances.
    train_dataset_common = subsample_classes(deepcopy(whole_training_set), include_classes=common_classes)
    train_dataset_private = subsample_classes(deepcopy(whole_training_set), include_classes=private_classes)
    subsample_indices = subsample_instances(train_dataset_common, prop_indices_to_subsample=prop_train_labels)
    # The indices were drawn from train_dataset_common, so they must be applied
    # to it. Applying them to whole_training_set selected unrelated samples.
    train_dataset_common = subsample_dataset(train_dataset_common, subsample_indices)
    # Both inputs are private copies that are not used afterwards, so no extra deepcopy.
    train_dataset_labelled = merge_dataset(train_dataset_private, train_dataset_common)

    # Unlabelled set: every training sample whose uq_idx is not labelled.
    # Select by *position*: uq_idxs need not equal positions (e.g. after the
    # DEBUG subsampling/loading path).
    is_labelled = np.isin(whole_training_set.uq_idxs, train_dataset_labelled.uq_idxs)
    unlabelled_positions = np.flatnonzero(~is_labelled)
    train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), unlabelled_positions)

    # Train split (labelled and unlabelled classes) for training.
    # Initialised so the returned Struct is valid when merge_mask is False.
    train_dataset_mask = None
    if merge_mask:
        # NOTE(review): the labelled dataset is passed without a copy. When
        # split_train_val is False, the uq_idxs reset below is therefore also
        # visible through the mask. Confirm this is intended.
        train_dataset_mask = MergedDatasetMask(labelled_dataset=train_dataset_labelled, unlabelled_dataset=deepcopy(train_dataset_unlabelled))

    # Split into training and validation sets
    if split_train_val:
        train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
        train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
        val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
        val_dataset_labelled_split.transform = test_transforms
        train_dataset_labelled = train_dataset_labelled_split
        val_dataset_labelled = val_dataset_labelled_split
    else:
        val_dataset_labelled = None

    # Reset uq_idxs to positions within each returned split.
    train_dataset_labelled.uq_idxs = np.array(range(len(train_dataset_labelled)))
    train_dataset_unlabelled.uq_idxs = np.array(range(len(train_dataset_unlabelled)))
    all_datasets = {
        'train': whole_training_set,
        'train_mask': train_dataset_mask,
        'train_labelled': train_dataset_labelled,
        'train_unlabelled': train_dataset_unlabelled,
        'val': val_dataset_labelled,
        'test': test_dataset,
        'num_known_classes': len(labeled_classes),
        'num_unknown_classes': len(unlabeled_classes),
        'common_classes': common_classes,
        'private_classes': private_classes,
        'unlabeled_classes': unlabeled_classes
    }
    return Struct(**all_datasets)



if __name__ == '__main__':
    # Smoke test: build the splits and print their sizes / overlap.
    os.chdir(os.path.dirname(os.getcwd()))

    # ``get_imagenet_100_datasets`` has no ``train_classes`` parameter; the
    # partially labelled classes are passed as ``common_classes``.
    x = get_imagenet_100_datasets(None, None, split_train_val=False,
                                  common_classes=set(range(50)), prop_train_labels=0.5)
    # The result is a Struct (attribute access only), so read its fields via vars().
    splits = vars(x)

    print('Printing lens...')
    for k, v in splits.items():
        # Skip None and scalar fields such as num_known_classes.
        if v is not None and hasattr(v, '__len__'):
            print(f'{k}: {len(v)}')

    labelled_uq = set(x.train_labelled.uq_idxs)
    unlabelled_uq = set(x.train_unlabelled.uq_idxs)

    print('Printing labelled and unlabelled overlap...')
    print(set.intersection(labelled_uq, unlabelled_uq))
    print('Printing total instances in train...')
    print(len(labelled_uq) + len(unlabelled_uq))

    print(f'Num Labelled Classes: {len(set(x.train_labelled.targets))}')
    print(f'Num Unlabelled Classes: {len(set(x.train_unlabelled.targets))}')
    print(f'Len labelled set: {len(x.train_labelled)}')
    print(f'Len unlabelled set: {len(x.train_unlabelled)}')