import logging

import torch

import acquire
import utils


class LcBeamCoreset(acquire.DistanceLogitsBased):
    """Coreset acquisition that keeps several candidate selections alive.

    Instead of committing to a single greedy choice per step, ``beam_width``
    partial selections ("beams") are extended in parallel.  At every step
    each beam proposes its ``beam_width`` points with the largest minimum
    distance, and the ``beam_width`` extensions with the highest cumulative
    candidate score survive.

    The default hooks (``select_candidate_scores``,
    ``compute_relative_unlabelled_scores``, ``combine_min_deltas_with_score``)
    rank beams by the sum of the per-point scores of the selected points and
    leave the distances untouched; they look like extension points for
    subclasses.
    """

    NAME = "lc-beam-coreset"

    def __init__(self, budget, batch_size, device, beam_width):
        """Store the beam width on top of the base-class configuration.

        Parameters:
        ===========
        budget, batch_size, device: forwarded unchanged to the base class.
        beam_width: number of beams kept alive during the search (>= 1).
        """
        super().__init__(budget, batch_size, device)
        assert beam_width >= 1
        self._beam_width = beam_width

    def get_descriptive_name(self):
        """Return the acquisition name annotated with the beam width."""
        return "{} (beams={})".format(self.get_name(), self._beam_width)

    def get_args(self):
        """Return the base-class args with the beam width appended."""
        return super().get_args() + (self._beam_width,)

    def score_zs(self, z_unlabelled, z_labelled):
        """Return a boolean mask of length U marking the selected points.

        The mask is True for every unlabelled point that is *not* left over
        in the first beam once ``budget`` points have been spliced out of it,
        i.e. for the points that beam selected.  The first beam is the first
        entry returned by ``topk`` over the cumulative beam scores.

        Side effect: stores the final ``indices`` and ``beam_scores`` in
        ``self._plotting_data``.

        NOTE(review): requires ``beam_width <= U`` and ``budget <= U`` for
        the ``topk`` calls to succeed; this is not validated here.
        """
        K = self._beam_width
        score_unlabelled, z_unlabelled = self.extract_hidden_features_and_scores(z_unlabelled)
        score_labelled, z_labelled = self.extract_hidden_features_and_scores(z_labelled)

        (
            U,
            _,
            D,
            min_deltas,
            min_deltas_indices,
            indices,
        ) = acquire.GreedyCoreset.compute_min_deltas(
            z_unlabelled, z_labelled, self._batch_size, self._device
        )

        # NOTE(review): min_deltas is combined element-wise with
        # score_unlabelled here, which presumes min_deltas is in the same
        # order as the unlabelled points (indices == arange(U)) -- confirm
        # against GreedyCoreset.compute_min_deltas.
        score_unlabelled = self.compute_relative_unlabelled_scores(
            score_unlabelled, score_labelled, min_deltas_indices
        )
        min_deltas = self.combine_min_deltas_with_score(min_deltas, score_unlabelled)

        assert len(score_unlabelled.shape) == 1 and len(score_unlabelled) == U

        # Move everything to GPU if possible
        min_deltas = min_deltas.to(self._device)
        indices = indices.to(self._device)
        score_unlabelled = score_unlabelled.to(self._device)
        z_unlabelled = z_unlabelled.to(self._device)

        logging.info("Searching for coresets with %s beams...", K)

        # Manually do the first beam step: every beam starts from one of the
        # K points with the largest minimum distance.
        assert self._budget >= 1
        max_min_deltas, i = min_deltas.topk(K, dim=0)
        # `i` holds positions *within* min_deltas/indices (what _splice
        # needs); `indices[i]` holds the matching ids into the unlabelled
        # pool (what z_unlabelled/score_unlabelled need).  This mirrors the
        # convention used inside the loop below.
        i_cand_relative = i
        i_cand = select = indices[i]
        beam_scores = self.select_candidate_scores(max_min_deltas, score_unlabelled, select)
        assert beam_scores.size() == (K,)
        z_cand = z_unlabelled[i_cand]
        assert z_cand.size() == (K, D)
        beam_indices = torch.arange(K, dtype=torch.long).to(self._device)

        # Mutable beam path search variables: one row per beam
        min_deltas = min_deltas.view(1, U).repeat(K, 1)     # size K * U
        indices = indices.view(1, U).repeat(K, 1)           # size K * U

        indices, min_deltas = self._advance_beams(
            indices, min_deltas, beam_indices, i_cand_relative,
            z_cand, z_unlabelled, score_unlabelled,
        )

        with utils.Bar(range(self._budget-1), desc="Computing coresets") as bar:
            for _ in bar:
                # size K * K      cost: O(K * U)
                max_min_deltas, i = min_deltas.topk(K, dim=1)

                # i is the index in indices that ultimately suggest the best K
                # candidate coreset inclusions. Since indices mutates with every
                # iteration, we must index them here instead of use i directly.
                select = indices.gather(dim=1, index=i).view(-1)
                assert select.size() == (K*K,)

                cand_scores = self.select_candidate_scores(max_min_deltas, score_unlabelled, select)
                assert cand_scores.size() == (K*K,)

                # Select the K-best beams that have the highest scores
                new_scores = beam_scores.unsqueeze(1) + cand_scores.view(K, K)
                assert new_scores.size() == (K, K)
                beam_scores, global_beam_indices = new_scores.view(-1).topk(K)
                assert global_beam_indices.size() == (K,)

                i_cand = select[global_beam_indices]
                z_cand = z_unlabelled[i_cand]
                assert z_cand.size() == (K, D)

                # Recover the parent beam of each flattened (beam, candidate)
                # index.  Floor division keeps the result an integer tensor
                # regardless of how the installed PyTorch treats `/` on
                # integer tensors.
                beam_indices = (global_beam_indices // K).long()

                # Locate the relative (mutated) index to splice out in i
                i_cand_relative = i.view(-1)[global_beam_indices]

                indices, min_deltas = self._advance_beams(
                    indices, min_deltas, beam_indices, i_cand_relative,
                    z_cand, z_unlabelled, score_unlabelled,
                )

        # store some variables for plotting
        self._plotting_data = {"indices": indices, "beam_scores": beam_scores}

        mask = torch.ones(U)
        mask[indices[0].cpu()] = 0  # indices refer to unlabelled data points
        return mask == 1  # NOTE: newer PyTorch uses bool, older uses uint8

    # === PROTECTED ===

    def select_candidate_scores(self, max_min_deltas, score_unlabelled, select):
        """Return the score of each candidate in ``select``.

        The default ignores ``max_min_deltas`` and simply looks up
        ``score_unlabelled`` at the candidate ids.
        """
        return score_unlabelled[select]

    def compute_relative_unlabelled_scores(self, score_unlabelled, score_labelled, min_deltas_indices):
        """Return the unlabelled scores; the default leaves them unchanged."""
        return score_unlabelled

    def combine_min_deltas_with_score(self, min_deltas, scores):
        """Return the distances; the default ignores ``scores``."""
        return min_deltas

    def score_logits(self, logits):
        """Return a 1D float tensor of scores given logits.

        The score is the negated maximum log-softmax value per row, so a
        larger score means the most likely class has lower probability.

        Parameters:
        ===========
        logits: float tensor of size (batch, classes).
        """
        assert len(logits.shape) == 2
        v, _ = torch.nn.functional.log_softmax(logits, dim=1).max(dim=1)
        return -v

    # === PRIVATE ===

    def _advance_beams(self, indices, min_deltas, beam_indices, i_cand_relative,
                       z_cand, z_unlabelled, score_unlabelled):
        """Commit one chosen candidate per beam and refresh the distances.

        Reorders the per-beam state by ``beam_indices``, removes column
        ``i_cand_relative[k]`` from row ``k`` of both ``indices`` and
        ``min_deltas``, then lowers the remaining minimum distances using the
        newly selected points ``z_cand``.

        Returns the updated ``(indices, min_deltas)`` pair.
        """
        D = z_cand.size(1)
        indices = self._splice(indices, beam_indices, i_cand_relative)
        min_deltas = self._splice(min_deltas, beam_indices, i_cand_relative)

        rest_i = indices.view(-1)
        z_leftover = z_unlabelled[rest_i].view(*indices.shape, D)
        score_leftover = score_unlabelled[rest_i].view(*indices.shape)
        min_deltas = self._update_min_deltas(min_deltas, z_cand, z_leftover, score_leftover)
        return indices, min_deltas

    def _splice(self, X, i, j):
        """Return rows ``X[i]`` with column ``j[k]`` removed from row ``k``.

        Parameters:
        ===========
        X: 2D tensor of size (K, N).
        i: 1D long tensor of length K selecting (and reordering) rows of X.
        j: 1D long tensor of length K; column to drop from each new row.

        Returns a tensor of size (K, N-1).
        """
        assert len(i.shape) == 1 and i.size() == j.size()
        K = len(i)
        assert len(X.shape) == 2
        _K, N = X.size()
        assert K == _K

        mask = torch.ones_like(X, dtype=torch.uint8)
        rows = torch.arange(K, dtype=torch.long, device=X.device)
        mask[rows, j] = 0
        mask = mask == 1  # NOTE: newer PyTorch uses bool, older uses uint8

        # Exactly one entry per row is masked out, so the flattened result
        # reshapes cleanly to (K, N-1).
        return X[i][mask].view(K, N-1)

    def _update_min_deltas(self, min_deltas, z, z_rest, score_rest):
        """Lower each minimum distance by the distance to the new point.

        Parameters:
        ===========
        min_deltas: tensor of size (K, N), current minimum distances.
        z: tensor of size (K, D), the point just selected by each beam.
        z_rest: tensor of size (K, N, D), the points still available.
        score_rest: scores of the remaining points, passed on to
            ``combine_min_deltas_with_score``.

        Returns a tensor of size (K, N).
        """
        K, N = min_deltas.size()
        _K1, D = z.size()
        _K2, _N, _D = z_rest.size()
        assert K == _K1 == _K2 and N == _N and D == _D
        new_delta = (z.unsqueeze(1) - z_rest).norm(p=2, dim=2)  # size K * N
        new_delta = self.combine_min_deltas_with_score(new_delta, score_rest)
        assert new_delta.size() == min_deltas.size()
        # Element-wise minimum of the old and new distances.
        new_min_deltas = torch.min(min_deltas, new_delta)
        assert new_min_deltas.size() == (K, N)
        return new_min_deltas