import numpy as np
from numpy.random import default_rng

def _argmax_random_tie(rng, scores, available):
    """Return the index of the largest ``scores`` entry among ``available`` ones.

    Ties are broken uniformly at random with ``rng.choice``. Only indices where
    ``available`` is True can be returned, so an already-selected row is never
    picked again (masking with ``-inf`` could re-pick one whenever every
    remaining score was itself ``-inf``).
    """
    best = np.max(scores[available])
    candidates = np.flatnonzero(available & (scores == best))
    return rng.choice(candidates)


def _greedy_select(rng, X_train, train_scores, k, l, lam, knn_scores):
    """Greedily pick ``k`` distinct row indices of ``X_train``.

    The first pick maximises ``train_scores``. Every later pick maximises
    ``lam * diversity + (1 - lam) * knn_scores`` over the rows not yet picked,
    where ``diversity = (m*l - s) / (m*l)``, ``m`` is the number of rows picked
    so far and ``s`` is each row's summed dot product with the picked rows.
    ``knn_scores=None`` drops the relevance term entirely.

    Raises:
        ValueError: If ``k`` is negative or larger than the number of rows.
    """
    n_train = X_train.shape[0]
    if not 0 <= k <= n_train:
        raise ValueError(f"k must be between 0 and {n_train}, got {k}")
    if k == 0:
        return np.empty(0, dtype=np.intp)

    available = np.ones(n_train, dtype=bool)
    selected = [_argmax_random_tie(rng, train_scores, available)]
    available[selected[0]] = False

    # Running sum of each row's similarity to the picked rows. One
    # matrix-vector product per pick replaces the full N x N Gram matrix
    # (O(N^2) memory) and the re-summation of its selected columns each step.
    simi_scores = 0
    for _ in range(k - 1):
        simi_scores = simi_scores + X_train @ X_train[selected[-1]]
        norm = len(selected) * l
        total_score = lam * ((norm - simi_scores) / norm)
        if knn_scores is not None:
            total_score = total_score + (1 - lam) * knn_scores
        next_idx = _argmax_random_tie(rng, total_score, available)
        selected.append(next_idx)
        available[next_idx] = False
    return np.array(selected, dtype=np.intp)


def select_knn_diversity(rng, X_train, x_i, k, l, d, lam=0.5):
    """Select ``k`` training rows balancing relevance to ``x_i`` and diversity.

    The first row is the one with the largest ``X_train @ x_i``. Each
    subsequent row maximises ``lam * diversity + (1 - lam) * relevance`` among
    the rows not yet chosen. Equal scores are broken at random with ``rng``.

    Args:
        rng: Random generator exposing ``choice`` (e.g. from ``default_rng``).
        X_train: 2-D array whose rows are the candidates.
        x_i: Query vector; raw relevance of each row is ``X_train @ x_i``.
        k: Number of indices to return, ``0 <= k <= X_train.shape[0]``.
        l: Normaliser of the similarity sum in the diversity score
            (presumably the squared norm of each row — TODO confirm).
        d: Relevance is scaled as ``(X_train @ x_i) / (d - 1)``, so
            ``d == 1`` divides by zero.
        lam: Weight of the diversity term; ``1 - lam`` weights relevance.
            The first pick ignores ``lam``.

    Returns:
        Integer array of ``k`` distinct row indices in selection order.

    Raises:
        ValueError: If ``k`` is negative or exceeds the number of rows.
    """
    train_scores = X_train @ x_i  # (N_train,)
    knn_scores = train_scores / (d - 1)
    return _greedy_select(rng, X_train, train_scores, k, l, lam, knn_scores)


def select_knn_then_diversity(rng, X_train, x_i, k, l, d, lam=0.5):
    """Pick the nearest row to ``x_i``, then the ``k - 1`` most diverse rows.

    This is the same greedy procedure as ``select_knn_diversity`` with the
    diversity weight pinned to 1. After the first (kNN) pick, relevance to
    ``x_i`` plays no part. ``lam`` and ``d`` are accepted only for signature
    compatibility with ``select_knn_diversity`` and are ignored.

    Args:
        rng: Random generator exposing ``choice`` (e.g. from ``default_rng``).
        X_train: 2-D array whose rows are the candidates.
        x_i: Query vector used only for the first pick (``X_train @ x_i``).
        k: Number of indices to return, ``0 <= k <= X_train.shape[0]``.
        l: Normaliser of the similarity sum in the diversity score
            (presumably the squared norm of each row — TODO confirm).
        d: Ignored.
        lam: Ignored (the diversity weight is always 1).

    Returns:
        Integer array of ``k`` distinct row indices in selection order.

    Raises:
        ValueError: If ``k`` is negative or exceeds the number of rows.
    """
    train_scores = X_train @ x_i  # (N_train,)
    # With the diversity weight fixed at 1 the relevance term is multiplied
    # by zero, so it is skipped. Computing it as train_scores / (l - 1) gave
    # inf/nan for l == 1, and 0 * inf == nan then broke the argmax.
    return _greedy_select(rng, X_train, train_scores, k, l, 1, None)

def select_knn(rng, X_train, x_i, k):
    """Select the ``k`` rows of ``X_train`` with the largest dot product with ``x_i``.

    Rows are taken greedily in descending score order. Among rows with equal
    score, the pick is made uniformly at random with ``rng.choice``, and only
    rows not yet chosen are eligible.

    Args:
        rng: Random generator exposing ``choice`` (e.g. from ``default_rng``).
        X_train: 2-D array whose rows are the candidates.
        x_i: Query vector; each row's score is ``X_train @ x_i``.
        k: Number of indices to return, ``0 <= k <= X_train.shape[0]``.

    Returns:
        Integer array of ``k`` distinct row indices in selection order.

    Raises:
        ValueError: If ``k`` is negative or exceeds the number of rows
            (previously this silently returned duplicate indices).
    """
    n_train = X_train.shape[0]
    if not 0 <= k <= n_train:
        raise ValueError(f"k must be between 0 and {n_train}, got {k}")

    scores = X_train @ x_i  # (N_train,)
    available = np.ones(n_train, dtype=bool)
    selected = np.empty(k, dtype=np.intp)
    for pos in range(k):
        best = np.max(scores[available])
        # Restrict ties to still-available rows so nothing is picked twice.
        ties = np.flatnonzero(available & (scores == best))
        chosen = rng.choice(ties)
        selected[pos] = chosen
        available[chosen] = False
    return selected
