import logging
import os

import numpy

import utils


class AcquisitionFn(utils.Namable):
    """Base class for functions that choose which unlabelled items to label.

    Subclasses implement :meth:`acquire_unlabelled_indices`. :meth:`acquire`
    handles the surrounding bookkeeping: checking the budget, toggling
    real-index mode on the pools, and labelling the chosen items.

    An instance is configured with ``(budget, batch_size, device)`` and is
    bound to a model through :meth:`init`. ``init`` returns a new instance and
    does not modify ``self``.
    """

    def __init__(self, budget, batch_size, device):
        """Store the acquisition configuration.

        Parameters:
        ===========
        budget: int number of labels acquired by each call to ``acquire``.
        batch_size: stored for subclasses; this base class does not read it.
        device: stored for subclasses; this base class does not read it
            (presumably a compute device for the model -- TODO confirm).
        """
        super().__init__()
        self._budget = budget
        self._device = device
        self._batch_size = batch_size
        # Stays None until this function is bound to a model via ``init``.
        self._model = None

    def get_descriptive_name(self):
        """Return a human-readable name; defaults to ``self.get_name()``.

        ``get_name`` is presumably provided by ``utils.Namable`` -- verify.
        """
        return self.get_name()

    def init(self, model):
        """Return a fresh copy of this acquisition function bound to ``model``.

        The copy is built as ``self.__class__(*self.get_args())``. A subclass
        whose constructor takes different arguments must therefore override
        :meth:`get_args`. ``self`` is left unchanged.
        """
        out = self.__class__(*self.get_args())
        out._model = model
        return out

    def get_args(self):
        """Return the positional constructor arguments used by :meth:`init`."""
        return (self._budget, self._batch_size, self._device)

    def acquire(self, datapool):
        """Mutate ``datapool`` by labelling exactly ``budget`` unlabelled items.

        An unlabelled copy of the pool is passed to
        :meth:`acquire_unlabelled_indices`. While that hook runs, both the
        copy and ``datapool`` are switched to give real indices.
        ``datapool`` is always reset to ``set_give_real_indices(False)``,
        even if the hook or a check raises, so a failed acquisition does not
        leave the caller's pool in a changed mode. The returned indices are
        then passed to ``datapool.label``.

        Raises:
        =======
        AssertionError: if fewer than ``budget`` unlabelled items are
            available, or the hook returns a number of indices other than
            ``budget``. Note: these checks are skipped under ``python -O``.
        """
        unlabelled_datapool = datapool.get_unlabelled_copy()

        unlabelled_datapool.set_give_real_indices(True)
        datapool.set_give_real_indices(True)
        try:
            n_unlabelled = len(unlabelled_datapool)
            logging.info("Found %s unlabelled features.", n_unlabelled)
            assert n_unlabelled >= self._budget, (
                "budget {} exceeds the {} unlabelled features available".format(
                    self._budget, n_unlabelled
                )
            )

            indices = self.acquire_unlabelled_indices(
                budget=self._budget,
                unlabelled_datapool=unlabelled_datapool,
                labelled_datapool=datapool
            )

            assert len(indices) == self._budget, (
                "acquire_unlabelled_indices returned {} indices, expected {}".format(
                    len(indices), self._budget
                )
            )
        finally:
            # Reset even on failure so an exception does not leave the
            # caller's pool in real-index mode.
            datapool.set_give_real_indices(False)

        datapool.label(indices)

    # === PROTECTED ===

    def get_model(self):
        """Return the model bound by :meth:`init`, or ``None`` if unbound."""
        return self._model

    @utils.abstract
    def acquire_unlabelled_indices(self, budget, unlabelled_datapool, labelled_datapool):
        """Return the budgeted number of indices pointing to the given unlabelled features.

        ``acquire`` calls this with both pools in real-index mode. ``acquire``
        requires the result to have exactly ``budget`` entries.

        Parameters:
        ===========
        budget: int number of indices to acquire
        unlabelled_datapool: DataPool instance of unlabelled features.
        labelled_datapool: DataPool instance of labelled features.
        """

