import logging

import torch
import numpy
import sklearn.mixture
import tqdm

import acquire
import utils


class ProbabilisticCoreset(acquire.DistanceBased):
    """Acquisition function that scores unlabelled features with two GMMs.

    A Gaussian mixture is fitted to the labelled features and used to randomly
    filter the unlabelled features. A second mixture is then fitted to the
    surviving unlabelled features, which are scored in random budget-sized
    batches by how well each batch covers the second mixture's components.
    """

    NAME = "p-coreset"

    def __init__(
        self,
        # AcquisitionFn
        budget,
        # ModelBased
        batch_size, device,
        # ProbabilisticCoreset
        gmm_components, gmm_train_verbose, min_accept_proba, mc_cycles=1, eps=1e-32
    ):
        """Store the GMM and sampling hyper-parameters.

        Parameters:
        ===========
        budget, batch_size, device: forwarded unchanged to the parent class.
        gmm_components: ``n_components`` of every fitted GaussianMixture.
        gmm_train_verbose: ``verbose`` flag of every fitted GaussianMixture.
        min_accept_proba: acceptance probability of an unlabelled sample whose
            maximum labelled-GMM responsibility equals 1 (see ``score_zs``).
        mc_cycles: number of times the accepted samples are reshuffled and
            split into batches.
        eps: constant added to probabilities before taking the log.
        """
        super().__init__(budget, batch_size, device)
        self._gmm_k = gmm_components
        self._gmm_verbose = gmm_train_verbose
        self._min_accept_proba = min_accept_proba
        self._mc_cycles = mc_cycles
        self._eps = eps

    def get_args(self):
        """Return the parent's args followed by this class's own, in __init__ order."""
        return super().get_args() + (
            self._gmm_k,
            self._gmm_verbose,
            self._min_accept_proba,
            self._mc_cycles,
            self._eps
        )

    def score_zs(self, z_unlabelled, z_labelled):
        """Return a tensor of length z_unlabelled that represents each score.

        Unlabelled samples that are rejected by the random filter, or that are
        never drawn into a batch, keep a score of ``-inf``.

        Parameters:
        ===========
        z_unlabelled: torch tensor of unlabelled features, one row per sample.
        z_labelled: torch tensor of labelled features, one row per sample.
        """
        # detach()/cpu() are no-ops for plain CPU tensors but make the
        # conversion work for tensors that require grad or live on a device.
        z_unlabelled = z_unlabelled.detach().cpu().numpy()
        z_labelled = z_labelled.detach().cpu().numpy()
        out_size = len(z_unlabelled)
        i_unlabelled = numpy.arange(out_size, dtype=numpy.int64)

        # Fit gmm to z_labelled
        logging.info("Training GMM (k=%s) on labelled features...", self._gmm_k)
        labelled_gmm = self._fit_gmm(z_labelled)

        # Largest component responsibility of each unlabelled z under the
        # labelled GMM, used as a proxy for "already covered by labelled data".
        prob_in_labelled = labelled_gmm.predict_proba(z_unlabelled).max(axis=1)
        assert prob_in_labelled.ndim == 1 and len(prob_in_labelled) == out_size

        # A sample is kept when a uniform draw reaches the rejection threshold,
        # i.e. with probability 1 - (1 - min_accept_proba) * prob_in_labelled,
        # which never drops below min_accept_proba.
        reject_proba = (1 - self._min_accept_proba) * prob_in_labelled
        accept = numpy.random.rand(out_size) >= reject_proba
        assert len(accept) == out_size

        # Fit GMM to these selected unlabelled z
        accepted_unlabelled_z = z_unlabelled[accept]
        accepted_unlabelled_i = i_unlabelled[accept]
        assert len(accepted_unlabelled_z) >= self._budget, (
            "fewer unlabelled samples were accepted than the budget"
        )
        logging.info("Training GMM (k=%s) on unlabelled features...", self._gmm_k)
        accepted_unlabelled_gmm = self._fit_gmm(accepted_unlabelled_z)
        accepted_unlabelled_proba = accepted_unlabelled_gmm.predict_proba(accepted_unlabelled_z)

        all_scores = numpy.full(out_size, fill_value=-numpy.inf, dtype=numpy.float64)
        self.compute_and_store_scores_from_gmm(
            unlabelled_gmm=accepted_unlabelled_gmm,
            labelled_gmm=labelled_gmm,
            unlabelled_i=accepted_unlabelled_i,
            unlabelled_p=accepted_unlabelled_proba,
            all_scores=all_scores
        )
        return torch.from_numpy(all_scores)

    def compute_and_store_scores_from_gmm(
        self, unlabelled_gmm, labelled_gmm, unlabelled_i, unlabelled_p, all_scores
    ):
        """Score random batches and write each batch's score into ``all_scores``.

        Every sample of a batch receives that batch's score. With more than
        one MC cycle a sample can be scored several times; the score of the
        most recent batch it was drawn into is the one that is kept.

        Parameters:
        ===========
        unlabelled_gmm, labelled_gmm: fitted mixtures; currently unused here.
        unlabelled_i: positions in ``all_scores`` of the accepted samples.
        unlabelled_p: array of shape (len(unlabelled_i), gmm components) with
            each accepted sample's responsibilities under the unlabelled GMM.
        all_scores: 1-d numpy array, modified in place.
        """
        logging.info("Scoring random batches...")

        for index, budgeted_batch in self._repeat_iter_batches(
            unlabelled_i, unlabelled_p
        ):
            # NOTE(review): the original code sliced the component axis with an
            # undefined `lowest_prob_indices` and raised NameError. Until that
            # component selection is specified, score over all components.
            all_scores[index] = self._score(budgeted_batch)

    # === PRIVATE ===

    def _score(self, probs):
        """Return a higher score for more desirable batches, averaged across batch.

        The score is the sum over components of the log of the best
        (maximum) probability any batch member assigns to that component,
        divided by the batch size.

        Parameters:
        ===========
        probs: numpy array of shape (batch, gmm components), probabilities.
        """
        n, c = probs.shape
        comp_probs_coverage = probs.max(axis=0)
        assert comp_probs_coverage.ndim == 1 and len(comp_probs_coverage) == c
        # eps keeps the log finite when a component is not covered at all.
        return numpy.log(comp_probs_coverage + self._eps).sum() / n

    def _repeat_iter_batches(self, zi, z):
        """Yield ``(zi[batch], z[batch])`` for random budget-sized batches.

        For each of the ``mc_cycles`` cycles the samples are reshuffled and
        split into ``len(z) // budget`` disjoint batches; the remainder of the
        shuffle that does not fill a whole batch is skipped for that cycle.

        Parameters:
        ===========
        zi: numpy array of sample indices, same length as ``z``.
        z: numpy array with one row per sample.
        """
        assert len(zi) == len(z)
        n = len(z)
        b = n // self._budget
        # numpy.long is not available in every numpy release; int64 matches
        # the index dtype used in score_zs.
        i = numpy.arange(n, dtype=numpy.int64)
        with utils.Bar(range(self._mc_cycles * b)) as bar:
            for mc in range(self._mc_cycles):
                bar.set_description("Cycle {} of random batches".format(mc))
                numpy.random.shuffle(i)
                for bi in range(b):
                    select = i[bi * self._budget: (bi + 1) * self._budget]
                    bar.update()
                    yield zi[select], z[select]

    def _fit_gmm(self, z):
        """Fit and return a GaussianMixture with the configured settings on ``z``."""
        gmm = sklearn.mixture.GaussianMixture(
            n_components=self._gmm_k,
            verbose=self._gmm_verbose
        )
        gmm.fit(z)
        return gmm