from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from torch import Tensor
from tqdm import tqdm

from puupl.lib.utils import ConfigurationException, boolean_nested_and


def find_imbalanced_label_mask(
    uncertainties: Tensor, percentile: Optional[float],
    max_number: Optional[int], max_uncertainty: Optional[float]
) -> Tensor:
    """
    returns a boolean mask indicating which samples to pseudo-label
    given the uncertainties of currently unlabeled samples

    The least uncertain samples are selected first, and the number of
    selected samples is capped by every limit that is provided:

    percentile: select at most this percentage of the samples
        (must lie strictly between 0 and 100)
    max_number: select at most this many samples (must be positive)
    max_uncertainty: select only samples with uncertainty <= this value

    At least one of the three limits must be given. The mask has the same
    shape as `uncertainties`; an empty input gives an empty mask.
    """

    assert percentile is not None or max_number is not None or max_uncertainty is not None

    order = torch.argsort(uncertainties)

    # start from "everything" and let each provided limit shrink the budget
    n = len(order)
    if percentile is not None:
        assert 0 < percentile < 100
        n = int(len(order) * (percentile / 100))

    if max_number is not None:
        assert max_number > 0
        n = min(n, max_number)

    if max_uncertainty is not None:
        # never select more samples than there are below the threshold; since
        # `order` is sorted, the first n entries are then all within the threshold
        n_confident = int(torch.sum(uncertainties <= max_uncertainty).item())
        n = min(n, n_confident)

    mask = torch.zeros_like(uncertainties, dtype=torch.bool)
    mask[order[:n]] = True

    assert max_number is None or mask.sum() <= max_number

    return mask


def find_balanced_label_mask(
    uncertainties: Tensor, yhat: Tensor, percentile: Optional[float],
    max_number: Optional[int], max_uncertainty: Optional[float],
    pos_neg_ratio: float
) -> Tuple[Tensor, Tensor]:
    """
    returns two boolean masks (predicted positives, predicted negatives)
    indicating which samples to pseudo-label, keeping the number of selected
    positives and negatives close to the requested ratio

    uncertainties: uncertainty of each currently unlabeled sample
    yhat: prediction for each sample; >= 0.5 counts as positive, < 0.5 as negative
    percentile: select at most this percentage of the samples
    max_number: select at most this many samples, counting both classes together
    max_uncertainty: select only samples with uncertainty <= this value
    pos_neg_ratio: desired number of positives per negative

    At least one of percentile, max_number and max_uncertainty must be given.
    Within each class the least uncertain samples are selected first.
    Both masks have the same shape as `yhat` (expected to be one-dimensional).
    """
    # max_number considers the selected samples from both classes together
    assert percentile is not None or max_number is not None or max_uncertainty is not None

    # pos_neg_ratio = 2 -> pos_scale = 2/3
    # pos_neg_ratio = 0.5 -> pos_scale = 1/3
    pos_scale = pos_neg_ratio / (1 + pos_neg_ratio)

    # find maximum number of samples we can label
    n = len(uncertainties)
    if percentile is not None:
        n = min(n, int(np.round(len(uncertainties) * percentile / 100)))
    if max_number is not None:
        n = min(n, max_number)

    # distribute the number of labels according to the given ratio; the negatives
    # get the remainder so that the two counts can never add up to more than n
    # (rounding both independently can overshoot, e.g. n = 3 with ratio 1 -> 2 + 2)
    npos = int(np.round(n * pos_scale))
    nneg = n - npos

    # find and sort uncertainties of predicted positives and negatives
    is_pos = yhat >= 0.5
    is_neg = yhat < 0.5
    pos_unc = uncertainties[is_pos]
    neg_unc = uncertainties[is_neg]

    pos_idx = torch.argsort(pos_unc)
    neg_idx = torch.argsort(neg_unc)

    if max_uncertainty is not None:
        # if we are given a maximum uncertainty, check if we need to reduce the
        # number of samples
        n_good_pos = int(torch.sum(pos_unc <= max_uncertainty).item())
        n_good_neg = int(torch.sum(neg_unc <= max_uncertainty).item())

        if n_good_pos == 0 or n_good_neg == 0:
            # no predicted samples for a class, or one or both classes have no
            # samples with uncertainty lower than the maximum.
            # in this case we do not pseudo-label anything
            npos = nneg = 0
        elif npos > n_good_pos or nneg > n_good_neg:
            # the number of samples requested is larger than the number of samples
            # with uncertainty lower than the threshold for one or both classes:
            # take everything from the limiting class and scale the other class
            # down so that the ratio is kept
            if n_good_pos / n_good_neg <= pos_neg_ratio:
                npos = n_good_pos
                nneg = int(np.round(npos / pos_neg_ratio))
            else:
                nneg = n_good_neg
                npos = int(np.round(nneg * pos_neg_ratio))

    # translate positions within each class back to positions in the full input
    pos_positions = torch.nonzero(is_pos, as_tuple=True)[0]
    neg_positions = torch.nonzero(is_neg, as_tuple=True)[0]

    expanded_pos_mask = torch.zeros_like(is_pos)
    expanded_neg_mask = torch.zeros_like(is_neg)
    expanded_pos_mask[pos_positions[pos_idx[:npos]]] = True
    expanded_neg_mask[neg_positions[neg_idx[:nneg]]] = True

    assert max_number is None \
        or expanded_pos_mask.sum() + expanded_neg_mask.sum() <= max_number

    return expanded_pos_mask, expanded_neg_mask


def find_balanced_unlabel_mask(
    pseudolabeled_mask: Tensor, uncertainties: Tensor, yhat: Tensor,
    min_uncertainty: Optional[float], pos_neg_ratio: float
) -> Tensor:
    """
    returns a boolean mask indicating which pseudo-labels to remove, keeping
    the number of removed positives and negatives close to the requested ratio

    pseudolabeled_mask: boolean indicator for pseudo-labeled samples
    uncertainties: uncertainty in the prediction of each sample
    yhat: prediction for each sample; >= 0.5 counts as positive, < 0.5 as negative
    min_uncertainty: only pseudo-labels with uncertainty >= this value are
        candidates for removal; if None nothing is removed
    pos_neg_ratio: desired number of removed positives per removed negative

    Within each class the most uncertain candidates are removed first.
    Nothing is removed when either class has no candidates or when the ratio
    rounds the count of one class down to zero.
    """

    if min_uncertainty is None:
        return torch.zeros_like(yhat, dtype=torch.bool)

    # candidates: pseudo-labeled samples that became too uncertain, split by class
    too_uncertain = pseudolabeled_mask & (uncertainties >= min_uncertainty)
    ulpos = too_uncertain & (yhat >= 0.5)
    ulneg = too_uncertain & (yhat < 0.5)
    pos_unc, neg_unc = uncertainties[ulpos], uncertainties[ulneg]
    npos, nneg = len(pos_unc), len(neg_unc)

    if npos == 0 or nneg == 0:
        # no positives or negatives available to unlabel
        return torch.zeros_like(yhat, dtype=torch.bool)

    # keep all candidates of the limiting class and scale the other class down
    if npos / nneg <= pos_neg_ratio:
        nneg = int(np.round(npos / pos_neg_ratio))
    else:
        npos = int(np.round(nneg * pos_neg_ratio))

    assert 0 <= npos <= len(pos_unc)
    assert 0 <= nneg <= len(neg_unc)

    if npos == 0 or nneg == 0:
        # desired ratio is too skewed and cannot be achieved
        return torch.zeros_like(yhat, dtype=torch.bool)

    # the npos-th / nneg-th largest uncertainty of each class is the removal
    # threshold; sorting with torch (on detached values) works for tensors on
    # any device and for tensors that require grad, unlike a numpy round-trip
    pos_thr = torch.sort(pos_unc.detach()).values[-npos]
    neg_thr = torch.sort(neg_unc.detach()).values[-nneg]
    return (ulpos & (uncertainties >= pos_thr)) | (ulneg & (uncertainties >= neg_thr))


class PseudoLabeler:
    """Assign and remove pseudo-labels based on prediction uncertainty.

    max_new_labels, new_labels_uncertainty_percentile and
    new_labels_max_uncertainty limit which unlabeled samples receive a
    pseudo-label; at least one of them must be set.

    new_labels_pos_neg_ratio selects the strategy: when None, samples are
    chosen regardless of the predicted class, otherwise positives and
    negatives are chosen according to this ratio.

    unlabel_min_uncertainty, when set, is the uncertainty from which existing
    pseudo-labels become candidates for removal.

    use_soft_labels: store the predictions themselves as labels instead of
    predictions thresholded at 0.5.

    reassign_pseudo_labels: in the class-agnostic strategy, also refresh the
    labels of samples that are already pseudo-labeled.
    """

    def __init__(self,
                 max_new_labels: Optional[int] = 100,
                 new_labels_uncertainty_percentile: Optional[int] = None,
                 new_labels_max_uncertainty: Optional[float] = None,
                 new_labels_pos_neg_ratio: Optional[float] = None,
                 unlabel_min_uncertainty: Optional[float] = None,
                 use_soft_labels: bool = True,
                 reassign_pseudo_labels: bool = True):

        # limits on the newly assigned pseudo-labels
        self.max_new_labels = max_new_labels
        self.new_labels_uncertainty_percentile = new_labels_uncertainty_percentile
        self.new_labels_max_uncertainty = new_labels_max_uncertainty

        # strategy and label handling
        self.unlabel_min_uncertainty = unlabel_min_uncertainty
        self.new_labels_pos_neg_ratio = new_labels_pos_neg_ratio
        self.use_soft_labels = use_soft_labels
        self.reassign_pseudo_labels = reassign_pseudo_labels

        limits = (
            max_new_labels,
            new_labels_uncertainty_percentile,
            new_labels_max_uncertainty,
        )
        if all(limit is None for limit in limits):
            raise ConfigurationException(
                'must specify at least one of max_new_labels, '
                'new_labels_uncertainty_percentile, new_labels_max_uncertainty'
            )

    def pseudolabel(self, y: Tensor, yhat: Tensor, unc: Tensor, p: Tensor, l: Tensor
                    ) -> None:
        """
        Update the labels (in-place) according to the predictions.

        y: current labels, modified in-place
        yhat: predictions (one per data point)
        unc: uncertainty in the predictions of each data point
        p: boolean indicator for original positive samples
        l: boolean indicator for pseudo-labeled samples, modified in-place
        """
        assert yhat.dim() == 1 and unc.dim() == 1

        # predictions may live on the gpu, the selection happens on the cpu
        yhat = yhat.cpu()
        unc = unc.cpu()

        # pick the strategy, then find which samples to pseudo-label (pl)
        # and which to pseudo-unlabel (ul)
        if self.new_labels_pos_neg_ratio is None:
            strategy = self._imbalanced_pl
        else:
            strategy = self._balanced_pl
        pl, ul = strategy(unc, y, yhat, p, l)

        # original positives must never be touched
        assert not torch.any(pl & p)
        assert not torch.any(ul & p)

        # assign pseudo-labels
        if len(yhat.shape) > 1:
            preds = torch.mean(yhat, dim=0)
        else:
            preds = yhat

        if self.use_soft_labels:
            y[pl] = preds[pl]
        else:
            y[pl] = (preds[pl] > 0.5).float()
        l[pl] = True

        # remove pseudo-labels (done last, so removal wins over assignment)
        y[ul] = 0.0
        l[ul] = False

        tqdm.write(f'Pseudo-labeled {torch.sum(pl)} samples and unlabeled {torch.sum(ul)}')
        tqdm.write(f'total {torch.sum(p | l)} (pseudo) labelled samples now')

    def _imbalanced_pl(self, unc: Tensor, y: Tensor, yhat: Tensor, p: Tensor, l: Tensor
                       ) -> Tuple[Tensor, Tensor]:
        """
        select pseudo-labels to assign and pseudo-labels to remove
        without looking at the predicted class

        returns the pair (mask of samples to label, mask of samples to unlabel)
        """

        # samples that are neither original positives nor pseudo-labeled
        unlabeled = ~p & ~l

        chosen = find_imbalanced_label_mask(
            unc[unlabeled],
            self.new_labels_uncertainty_percentile,
            self.max_new_labels,
            self.new_labels_max_uncertainty,
        )
        pl = boolean_nested_and(unlabeled, chosen)

        if self.reassign_pseudo_labels:
            # existing pseudo-labels get refreshed with the current predictions
            pl |= l

        if self.unlabel_min_uncertainty is None:
            ul = torch.zeros_like(y).bool()
        else:
            ul = l & (unc > self.unlabel_min_uncertainty)

        return pl, ul

    def _balanced_pl(self, unc: Tensor, y: Tensor, yhat: Tensor, p: Tensor, l: Tensor
                     ) -> Tuple[Tensor, Tensor]:
        """
        select pseudo-labels to assign and pseudo-labels to remove
        while keeping positives and negatives at the configured ratio

        returns the pair (mask of samples to label, mask of samples to unlabel)
        """
        ratio = self.new_labels_pos_neg_ratio
        assert ratio is not None

        # samples that are neither original positives nor pseudo-labeled
        unlabeled = ~p & ~l

        chosen_pos, chosen_neg = find_balanced_label_mask(
            unc[unlabeled],
            yhat[unlabeled],
            self.new_labels_uncertainty_percentile,
            self.max_new_labels,
            self.new_labels_max_uncertainty,
            ratio,
        )

        # the masks refer to the unlabeled subset, bring them to full length
        chosen_pos = boolean_nested_and(unlabeled, chosen_pos)
        chosen_neg = boolean_nested_and(unlabeled, chosen_neg)
        pl = unlabeled & (chosen_pos | chosen_neg)

        ul = find_balanced_unlabel_mask(
            l, unc, yhat, self.unlabel_min_uncertainty, ratio
        )

        return pl, ul


def get_pseudolabeler(config: Optional[Dict[str, Any]]) -> Optional[PseudoLabeler]:
    """
    build a PseudoLabeler from a configuration dictionary whose entries are
    passed as keyword arguments; returns None when no configuration is given
    """
    return None if config is None else PseudoLabeler(**config)
