import tqdm
import torch
import torch.utils.data

import acquire
import utils


class ModelBased(acquire.AcquisitionFn):
    """Acquisition function that scores unlabelled points with a model.

    Subclasses implement `compute_z` (what to extract from a batch of
    features using the model) and `score_zs` (how to turn the extracted
    tensors into one score per unlabelled point). The unlabelled points with
    the highest scores are acquired.

    This class reads `self._batch_size` and `self._device` and calls
    `self.get_model()`; none of these are defined in this class, so they are
    expected to be provided elsewhere (presumably by the base class).
    """

    def acquire_unlabelled_indices(self, budget, unlabelled_datapool, labelled_datapool):
        """Return up to the budgeted number of indices pointing to the given unlabelled features.

        The model is switched to eval mode and the whole ranking runs with
        gradient tracking disabled. Note that the model is left in eval mode
        afterwards; nothing here restores its previous mode.

        Parameters:
        ===========
        budget: int number of indices to acquire
        unlabelled_datapool: DataPool instance of unlabelled features.
        labelled_datapool: DataPool instance of labelled features.
        """
        model = self.get_model()
        model.eval()
        with torch.no_grad():
            return self._rank_unlabelled_indices_with_model(
                budget, model, unlabelled_datapool, labelled_datapool
            )

    def process_labelled_datapool(self, model, labelled_datapool):
        """Return the (indices, z) pair computed over the labelled datapool.

        Can override this to do nothing to save computations if necessary.
        """
        return self._process_datapool_in_batches(model, labelled_datapool, "labelled")

    @utils.abstract
    def compute_z(self, model, X):
        """Extract the necessary tensor z from X using the model."""

    @utils.abstract
    def score_zs(self, z_unlabelled, z_labelled):
        """Return a tensor of length z_unlabelled that represents each score."""

    def _rank_unlabelled_indices_with_model(self, budget, model, unlabelled_datapool, labelled_datapool):
        """Return the indices pointing to the most promising unlabelled features.

        The returned indices are ordered from highest to lowest score. If
        `budget` is larger than the number of unlabelled points, every
        unlabelled index is returned instead of failing.

        Parameters:
        ===========
        budget: int number of unlabelled points to select.
        model: model.Model instance
        unlabelled_datapool: DataPool instance of unlabelled features.
        labelled_datapool: DataPool instance of labelled features.

        Raises:
        =======
        ValueError: if `score_zs` does not return exactly one score per
            unlabelled point.
        """
        i_u, z_u = self._process_datapool_in_batches(model, unlabelled_datapool, "unlabelled")
        # Only the z tensor of the labelled pool is needed for scoring.
        _, z_l = self.process_labelled_datapool(model, labelled_datapool)
        scores = self.score_zs(z_u, z_l)
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(scores) != len(i_u):
            raise ValueError(
                "score_zs returned {} scores for {} unlabelled points".format(
                    len(scores), len(i_u)
                )
            )
        # topk raises if k exceeds the number of elements, so clamp the
        # budget to what is actually available (e.g. a nearly empty pool).
        k = min(budget, len(scores))
        _, best_i = scores.float().topk(k)
        return i_u[best_i]

    def _process_datapool_in_batches(self, model, datapool, desc):
        """Process the datapool a batch at a time using the model.

        Each batch yielded by the data loader is unpacked as a triple; the
        first element is collected as the indices, the second is moved to
        `self._device` and passed to `compute_z`, and the third is ignored.

        Parameters:
        ===========
        model: model passed through to `compute_z`.
        datapool: dataset handed to `torch.utils.data.DataLoader`.
        desc: str inserted into the progress bar description.

        Returns a pair `(indices, z)` with the per-batch results concatenated.
        `torch.cat` is given the raw lists, so a datapool that yields no
        batches makes it raise.
        """
        all_i = []
        all_z = []
        dataloader = torch.utils.data.DataLoader(
            datapool, batch_size=self._batch_size
        )
        with utils.Bar(dataloader, desc="Encoding {} data".format(desc)) as bar:
            for real_i, X, _ in bar:
                all_i.append(real_i)
                # Detach and move to CPU so results from all batches can be
                # accumulated without holding on to device memory or graphs.
                z = self.compute_z(model, X.to(self._device))
                all_z.append(z.detach().cpu())
        return torch.cat(all_i), torch.cat(all_z)