import logging
import random

import torch
import torch.utils.data


class Datapool(torch.utils.data.Dataset):
    """Dataset view over a base dataset exposing either labelled or unlabelled samples.

    A per-sample integer mask controls which samples of the wrapped dataset are
    visible and how often: a sample whose mask entry is ``r > 0`` appears ``r``
    times in this view (e.g. after being labelled more than once); entries
    ``<= 0`` hide the sample.
    """

    def __init__(self, name, dataset, mask, classes, label_smooth, show_labelled=True, null_label=-1):
        """Instantiate a Datapool object with labeled and unlabelled data.

        Parameters:
        ===========
        name: unique str name of this dataset.
        dataset: torch.utils.data.Dataset object, len N, whose items are ``(X, y)`` pairs.
        mask: tensor (or sequence) of shape (N). Entry ``r`` is the number of
            times sample ``i`` is shown (0 = hidden). Bool masks are accepted.
            The mask is copied, so the caller's object is never mutated.
        classes: int number of classes; label smoothing draws replacement
            labels uniformly from ``[0, classes - 1]``.
        label_smooth: probability in [0, 1] of replacing a shown label with a
            uniformly random class, or -1 to disable smoothing.
        show_labelled: bool whether to show labelled or unlabelled data. Default: True.
        null_label: int label for unlabelled data. Default: -1.
        """
        assert len(dataset) == len(mask)
        assert (label_smooth == -1) or (0 <= label_smooth <= 1)
        self._name = name
        self._dataset = dataset
        # Private int64 copy: label() increments it in place and must not touch
        # the caller's tensor (or the mask of the pool this one was copied from).
        self._mask = torch.as_tensor(mask).clone().long()
        self._classes = classes
        self._lsmooth = label_smooth
        self._show_labelled = show_labelled
        self._null_label = null_label
        self._give_real_indices = False

        # Relative index -> global dataset index; rebuilt whenever the mask changes.
        self._mask_indices = None
        self._update_mask_indices()

    def set_give_real_indices(self, v):
        """Toggle whether items are returned as ``(real_i, X, y)`` instead of ``(X, y)``.

        The assertion only allows flipping the current setting, which catches
        accidental double enables/disables.
        """
        assert self._give_real_indices is (not v)
        self._give_real_indices = v

    def get_name(self):
        """Return the name given at construction."""
        return self._name

    def copy(self):
        """Return a new Datapool with the same settings and an independent mask copy.

        The underlying dataset object is shared. The copy starts with
        real-index output disabled.
        """
        return Datapool(
            name=self._name,
            dataset=self._dataset,
            mask=self._mask,
            classes=self._classes,
            label_smooth=self._lsmooth,
            show_labelled=self._show_labelled,
            null_label=self._null_label
        )

    def count_labelled(self):
        """Return the number of visible labelled entries (repeated samples count once per repeat)."""
        assert self._show_labelled
        return len(self)

    def get_unlabelled_copy(self):
        """Return a pool over the same dataset that shows every sample with ``null_label`` targets."""
        return Datapool(
            name=self._name,
            dataset=self._dataset,
            # NOTE: show all data (doesn't make sense to remove the already labelled)
            # We assume a stream of constant distribution that may contain training data.
            mask=torch.ones_like(self._mask),
            classes=self._classes,
            label_smooth=0,
            show_labelled=False,
            null_label=self._null_label
        )

    def label(self, indices):
        """Label the samples pointed to by the given global indices.

        Each occurrence of an index increments that sample's count by one,
        including duplicates within a single call (relabelling is allowed).

        Parameters:
        ===========
        indices: torch LongTensor indices (or a sequence / int), representing
            indices pointing to the full dataset (not just the labelled or
            unlabelled ones). Negative indices count from the end; a bool
            tensor is treated as a selection mask.

        Raises:
        =======
        IndexError: if ``indices`` is not an integer or boolean index.
        """
        idx = torch.as_tensor(indices, device=self._mask.device)
        if idx.dtype == torch.bool:
            idx = idx.reshape(-1).nonzero().reshape(-1)
        elif idx.is_floating_point() or idx.is_complex():
            raise IndexError("label() expects integer or boolean indices")
        else:
            idx = idx.long().reshape(-1)
        # Wrap negative indices like ordinary tensor indexing; out-of-range
        # values still fail inside index_add_.
        idx = torch.where(idx < 0, idx + len(self._mask), idx)
        # index_add_ accumulates repeated indices, unlike `mask[idx] += 1`,
        # which silently counts a duplicated index only once.
        self._mask.index_add_(0, idx, torch.ones_like(idx))
        self._update_mask_indices()

    def __len__(self):
        return len(self._mask_indices)

    def __getitem__(self, i):
        """Get the values corresponding to the given relative index.

        Parameters:
        ===========
        i: int index to the labelled (or unlabelled) dataset.

        Returns ``(X, y)``, or ``(real_i, X, y)`` when real indices are enabled.
        ``y`` is the (possibly smoothed) label when showing labelled data, and
        ``null_label`` otherwise.
        """
        real_i = self._mask_indices[i]
        X, real_y = self._dataset[real_i]
        y = self._apply_lsmoothing(real_y) if self._show_labelled else self._null_label

        if self._give_real_indices:
            return real_i, X, y
        else:
            return X, y

    # === PRIVATE ===

    def _apply_lsmoothing(self, y):
        """Return ``y``, or with probability ``label_smooth`` a uniformly random class.

        ``random.random()`` lies in [0, 1), so comparing with ``>=`` makes the
        replacement probability exactly ``label_smooth``: 0 never replaces,
        1 always replaces, and -1 (disabled) never replaces.
        """
        if random.random() >= self._lsmooth:
            return y
        return random.randint(0, self._classes - 1)

    def _update_mask_indices(self):
        """Rebuild the relative -> global index table from the current mask and log stats."""
        # Non-positive counts hide a sample; clamping keeps repeat_interleave valid.
        counts = self._mask.clamp(min=0)
        positions = torch.arange(len(counts), device=counts.device)
        # Sample i appears counts[i] times; convert to plain ints for dataset indexing.
        self._mask_indices = positions.repeat_interleave(counts).tolist()
        visible = int((counts > 0).sum())
        redundant = int((counts - 1).clamp(min=0).sum())

        logging.info(
            "Showing labelled: %s (%d/%d visible, %d redundant)",
            self._show_labelled, visible, len(self._mask), redundant,
        )