import logging
import math

import torch
import tqdm

import acquire
import utils


class GreedyCoreset(acquire.DistanceBased):
    """Greedy k-center ("coreset") acquisition over embedding vectors.

    Repeatedly selects the unlabelled vector whose Euclidean distance to its
    nearest labelled-or-already-selected vector is largest, until
    ``self._budget`` vectors have been chosen.
    """

    NAME = "greedy-coreset"

    @staticmethod
    def compute_min_deltas(z_unlabelled, z_labelled, batch_size, device):
        """For every unlabelled vector, find its nearest labelled vector.

        Distances are computed block-wise (at most ``batch_size`` rows of each
        input at a time) on ``device``, so the full U x L distance matrix is
        never materialised.

        Args:
            z_unlabelled: 2-D tensor of shape (U, D).
            z_labelled: 2-D tensor of shape (L, D); must have at least one row.
            batch_size: maximum number of rows per block along each input.
            device: device on which the pairwise distances are computed.

        Returns:
            Tuple ``(U, L, D, min_deltas, min_deltas_indices, indices)``:
            ``min_deltas[u]`` is the L2 distance from unlabelled row ``u`` to
            its nearest labelled row, ``min_deltas_indices[u]`` is that
            labelled row's index, and ``indices`` is ``arange(U)``. All three
            tensors are on the CPU.

        Raises:
            ValueError: if ``z_labelled`` has no rows.
        """
        U, D = z_unlabelled.size()
        L, _D = z_labelled.size()
        assert D == _D
        if L == 0:
            # With no labelled vectors there is no nearest neighbour to report.
            raise ValueError("compute_min_deltas requires at least one labelled vector")

        logging.info(
            "Computing distance between {} labelled and {} unlabelled vectors of length {}...".format(
                L, U, D
            )
        )

        batched_min_deltas = []
        batched_min_deltas_indices = []
        batch_U = math.ceil(U / batch_size)
        batch_L = math.ceil(L / batch_size)
        with utils.Bar(range(batch_U * batch_L), desc="Computing {} x {} min deltas".format(U, L)) as bar:
            for sub_z_unlabelled in GreedyCoreset.iter_batches_of_tensor(batch_size, z_unlabelled):
                # Transfer this unlabelled block once; it is reused against
                # every labelled block in the inner loop.
                sub_u = sub_z_unlabelled.to(device).unsqueeze(1)  # (b_u, 1, D)

                total_min_deltas = None
                total_min_deltas_indices = None
                total_processed = 0
                for sub_z_labelled in GreedyCoreset.iter_batches_of_tensor(batch_size, z_labelled):
                    bar.update()

                    sub_l = sub_z_labelled.to(device).unsqueeze(0)  # (1, b_l, D)
                    sub_deltas = (sub_u - sub_l).norm(p=2, dim=2)  # (b_u, b_l)
                    min_sub_deltas, local_min_sub_deltas_indices = sub_deltas.min(dim=1)
                    # Shift block-local column indices to row indices of z_labelled.
                    global_min_sub_deltas_indices = local_min_sub_deltas_indices + total_processed

                    total_processed += len(sub_z_labelled)

                    if total_min_deltas is None:
                        total_min_deltas = min_sub_deltas
                        total_min_deltas_indices = global_min_sub_deltas_indices
                    else:
                        # Strict '<' keeps the earlier block's index on ties.
                        is_better = min_sub_deltas < total_min_deltas
                        total_min_deltas[is_better] = min_sub_deltas[is_better]
                        total_min_deltas_indices[is_better] = global_min_sub_deltas_indices[is_better]

                assert total_min_deltas is not None
                assert total_processed == L
                assert len(total_min_deltas.shape) == 1
                batched_min_deltas.append(total_min_deltas.cpu())
                batched_min_deltas_indices.append(total_min_deltas_indices.cpu())

        if batched_min_deltas:
            min_deltas = torch.cat(batched_min_deltas, dim=0)
            min_deltas_indices = torch.cat(batched_min_deltas_indices, dim=0)
        else:
            # U == 0: torch.cat rejects an empty list, so build empty results.
            min_deltas = torch.empty(0, dtype=z_unlabelled.dtype)
            min_deltas_indices = torch.empty(0, dtype=torch.long)
        assert len(min_deltas.shape) == 1 and len(min_deltas) == U
        assert len(min_deltas_indices.shape) == 1 and len(min_deltas_indices) == U

        indices = torch.arange(U, dtype=torch.long)
        return (
            U,
            L,
            D,
            min_deltas,
            min_deltas_indices,
            indices
        )

    @staticmethod
    def iter_batches_of_tensor(batch_size, X):
        """Yield consecutive row slices of ``X`` holding at most ``batch_size`` rows.

        Plain slicing yields views of ``X``. This avoids the per-row indexing
        and re-stacking that a ``DataLoader`` over a ``TensorDataset`` performs.
        The last slice may be shorter, and an empty ``X`` yields nothing.
        Callers must not modify the yielded slices in place, because they
        share storage with ``X``.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size should be a positive integer, got {}".format(batch_size))
        for start in range(0, len(X), batch_size):
            yield X[start:start + batch_size]

    def score_zs(self, z_unlabelled, z_labelled):
        """Return a length-U long mask with 1 at the ``self._budget`` greedily chosen rows.

        Each step picks the unlabelled row farthest from its nearest
        labelled-or-selected row, then tightens every row's nearest distance
        with the newly chosen row.

        Raises:
            ValueError: if ``self._budget`` exceeds the number of unlabelled rows.
        """
        if self._budget > len(z_unlabelled):
            # Check up front, before the expensive distance computation.
            raise ValueError(
                "budget {} exceeds the {} unlabelled vectors available".format(
                    self._budget, len(z_unlabelled)
                )
            )

        (
            U,
            L,
            D,
            min_deltas,
            _,
            indices
        ) = GreedyCoreset.compute_min_deltas(
            z_unlabelled, z_labelled, self._batch_size, self._device
        )

        logging.info("Searching for coresets greedily...")

        mask = torch.zeros(U, dtype=torch.long).to(self._device)
        min_deltas = min_deltas.to(self._device)
        indices = indices.to(self._device)
        z_unlabelled = z_unlabelled.to(self._device)

        with utils.Bar(range(self._budget), desc="Computing coresets") as bar:
            for _ in bar:  # Overall O(B * U * D)
                i = min_deltas.argmax().item()  # O(U)
                mask[indices[i]] = 1
                z = z_unlabelled[i:i+1]

                min_deltas = self._update_min_deltas(min_deltas, z, z_unlabelled)  # O(U * D)
                # Retire the chosen row in place rather than splicing (copying)
                # every tensor. Distances are >= 0, so -inf never wins argmax
                # against an unselected row. argmax returns the first maximum,
                # so the selections match the splicing approach.
                min_deltas[i] = float("-inf")

        assert mask.sum() == self._budget
        return mask

    def _splice(self, X, i):
        """Return ``X`` without row ``i``. Not used by ``score_zs``; kept for API compatibility."""
        return torch.cat([X[:i], X[i+1:]], dim=0)

    def _update_min_deltas(self, min_deltas, z, z_all):
        """Return element-wise min of ``min_deltas`` and each row's L2 distance to ``z``.

        ``z`` is a single row kept 2-D (shape (1, D)) so that it broadcasts
        against ``z_all`` of shape (N, D).
        """
        new_delta = (z - z_all).norm(p=2, dim=1)  # O(N * D)
        return torch.min(min_deltas, new_delta)