import numpy as np
from torch.utils.data import DataLoader, Subset
from datasets import Imagenet
from tqdm import tqdm

class CustomDataModule():
    """Wrap a labelled dataset and allow whole classes to be dropped from it.

    The wrapped dataset must expose a ``labels`` attribute. ``labels[i]`` is
    treated as the class of example ``i`` (assumed to line up with the
    dataset's own indexing -- TODO confirm for each dataset used).

    ``self.mask`` holds one 0/1 flag per example of the original dataset.
    ``self.dataset`` is a ``Subset`` of the original dataset restricted to the
    examples whose flag is 1.
    """

    def __init__(self, dataset, num_workers = 1, batch_size = 32):
        """
        Args:
            dataset: dataset to wrap; must have a ``labels`` attribute.
            num_workers: worker count passed to every ``DataLoader``.
            batch_size: batch size passed to every ``DataLoader``.
        """
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Unfiltered dataset; every removal rebuilds the Subset from it.
        self._base_dataset = dataset

        self.class_labels = np.array(dataset.labels)
        # Every example starts out active.
        self.mask = np.full(len(self.class_labels), fill_value = 1, dtype = int)
        self._rebuild_subset()

    def _rebuild_subset(self):
        """Point ``self.dataset`` at the examples currently flagged in ``self.mask``."""
        self.dataset = Subset(self._base_dataset, self.mask.nonzero()[0])

    def remove_class_with_residue(self, class_label, residue = 0.05, rng = None):
        """Drop the examples of one class, optionally keeping a random fraction.

        The residue is drawn from *all* examples of ``class_label``, including
        ones dropped by an earlier call for the same class. Calling this again
        for a class therefore resamples its residue rather than shrinking the
        previous one. Examples of other classes keep their current flags.

        Args:
            class_label: label to remove (compared with ``==`` against
                ``dataset.labels``).
            residue: fraction of the class to keep. ``int(class_size * residue)``
                examples are kept (rounded down); values <= 0 remove the class
                entirely.
            rng: optional NumPy random generator (anything with an in-place
                ``shuffle``, e.g. ``np.random.default_rng(seed)``) used to pick
                the residue. When ``None``, the global ``np.random`` state is
                used, matching the previous behavior.
        """
        in_class = self.class_labels == class_label
        # True = survives this call. Everything outside the class survives here;
        # its existing flag in self.mask is applied below.
        keep = ~in_class

        if residue > 0:
            class_indices = np.flatnonzero(in_class)
            if rng is None:
                np.random.shuffle(class_indices)
            else:
                rng.shuffle(class_indices)
            num_to_keep = int(len(class_indices) * residue)
            keep[class_indices[:num_to_keep]] = True

        # Re-enable the whole class (so the residue is drawn from all of it),
        # then zero out everything not selected to survive.
        self.mask = (in_class | self.mask) * keep
        self._rebuild_subset()

    def dataloader(self, shuffle = False):
        """Return a ``DataLoader`` over the currently active examples.

        Shuffling is disabled when no examples remain, because torch's
        ``RandomSampler`` raises ``ValueError`` for an empty dataset.
        """
        shuffle = shuffle and len(self.dataset) > 0
        return DataLoader(self.dataset, batch_size = self.batch_size, num_workers = self.num_workers, shuffle = shuffle)

if __name__ == '__main__':
    import sys

    print('Testing Datamodule')

    # The image directory may be given as the first CLI argument; the
    # default is the original hard-coded location.
    image_dir = sys.argv[1] if len(sys.argv) > 1 else '/data/archived_data/progressive_data_dropout/imagenet/train'

    dataset = Imagenet(image_dir = image_dir)
    datamodule = CustomDataModule(dataset = dataset)

    def report(stage):
        """Print the first batch (when there is one) and the batch count for the current subset."""
        loader = datamodule.dataloader()
        first_batch = next(iter(loader), None)
        if first_batch is not None:
            images, labels = first_batch
            print(images)
            print(labels)
        print('Number of batches {}: {}'.format(stage, len(loader)))

    report('before removal')

    datamodule.remove_class_with_residue(class_label = 0, residue = 0.05)
    report('after 1 partial removal')

    datamodule.remove_class_with_residue(class_label = 0, residue = 0.00)
    report('after 1 full removal')

    # Iterate over the labels actually present rather than assuming 0..999.
    for class_label in tqdm(np.unique(datamodule.class_labels)):
        datamodule.remove_class_with_residue(class_label = class_label, residue = 0.00)

    report('after all removal')