import numpy as np
import torch


def select_from_scores(
    all_scores,
    elitism=False,
    epsilon=False,
    num_to_select=None,
    beta=None
):
    """Select ``num_to_select`` individuals from an n x m score matrix using lexicase selection.

    Args:
        all_scores (torch.Tensor): n x m tensor of (binary, integer or float) scores,
            where n is the number of individuals and m is the number of preferences
            (cases). ``all_scores[i][j]`` is the score of the ith candidate on the jth
            preference; higher is better. Torch operations (``torch.sum(..., dim=...)``,
            ``.device``) are used, so a NumPy array is not accepted.
        elitism (bool, optional): If True, the individual with the highest total score
            is selected first (only when at least one individual is requested).
            Defaults to False.
        epsilon (bool, optional): Only used when ``beta`` is None. If True, relax the
            pass criterion from "elite" to "epsilon-elite": on each case an individual
            survives if its score is within that case's median absolute deviation
            (computed over the whole population) of the best remaining score.
            Defaults to False.
        num_to_select (int, optional): Number of individuals to select. Defaults to
            None, meaning the population size ``all_scores.shape[0]``.
        beta (optional): If not None, use probabilistic ("noisy") lexicase via
            ``select_from_noisy_scores``. There, an individual passes a case when its
            score equals 1 or is at least a uniform random draw. NOTE(review): the
            value of ``beta`` is not used as a temperature; only whether it is None
            matters. Defaults to None.

    Returns:
        If ``beta`` is None: a 1-D ``torch.long`` tensor of selected row indices.
        Otherwise: a Python list of 0-d index tensors (kept as-is for callers that
        rely on list behavior).

    Note:
        If several individuals survive every case, the lowest index is chosen;
        remaining ties are not broken at random.
    """
    if num_to_select is None:
        # infer num_to_select from the number of individuals in the population
        num_to_select = all_scores.shape[0]
    if num_to_select == 1 and elitism:
        print(
            "WARNING: elitism is not compatible with num_to_select=1. This will just return the best individual (i.e. not using lexicase selection)"
        )
    # Only reserve a slot for the elite when something is requested at all;
    # otherwise num_to_select=0 would still return one individual.
    use_elite = bool(elitism) and num_to_select > 0

    if beta is not None:
        selected = []
        if use_elite:
            selected.append(torch.argmax(torch.sum(all_scores, dim=1)))  # elitism
        selected += select_from_noisy_scores(
            all_scores, num_to_select=num_to_select - int(use_elite)
        )
        return selected

    selected = []
    if use_elite:
        selected.append(int(torch.argmax(torch.sum(all_scores, dim=1))))  # elitism

    device = all_scores.device
    num_individuals, num_features = all_scores.shape[0], all_scores.shape[1]

    if epsilon:
        # Epsilon-lexicase: the per-case epsilon is the median absolute deviation
        # (MAD) of that case's scores across the population (i.e. along dim 0).
        # torch.median(dim=...) returns a (values, indices) namedtuple, hence .values.
        scores_f = all_scores.to(torch.float32)
        case_median = torch.median(scores_f, dim=0).values
        deviation = torch.abs(scores_f - case_median[None, :])
        mad = torch.median(deviation, dim=0).values
    else:
        mad = torch.zeros(num_features, device=device)
    mad = mad.to(device=device, dtype=torch.float32)

    # The first slot is already taken by the elite when elitism is active.
    for _ in range(int(use_elite), num_to_select):
        pool = torch.ones(num_individuals, dtype=torch.bool, device=device)
        # Visit cases in a fresh random order; stop once one individual is left.
        for feature in torch.randperm(num_features):
            if torch.sum(pool) == 1:
                break
            best = torch.max(all_scores[pool, feature]).to(torch.float32)
            # The best remaining individual always satisfies this threshold
            # (mad >= 0), so the pool can never become empty.
            pool = torch.logical_and(
                pool,
                all_scores[:, feature] >= best - mad[feature],
            )
        selected.append(int(torch.argmax(pool.to(torch.long))))

    return torch.tensor(selected, dtype=torch.long)

# NumPy implementation of probabilistic ("noisy") lexicase selection.
def np_select_from_noisy_scores(all_scores, num_to_select: int):
    """Select individuals with probabilistic ("noisy") lexicase selection using NumPy.

    For each selection, the cases (columns) are shuffled and the pool is filtered
    case by case. An individual passes a case if its score equals 1 or is at least
    a fresh uniform draw from [0, 1). For scores in [0, 1], the pass probability
    therefore equals the score. A case that would eliminate every remaining
    individual is skipped. Filtering stops when one individual is left or the cases
    run out, and the lowest surviving index is selected.

    Args:
        all_scores (np.ndarray): n x m array of scores (n individuals, m cases).
        num_to_select (int): Number of individuals to select.

    Returns:
        list: ``num_to_select`` row indices into ``all_scores`` (NumPy integer
        scalars, as returned by ``np.argmax``).
    """
    selected = []
    num_individuals, num_features = all_scores.shape[0], all_scores.shape[1]
    for _ in range(num_to_select):
        features = np.arange(num_features)
        np.random.shuffle(features)
        # logical array if selected
        pool = np.ones(num_individuals, dtype=bool)
        for feature in features:
            if np.sum(pool) == 1:
                break
            # Plain NumPy array on the host: no device transfer is needed
            # (ndarray has no .to() method).
            rand = np.random.rand(num_individuals)
            column = all_scores[:, feature]
            passed = np.logical_or(column == 1, column >= rand)
            candidate = np.logical_and(pool, passed)
            # If this case filters everyone out, skip it and keep the current pool.
            if np.sum(candidate) != 0:
                pool = candidate
        selected.append(np.argmax(pool))
    return selected

# PyTorch implementation of probabilistic ("noisy") lexicase selection.
def select_from_noisy_scores(all_scores, num_to_select: int):
    """Pick ``num_to_select`` indices via probabilistic ("noisy") lexicase selection.

    Every pick draws a random case order with ``torch.randperm``. The surviving
    individuals are then narrowed one case at a time: a survivor passes a case when
    its score equals 1 or is at least a fresh ``torch.rand`` draw. If no survivor
    would pass a case, that case is ignored. Narrowing ends when exactly one
    survivor remains or all cases are used, and the lowest surviving index is taken.

    Args:
        all_scores (torch.Tensor): n x m tensor of scores (n individuals, m cases).
        num_to_select (int): How many indices to pick.

    Returns:
        list: ``num_to_select`` 0-d index tensors from ``torch.argmax``.
    """
    picks = []
    for _ in range(num_to_select):
        n_cases = all_scores.shape[1]
        case_order = torch.randperm(n_cases)
        survivors = torch.ones(all_scores.shape[0], dtype=torch.bool).to(all_scores.device)
        for case in case_order:
            if torch.sum(survivors) == 1:
                break
            draws = torch.rand(survivors.shape[0]).to(all_scores.device)
            column = all_scores[:, case]
            passed = torch.logical_or(column == 1, column >= draws)
            narrowed = torch.logical_and(survivors, passed)
            # Ignore a case that nobody passes instead of emptying the pool.
            if torch.sum(narrowed) != 0:
                survivors = narrowed
        picks.append(torch.argmax(survivors.to(torch.long)))
    return picks