import numpy
import torch


class BaseDataset:
    """Base class for datasets that can produce class-balanced subsets.

    Subclasses must implement :meth:`get_uniform_subset` and are expected to
    define ``train_data_args`` and ``val_data_args`` dicts containing at
    least a ``'batch_size'`` key (read by :meth:`get_uniform_subset_dataloader`).
    """

    def get_uniform_subset(self, samples_per_label):
        """
        Creates a training subset with fixed number of samples per class.
        :param samples_per_label: For each label, number of samples to choose.
        :return: The subset (a dataset usable by ``torch.utils.data.DataLoader``;
            presumably a ``torch.utils.data.Subset`` -- confirm in subclasses).
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_uniform_subset_dataloader(self, samples_per_label, no_grad):
        """
        Creates a dataloader of subset with fixed number of samples per class.
        :param samples_per_label: For each label, number of samples to choose.
        :param no_grad: If true, batch settings come from ``self.val_data_args``
            (intended for inference, where a larger batch size fits); otherwise
            from ``self.train_data_args``.
        :return: The dataloader, iterating the subset in a fixed (unshuffled) order.
        """
        uniform_subset = self.get_uniform_subset(samples_per_label)
        data_args = self.val_data_args if no_grad else self.train_data_args

        uniform_dataloader = torch.utils.data.DataLoader(uniform_subset,
                                                         batch_size=data_args['batch_size'],
                                                         shuffle=False,  # No need to shuffle
                                                         pin_memory=True,
                                                         num_workers=0)

        return uniform_dataloader

    @staticmethod
    def get_subset_targets(subset):
        """
        Returns the labels of the samples selected by ``subset``.

        Nested ``torch.utils.data.Subset`` objects are supported: their index
        mappings are composed until a dataset exposing ``targets`` is reached.

        :param subset: An object with ``.dataset`` and ``.indices`` attributes
            (e.g. ``torch.utils.data.Subset``). The innermost dataset must
            expose ``targets`` (list, tuple, numpy array or torch tensor).
        :return: A numpy array with one label per subset index, in subset order.
        """
        # An integer dtype matters: numpy.asarray([]) would be float64, and
        # indexing with a float array raises IndexError for empty subsets.
        indices = numpy.asarray(subset.indices, dtype=numpy.int64)
        dataset = subset.dataset
        # Map indices through any intermediate Subset wrappers, which do not
        # expose ``targets`` themselves.
        while isinstance(dataset, torch.utils.data.Subset):
            indices = numpy.asarray(dataset.indices, dtype=numpy.int64)[indices]
            dataset = dataset.dataset

        targets = dataset.targets
        if isinstance(targets, torch.Tensor):
            # Tensor.numpy() fails for tensors on an accelerator or that
            # require grad, so detach and move to CPU first.
            targets = targets.detach().cpu().numpy()
        else:
            # Covers lists, tuples and numpy arrays (no copy for arrays).
            targets = numpy.asarray(targets)

        labels = targets[indices]
        return labels
