import math
import gc
from common_imports import torch, np, tqdm
from common_use_functions import next_power_of_2, class_index_dict_build
from scipy.ndimage import gaussian_filter1d

def estimate_dense_integer_k(similarities, verify_steps=3, variation_threshold=0.05, min_k=5, smooth_sigma=1.0, integer_step=10):
    """
    Estimate a k for the KNN scores, rounded down to a multiple of integer_step.
    The column-wise mean similarity curve is (optionally) smoothed, differentiated twice,
    and scanned for the first window in which the second derivative is stable.

    similarities: The knn similarities, one row per example and one column per k value.
    verify_steps: The number of consecutive second-derivative values inspected in each window.
    variation_threshold: The window is accepted when std / (mean of absolute values) is below this value.
    min_k: The first (0-based) position at which the scan starts.
    smooth_sigma: The sigma of the Gaussian smoothing (no smoothing when it is not positive).
    integer_step: The base unit for the integer conversion.

    Returns: The detected k truncated to a multiple of integer_step, or the truncated number of
    columns when no stable window is found. The value is 0 when the untruncated k is smaller
    than integer_step.
    """
    def truncate_to_step(value):
        # Round down to the closest lower multiple of integer_step
        return int(value / integer_step) * integer_step

    # Mean similarity curve over all the examples (one value per k)
    mean_curve = np.mean(similarities, axis=0)
    if smooth_sigma > 0:
        mean_curve = gaussian_filter1d(mean_curve, sigma=smooth_sigma)

    # Second derivative of the curve
    curvature = np.gradient(np.gradient(mean_curve))

    # Scan for the first window with a small relative variation
    scan_positions = range(min_k, len(curvature) - verify_steps)
    for start in scan_positions:
        segment = curvature[start : start + verify_steps]
        spread = np.std(segment)
        magnitude = np.mean(np.abs(segment))
        # The 1e-16 term avoids a division by zero on perfectly flat windows
        if spread / (magnitude + 1e-16) < variation_threshold:
            return truncate_to_step(start + 1)

    # No stable window: fall back to the total number of k values
    return truncate_to_step(len(mean_curve))

def estimate_dense_k(similarities, verify_steps=3, variation_threshold=0.05, min_k=5, smooth_sigma=1.0):
    """
    Estimate a k for the KNN scores from the shape of the mean similarity curve.
    The column-wise mean similarity curve is (optionally) smoothed, differentiated twice,
    and scanned for the first window in which the second derivative is stable.

    similarities: The knn similarities, one row per example and one column per k value.
    verify_steps: The number of consecutive second-derivative values inspected in each window.
    variation_threshold: The window is accepted when std / (mean of absolute values) is below this value.
    min_k: The first (0-based) position at which the scan starts.
    smooth_sigma: The sigma of the Gaussian smoothing (no smoothing when it is not positive).

    Returns: The 1-based position of the first stable window, or the number of columns
    when no stable window is found.
    """
    # Mean similarity curve over all the examples (one value per k)
    mean_curve = np.mean(similarities, axis=0)
    if smooth_sigma > 0:
        mean_curve = gaussian_filter1d(mean_curve, sigma=smooth_sigma)

    # Second derivative of the curve
    curvature = np.gradient(np.gradient(mean_curve))

    # Scan for the first window with a small relative variation
    scan_positions = range(min_k, len(curvature) - verify_steps)
    for start in scan_positions:
        segment = curvature[start : start + verify_steps]
        spread = np.std(segment)
        magnitude = np.mean(np.abs(segment))
        # The 1e-16 term avoids a division by zero on perfectly flat windows
        relative_variation = spread / (magnitude + 1e-16)
        if relative_variation < variation_threshold:
            return start + 1

    # No stable window: fall back to the total number of k values
    return len(mean_curve)

def estimate_precise_appropriate_k(similarities, min_k=10, verify_steps=3, smooth_sigma=1.0):
    """
    Estimate a k for the KNN scores from the variance of the similarities at each k.
    The returned k is the first one (starting at min_k) after which the variance
    strictly decreases for verify_steps consecutive steps.

    similarities: The knn similarities, one row per example and one column per k value.
    min_k: The minimum k (1-based) to consider.
    verify_steps: The number of consecutive decreasing steps required to accept a k.
    smooth_sigma: The sigma of the Gaussian smoothing (no smoothing when it is not positive).

    Returns: The selected k (1-based), or the number of columns when no such k exists.
    """
    # Variance over the examples for every k, optionally smoothed
    variances = np.var(similarities, axis=0)
    if smooth_sigma > 0:
        variances = gaussian_filter1d(variances, sigma=smooth_sigma)

    # Look for the first position followed by verify_steps strict decreases
    scan_stop = len(variances) - verify_steps
    for start in range(min_k - 1, scan_stop):
        keeps_decreasing = all(
            variances[start + offset] > variances[start + offset + 1]
            for offset in range(verify_steps)
        )
        if keeps_decreasing:
            return start + 1

    # The variance never decreases long enough: use all the k values
    return len(variances)

def estimate_appropriate_k(similarities, potential_k_list=(1, 10, 20, 50, 100, 200, 500, 1000)):
    """
    Estimate a k for the KNN scores among a list of candidate values.
    The candidates are compared pairwise, in the given order, and the first candidate
    whose variance is strictly greater than the variance of the next candidate is returned.

    similarities: The knn similarities, a 2D array with one row per example and one column per k value.
    potential_k_list: The candidate k values (1-based) to consider. Candidates larger than
                      the number of available columns are ignored.

    Returns: The selected candidate, or the last usable candidate when the variance
    never decreases between two consecutive candidates.
    Raises: ValueError when no candidate fits in the available columns.
    """
    # Keep only the candidates that can actually be evaluated on the provided columns
    nb_columns = similarities.shape[1]
    candidates = [k for k in potential_k_list if k <= nb_columns]
    if not candidates:
        raise ValueError(
            "No candidate k in potential_k_list fits in the %d available columns." % nb_columns
        )

    # Evaluate the variance for every k up to the largest usable candidate
    var_per_k = np.var(similarities[:, :max(candidates)], axis=0)

    # Take the k where the variance starts to decrease (candidate k lives at column k - 1)
    for current_k, next_k in zip(candidates, candidates[1:]):
        if var_per_k[current_k - 1] > var_per_k[next_k - 1]:
            return current_k

    # The variance never decreases: default to the last usable candidate
    return candidates[-1]

def dot_product_gpu(first_array, second_array, transpose=True):
    """
    Multiply two numpy arrays with pytorch on the CUDA device and return a numpy array.

    first_array: The first 2D numpy array (left operand).
    second_array: The second 2D numpy array (right operand).
    transpose: If the second array is transposed before the product.

    Returns: The matrix product as a numpy array on the host.
    """
    # Choose the right operand on the host, before any transfer
    right_operand = np.transpose(second_array) if transpose else second_array

    # Send both operands to the GPU
    left_tensor = torch.from_numpy(first_array).cuda()
    right_tensor = torch.from_numpy(right_operand).cuda()

    # Multiply on the GPU and bring the result back to the host
    product = torch.matmul(left_tensor, right_tensor)
    return product.cpu().numpy()

def knn_sim_IP_GPU(database, query, batch_size=50, k=200, display=True, half_precision=False):
    """
    This function applies an inner-product knn search on GPU. All the provided vectors should be already normalized.
    This function returns the raw inner-product similarities directly.

    database: The database for the KNN search, with a shape of (N_d, feat_dim).
    query: The query for the KNN search, with a shape of (N_q, feat_dim).
    batch_size: The number of query examples to treat in each batch.
    k: The number of the K nearest neighbours (it cannot exceed N_d, torch.topk fails otherwise).
    display: If we would like to display the progress bar.
    half_precision: If we use FP16 to accelerate more.

    Returns: A tuple (top_k_similarities, top_k_indices), both with a shape of (N_q, k).
    The similarities are float32 and sorted in descending order for each query,
    the indices are int32 positions in the database.
    """
    # Move the database to the GPU (converted to FP16 on the host first when requested)
    database_tensor = torch.from_numpy(database)
    if half_precision:
        database_tensor = database_tensor.half()
    database_tensor = database_tensor.cuda()

    # Initialize the result
    nb_examples = query.shape[0]
    top_k_similarities = np.zeros((nb_examples, k), dtype=np.float32)
    top_k_indices = np.zeros((nb_examples, k), dtype=np.int32)

    # Determine the information related to the batches
    nb_batches = math.ceil(nb_examples / batch_size)

    # Build the batch iterator
    batch_progress_bar = range(nb_batches)
    if display:
        batch_progress_bar = tqdm(batch_progress_bar, desc='Processed batches')

    # Pre-bind the loop-local names so that the final cleanup also works with an empty query
    query_tensor = similarities = batch_top_k_sim = batch_top_k_idx = None

    # Execute the knn search
    with torch.no_grad():
        for batch_index in batch_progress_bar:
            # Take the current batch data and move it to GPU
            batch_start_pos = batch_index * batch_size
            batch_end_pos = min(batch_start_pos + batch_size, nb_examples)
            query_tensor = torch.from_numpy(query[batch_start_pos:batch_end_pos])
            if half_precision:
                query_tensor = query_tensor.half()
            query_tensor = query_tensor.cuda()

            # Compute the inner-product between the batch and the whole database
            similarities = torch.mm(query_tensor, database_tensor.T)

            # Get the Top-k examples
            batch_top_k_sim, batch_top_k_idx = similarities.topk(k, dim=1)

            # Register the result
            top_k_similarities[batch_start_pos:batch_end_pos] = batch_top_k_sim.float().cpu().numpy()
            top_k_indices[batch_start_pos:batch_end_pos] = batch_top_k_idx.cpu().numpy().astype(np.int32)

    # Drop the GPU references before emptying the cache, otherwise the memory cannot be released
    del database_tensor, query_tensor, similarities, batch_top_k_sim, batch_top_k_idx
    torch.cuda.empty_cache()
    gc.collect()

    return top_k_similarities, top_k_indices

def knn_L2_dist_GPU(database, query, batch_size=50, k=200, display=True, half_precision=False, db_batch_size=50000):
    """
    This function applies a knn search on GPU based on the squared L2 distances.
    This function returns the squared distances directly (smallest first).

    database: The database for the KNN search, with a shape of (N_d, feat_dim).
    query: The query for the KNN search, with a shape of (N_q, feat_dim).
    batch_size: The number of query examples to treat in each batch.
    k: The number of the K nearest neighbours (it cannot exceed N_d, torch.topk fails otherwise).
    display: If we would like to display the progress bar.
    half_precision: If we use FP16 to accelerate more.
    db_batch_size: The number of database examples to treat in each database chunk (adjust based on GPU memory).

    Returns: A tuple (top_k_distances, top_k_indices), both with a shape of (N_q, k).
    The distances are float32 and sorted in ascending order for each query,
    the indices are int32 positions in the database.

    Note: Here, we compute the square distances with ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x,y>.
    """
    # Move the database to the GPU (converted to FP16 on the host first when requested)
    database_tensor = torch.from_numpy(database)
    if half_precision:
        database_tensor = database_tensor.half()
    database_tensor = database_tensor.cuda()

    # Initialize the result
    nb_examples = query.shape[0]
    top_k_distances = np.zeros((nb_examples, k), dtype=np.float32)
    top_k_indices = np.zeros((nb_examples, k), dtype=np.int32)

    # Determine the information related to the batches
    nb_batches = math.ceil(nb_examples / batch_size)

    # Build the batch iterator
    batch_progress_bar = range(nb_batches)
    if display:
        batch_progress_bar = tqdm(batch_progress_bar, desc='Processed batches')

    # Boundaries of the database chunks
    nb_database = database_tensor.shape[0]
    db_chunk_bounds = [
        (db_start, min(db_start + db_batch_size, nb_database))
        for db_start in range(0, nb_database, db_batch_size)
    ]

    # Pre-bind the loop-local names so that the final cleanup also works with an empty query
    query_tensor = all_distances = batch_top_k_dist = batch_top_k_idx = None

    # Execute the knn search
    with torch.no_grad():
        # The database norms (||y||^2) do not depend on the query: compute them once,
        # chunk by chunk to avoid squaring the whole database at the same time.
        db_chunk_norms = [
            (database_tensor[db_start:db_end] ** 2).sum(dim=1, keepdim=True).T  # [1, chunk_size]
            for db_start, db_end in db_chunk_bounds
        ]

        for batch_index in batch_progress_bar:
            # Take the current batch data and move it to GPU
            batch_start_pos = batch_index * batch_size
            batch_end_pos = min(batch_start_pos + batch_size, nb_examples)
            query_tensor = torch.from_numpy(query[batch_start_pos:batch_end_pos])
            if half_precision:
                query_tensor = query_tensor.half()
            query_tensor = query_tensor.cuda()

            # Compute query norms (||x||^2)
            query_norms = (query_tensor ** 2).sum(dim=1, keepdim=True)  # [batch_size, 1]

            chunk_distances = []
            for (db_start, db_end), db_norms in zip(db_chunk_bounds, db_chunk_norms):
                # Compute inner product <x, y> with the current database chunk
                inner_products = torch.mm(query_tensor, database_tensor[db_start:db_end].T)

                # Compute L2 distances: ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x,y>
                distances = query_norms + db_norms - 2 * inner_products  # [batch_size, chunk_size]

                # Rounding errors can produce small negative values: clamp them to zero
                chunk_distances.append(torch.clamp(distances, min=0))

            # Concatenate all distances and find the k smallest ones
            all_distances = torch.cat(chunk_distances, dim=1)  # [batch_size, N_d]
            batch_top_k_dist, batch_top_k_idx = all_distances.topk(k, dim=1, largest=False)

            # Store results
            top_k_distances[batch_start_pos:batch_end_pos] = batch_top_k_dist.float().cpu().numpy()
            top_k_indices[batch_start_pos:batch_end_pos] = batch_top_k_idx.cpu().numpy().astype(np.int32)

    # Drop the GPU references before emptying the cache, otherwise the memory cannot be released
    del database_tensor, db_chunk_norms, query_tensor, all_distances, batch_top_k_dist, batch_top_k_idx
    torch.cuda.empty_cache()
    gc.collect()

    return top_k_distances, top_k_indices

# def knn_L2_dist_GPU(database, query, batch_size=50, k=200, display=True, half_precision=False):
#     """
#     This functions applies a knn search on GPU. All the provided vectors should be already normalized.
#     This function returns the similarities directly.

#     database: The database for the KNN search, with a shape of (N_d, feat_dim).
#     query: The query for the KNN search, with a shape of (N_q, feat_dim).
#     batch_size: The number of examples to treat in each batch.
#     k: The number of the K nearest neighbours.
#     display: If we would like to display the progress bar.
#     half_precision: If we use FP16 to accelerate more.

#     Note: Here, we compute the square distances.
#     """
#      # Move the database to the GPU
#     database_tensor = torch.from_numpy(database)
#     if half_precision:
#         database_tensor = database_tensor.half() 
#     database_tensor = database_tensor.cuda()

#     # Precompute squared norms of database vectors
#     db_norms = (database_tensor.float() ** 2).sum(dim=1).unsqueeze(0)
#     if half_precision:
#         db_norms = db_norms.half()
    
#     # Initialize the result
#     nb_examples = query.shape[0]
#     top_k_distances = np.zeros((nb_examples, k), dtype=np.float32)
#     top_k_indices = np.zeros((nb_examples, k), dtype=np.int32)

#     # Determine the information related to the batches
#     nb_batches = math.ceil(nb_examples / batch_size)

#     # Build the iterate index list
#     batch_progress_bar = list(range(nb_batches))
#     if display:
#         batch_progress_bar = tqdm(list(range(nb_batches)), desc='Processed batches')
    
#     # Execute the knn search
#     with torch.no_grad():
#         for batch_index in batch_progress_bar:
#             # Take the current batch data and move it to GPU
#             batch_start_pos = batch_index*batch_size
#             batch_end_pos = min((batch_index+1)*batch_size, nb_examples)
#             batch_query = query[batch_start_pos:batch_end_pos]
#             query_tensor = torch.from_numpy(batch_query)
#             if half_precision:
#                 query_tensor = query_tensor.half() 
#             query_tensor = query_tensor.cuda()

#             # Compute squared norms of query vectors (||x||^2)
#             query_norms = (query_tensor.float() ** 2).sum(dim=1, keepdim=True)
#             if half_precision:
#                 query_norms = query_norms.half()
            
#             # Compute the inner-product
#             similarities = torch.mm(query_tensor, database_tensor.T)

#             # Compute L2 distances: ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x,y>
#             distances = query_norms + db_norms - 2 * similarities
#             distances = torch.clamp(distances, min=0) # To make sure when the half precision is used, we don't have negative values.
            
#             # Get top-k smallest distances
#             batch_top_k_dist, batch_top_k_idx = distances.topk(k, dim=1, largest=False)
            
#             # Store results
#             top_k_distances[batch_start_pos:batch_end_pos] = batch_top_k_dist.float().cpu().numpy()
#             top_k_indices[batch_start_pos:batch_end_pos] = batch_top_k_idx.cpu().numpy().astype(np.int32)

#     # Clean the memory
#     del database_tensor, query_tensor, db_norms, query_norms, similarities, distances, batch_top_k_dist, batch_top_k_idx
#     torch.cuda.empty_cache()
#     gc.collect()
    
#     return top_k_distances, top_k_indices

def knn_search_IP_GPU(database, query, batch_size=50, k=200, display=True, half_precision=False):
    """
    This function applies an inner-product knn search on GPU. All the provided vectors should be already normalized.

    database: The database for the KNN search, with a shape of (N_d, feat_dim).
    query: The query for the KNN search, with a shape of (N_q, feat_dim).
    batch_size: The number of query examples to treat in each batch.
    k: The number of the K nearest neighbours (it cannot exceed N_d, torch.topk fails otherwise).
    display: If we would like to display the progress bar.
    half_precision: If we use FP16 to accelerate more.

    Returns: A tuple (rescaled_similarities, top_k_indices), both with a shape of (N_q, k).
    The indices are int32 positions in the database.

    Note: This function doesn't return the similarities directly, it returns (similarity + 1) / 2,
    which lies within [0,1] when the inner products lie within [-1,1] (normalized vectors).
    """
    # Move the database to the GPU (converted to FP16 on the host first when requested)
    database_tensor = torch.from_numpy(database)
    if half_precision:
        database_tensor = database_tensor.half()
    database_tensor = database_tensor.cuda()

    # Initialize the result
    nb_examples = query.shape[0]
    top_k_similarities = np.zeros((nb_examples, k), dtype=np.float32)
    top_k_indices = np.zeros((nb_examples, k), dtype=np.int32)

    # Determine the information related to the batches
    nb_batches = math.ceil(nb_examples / batch_size)

    # Build the batch iterator
    batch_progress_bar = range(nb_batches)
    if display:
        batch_progress_bar = tqdm(batch_progress_bar, desc='Processed batches')

    # Pre-bind the loop-local names so that the final cleanup also works with an empty query
    query_tensor = similarities = batch_top_k_sim = batch_top_k_idx = None

    # Execute the knn search
    with torch.no_grad():
        for batch_index in batch_progress_bar:
            # Take the current batch data and move it to GPU
            batch_start_pos = batch_index * batch_size
            batch_end_pos = min(batch_start_pos + batch_size, nb_examples)
            query_tensor = torch.from_numpy(query[batch_start_pos:batch_end_pos])
            if half_precision:
                query_tensor = query_tensor.half()
            query_tensor = query_tensor.cuda()

            # Compute the inner-product between the batch and the whole database
            similarities = torch.mm(query_tensor, database_tensor.T)

            # Get the Top-k examples
            batch_top_k_sim, batch_top_k_idx = similarities.topk(k, dim=1)

            # Register the result
            top_k_similarities[batch_start_pos:batch_end_pos] = batch_top_k_sim.float().cpu().numpy()
            top_k_indices[batch_start_pos:batch_end_pos] = batch_top_k_idx.cpu().numpy().astype(np.int32)

    # Drop the GPU references before emptying the cache, otherwise the memory cannot be released
    del database_tensor, query_tensor, similarities, batch_top_k_sim, batch_top_k_idx
    torch.cuda.empty_cache()
    gc.collect()

    # Rescale the similarities from [-1,1] to [0,1]
    return (top_k_similarities + 1) / 2, top_k_indices

def knn_search_IP_GPU_by_class(database, database_labels, query, query_preds, k=200, display=True, half_precision=False):
    """
    This function applies an inner-product knn search on GPU, class by class: each query example
    is only compared with the database examples whose label equals its prediction.
    All the provided vectors should be already normalized.

    database: The database for the KNN search, with a shape of (N_d, feat_dim).
    database_labels: The labels of the database examples.
    query: The query for the KNN search, with a shape of (N_q, feat_dim).
    query_preds: The predictions for the query.
    k: The number of the K nearest neighbours (torch.topk fails when a class has fewer than k database examples).
    display: If we would like to display the progress bar.
    half_precision: If we use FP16 to accelerate more.

    Returns: A tuple (rescaled_similarities, top_k_indices), both with a shape of (N_q, k).
    The indices are positions WITHIN the database subset of the predicted class,
    not positions in the whole database.
    Raises: KeyError when a predicted class has no example in the database.

    Note: This function doesn't return the similarities directly, it returns (similarity + 1) / 2,
    which lies within [0,1] when the inner products lie within [-1,1] (normalized vectors).
    """
    # Create the database tensor (converted to FP16 on the host first when requested)
    database_tensor = torch.from_numpy(database)
    if half_precision:
        database_tensor = database_tensor.half()
    database_tensor = database_tensor.cuda()

    # Separate the database by class
    # NOTE(review): class_index_dict_build presumably maps each class id to the positions
    # of the examples of this class - confirm against its definition.
    database_map_dict = class_index_dict_build(database_labels)
    database_class_tensor = {
        classId: database_tensor[database_map_dict[classId]]
        for classId in database_map_dict
    }

    # Separate the query by class
    query_map_dict = class_index_dict_build(query_preds)
    query_class_vecs = {
        classId: query[query_map_dict[classId]]
        for classId in query_map_dict
    }

    # Initialize the result
    nb_examples = query.shape[0]
    top_k_similarities = np.zeros((nb_examples, k), dtype=np.float32)
    top_k_indices = np.zeros((nb_examples, k), dtype=np.int32)

    # Build the class iterator
    class_progress_bar = list(query_class_vecs.keys())
    if display:
        class_progress_bar = tqdm(class_progress_bar, desc='Processed batches')

    # Pre-bind the loop-local names so that the final cleanup also works with an empty query
    query_tensor = similarities = batch_top_k_sim = batch_top_k_idx = None

    # Execute the knn search
    with torch.no_grad():
        for class_index in class_progress_bar:
            # Take the query examples predicted as the current class and move them to GPU
            query_tensor = torch.from_numpy(query_class_vecs[class_index])
            if half_precision:
                query_tensor = query_tensor.half()
            query_tensor = query_tensor.cuda()

            # Compute the inner-product with the database examples of the same class only
            similarities = torch.mm(query_tensor, database_class_tensor[class_index].T)

            # Get the Top-k examples
            batch_top_k_sim, batch_top_k_idx = similarities.topk(k, dim=1)

            # Register the result at the original positions of the query examples
            top_k_similarities[query_map_dict[class_index]] = batch_top_k_sim.float().cpu().numpy()
            top_k_indices[query_map_dict[class_index]] = batch_top_k_idx.cpu().numpy().astype(np.int32) # The top-k indices are within the class now.

    # Drop the GPU references before emptying the cache, otherwise the memory cannot be released
    del database_tensor, database_class_tensor, query_tensor, query_class_vecs, similarities, batch_top_k_sim, batch_top_k_idx
    torch.cuda.empty_cache()
    gc.collect()

    # Rescale the similarities from [-1,1] to [0,1]
    return (top_k_similarities + 1) / 2, top_k_indices

def get_knn_OOD_score(similarities):
    """
    Compute the ood scores from the knn inner-product similarities.
    The score of an example is its last similarity column rescaled with (value + 1) / 2.

    similarities: The similarities obtained after the KNN search, one row per example.

    Returns: A flat array with one score per example.
    """
    # Similarity stored in the last column of every row
    last_column = similarities[:, -1].reshape(-1)
    # Rescale from [-1,1] to [0,1]
    return (last_column + 1) / 2

def knn_scores_IP_GPU(database, query, batch_size=50, k=200, display=True, half_precision=False):
    """
    This function applies an inner-product knn search on GPU and returns one score per query example.
    All the provided vectors should be already normalized.

    database: The database for the KNN search, with a shape of (N_d, feat_dim).
    query: The query for the KNN search, with a shape of (N_q, feat_dim).
    batch_size: The number of query examples to treat in each batch.
    k: The number of the K nearest neighbours (it cannot exceed N_d, torch.topk fails otherwise).
    display: If we would like to display the progress bar.
    half_precision: If we use FP16 to accelerate more.

    Returns: A float32 array with a shape of (N_q,), containing (s_k + 1) / 2 where s_k is
    the similarity between the query example and its k-th nearest neighbour.
    """
    # Move the database to the GPU (converted to FP16 on the host first when requested)
    database_tensor = torch.from_numpy(database)
    if half_precision:
        database_tensor = database_tensor.half()
    database_tensor = database_tensor.cuda()

    # Initialize the result: only the k-th similarity of each query is needed for the score
    nb_examples = query.shape[0]
    kth_similarities = np.zeros(nb_examples, dtype=np.float32)

    # Determine the information related to the batches
    nb_batches = math.ceil(nb_examples / batch_size)

    # Build the batch iterator
    batch_progress_bar = range(nb_batches)
    if display:
        batch_progress_bar = tqdm(batch_progress_bar, desc='Processed batches')

    # Pre-bind the loop-local names so that the final cleanup also works with an empty query
    query_tensor = similarities = batch_top_k_sim = None

    # Execute the knn search
    with torch.no_grad():
        for batch_index in batch_progress_bar:
            # Take the current batch data and move it to GPU
            batch_start_pos = batch_index * batch_size
            batch_end_pos = min(batch_start_pos + batch_size, nb_examples)
            query_tensor = torch.from_numpy(query[batch_start_pos:batch_end_pos])
            if half_precision:
                query_tensor = query_tensor.half()
            query_tensor = query_tensor.cuda()

            # Compute the inner-product between the batch and the whole database
            similarities = torch.mm(query_tensor, database_tensor.T)

            # Get the Top-k similarities (sorted in descending order, the last one is the k-th)
            batch_top_k_sim, _ = similarities.topk(k, dim=1)

            # Register the k-th similarity only
            kth_similarities[batch_start_pos:batch_end_pos] = batch_top_k_sim[:, -1].float().cpu().numpy()

    # Drop the GPU references before emptying the cache, otherwise the memory cannot be released
    del database_tensor, query_tensor, similarities, batch_top_k_sim
    torch.cuda.empty_cache()
    gc.collect()

    # Rescale the similarities from [-1,1] to [0,1]
    return (kth_similarities + 1) / 2

# def knn_scores_IP_db_block_GPU(database, query, batch_size=50, block_size=200000, k=200, display=True, half_precision=False):
#     """
#     This functions applies a knn search on GPU. All the provided vectors should be already normalized. (by-block version when using the database)

#     database: The database for the KNN search, with a shape of (N_d, feat_dim).
#     query: The query for the KNN search, with a shape of (N_q, feat_dim).
#     batch_size: The number of examples to treat in each batch.
#     block_size: The number of examples to treat in each block from the database.
#     k: The number of the K nearest neighbours.
#     display: If we would like to display the progress bar.
#     half_precision: If we use FP16 to accelerate more.

#     Old Code:
#     # Move block results to CPU for merging
#     block_sim_cpu = block_sim.cpu().numpy()
#     block_idx_cpu = np.arange(block_start, block_end)

#     # Merge with existing top-K (on CPU and handle each query one by one)
#     for i in range(batch_end - batch_start):
#         # Get the position of each query
#         query_pos = batch_start + i
        
#         # Combine existing and new candidates
#         combined_sim = np.concatenate([
#             top_k_similarities[query_pos], 
#             block_sim_cpu[i]
#         ])
#         combined_idx = np.concatenate([
#             top_k_indices[query_pos],
#             block_idx_cpu
#         ])
        
#         # Get top-k with correct order
#         top_k_pos = np.argpartition(-combined_sim, k)[:k]  # Unsorted top-k positions
#         sorted_subset = np.argsort(-combined_sim[top_k_pos])
#         top_k_similarities[query_pos] = combined_sim[top_k_pos][sorted_subset]
#         top_k_indices[query_pos] = combined_idx[top_k_pos][sorted_subset]

#     # ## Merge with existing top-K (on CPU and handle all examples togther) 
#     # # Combine the results
#     # combined_sim = np.hstack([
#     #     top_k_similarities[batch_start:batch_end],
#     #     block_sim_cpu
#     # ])
#     # combined_idx = np.hstack([
#     #     top_k_indices[batch_start:batch_end],
#     #     np.tile(block_idx_cpu, (batch_end-batch_start, 1))
#     # ])
#     # # Get top-k with correct order
#     # top_k_pos = np.argpartition(-combined_sim, k, axis=1)[:, :k]
#     # sorted_subset = np.argsort(-np.take_along_axis(combined_sim, top_k_pos, axis=1), axis=1)
#     # top_k_similarities[batch_start:batch_end] = np.take_along_axis(
#     #     np.take_along_axis(combined_sim, top_k_pos, axis=1),
#     #     sorted_subset, axis=1
#     # )
#     # top_k_indices[batch_start:batch_end] = np.take_along_axis(
#     #     np.take_along_axis(combined_idx, top_k_pos, axis=1),
#     #     sorted_subset, axis=1
#     # )

#     del block_db_tensor, block_sim, block_sim_cpu
#     torch.cuda.empty_cache()
#     """
#     # Move the database to the GPU
#     database_tensor = torch.from_numpy(database)
    
#     # Initialize the result
#     nb_examples = query.shape[0]
#     top_k_similarities = np.full((nb_examples, k), -np.inf, dtype=np.float32)
#     top_k_indices = np.zeros((nb_examples, k), dtype=np.int64)

#     # Determine the information related to the batches
#     nb_batches = math.ceil(nb_examples / batch_size)
#     nb_blocks = math.ceil(database.shape[0] / block_size)

#     # Build the iterate index list
#     batch_progress_bar = list(range(nb_batches))
#     if display:
#         batch_progress_bar = tqdm(list(range(nb_batches)), desc='Processed batches')
    
#     # Execute the knn search
#     with torch.no_grad():
#         for batch_index in batch_progress_bar:
#             # Take the current batch data and move it to GPU
#             batch_start = batch_index*batch_size
#             batch_end = min((batch_index+1)*batch_size, nb_examples)
#             batch_query = query[batch_start:batch_end]
#             batch_query_tensor = torch.from_numpy(batch_query).cuda()
#             if half_precision:
#                 batch_query_tensor = batch_query_tensor.half()

#             # Initialize GPU buffers for current batch
#             current_top_k_sim = torch.from_numpy(top_k_similarities[batch_start:batch_end]).cuda()
#             current_top_k_idx = torch.from_numpy(top_k_indices[batch_start:batch_end]).cuda()

#             # Process database in blocks
#             for block_id in range(nb_blocks):
#                 # Get the block in the database
#                 block_start = block_id * block_size
#                 block_end = min((block_id + 1) * block_size, database.shape[0])
#                 block_db_tensor = database_tensor[block_start:block_end].cuda()
#                 if half_precision:
#                     block_db_tensor = block_db_tensor.half()

#                 # Compute block similarities
#                 block_sim = torch.mm(batch_query_tensor, block_db_tensor.T).float()
                
#                 # Prepare block indices
#                 block_indices = torch.arange(block_start, block_end).cuda().expand(batch_query_tensor.size(0), -1)
                
#                 # Combine with existing top-k
#                 combined_sim = torch.cat([current_top_k_sim, block_sim], dim=1)
#                 combined_idx = torch.cat([current_top_k_idx, block_indices], dim=1)
                
#                 # Get the new top k values
#                 new_topk = combined_sim.topk(k, dim=1)
                
#                 # Update current top-k
#                 current_top_k_sim = new_topk.values
#                 current_top_k_idx = torch.gather(combined_idx, 1, new_topk.indices)
                
#                 # Clean the memory
#                 del block_db_tensor, block_sim, block_indices, combined_sim, combined_idx, new_topk
#                 torch.cuda.empty_cache()

#             # Copy results back to CPU
#             top_k_similarities[batch_start:batch_end] = current_top_k_sim.cpu().numpy()
#             top_k_indices[batch_start:batch_end] = current_top_k_idx.cpu().numpy()
            
#             # Clean the memory
#             del batch_query_tensor
#             torch.cuda.empty_cache()

#     # Clean the memory
#     torch.cuda.empty_cache()
#     gc.collect()
    
#     return (top_k_similarities[:, -1].reshape(-1) + 1) / 2

# def next_optimal_dim(feat_dim):
#     """
#     This function returns the next optimal dimension greater than x.

#     x: The number.
#     """
#     padded_dim = 2048 if feat_dim <= 2048 else next_power_of_2(feat_dim)
#     return padded_dim

def knn_scores_IP_GPU_auto_dim(database, query, batch_size=50, k=200, display=True, half_precision=False):
    """
    This function applies an inner-product knn search on GPU and returns one score per query example.
    All the provided vectors should be already normalized.
    The feature dimension is padded with zeros up to the value returned by next_power_of_2
    (to fit the GPU design); the added zeros do not change the inner products.

    database: The database for the KNN search, with a shape of (N_d, feat_dim).
    query: The query for the KNN search, with a shape of (N_q, feat_dim).
    batch_size: The number of query examples to treat in each batch.
    k: The number of the K nearest neighbours (it cannot exceed N_d, torch.topk fails otherwise).
    display: If we would like to display the progress bar.
    half_precision: If we use FP16 to accelerate more.

    Returns: A float32 array with a shape of (N_q,), containing (s_k + 1) / 2 where s_k is
    the similarity between the query example and its k-th nearest neighbour.
    """
    # Get the original feature dimension
    feat_dim = database.shape[1]

    # Find the next power of two for the feature dimension and determine the padding size
    power_two_feat_dim = next_power_of_2(feat_dim)
    pad_size = power_two_feat_dim - feat_dim

    # Move the database to the GPU
    database_tensor = torch.from_numpy(database).cuda()
    if half_precision:
        database_tensor = database_tensor.half()

    # Pad the database with zeros if needed (on GPU)
    if pad_size > 0:
        database_tensor = torch.nn.functional.pad(database_tensor, (0, pad_size), "constant", 0)

    # Initialize the result: only the k-th similarity of each query is needed for the score
    nb_examples = query.shape[0]
    kth_similarities = np.zeros(nb_examples, dtype=np.float32)

    # Determine the information related to the batches
    nb_batches = math.ceil(nb_examples / batch_size)

    # Build the batch iterator
    batch_progress_bar = range(nb_batches)
    if display:
        batch_progress_bar = tqdm(batch_progress_bar, desc='Processed batches')

    # Pre-bind the loop-local names so that the final cleanup also works with an empty query
    query_tensor = similarities = batch_top_k_sim = None

    # Execute the knn search
    with torch.no_grad():
        for batch_index in batch_progress_bar:
            # Take the current batch data and move it to GPU
            batch_start_pos = batch_index * batch_size
            batch_end_pos = min(batch_start_pos + batch_size, nb_examples)
            query_tensor = torch.from_numpy(query[batch_start_pos:batch_end_pos]).cuda()
            if half_precision:
                query_tensor = query_tensor.half()

            # Pad the query batch with zeros if needed (on GPU), to match the padded database
            if pad_size > 0:
                query_tensor = torch.nn.functional.pad(query_tensor, (0, pad_size), "constant", 0)

            # Compute the inner-product between the batch and the whole database
            similarities = torch.mm(query_tensor, database_tensor.T)

            # Get the Top-k similarities (sorted in descending order, the last one is the k-th)
            batch_top_k_sim, _ = similarities.topk(k, dim=1)

            # Register the k-th similarity only
            kth_similarities[batch_start_pos:batch_end_pos] = batch_top_k_sim[:, -1].float().cpu().numpy()

    # Drop the GPU references before emptying the cache, otherwise the memory cannot be released
    del database_tensor, query_tensor, similarities, batch_top_k_sim
    torch.cuda.empty_cache()
    gc.collect()

    # Rescale the similarities from [-1,1] to [0,1]
    return (kth_similarities + 1) / 2

