import math

import numpy as np
from sklearn.cluster import KMeans
import torch


def cal_selection(labelled_idx: list, c_loss: torch.Tensor, n_p: int,
                  arg_max: torch.Tensor = None, diversity=False) -> list:
    """
    Constrained Active Learning strategy.
    Select the n_p not-yet-labelled points that most violate the constraints.

    The input tensor is never modified. Labelled points are excluded through
    a boolean mask, so the result does not depend on the range of the loss
    values, on duplicates in labelled_idx, or on labelled_idx being non-empty.

    :param labelled_idx: unavailable data (already selected); may be empty
    :param c_loss: constraint violation calculated for each point (1-D)
    :param n_p: number of points to select
    :param arg_max: the number of the most violated rule for each point
        (required when diversity is True)
    :param diversity: whether to select points based also on their diversity,
        i.e. at most ceil(n_p / 2) points per most-violated rule
    :return list of the selected idx (plain ints), ordered by decreasing
        violation. Without diversity fewer than n_p indices are returned when
        fewer than n_p points are still available.
    :raises AssertionError: if diversity is requested without arg_max, or if
        diversity is requested and fewer than n_p points are available
    """
    assert not diversity or arg_max is not None, "In case the diversity strategy " \
                                                 "is applied, arg max need to be passed"
    c_loss = c_loss.detach()

    # Mask of the points that can still be selected.
    available = torch.ones_like(c_loss, dtype=torch.bool)
    # len() instead of truthiness so that array-like inputs also work; the
    # explicit dtype keeps an empty list from becoming a float index tensor.
    if len(labelled_idx) > 0:
        available[torch.as_tensor(labelled_idx, dtype=torch.long)] = False

    # Rank every point by violation, then keep only the available ones.
    ranking = torch.argsort(c_loss, descending=True)
    cal_idx = ranking[available[ranking]].tolist()

    if not diversity:
        return cal_idx[:n_p]

    # max number of samples per rule: 1/2 of the total number of samples
    max_p = math.ceil(n_p / 2)
    selected_idx = []
    picks_per_rule = {}
    for index in cal_idx:
        if len(selected_idx) == n_p:
            break
        rule = arg_max[index].item()
        taken = picks_per_rule.get(rule, 0)
        if taken >= max_p:
            # quota of this rule exhausted, look at less violating points
            continue
        picks_per_rule[rule] = taken + 1
        selected_idx.append(index)

    if len(selected_idx) < n_p:
        # we allow to break diversity in case we have no samples available:
        # fall back to the plain ranking
        print("Breaking diversity")
        selected_idx = cal_idx[:n_p]
    assert len(selected_idx) == n_p, "Error in the diversity " \
                                     "selection operation"
    return selected_idx


def random_selection(avail_idx: list, n_p: int) -> list:
    """
    Random Active learning strategy
    Theoretically the worst possible strategy. At each iteration
    we just take n elements randomly.

    The draw is done without replacement, so the same point can never be
    selected twice in one call. If fewer than n_p points are available, all
    of them are returned (in random order).

    :param avail_idx: available data (not already selected)
    :param n_p: number of points to select
    :return selected idx (at most n_p distinct elements of avail_idx)
    """
    n_draw = min(n_p, len(avail_idx))
    if n_draw <= 0:
        return []
    # replace=False: np.random.choice samples WITH replacement by default,
    # which would allow duplicated indices in the selection.
    return np.random.choice(avail_idx, n_draw, replace=False).tolist()


def supervised_selection(labelled_idx: list, s_loss: torch.Tensor, n_p: int) -> list:
    """
    Supervised Active learning strategy
    Possibly an upper bound to a learning strategy efficacy (fake, obviously).
    We directly select the points which mostly violate the supervision loss.

    Already labelled points are removed from the ranking (not just pushed to
    its end), so they are never returned, even when n_p is larger than the
    number of points still available. The input tensor is not modified.

    :param labelled_idx: unavailable data (already selected); may be empty
    :param s_loss: supervision violation calculated for each point (1-D)
    :param n_p: number of points to select
    :return: selected idx (plain ints) ordered by decreasing loss; fewer than
        n_p when fewer points are still available
    """
    s_loss = s_loss.detach()

    # Mask of the points that can still be selected.
    available = torch.ones_like(s_loss, dtype=torch.bool)
    # The explicit dtype keeps an empty list from becoming a float index tensor.
    if len(labelled_idx) > 0:
        available[torch.as_tensor(labelled_idx, dtype=torch.long)] = False

    ranking = torch.argsort(s_loss, descending=True)
    return ranking[available[ranking]][:n_p].tolist()


def uncertainty_loss(p: torch.Tensor):
    """
    Uncertainty measure based on the proximity of the predictions to the
    decision boundary (predictions = 0.5).

    The absolute distance from 0.5 is computed element-wise; when the
    predictions have more than one dimension the distances are averaged
    over dim 1. The result is then flipped and rescaled so that a prediction
    sitting on the boundary scores 1 and a prediction at 0 or 1 scores 0:
    uncertainty = 1 - 2 * ||preds - 0.5||

    :param p: predictions of the network
    :return: uncertainty measure
    """
    boundary_gap = (p - 0.5).abs()
    if p.dim() > 1:
        # one value per sample: average the gap over dim 1
        boundary_gap = boundary_gap.mean(dim=1)
    return 1 - 2 * boundary_gap


def uncertainty_diversity_selection(labelled_idx: list, u_loss: torch.Tensor, n_p: int,
                                    cluster_assignment: np.ndarray, max_p=None) -> list:
    """
    Uncertainty Active learning strategy with a diversity constraint.
    We take n elements which are the ones on which the networks is
    mostly uncertain (i.e. the points lying closer to the decision boundaries),
    while limiting how many of them may come from the same cluster.

    The input tensor is not modified. Labelled points are excluded through a
    boolean mask, so labelled_idx may be empty.

    :param labelled_idx: unavailable data (already selected); may be empty
    :param u_loss: uncertainty calculated for each point (1-D)
    :param n_p: number of points to select
    :param cluster_assignment: cluster to which each idx is assigned by KMeans
    :param max_p: max number of points selected from a single cluster;
        defaults to floor(n_p / 2). The first point of every cluster is
        always accepted, so the effective cap is never lower than 1.
    :return selected idx (plain ints), ordered by decreasing uncertainty
    :raises AssertionError: if n_p points cannot be selected under the
        per-cluster cap (there is no fallback that relaxes the cap)
    """
    u_loss = u_loss.detach()

    # Mask of the points that can still be selected.
    available = torch.ones_like(u_loss, dtype=torch.bool)
    # The explicit dtype keeps an empty list from becoming a float index tensor.
    if len(labelled_idx) > 0:
        available[torch.as_tensor(labelled_idx, dtype=torch.long)] = False

    ranking = torch.argsort(u_loss, descending=True)
    unc_idx = ranking[available[ranking]].tolist()

    if max_p is None:
        max_p = math.floor(n_p / 2)
    # One point per cluster is always allowed, even when max_p is 0
    # (e.g. the default cap for n_p == 1).
    cluster_cap = max(max_p, 1)

    selected_idx = []
    picks_per_cluster = {}
    for index in unc_idx:
        if len(selected_idx) == n_p:
            break
        cluster = cluster_assignment[index].item()
        taken = picks_per_cluster.get(cluster, 0)
        if taken >= cluster_cap:
            # quota of this cluster exhausted, look at less uncertain points
            continue
        picks_per_cluster[cluster] = taken + 1
        selected_idx.append(index)
    assert len(selected_idx) == n_p, "Error in the diversity " \
                                     "selection operation"
    return selected_idx


def uncertainty_selection(labelled_idx: list, u_loss: torch.Tensor, n_p: int) -> list:
    """
    Uncertainty Active learning strategy
    We take n elements which are the ones on which the networks is
    mostly uncertain (i.e. the points lying closer to the decision boundaries).

    The input tensor is not modified. Labelled points are excluded through a
    boolean mask, so labelled_idx may be empty.

    :param labelled_idx: unavailable data (already selected); may be empty
    :param u_loss: uncertainty calculated for each point (1-D)
    :param n_p: number of points to select
    :return selected idx (plain ints) ordered by decreasing uncertainty;
        fewer than n_p when fewer points are still available
    """
    u_loss = u_loss.detach()

    # Mask of the points that can still be selected.
    available = torch.ones_like(u_loss, dtype=torch.bool)
    # The explicit dtype keeps an empty list from becoming a float index tensor.
    if len(labelled_idx) > 0:
        available[torch.as_tensor(labelled_idx, dtype=torch.long)] = False

    ranking = torch.argsort(u_loss, descending=True)
    return ranking[available[ranking]][:n_p].tolist()


# String identifiers of the active learning strategies.
# NOTE(review): presumably used by callers to pick one of the selection
# functions of this module -- the dispatch is not visible here, confirm.
SUPERVISED = "supervised"
RANDOM = "random"
CAL = "constrained"
# NOTE(review): no selection function for this strategy is visible in this
# module -- presumably implemented elsewhere, confirm.
CAL_U = "constrained_uncertain"
UNCERTAIN = "uncertainty"
UNCERTAIN_D = "uncertainty_diversity"
