import torch
from utils.datasets.imagenet_subsets import ImageNetWIDSubset, IMAGENET100_WIDS
from utils.datasets.imagenet_augmentation import get_imageNet_augmentation
from utils.datasets.paths import get_imagenet_path
from torch.utils.data import Dataset

# Fallback batch sizes used by get_ImageNet100_trainVal when the caller does
# not pass ``batch_size``: the first for train=True, the second otherwise.
DEFAULT_TRAIN_BATCHSIZE = 128
DEFAULT_TEST_BATCHSIZE = 128

def _generate_split(labels, num_val_per_class, num_classes):
    """Randomly split sample indices into train and validation parts, per class.

    For every class index in ``range(num_classes)`` the positions in
    ``labels`` carrying that class are shuffled with ``torch.randperm``; the
    first ``num_val_per_class`` of them go to the validation split and the
    remainder to the training split. Labels outside ``range(num_classes)``
    end up in neither split.

    Args:
        labels: Sequence of integer class labels, one entry per sample.
        num_val_per_class: Number of validation samples to draw for each class.
        num_classes: Number of classes to iterate over.

    Returns:
        Tuple ``(train_idcs, val_idcs)`` of 1-D index tensors into ``labels``,
        each concatenated class by class in increasing class order.

    Raises:
        AssertionError: If a class does not contribute exactly
            ``num_val_per_class`` validation samples (for example because it
            has fewer samples than that).
    """
    labels_tensor = torch.LongTensor(labels)

    train_idcs = []
    val_idcs = []

    for class_idx in range(num_classes):
        # nonzero() yields an (N, 1) tensor. Flatten with reshape(-1) rather
        # than a bare squeeze(): squeeze() turns a single match into a 0-d
        # tensor, on which len() and torch.cat() both fail.
        class_idcs = torch.nonzero(labels_tensor == class_idx, as_tuple=False).reshape(-1)
        shuffled_idcs = class_idcs[torch.randperm(len(class_idcs))]

        train_idcs.append(shuffled_idcs[num_val_per_class:])
        val_idcs.append(shuffled_idcs[:num_val_per_class])

    train_idcs = torch.cat(train_idcs)
    val_idcs = torch.cat(val_idcs)

    validation_labels = labels_tensor[val_idcs]

    # Sanity check: every class must be represented by exactly the requested
    # number of validation samples.
    for class_idx in range(num_classes):
        num_found = int(torch.sum(validation_labels == class_idx))
        assert num_found == num_val_per_class, (
            f'class {class_idx}: expected {num_val_per_class} validation '
            f'samples, got {num_found}'
        )

    print('Split generation completed')

    return train_idcs, val_idcs

class ImageNet100TrainValidationSplit(Dataset):
    """Train or validation portion of the ImageNet100 ``'train'`` images.

    Wraps an ``ImageNetWIDSubset`` restricted to ``IMAGENET100_WIDS`` and
    exposes only the samples whose indices are stored in a precomputed split
    file, loaded with ``torch.load`` from a path relative to the current
    working directory.
    """

    def __init__(self, path, train, transform):
        """Load the underlying subset and the index list of the chosen split.

        Args:
            path: Passed to ``ImageNetWIDSubset`` as its first argument.
            train: Truthy loads ``imagenet100_train_split.pth``; falsy loads
                ``imagenet100_val_split.pth``.
            transform: Passed to ``ImageNetWIDSubset`` as ``transform``.
        """
        self.imagenet100 = ImageNetWIDSubset(
            path, split='train', wids=IMAGENET100_WIDS, transform=transform)

        if not train:
            self.idcs = torch.load('imagenet100_val_split.pth')
            print(f'ImageNet100 Validation split - Length {len(self.idcs)}')
        else:
            self.idcs = torch.load('imagenet100_train_split.pth')
            print(f'ImageNet100 Train split - Length {len(self.idcs)}')

        self.length = len(self.idcs)
        # Targets of the wrapped dataset, in the order given by the split.
        self.targets = [self.imagenet100.targets[i] for i in self.idcs]

    def __getitem__(self, index):
        """Return the wrapped dataset's item for position ``index`` of the split."""
        return self.imagenet100[self.idcs[index]]

    def __len__(self):
        """Return the number of indices in the loaded split."""
        return self.length

def get_ImageNet100_trainVal(train=True, batch_size=None, shuffle=None, augm_type='none',
                    num_workers=8, size=224, config_dict=None):
    """Build a DataLoader over the ImageNet100 train or validation split.

    Args:
        train: Selects the train split when truthy, the validation split
            otherwise. It is normalised with ``bool()`` once so that split
            selection, default batch size and default shuffling always agree.
        batch_size: Loader batch size. ``None`` falls back to
            ``DEFAULT_TRAIN_BATCHSIZE`` or ``DEFAULT_TEST_BATCHSIZE``
            depending on ``train``.
        shuffle: Whether the loader shuffles. ``None`` means shuffle exactly
            when ``train`` is set.
        augm_type: Augmentation type forwarded to
            ``get_imageNet_augmentation``.
        num_workers: Number of DataLoader worker processes.
        size: Forwarded to ``get_imageNet_augmentation`` as ``out_size``.
        config_dict: Optional dict; when given, the dataset name, batch size
            and augmentation config are written into it in place.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected split.
    """
    train = bool(train)

    if batch_size is None:
        batch_size = DEFAULT_TRAIN_BATCHSIZE if train else DEFAULT_TEST_BATCHSIZE

    # Passed to the augmentation factory as config_dict and recorded below;
    # presumably the factory fills it in - verify against its implementation.
    augm_config = {}
    transform = get_imageNet_augmentation(type=augm_type, out_size=size, config_dict=augm_config)
    if not train and augm_type not in ('test', 'none'):
        print('Warning: ImageNet test set with ref_data augmentation')

    if shuffle is None:
        shuffle = train

    path = get_imagenet_path()
    dataset = ImageNet100TrainValidationSplit(path, train=train, transform=transform)

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=shuffle, num_workers=num_workers)

    if config_dict is not None:
        config_dict['Dataset'] = 'ImageNet100-TrainValSplit'
        config_dict['Batch out_size'] = batch_size
        config_dict['Augmentation'] = augm_config

    return loader
