"""
This files defines a ChunkDataset class. The main idea is to chunk a dataset based on the labels
E.g. In cifar-100 there are 100 classes and if chunk_size is 5, this class will create 5 chunks of 20 classes.
"""

from torch.utils.data import Dataset
from tqdm import tqdm


class ChunkDataset(Dataset):
    """Serve a labelled dataset one group ("chunk") of classes at a time.

    Classes are assigned to chunks in the order in which their labels are
    first encountered while iterating ``dataset``: the first ``chunk_size``
    distinct labels form chunk 0, the next ``chunk_size`` form chunk 1, and so
    on (the last chunk may hold fewer classes). ``len()`` and indexing always
    refer to the currently active chunk; call :meth:`next_chunk` to advance.

    Every sample is read from ``dataset`` once, at construction time, and the
    resulting ``(data, label)`` pairs are kept in memory. Whatever the wrapped
    dataset computes on access is therefore evaluated once per sample, not
    once per epoch.

    Args:
        dataset: Iterable yielding ``(data, label)`` pairs; ``label`` must be
            an ``int``.
        chunk_size: Number of distinct classes per chunk. Must be at least 1.
        show_progress: If true (the default), wrap the initial pass over
            ``dataset`` in a ``tqdm`` progress bar.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
        TypeError: If a label is not an ``int``.
    """

    # Class-level default so that make_chunks() keeps showing the progress bar
    # on instances whose __init__ did not set the attribute.
    show_progress = True

    def __init__(self, dataset, chunk_size=10, *, show_progress=True):
        super().__init__()
        # A non-positive size would make every sample look like it "fills" a
        # chunk, producing meaningless chunk indices; reject it up front.
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be at least 1, got {chunk_size!r}")

        self.dataset = dataset
        self.chunk_size = chunk_size
        self.show_progress = show_progress

        self.chunks, self.class_chunk_mapping, self.chunk_class_mapping = self.make_chunks()

        self.chunk_keys = sorted(self.chunks.keys())
        self.chunk_idx = 0
        # An empty dataset yields no chunks; expose an empty chunk rather than
        # failing on the key lookup.
        self.current_chunk = self.chunks[self.chunk_keys[0]] if self.chunk_keys else []

    def make_chunks(self):
        """Group the samples of ``self.dataset`` into chunks of classes.

        Returns:
            A 3-tuple ``(chunks, class_chunk_mapping, chunk_class_mapping)``:

            * ``chunks`` maps chunk index -> list of ``(data, label)`` pairs,
            * ``class_chunk_mapping`` maps label -> chunk index,
            * ``chunk_class_mapping`` maps chunk index -> set of labels.

        Raises:
            TypeError: If a sample's annotation is not an ``int``.
        """
        chunks = {}
        class_chunk_mapping = {}
        chunk_class_mapping = {}
        open_chunk_idx = 0  # chunk that newly discovered classes are added to

        samples = tqdm(self.dataset) if self.show_progress else self.dataset
        for data, anno in samples:
            if not isinstance(anno, int):
                raise TypeError(f"Chunking for annotation type {type(anno)} has not been implemented yet")
            label = anno

            if label not in class_chunk_mapping:
                # Move on to a fresh chunk only when a new class actually
                # arrives and the open chunk is full, so no empty trailing
                # chunk is ever created.
                if len(chunk_class_mapping.get(open_chunk_idx, ())) >= self.chunk_size:
                    open_chunk_idx += 1
                class_chunk_mapping[label] = open_chunk_idx
                chunk_class_mapping.setdefault(open_chunk_idx, set()).add(label)

            chunk_idx = class_chunk_mapping[label]
            chunks.setdefault(chunk_idx, []).append((data, label))

        return chunks, class_chunk_mapping, chunk_class_mapping

    def __len__(self):
        """Return the number of samples in the currently active chunk."""
        return len(self.current_chunk)

    def __getitem__(self, idx):
        """Return the ``(data, label)`` pair at ``idx`` in the active chunk."""
        return self.current_chunk[idx]

    def next_chunk(self):
        """Activate the next chunk, wrapping around to the first after the last.

        Does nothing when the dataset produced no chunks.
        """
        if not self.chunk_keys:
            return
        self.chunk_idx = (self.chunk_idx + 1) % len(self.chunk_keys)
        self.current_chunk = self.chunks[self.chunk_keys[self.chunk_idx]]


if __name__ == "__main__":
    # Demo: chunk the CIFAR-100 training set by class and run one pass over
    # each of five consecutive chunks.
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    # Mean and standard deviation handed to Normalize below.
    CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

    augmentations = [
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
    ]
    transform_train = transforms.Compose(augmentations)

    cifar100_training = torchvision.datasets.CIFAR100(
        root='./src/data',
        train=True,
        download=True,
        transform=transform_train,
    )
    train_dataset = ChunkDataset(cifar100_training)

    for _ in range(5):
        # A fresh loader is built for every chunk that gets activated.
        chunk_loader = DataLoader(train_dataset, batch_size=4, shuffle=False, num_workers=4)

        print(train_dataset.chunk_class_mapping)

        # Drain the loader; the batches themselves are not used here.
        for data, label in chunk_loader:
            pass

        # Required to move on to the next chunk before the following pass.
        train_dataset.next_chunk()