import numpy as np
from tqdm import trange
from impact_search import *
from utils import *

def myopic_search(labels, distances, budget=1000, K=10, influence_dict = None,
                  logits_orig= None, poisoned_label = None, random_seed=None):
    """
    Myopic active search driven by a K-nearest-neighbour label average.

    K points are drawn uniformly at random as seeds.  After that, every step
    picks the unselected point whose K nearest *already selected* points have
    the highest mean label (ties are broken uniformly at random), reveals its
    label and updates the neighbour lists of the remaining points.

    Args:
        labels: NumPy array of per-point labels; it is indexed with index
            lists and compared against 1.
        distances: (N, N) array; ``distances[a, b]`` is read both as the
            distance from ``a`` to ``b`` and from ``b`` to ``a``
            (presumably symmetric -- TODO confirm).
        budget: total number of points to select, seeds included.  Capped at
            N so that no point can be selected twice.
        K: number of seeds and number of neighbours used for the prediction.
        influence_dict: indexable mapping from point index to influence value.
        logits_orig, poisoned_label: forwarded unchanged to ``estimate_prob``.
        random_seed: if not None, seeds NumPy's global RNG.

    Returns:
        Tuple ``(cum_labels, cum_influences, prob_trace, remaining_indices)``:
        the cumulative sum of the selected labels, the cumulative sum of the
        influence of selected points whose label is 1, the value returned by
        ``estimate_prob``, and the list of indices that were never selected.
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    N = distances.shape[0]

    # Random seeds; this is the first draw from the RNG.
    initial_selection = np.random.choice(N, size=K, replace=False).tolist()
    selected_indices = initial_selection.copy()
    selected = np.zeros(N, dtype=bool)
    selected[initial_selection] = True

    # Every point starts with the K seeds as its neighbours.  The distances
    # must therefore be taken from the seed columns (not from the first K
    # columns of the matrix), otherwise the farthest-neighbour bookkeeping
    # would not describe the neighbours actually stored.
    nearest_neighbors = np.tile(initial_selection, (N, 1))
    nearest_distances = distances[:, initial_selection]  # fancy indexing copies
    max_distance = np.max(nearest_distances, axis=1)

    predictions = np.full(N, np.mean(labels[initial_selection]))
    predictions[initial_selection] = -1.0  # selected points are never candidates
    selected_labels = list(labels[initial_selection])
    found_influences = [influence_dict[i] if labels[i] == 1 else 0.0 for i in initial_selection]

    # Never run more steps than there are unselected points.
    steps = min(budget, N) - K

    for _ in trange(steps, desc='Myopic search'):
        max_pred = np.max(predictions)
        candidates = np.where(predictions == max_pred)[0]
        next_image = np.random.choice(candidates)
        selected[next_image] = True
        predictions[next_image] = -1.0
        selected_labels.append(labels[next_image])
        selected_indices.append(next_image)
        found_influences.append(influence_dict[next_image] if labels[next_image] == 1 else 0.0)

        # Unselected points that are closer to the new point than to their
        # current farthest neighbour swap that neighbour for the new point.
        rows = np.flatnonzero(~selected & (distances[next_image] < max_distance))
        if rows.size:
            # argmax returns the first slot holding the row maximum, i.e. the
            # farthest neighbour.
            slots = np.argmax(nearest_distances[rows], axis=1)
            nearest_neighbors[rows, slots] = next_image
            nearest_distances[rows, slots] = distances[next_image, rows]
            max_distance[rows] = nearest_distances[rows].max(axis=1)
            predictions[rows] = labels[nearest_neighbors[rows]].mean(axis=1)

    prob_trace = estimate_prob(logits_orig, poisoned_label, labels, found_influences)
    remaining_indices = list(set(range(len(labels))) - set(selected_indices))
    return np.cumsum(selected_labels), np.cumsum(found_influences), prob_trace, remaining_indices

def myopic_influence_search(labels, distances, budget=1000, K=10, influence_dict=None, 
                            logits_orig=None, poisoned_label = None, random_seed=None):
    """
    Myopic search that ranks points by predicted label times influence.

    The seeds come from ``initial_selection_by_influence``.  Each step then
    scores every unselected point as (mean label of its K nearest selected
    points) * (min-max normalised influence), picks the best one (ties within
    ``np.isclose`` are broken uniformly at random) and reveals its label.

    Args:
        labels: per-point labels, indexable by integer position.
        distances: 2-D array; ``distances[i, j]`` is the distance used to rank
            selected point ``j`` as a neighbour of point ``i``.
        budget: total number of points to select; ``budget - K`` search steps
            are run, stopping early once no unselected point is left.
        K: number of neighbours used for the prediction (also passed to the
            seeding helper).
        influence_dict: dict mapping point index to influence; indices that
            are missing count as 0.0.
        logits_orig, poisoned_label: forwarded unchanged to ``estimate_prob``.
        random_seed: if not None, seeds NumPy's global RNG.  With None the
            global RNG state is left untouched.

    Returns:
        Tuple ``(cum_labels, cum_influences, prob_trace, remaining_indices)``:
        cumulative sum of the selected labels, cumulative sum of the influence
        of selected points whose label is 1, the value returned by
        ``estimate_prob``, and the list of indices that were never selected.
    """
    N = len(labels)
    influence_scores = np.array([influence_dict.get(i, 0.0) for i in range(N)])
    # Only reseed on request: np.random.seed(None) would re-initialise the
    # global RNG from OS entropy and throw away a seed set by the caller.
    if random_seed is not None:
        np.random.seed(random_seed)
    
    # ---- Initial selection: delegated to initial_selection_by_influence ---
    initial = initial_selection_by_influence(labels, influence_scores, K)
    
    selected = set(initial)
    observed_labels = {i: labels[i] for i in initial}
    selected_labels = [labels[i] for i in initial]
    print(f"Initial Selection: {initial} - {selected_labels}")

    predictions = np.zeros(N)

    found_influences = [influence_scores[i] if labels[i] == 1 else 0.0 for i in initial]

    # Min-max normalise the influence; a constant vector maps to all zeros.
    min_val, max_val = influence_scores.min(), influence_scores.max()
    if max_val > min_val:
        normalized_influence = (influence_scores - min_val) / (max_val - min_val + 1e-8)
    else:
        normalized_influence = np.zeros(N)

    # Initial predictions: mean observed label of the K nearest seeds.
    for i in range(N):
        neighbors = sorted(selected, key=lambda j: distances[i, j])[:K]
        predictions[i] = np.mean([observed_labels[j] for j in neighbors]) if neighbors else 0.0
        if i in selected:
            predictions[i] = -1.0  # selected points are never candidates

    steps = budget - K
    for _ in trange(steps, desc="Expected Influence Search"):
        candidates = [i for i in range(N) if i not in selected]
        if not candidates:
            break
        scores = predictions[candidates] * normalized_influence[candidates]
        max_score = np.max(scores)
        tie_candidates = [c for c, s in zip(candidates, scores) if np.isclose(s, max_score)]
        i_star = np.random.choice(tie_candidates)
        selected.add(i_star)
        selected_labels.append(labels[i_star])
        observed_labels[i_star] = labels[i_star]
        predictions[i_star] = -1.0
        found_influences.append(influence_scores[i_star] if labels[i_star] == 1 else 0.0)

        # Recompute the prediction of every unselected point from scratch.
        # NOTE: this sorts the whole selected set once per point and step.
        for i in range(N):
            if i in selected:
                continue
            neighbors = sorted(selected, key=lambda j: distances[i, j])[:K]
            if neighbors:
                predictions[i] = np.mean([observed_labels[j] for j in neighbors])

    prob_trace = estimate_prob(logits_orig, poisoned_label, labels, found_influences)
    remaining_indices = list(set(range(N)) - selected)
    return np.cumsum(selected_labels), np.cumsum(found_influences), prob_trace, remaining_indices

def two_step_search(labels, distances, budget=1000, K=10,
                    influence_dict=None, logits_orig=None, poisoned_label=None,
                    random_seed=None):
    """
    Two-step look-ahead search that scores points by probability only.

    K points are drawn uniformly at random as seeds.  The poison probability
    of a point is the mean label of its K nearest selected points.  At every
    step each unselected point ``i`` is scored as

        p[i] + p[i] * best_next_if_poisoned + (1 - p[i]) * best_next_if_clean

    where ``best_next_*`` is the highest probability among the other
    unselected points after hypothetically labelling ``i``.  Influence values
    are not used for scoring; they are only accumulated for the returned
    trace.  Ties go to the lowest index (strict ``>`` comparison, and
    ``np.argmax`` on the final step).

    Args:
        labels: NumPy array of per-point labels (indexed with index arrays and
            multiplied with influence values).
        distances: (N, N) array; read in both orientations, so presumably
            symmetric -- TODO confirm.
        budget: total number of points to select, seeds included (capped at N).
        K: number of seeds and of neighbours per point.
        influence_dict: dict mapping point index to influence; missing indices
            count as 0.0.
        logits_orig, poisoned_label: forwarded unchanged to ``estimate_prob``.
        random_seed: if not None, seeds NumPy's global RNG.

    Returns:
        Tuple ``(cum_labels, cum_influences, prob_trace, remaining_indices)``:
        cumulative sum of the selected labels, cumulative sum of
        ``label * influence`` of the selected points, the value returned by
        ``estimate_prob``, and the list of indices that were never selected.
    """
    ACTIVE_SEARCH_PRUNING = True
    # Only the pruning bound depends on this flag; the per-candidate
    # evaluation below always looks exactly one step ahead.
    TWO_STEP_LOOKAHEAD = True

    if random_seed is not None:
        np.random.seed(random_seed)

    NEGATIVE_INFINITY = -float('inf')

    N = distances.shape[0]
    assert distances.shape == (N, N)

    influence = np.zeros(N, dtype=float)
    infl_indices = np.array(list(influence_dict.keys()), dtype=int)
    values = np.array(list(influence_dict.values()), dtype=float)
    influence[infl_indices] = values

    # change in a point's probability when one of its K neighbours flips label
    probability_adjustment = 1.0 / float(K)

    # ---- Initial selection: K seeds drawn uniformly at random ----
    initial_selection = np.random.choice(N, size=K, replace=False).tolist()

    selected = np.zeros(N, dtype=bool)
    selected[initial_selection] = True
    selected_indices = initial_selection.copy()
    selected_labels = list(labels[initial_selection])
    poison_influences = list(labels[initial_selection] * influence[initial_selection])

    # farthest-neighbour bookkeeping
    neighbors = np.tile(initial_selection, (N, 1)) # K nearest selected points (i.e., neighbors) for each point
    neighbor_distances = distances[:, initial_selection] # distances to the neighbors of each point
    farthest_neighbor = neighbors[np.arange(N), np.argmax(neighbor_distances, axis=1)] # farthest neighbor of each point
    farthest_neighbor_distance = distances[np.arange(N), farthest_neighbor] # distance to the farthest neighbor
    farthest_neighbor_label = labels[farthest_neighbor] # label of the farthest neighbor

    # probability = mean label of the K selected neighbours
    probability = np.full(N, sum(labels[initial_selection]) / K) # poison probability for each point
    poison_probability = np.copy(probability)
    poison_probability[initial_selection] = NEGATIVE_INFINITY # selected points never win a max

    # impact set of i: points j whose farthest neighbour would be replaced by i
    impact_mask = distances < farthest_neighbor_distance
    impact_mask[:, initial_selection] = False # exclude selected points
    np.fill_diagonal(impact_mask, False) # points do not impact themselves
    impact_set = [np.nonzero(row)[0] for row in impact_mask]

    total_steps = min(budget, N) - len(initial_selection)
    for n in trange(total_steps, desc="Two-Step Active Search by Probability"):
        subsequent_steps = total_steps - n - 1 # number of steps left after this one

        if subsequent_steps > 0:
            selected_point = None
            highest_total_impact = NEGATIVE_INFINITY

            if ACTIVE_SEARCH_PRUNING:
                if TWO_STEP_LOOKAHEAD:
                    ub_nonmyopic_impact_if_clean = poison_probability.max()
                    ub_nonmyopic_impact_if_poisoned = np.max(poison_probability + probability_adjustment)
                else:
                    ub_nonmyopic_impact_if_clean = np.sum(np.partition(poison_probability, -subsequent_steps)[-subsequent_steps:])
                    ub_nonmyopic_impact_if_poisoned = np.sum(np.partition(poison_probability + probability_adjustment, -subsequent_steps)[-subsequent_steps:])

            for i in range(N):
                if selected[i]:
                    continue

                if ACTIVE_SEARCH_PRUNING:
                    # skip i when even an optimistic bound cannot beat the best so far
                    ub_total_impact = poison_probability[i] \
                                      + probability[i] * ub_nonmyopic_impact_if_poisoned \
                                      + (1 - probability[i]) * ub_nonmyopic_impact_if_clean
                    if ub_total_impact < highest_total_impact:
                        continue

                # Refresh the impact set of i: drop points that were selected
                # in the meantime or whose farthest neighbour is by now closer
                # than i.  Farthest-neighbour distances only ever shrink, so a
                # dropped point never has to be re-added.
                stale = impact_set[i]
                still_impacted = (distances[i][stale] < farthest_neighbor_distance[stale]) & (~selected[stale])
                impacted = stale[still_impacted]
                impact_set[i] = impacted

                # remember the values overwritten by the hypothetical updates
                own_value = poison_probability[i]
                impacted_values = poison_probability[impacted].copy()
                replaced_label = farthest_neighbor_label[impacted]

                # i itself must not be the follow-up pick in either case
                poison_probability[i] = NEGATIVE_INFINITY

                # Case 1: i turns out clean => a point loses 1/K only if the
                # neighbour that i replaces was positive.
                poison_probability[impacted] = impacted_values - replaced_label * probability_adjustment
                nonmyopic_impact_if_clean = poison_probability.max()

                # Case 2: i turns out poisoned => a point gains 1/K only if the
                # neighbour that i replaces was not positive.
                poison_probability[impacted] = impacted_values + (1 - replaced_label) * probability_adjustment
                nonmyopic_impact_if_poisoned = poison_probability.max()

                # undo the hypothetical updates
                poison_probability[i] = own_value
                poison_probability[impacted] = impacted_values

                total_impact = own_value \
                               + probability[i] * nonmyopic_impact_if_poisoned \
                               + (1 - probability[i]) * nonmyopic_impact_if_clean
                if total_impact > highest_total_impact:
                    highest_total_impact = total_impact
                    selected_point = i
        else:
            # last step within the budget: no look-ahead needed
            selected_point = np.argmax(poison_probability)

        selected[selected_point] = True
        # keep the index list in sync so remaining_indices excludes this point
        selected_indices.append(selected_point)
        poison_influences.append(labels[selected_point] * influence[selected_point])
        selected_labels.append(labels[selected_point])
        poison_probability[selected_point] = NEGATIVE_INFINITY

        # update nearest neighbours and probabilities of the impacted points
        for i in impact_set[selected_point]:
            if selected[i]: continue
            k = np.argmax(neighbor_distances[i])
            # the impact set may be stale, so re-check the distance condition
            if distances[selected_point, i] < neighbor_distances[i, k]:
                neighbors[i, k] = selected_point
                neighbor_distances[i, k] = distances[selected_point, i]
                farthest_idx = np.argmax(neighbor_distances[i])
                farthest_neighbor[i] = neighbors[i][farthest_idx]
                farthest_neighbor_distance[i] = neighbor_distances[i, farthest_idx]
                farthest_neighbor_label[i] = labels[farthest_neighbor[i]]
                probability[i] = sum(labels[neighbors[i]]) / K
                poison_probability[i] = probability[i]

    prob_trace = estimate_prob(logits_orig, poisoned_label, labels, poison_influences)
    remaining_indices = list(set(range(N)) - set(selected_indices))

    return np.cumsum(selected_labels), np.cumsum(poison_influences), prob_trace, remaining_indices

def two_step_influence_search(labels, distances, budget=1000, K=10,
                              influence_dict=None, logits_orig=None, poisoned_label=None,
                              random_seed=None):
    """
    Two-step look-ahead search on expected poison influence.

    The seeds come from ``initial_selection_by_influence``.  The expected
    poison influence of a point is (mean label of its K nearest selected
    points) * (its influence).  At every step each unselected point ``i`` is
    scored as

        e[i] + p[i] * best_next_if_poisoned + (1 - p[i]) * best_next_if_clean

    where ``best_next_*`` is the highest expected poison influence among the
    other unselected points after hypothetically labelling ``i``.  Ties go to
    the lowest index (strict ``>`` comparison, and ``np.argmax`` on the final
    step).

    Args:
        labels: NumPy array of per-point labels (indexed with index arrays and
            multiplied with influence values).
        distances: (N, N) array; read in both orientations, so presumably
            symmetric -- TODO confirm.
        budget: total number of points to select, seeds included (capped at N).
        K: number of neighbours per point; also passed to the seeding helper
            and used as the seed count when computing the number of steps.
        influence_dict: dict mapping point index to influence; missing indices
            count as 0.0.
        logits_orig, poisoned_label: forwarded unchanged to ``estimate_prob``.
        random_seed: if not None, seeds NumPy's global RNG.

    Returns:
        Tuple ``(cum_labels, cum_influences, prob_trace, remaining_indices)``:
        cumulative sum of the selected labels, cumulative sum of
        ``label * influence`` of the selected points, the value returned by
        ``estimate_prob``, and the list of indices that were never selected.
    """
    if random_seed is not None:
        np.random.seed(random_seed)
    
    NEGATIVE_INFINITY = -float('inf')
    ACTIVE_SEARCH_PRUNING = True
    TWO_STEP_LOOKAHEAD = True

    N = distances.shape[0]
    assert distances.shape == (N, N)

    influence = np.zeros(N, dtype=float)
    infl_indices = np.array(list(influence_dict.keys()), dtype=int)
    values = np.array(list(influence_dict.values()), dtype=float)
    influence[infl_indices] = values

    # precompute the increase/decrease in expected poison influence due to labeling a single neighbor
    influence_adjustment = influence * (1 / K)    

    # ---- Initial selection: delegated to initial_selection_by_influence ---
    initial_selection = initial_selection_by_influence(labels, influence, K)
    # Independent list: appending to it must neither mutate the helper's
    # return value nor fail when that value is an ndarray.
    selected_indices = list(initial_selection)

    poison_influences = list(labels[initial_selection] * influence[initial_selection]) # poison influences of the selected points (in order of selection)
    selected_labels = list(labels[initial_selection])
    selected = np.zeros(N, dtype=bool) # binary indicator for each point having been selected or not
    selected[initial_selection] = True

    neighbors = np.tile(initial_selection, (N, 1)) # K nearest selected points (i.e., neighbors) for each point
    neighbor_distances = distances[:, initial_selection] # distances to the neighbors of each point
    farthest_neighbor = neighbors[np.arange(N), np.argmax(neighbor_distances, axis=1)] # farthest neighbor of each point
    farthest_neighbor_distance = distances[np.arange(N), farthest_neighbor] # distance to the farthest neighbor
    farthest_neighbor_label = labels[farthest_neighbor] # label of the farthest neighbor

    probability = np.full(N, sum(labels[initial_selection]) / K) # poison probability for each point
    expected_poison_influence = probability * influence # expected poison influence for each point
    expected_poison_influence[initial_selection] = NEGATIVE_INFINITY # special value for selected points

    # for each point i, find the set of points to which point i is closer than their farthest neighbor
    impact_mask = distances < farthest_neighbor_distance
    impact_mask[:, initial_selection] = False # exclude selected points
    np.fill_diagonal(impact_mask, False) # points do not impact themselves
    impact_set = [np.nonzero(row)[0] for row in impact_mask]

    ## allocate an array to back-up values that will be overwritten during hypothetical selections
    backup_expected_poison_influence = np.copy(expected_poison_influence)

    for n in trange(min(budget, N) - K):
        subsequent_steps = min(budget, N) - K - n - 1 # number of steps left after this one

        if subsequent_steps > 0:
            selected_point = None
            highest_total_impact = NEGATIVE_INFINITY

            if ACTIVE_SEARCH_PRUNING:
                # NOTE(review): these are upper bounds only when all influence
                # values are non-negative; with a negative influence the
                # "clean" update below can raise a value above the bound.
                if TWO_STEP_LOOKAHEAD:
                    ub_nonmyopic_impact_if_clean = expected_poison_influence.max()
                    ub_nonmyopic_impact_if_poisoned = np.max(expected_poison_influence + influence_adjustment)
                else: # efficient non-myopic active search
                    # upper bound on the non-myopic impact when a point is discovered to be clean (i.e., not poisoned)
                    ub_nonmyopic_impact_if_clean = np.sum(np.partition(expected_poison_influence, -subsequent_steps)[-subsequent_steps:])
                    # upper bound on the non-myopic impact when a point is discovered to be poisoned
                    ub_nonmyopic_impact_if_poisoned = np.sum(np.partition(expected_poison_influence + influence_adjustment, -subsequent_steps)[-subsequent_steps:])

            # calculate the expected impact of a hypothetical selection for each point
            for i in range(N):
                if selected[i]:
                    continue

                if ACTIVE_SEARCH_PRUNING:
                    # pruning: calculate an upper bound on the total impact for point i
                    ub_total_impact = expected_poison_influence[i] \
                                      + probability[i] * ub_nonmyopic_impact_if_poisoned \
                                      + (1 - probability[i]) * ub_nonmyopic_impact_if_clean
                    # if the upper bound is lower than the current highest impact, we can skip point i
                    if ub_total_impact < highest_total_impact:
                        continue

                # update the set of other points that would be impacted by the selection of point i
                remaining_impact = (distances[i][impact_set[i]] < farthest_neighbor_distance[impact_set[i]]) & (~selected[impact_set[i]])
                impact_set[i] = impact_set[i][remaining_impact]

                # back-up the expected poison influence values for the impacted points
                backup_expected_poison_influence[i] = expected_poison_influence[i]
                backup_expected_poison_influence[impact_set[i]] = expected_poison_influence[impact_set[i]]
                # exclude point i from the calculation for the non-myopic impacts
                expected_poison_influence[i] = NEGATIVE_INFINITY

                # first case: if point i is discovered to be clean (i.e., not poisoned) =>
                # for each point j that would be updated if point i were investigated,
                # if the farthest neighbor of point j is poisoned,
                # then decrease the poison probability for point j by 1 / K (due to the hypothetical selection of point i)
                expected_poison_influence[impact_set[i]] -= farthest_neighbor_label[impact_set[i]] * influence_adjustment[impact_set[i]]
                # calculate the expected impact of the subsequent step (or steps)
                if TWO_STEP_LOOKAHEAD:
                    nonmyopic_impact_if_clean = expected_poison_influence.max()
                else: # efficient non-myopic active search approach
                    nonmyopic_impact_if_clean = np.sum(np.partition(expected_poison_influence, -subsequent_steps)[-subsequent_steps:])

                # second case: if point i is discovered to be poisoned =>
                # we do not need to check if the farthest neighbor of point j is poisoned or not:
                # if the farthest neighbor is poisoned, then the probability must be increased back to its original value;
                # otherwise, increase the probability for point j by 1 / K (due to the hypothetical selection of point i)
                expected_poison_influence[impact_set[i]] += influence_adjustment[impact_set[i]]
                # calculate the expected impact of the subsequent step (or steps)
                if TWO_STEP_LOOKAHEAD:
                    nonmyopic_impact_if_poisoned = expected_poison_influence.max()
                else: # efficient non-myopic active search approach
                    nonmyopic_impact_if_poisoned = np.sum(np.partition(expected_poison_influence, -subsequent_steps)[-subsequent_steps:])

                # restore expected poison influence values
                expected_poison_influence[i] = backup_expected_poison_influence[i]
                expected_poison_influence[impact_set[i]] = backup_expected_poison_influence[impact_set[i]]

                # calculate the total impact for point i and check if it is higher than the current highest impact
                total_impact = expected_poison_influence[i] \
                               + probability[i] * nonmyopic_impact_if_poisoned \
                               + (1 - probability[i]) * nonmyopic_impact_if_clean
                if total_impact > highest_total_impact:
                    highest_total_impact = total_impact
                    selected_point = i

        else: # if this is the very last step (within our budget)
            selected_point = np.argmax(expected_poison_influence)

        # select the point with the highest non-myopic impact
        selected[selected_point] = True
        poison_influences.append(labels[selected_point] * influence[selected_point])
        selected_indices.append(selected_point)
        selected_labels.append(labels[selected_point])
        expected_poison_influence[selected_point] = NEGATIVE_INFINITY

        # update nearest neighbors and predictions
        for i in impact_set[selected_point]:
            # The impact set is refreshed only while a point is evaluated in
            # the look-ahead branch; on the final step it may be stale, so
            # skip entries that were selected or are no longer closer.
            if selected[i] or distances[selected_point, i] >= farthest_neighbor_distance[i]:
                continue
            # replace the farthest neighbor of point i
            k = np.argmax(neighbor_distances[i])
            neighbors[i, k] = selected_point
            neighbor_distances[i, k] = distances[selected_point, i]
            # find the new farthest neighbor of point i
            farthest_neighbor[i] = neighbors[i][np.argmax(neighbor_distances[i])]
            farthest_neighbor_distance[i] = distances[farthest_neighbor[i], i]
            farthest_neighbor_label[i] = labels[farthest_neighbor[i]]
            # update the poison probability and the expected poison influence of point i
            probability[i] = sum(labels[neighbors[i]]) / K
            expected_poison_influence[i] = probability[i] * influence[i]

    prob_trace = estimate_prob(logits_orig, poisoned_label, labels, poison_influences)
    remaining_indices = list(set(range(N)) - set(selected_indices))

    return np.cumsum(selected_labels), np.cumsum(poison_influences), prob_trace, remaining_indices

def ens_influence_search(labels, distances, budget=1000, K=10,
                              influence_dict=None, logits_orig=None, poisoned_label=None,
                              random_seed=None):
    """
    Non-myopic search on expected poison influence over the remaining budget.

    The seeds come from ``initial_selection_by_influence``.  The expected
    poison influence of a point is (mean label of its K nearest selected
    points) * (its influence).  At every step each unselected point ``i`` is
    scored as

        e[i] + p[i] * future_if_poisoned + (1 - p[i]) * future_if_clean

    where ``future_*`` is the sum of the ``subsequent_steps`` largest expected
    poison influences among the other points after hypothetically labelling
    ``i`` (``subsequent_steps`` = number of selections left after the current
    one).  Ties go to the lowest index (strict ``>`` comparison, and
    ``np.argmax`` on the final step).

    Args:
        labels: NumPy array of per-point labels (indexed with index arrays and
            multiplied with influence values).
        distances: (N, N) array; read in both orientations, so presumably
            symmetric -- TODO confirm.
        budget: total number of points to select, seeds included (capped at N).
        K: number of neighbours per point; also passed to the seeding helper
            and used as the seed count when computing the number of steps.
        influence_dict: dict mapping point index to influence; missing indices
            count as 0.0.
        logits_orig, poisoned_label: forwarded unchanged to ``estimate_prob``.
        random_seed: if not None, seeds NumPy's global RNG.

    Returns:
        Tuple ``(cum_labels, cum_influences, prob_trace, remaining_indices)``:
        cumulative sum of the selected labels, cumulative sum of
        ``label * influence`` of the selected points, the value returned by
        ``estimate_prob``, and the list of indices that were never selected.
    """
    if random_seed is not None:
        np.random.seed(random_seed)

    NEGATIVE_INFINITY = -float('inf')
    ACTIVE_SEARCH_PRUNING = True
    TWO_STEP_LOOKAHEAD = False

    N = distances.shape[0]
    assert distances.shape == (N, N)

    influence = np.zeros(N, dtype=float)
    infl_indices = np.array(list(influence_dict.keys()), dtype=int)
    values = np.array(list(influence_dict.values()), dtype=float)
    influence[infl_indices] = values

    # precompute the increase/decrease in expected poison influence due to labeling a single neighbor
    influence_adjustment = influence * (1 / K)    

    # ---- Initial selection: delegated to initial_selection_by_influence ---
    initial_selection = initial_selection_by_influence(labels, influence, K)
    # Independent list: appending to it must neither mutate the helper's
    # return value nor fail when that value is an ndarray.
    selected_indices = list(initial_selection)

    poison_influences = list(labels[initial_selection] * influence[initial_selection]) # poison influences of the selected points (in order of selection)
    selected_labels = list(labels[initial_selection])
    selected = np.zeros(N, dtype=bool) # binary indicator for each point having been selected or not
    selected[initial_selection] = True

    neighbors = np.tile(initial_selection, (N, 1)) # K nearest selected points (i.e., neighbors) for each point
    neighbor_distances = distances[:, initial_selection] # distances to the neighbors of each point
    farthest_neighbor = neighbors[np.arange(N), np.argmax(neighbor_distances, axis=1)] # farthest neighbor of each point
    farthest_neighbor_distance = distances[np.arange(N), farthest_neighbor] # distance to the farthest neighbor
    farthest_neighbor_label = labels[farthest_neighbor] # label of the farthest neighbor

    probability = np.full(N, sum(labels[initial_selection]) / K) # poison probability for each point
    expected_poison_influence = probability * influence # expected poison influence for each point
    expected_poison_influence[initial_selection] = NEGATIVE_INFINITY # special value for selected points

    # for each point i, find the set of points to which point i is closer than their farthest neighbor
    impact_mask = distances < farthest_neighbor_distance
    impact_mask[:, initial_selection] = False # exclude selected points
    np.fill_diagonal(impact_mask, False) # points do not impact themselves
    impact_set = [np.nonzero(row)[0] for row in impact_mask]

    ## allocate an array to back-up values that will be overwritten during hypothetical selections
    backup_expected_poison_influence = np.copy(expected_poison_influence)

    for n in trange(min(budget, N) - K):
        subsequent_steps = min(budget, N) - K - n - 1 # number of steps left after this one

        if subsequent_steps > 0:
            selected_point = None
            highest_total_impact = NEGATIVE_INFINITY

            if ACTIVE_SEARCH_PRUNING:
                # NOTE(review): these are upper bounds only when all influence
                # values are non-negative; with a negative influence the
                # "clean" update below can raise a value above the bound.
                if TWO_STEP_LOOKAHEAD:
                    ub_nonmyopic_impact_if_clean = expected_poison_influence.max()
                    ub_nonmyopic_impact_if_poisoned = np.max(expected_poison_influence + influence_adjustment)
                else: # efficient non-myopic active search
                    # upper bound on the non-myopic impact when a point is discovered to be clean (i.e., not poisoned)
                    ub_nonmyopic_impact_if_clean = np.sum(np.partition(expected_poison_influence, -subsequent_steps)[-subsequent_steps:])
                    # upper bound on the non-myopic impact when a point is discovered to be poisoned
                    ub_nonmyopic_impact_if_poisoned = np.sum(np.partition(expected_poison_influence + influence_adjustment, -subsequent_steps)[-subsequent_steps:])

            # calculate the expected impact of a hypothetical selection for each point
            for i in range(N):
                if selected[i]:
                    continue

                if ACTIVE_SEARCH_PRUNING:
                    # pruning: calculate an upper bound on the total impact for point i
                    ub_total_impact = expected_poison_influence[i] \
                                      + probability[i] * ub_nonmyopic_impact_if_poisoned \
                                      + (1 - probability[i]) * ub_nonmyopic_impact_if_clean
                    # if the upper bound is lower than the current highest impact, we can skip point i
                    if ub_total_impact < highest_total_impact:
                        continue

                # update the set of other points that would be impacted by the selection of point i
                remaining_impact = (distances[i][impact_set[i]] < farthest_neighbor_distance[impact_set[i]]) & (~selected[impact_set[i]])
                impact_set[i] = impact_set[i][remaining_impact]

                # back-up the expected poison influence values for the impacted points
                backup_expected_poison_influence[i] = expected_poison_influence[i]
                backup_expected_poison_influence[impact_set[i]] = expected_poison_influence[impact_set[i]]
                # exclude point i from the calculation for the non-myopic impacts
                expected_poison_influence[i] = NEGATIVE_INFINITY

                # first case: if point i is discovered to be clean (i.e., not poisoned) =>
                # for each point j that would be updated if point i were investigated,
                # if the farthest neighbor of point j is poisoned,
                # then decrease the poison probability for point j by 1 / K (due to the hypothetical selection of point i)
                expected_poison_influence[impact_set[i]] -= farthest_neighbor_label[impact_set[i]] * influence_adjustment[impact_set[i]]
                # calculate the expected impact of the subsequent step (or steps)
                if TWO_STEP_LOOKAHEAD:
                    nonmyopic_impact_if_clean = expected_poison_influence.max()
                else: # efficient non-myopic active search approach
                    nonmyopic_impact_if_clean = np.sum(np.partition(expected_poison_influence, -subsequent_steps)[-subsequent_steps:])

                # second case: if point i is discovered to be poisoned =>
                # we do not need to check if the farthest neighbor of point j is poisoned or not:
                # if the farthest neighbor is poisoned, then the probability must be increased back to its original value;
                # otherwise, increase the probability for point j by 1 / K (due to the hypothetical selection of point i)
                expected_poison_influence[impact_set[i]] += influence_adjustment[impact_set[i]]
                # calculate the expected impact of the subsequent step (or steps)
                if TWO_STEP_LOOKAHEAD:
                    nonmyopic_impact_if_poisoned = expected_poison_influence.max()
                else: # efficient non-myopic active search approach
                    nonmyopic_impact_if_poisoned = np.sum(np.partition(expected_poison_influence, -subsequent_steps)[-subsequent_steps:])

                # restore expected poison influence values
                expected_poison_influence[i] = backup_expected_poison_influence[i]
                expected_poison_influence[impact_set[i]] = backup_expected_poison_influence[impact_set[i]]

                # calculate the total impact for point i and check if it is higher than the current highest impact
                total_impact = expected_poison_influence[i] \
                               + probability[i] * nonmyopic_impact_if_poisoned \
                               + (1 - probability[i]) * nonmyopic_impact_if_clean
                if total_impact > highest_total_impact:
                    highest_total_impact = total_impact
                    selected_point = i

        else: # if this is the very last step (within our budget)
            selected_point = np.argmax(expected_poison_influence)

        # select the point with the highest non-myopic impact
        selected[selected_point] = True
        poison_influences.append(labels[selected_point] * influence[selected_point])
        selected_indices.append(selected_point)
        selected_labels.append(labels[selected_point])
        expected_poison_influence[selected_point] = NEGATIVE_INFINITY

        # update nearest neighbors and predictions
        for i in impact_set[selected_point]:
            # The impact set is refreshed only while a point is evaluated in
            # the look-ahead branch; on the final step it may be stale, so
            # skip entries that were selected or are no longer closer.
            if selected[i] or distances[selected_point, i] >= farthest_neighbor_distance[i]:
                continue
            # replace the farthest neighbor of point i
            k = np.argmax(neighbor_distances[i])
            neighbors[i, k] = selected_point
            neighbor_distances[i, k] = distances[selected_point, i]
            # find the new farthest neighbor of point i
            farthest_neighbor[i] = neighbors[i][np.argmax(neighbor_distances[i])]
            farthest_neighbor_distance[i] = distances[farthest_neighbor[i], i]
            farthest_neighbor_label[i] = labels[farthest_neighbor[i]]
            # update the poison probability and the expected poison influence of point i
            probability[i] = sum(labels[neighbors[i]]) / K
            expected_poison_influence[i] = probability[i] * influence[i]

    prob_trace = estimate_prob(logits_orig, poisoned_label, labels, poison_influences)
    remaining_indices = list(set(range(N)) - set(selected_indices))

    return np.cumsum(selected_labels), np.cumsum(poison_influences), prob_trace, remaining_indices