import numpy

import acquire


class SelectiveProbabilisticCoreset(acquire.ProbabilisticCoreset):
    """Coreset acquisition that flags one sample per selected mixture component.

    Only the GMM-based scoring step is overridden here; everything else comes
    from ``acquire.ProbabilisticCoreset``.
    """

    NAME = "sp-coreset"

    def compute_and_store_scores_from_gmm(
        self, unlabelled_gmm, labelled_gmm, unlabelled_i, unlabelled_p, all_scores
    ):
        """Mark ``self._budget`` entries of ``all_scores`` with 1, in place.

        Steps:

        1. Score every center of ``unlabelled_gmm`` by the largest value in
           its row of ``labelled_gmm.predict_proba``; sort ascending.
        2. For each of the first ``self._budget`` centers in that order
           (except the very first), raise its score to at least the largest
           value that ``unlabelled_gmm.predict_proba`` gives the
           earlier-ranked centers in that center's own column.
        3. Re-sort, keep the ``self._budget`` lowest-scoring components and,
           for each, take the row of ``unlabelled_p`` with the largest value
           in that component's column.

        Args:
            unlabelled_gmm: Mixture exposing ``means_`` and ``predict_proba``
                (presumably an sklearn-style GaussianMixture -- confirm).
            labelled_gmm: Mixture exposing ``predict_proba``.
            unlabelled_i: Not used by this implementation.
            unlabelled_p: 2-D array indexed as ``[row, component]``
                (presumably per-sample component probabilities -- confirm).
            all_scores: Array modified in place; the chosen positions are
                set to 1.

        Raises:
            AssertionError: If there are fewer centers than the budget, if
                the labelled probabilities are not 2-D, or if the chosen
                rows are not all distinct.
        """
        centers = unlabelled_gmm.means_
        budget = self._budget
        assert len(centers) >= budget

        labelled_probs = labelled_gmm.predict_proba(centers)
        assert labelled_probs.ndim == 2

        # One score per unlabelled center: its highest probability under any
        # component of the labelled mixture.
        component_scores = labelled_probs.max(axis=1)
        initial_order = component_scores.argsort()

        # Row r, column c: value the unlabelled mixture assigns center r for
        # component c.
        pairwise_probs = unlabelled_gmm.predict_proba(centers)

        # Ranks always refer to the initial ordering; each iteration touches
        # only the score of its own component.
        for rank in range(1, budget):
            component = initial_order[rank]
            earlier = initial_order[:rank]
            overlap = pairwise_probs[earlier, component].max()
            component_scores[component] = max(overlap, component_scores[component])

        chosen_components = component_scores.argsort()[:budget]

        # For each chosen component, the row of unlabelled_p that peaks in it.
        per_component = unlabelled_p[:, chosen_components].T
        picked_rows = per_component.argmax(axis=1)
        assert len(set(picked_rows)) == budget, str(picked_rows)

        # NOTE(review): all_scores is indexed with row positions of
        # unlabelled_p and unlabelled_i is never consulted -- confirm that
        # all_scores shares that indexing.
        all_scores[picked_rows] = 1





