from torchvision.datasets import CIFAR10, CIFAR100
import numpy as np
from config import cfg
import os
from copy import deepcopy

from utils.datasets import *


class Struct:
    """Lightweight attribute container built from keyword arguments.

    Every keyword passed to the constructor becomes an instance attribute.
    Entries can additionally be read mapping-style (``s['name']``,
    ``'name' in s``, ``s.keys()``, ``s.values()``, ``s.items()``), which is
    how the ``__main__`` block of this module consumes the object.

    ``__len__`` is deliberately not defined so that an instance is always
    truthy, exactly as a plain object would be.
    """

    def __init__(self, **entries):
        self.__dict__.update(entries)

    def __getitem__(self, key):
        """Return the entry stored under ``key``; raise ``KeyError`` if absent."""
        return self.__dict__[key]

    def __contains__(self, key):
        """Return True if an entry named ``key`` exists."""
        return key in self.__dict__

    def __iter__(self):
        """Iterate over entry names, like a dict."""
        return iter(self.__dict__)

    def keys(self):
        """Return a view of the entry names."""
        return self.__dict__.keys()

    def values(self):
        """Return a view of the entry values."""
        return self.__dict__.values()

    def items(self):
        """Return a view of ``(name, value)`` pairs."""
        return self.__dict__.items()

    def __repr__(self):
        fields = ', '.join(f'{k}={v!r}' for k, v in self.__dict__.items())
        return f'{type(self).__name__}({fields})'

class CustomCIFAR10(CIFAR10):
    """CIFAR-10 dataset whose samples also carry an index from ``uq_idxs``.

    Indexing yields ``(img, label, uq_idx)`` instead of the parent's
    ``(img, label)`` pair.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded untouched to the parent constructor.
        super().__init__(*args, **kwargs)

        # Cleared unconditionally after construction, so any
        # ``target_transform`` supplied by the caller is discarded here.
        self.target_transform = None
        # One entry per sample, initially 0 .. len(self) - 1.
        self.uq_idxs = np.array(range(len(self)))

    def __getitem__(self, item):
        """Return the parent's ``(img, label)`` plus ``self.uq_idxs[item]``."""
        image, target = super().__getitem__(item)
        return image, target, self.uq_idxs[item]

    def __len__(self):
        """Number of samples, taken from the length of ``self.targets``."""
        return len(self.targets)


class CustomCIFAR100(CIFAR100):
    """CIFAR-100 dataset whose samples also carry an index from ``uq_idxs``.

    Indexing yields ``(img, label, uq_idx)`` instead of the parent's
    ``(img, label)`` pair.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are forwarded untouched to the parent constructor.
        super().__init__(*args, **kwargs)

        # Cleared unconditionally after construction, so any
        # ``target_transform`` supplied by the caller is discarded here.
        self.target_transform = None
        # One entry per sample, initially 0 .. len(self) - 1.
        self.uq_idxs = np.array(range(len(self)))

    def __getitem__(self, item):
        """Return the parent's ``(img, label)`` plus ``self.uq_idxs[item]``."""
        image, target = super().__getitem__(item)
        return image, target, self.uq_idxs[item]

    def __len__(self):
        """Number of samples, taken from the length of ``self.targets``."""
        return len(self.targets)


def get_cifar_10_datasets(
    train_transforms, test_transforms, 
    common_classes=None, private_classes=None, 
    prop_train_labels=0.6, 
    split_train_val=False, seed=0, 
    merge_mask=True
    ):
    """Build the CIFAR-10 train/test datasets and split them via ``get_split``.

    Args:
        train_transforms: Used as ``transform`` for the training dataset and
            forwarded to ``get_split``.
        test_transforms: Used as ``transform`` for the test dataset and
            forwarded to ``get_split``.
        common_classes: Forwarded to ``get_split``. ``None`` (the default)
            means ``set(range(2, 9))``.
        private_classes: Forwarded to ``get_split``. ``None`` (the default)
            means an empty set.
        prop_train_labels: Forwarded to ``get_split``.
        split_train_val: Forwarded to ``get_split``.
        seed: Seed applied to NumPy's global random state before anything
            else is done.
        merge_mask: Forwarded to ``get_split``.

    Returns:
        A ``Struct`` whose attributes are the entries of the mapping returned
        by ``get_split``.

    When ``cfg.DEBUG.IS_DEBUG`` is set, the datasets are either built,
    reduced with ``subsample_dataset_ratio`` (ratio ``cfg.DEBUG.DS_RATIO``)
    and pickled to ``cfg.DEBUG.SAVE_FILEPATH`` (if ``cfg.DEBUG.LOAD_FILEPATH``
    is empty), or unpickled from ``cfg.DEBUG.LOAD_FILEPATH``.
    """
    # Fresh sets on every call: a set literal in the signature would be a
    # single shared object that survives (and accumulates mutations) across
    # calls.
    if common_classes is None:
        common_classes = set(range(2, 9))
    if private_classes is None:
        private_classes = set()

    np.random.seed(seed)

    # Init entire training set
    if cfg.DEBUG.IS_DEBUG:
        import pickle
        if cfg.DEBUG.LOAD_FILEPATH == '':
            # Build order (train, subsample, test, subsample) is kept as-is in
            # case the subsampling helper draws from the seeded RNG.
            whole_training_set = CustomCIFAR10(root=cfg.DATASETS.CIFAR_10_ROOT, transform=train_transforms, train=True)
            whole_training_set = subsample_dataset_ratio(deepcopy(whole_training_set), float(cfg.DEBUG.DS_RATIO))
            test_dataset = CustomCIFAR10(root=cfg.DATASETS.CIFAR_10_ROOT, transform=test_transforms, train=False)
            test_dataset = subsample_dataset_ratio(deepcopy(test_dataset), float(cfg.DEBUG.DS_RATIO))
            # NOTE(review): the cache file names 'train'/'test' are the same
            # ones the CIFAR-100 loader uses; keep separate SAVE_FILEPATH
            # directories per dataset to avoid overwriting.
            os.makedirs(cfg.DEBUG.SAVE_FILEPATH, exist_ok=True)
            with open(os.path.join(cfg.DEBUG.SAVE_FILEPATH, 'train'), 'wb') as f:
                pickle.dump(whole_training_set, f)
            with open(os.path.join(cfg.DEBUG.SAVE_FILEPATH, 'test'), 'wb') as f:
                pickle.dump(test_dataset, f)
        else:
            # SECURITY: unpickling executes arbitrary code embedded in the
            # file. Only point LOAD_FILEPATH at caches written by this
            # function on a trusted machine.
            with open(os.path.join(cfg.DEBUG.LOAD_FILEPATH, 'train'), 'rb') as f:
                whole_training_set = pickle.load(f, encoding='latin1')
            with open(os.path.join(cfg.DEBUG.LOAD_FILEPATH, 'test'), 'rb') as f:
                test_dataset = pickle.load(f, encoding='latin1')

    else:
        whole_training_set = CustomCIFAR10(root=cfg.DATASETS.CIFAR_10_ROOT, transform=train_transforms, train=True)
        test_dataset = CustomCIFAR10(root=cfg.DATASETS.CIFAR_10_ROOT, transform=test_transforms, train=False)

    all_datasets = get_split(
        whole_training_set, test_dataset, 
        common_classes, private_classes, 
        train_transforms, test_transforms,
        prop_train_labels,
        split_train_val,
        merge_mask
    )

    return Struct(**all_datasets)


def get_cifar_100_datasets(
    train_transforms, test_transforms, 
    common_classes=None, private_classes=None,
    prop_train_labels=0.8, 
    split_train_val=False, seed=0,
    merge_mask=True
    ):
    """Build the CIFAR-100 train/test datasets and split them via ``get_split``.

    Args:
        train_transforms: Used as ``transform`` for the training dataset and
            forwarded to ``get_split``.
        test_transforms: Used as ``transform`` for the test dataset and
            forwarded to ``get_split``.
        common_classes: Forwarded to ``get_split``. ``None`` (the default)
            means ``set(range(20, 80))``.
        private_classes: Forwarded to ``get_split``. ``None`` (the default)
            means an empty set.
        prop_train_labels: Forwarded to ``get_split``.
        split_train_val: Forwarded to ``get_split``.
        seed: Seed applied to NumPy's global random state before anything
            else is done.
        merge_mask: Forwarded to ``get_split``.

    Returns:
        A ``Struct`` whose attributes are the entries of the mapping returned
        by ``get_split``.

    When ``cfg.DEBUG.IS_DEBUG`` is set, the datasets are either built,
    reduced with ``subsample_dataset_ratio`` (ratio ``cfg.DEBUG.DS_RATIO``)
    and pickled to ``cfg.DEBUG.SAVE_FILEPATH`` (if ``cfg.DEBUG.LOAD_FILEPATH``
    is empty), or unpickled from ``cfg.DEBUG.LOAD_FILEPATH``.
    """
    # Fresh sets on every call: a set literal in the signature would be a
    # single shared object that survives (and accumulates mutations) across
    # calls.
    if common_classes is None:
        common_classes = set(range(20, 80))
    if private_classes is None:
        private_classes = set()

    np.random.seed(seed)

    if cfg.DEBUG.IS_DEBUG:
        import pickle
        if cfg.DEBUG.LOAD_FILEPATH == '':
            # Build order (train, subsample, test, subsample) is kept as-is in
            # case the subsampling helper draws from the seeded RNG.
            whole_training_set = CustomCIFAR100(root=cfg.DATASETS.CIFAR_100_ROOT, transform=train_transforms, train=True)
            whole_training_set = subsample_dataset_ratio(deepcopy(whole_training_set), float(cfg.DEBUG.DS_RATIO))
            test_dataset = CustomCIFAR100(root=cfg.DATASETS.CIFAR_100_ROOT, transform=test_transforms, train=False)
            test_dataset = subsample_dataset_ratio(deepcopy(test_dataset), float(cfg.DEBUG.DS_RATIO))
            # NOTE(review): the cache file names 'train'/'test' are the same
            # ones the CIFAR-10 loader uses; keep separate SAVE_FILEPATH
            # directories per dataset to avoid overwriting.
            os.makedirs(cfg.DEBUG.SAVE_FILEPATH, exist_ok=True)
            with open(os.path.join(cfg.DEBUG.SAVE_FILEPATH, 'train'), 'wb') as f:
                pickle.dump(whole_training_set, f)
            with open(os.path.join(cfg.DEBUG.SAVE_FILEPATH, 'test'), 'wb') as f:
                pickle.dump(test_dataset, f)
        else:
            # SECURITY: unpickling executes arbitrary code embedded in the
            # file. Only point LOAD_FILEPATH at caches written by this
            # function on a trusted machine.
            with open(os.path.join(cfg.DEBUG.LOAD_FILEPATH, 'train'), 'rb') as f:
                whole_training_set = pickle.load(f, encoding='latin1')
            with open(os.path.join(cfg.DEBUG.LOAD_FILEPATH, 'test'), 'rb') as f:
                test_dataset = pickle.load(f, encoding='latin1')

    else:
        whole_training_set = CustomCIFAR100(root=cfg.DATASETS.CIFAR_100_ROOT, transform=train_transforms, train=True)
        test_dataset = CustomCIFAR100(root=cfg.DATASETS.CIFAR_100_ROOT, transform=test_transforms, train=False)

    all_datasets = get_split(
        whole_training_set, test_dataset, 
        common_classes, private_classes, 
        train_transforms, test_transforms,
        prop_train_labels,
        split_train_val,
        merge_mask
    )
    return Struct(**all_datasets)


if __name__ == '__main__':
    # Manual smoke test: build the CIFAR-100 splits and print basic statistics.

    # `get_cifar_100_datasets` has no `train_classes` parameter (passing it
    # raises TypeError); the class subset is supplied as `common_classes`.
    # NOTE(review): assumes `common_classes` is the intended replacement for
    # the old `train_classes` argument - confirm.
    x = get_cifar_100_datasets(None, None, split_train_val=False,
                         common_classes=set(range(80)), prop_train_labels=0.5)

    # The result is a Struct (attribute container); vars() exposes its entries
    # as a plain dict so they can be iterated and looked up by name.
    splits = vars(x)

    print('Printing lens...')
    for k, v in splits.items():
        if v is not None:
            print(f'{k}: {len(v)}')

    # NOTE(review): the keys 'train_labelled' / 'train_unlabelled' are assumed
    # to be produced by get_split - verify against utils.datasets.
    labelled = splits['train_labelled']
    unlabelled = splits['train_unlabelled']

    print('Printing labelled and unlabelled overlap...')
    print(set.intersection(set(labelled.uq_idxs), set(unlabelled.uq_idxs)))
    print('Printing total instances in train...')
    print(len(set(labelled.uq_idxs)) + len(set(unlabelled.uq_idxs)))

    print(f'Num Labelled Classes: {len(set(labelled.targets))}')
    print(f'Num Unabelled Classes: {len(set(unlabelled.targets))}')
    print(f'Len labelled set: {len(labelled)}')
    print(f'Len unlabelled set: {len(unlabelled)}')