import logging

import numpy
import torch

import dataset
import utils


class DatapoolCreator(utils.Namable):
    """Base class for building a partially labelled ``dataset.Datapool``.

    Subclasses supply the data and its description through the methods
    marked with ``utils.abstract`` (``get_classes``, ``get_input_dim``,
    ``get_full_testset`` and ``get_full_dataset``); ``create`` then picks
    a random subset of the full dataset to be treated as labelled.

    NOTE(review): ``get_name`` is not defined here and is presumably
    inherited from ``utils.Namable`` -- confirm.
    """

    def __init__(self, label_smooth):
        """Store the label-smoothing setting.

        Parameters:
        ===========
        label_smooth: value forwarded unchanged to ``dataset.Datapool``
            as its ``label_smooth`` keyword argument.
        """
        super().__init__()
        self._lsmooth = label_smooth

    @utils.abstract
    def get_classes(self):
        """Return the int number of classes."""

    @utils.abstract
    def get_input_dim(self):
        """Return tuple of int dimensions of input."""

    def create(self, target_n):
        """Return a Datapool instance holding the data, with only
            a fraction of the data points labelled.

        The labelled points are chosen uniformly at random using
        numpy's global random state (``numpy.random.shuffle``).

        Parameters:
        ===========
        target_n: number of random samples to label. Must satisfy
            ``0 <= target_n <= len(full dataset)``.

        Raises:
        =======
        ValueError: if the full dataset is empty, or if `target_n`
            lies outside the valid range.
        """
        full_dataset = self.get_full_dataset()
        total_labels = len(full_dataset)

        # Validate explicitly instead of with `assert`: assertions are
        # stripped under `python -O`, and an unchecked negative target_n
        # would silently label `total_labels - |target_n|` samples through
        # the slice below.
        if total_labels == 0:
            raise ValueError(
                "Cannot create a Datapool from an empty dataset"
            )
        if not 0 <= target_n <= total_labels:
            raise ValueError(
                "target_n must be between 0 and {} (got {})".format(
                    total_labels, target_n
                )
            )

        frac_labelled = target_n / total_labels
        logging.info("Using {:.2f}% labels of the dataset ({}/{})".format(
            frac_labelled * 100, target_n, total_labels
        ))

        # Build the mask as a float tensor and compare it against 1
        # afterwards, rather than allocating a boolean tensor directly.
        # This is to address how older versions of PyTorch represent
        # booleans as bytes.
        mask = torch.zeros(total_labels)
        indx = numpy.arange(total_labels, dtype=numpy.int64)
        numpy.random.shuffle(indx)
        # The first target_n entries of the shuffled indices are the
        # randomly selected labelled samples.
        mask[torch.from_numpy(indx[:target_n])] = 1
        mask = mask == 1

        logging.info("Randomly labelled {}/{}".format(
            mask.sum().item(), len(mask)
        ))
        return dataset.Datapool(
            self.get_name(), full_dataset, mask,
            classes=self.get_classes(),
            label_smooth=self._lsmooth
        )

    @utils.abstract
    def get_full_testset(self):
        """Return a torch.utils.data.Dataset instance, the full test set.
        """

    # === PROTECTED ===

    @utils.abstract
    def get_full_dataset(self):
        """Return a torch.utils.data.Dataset instance, the full dataset.

        Called by ``create``, which requires the result to support
        ``len()``.
        """