import logging

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

from src.entity.KnowledgeBase import KnowledgeBase
from src.entity.problems.Problem import Problem



def GetRelevantKnowledgeByCosSimilarityUseCase(problem: Problem, knowledge_base: KnowledgeBase, top_k: int,
                                               start_similarity, end_similarity, exclude_self: bool, random_from_k: bool, exclude_reference: bool,
                                               seed: int = 906) -> tuple:
    """
    Retrieve the knowledge items most similar to a problem by cosine similarity.

    The similarity between ``problem.embedding`` and every row of
    ``knowledge_base.embeddings`` is computed, candidates are restricted to
    the inclusive range ``[start_similarity, end_similarity]``, optional
    exclusions are applied, and at most ``top_k`` items are selected.

    Args:
        problem: Target problem; its ``embedding`` is flattened to one vector
            and its ``id`` is used by the exclusion filters.
        knowledge_base: Provides ``embeddings`` (one row per item) and the
            parallel ``knowledges`` sequence.
        top_k: Maximum number of items to select.
        start_similarity: Inclusive lower bound on the similarity score.
        end_similarity: Inclusive upper bound on the similarity score.
        exclude_self: Drop items whose ``id`` equals ``problem.id``.
        random_from_k: If true and more than ``top_k`` candidates remain,
            draw ``top_k`` of them at random without replacement (the result
            is then NOT ordered by similarity). If true and at most ``top_k``
            candidates remain, all of them are returned. If false, the
            candidates are sliced to the first ``top_k`` in descending
            similarity order.
        exclude_reference: Drop items whose ``reference_to`` equals
            ``problem.id``.
        seed: Seed of the private random generator used for the random draw.
            The default keeps the selection identical to the previously
            hard-coded seed.

    Returns:
        ``(knowledges, similarities)`` as two lists of equal length, or
        ``(None, None)`` when nothing is selected or the knowledge base has
        no embeddings.

    Raises:
        AssertionError: If the embedding dimensions of the problem and the
            knowledge base differ.
    """
    # size: 1, dim
    target = problem.embedding
    # size: N, dim
    items = knowledge_base.embeddings

    # An empty knowledge base cannot be reshaped to (0, -1); treat it the same
    # way as "no candidate passed the filters".
    if items.shape[0] == 0:
        return None, None

    target = target.reshape(1, -1)
    items = items.reshape(items.shape[0], -1)

    assert target.shape[1] == items.shape[1], "The dimensions of the target and items do not match."

    # size: N
    similarity = cosine_similarity(target, items).flatten()

    # Indices in descending similarity order, restricted to the inclusive
    # [start_similarity, end_similarity] window with a vectorised mask.
    order = np.argsort(similarity)[::-1]
    ordered_scores = similarity[order]
    in_range = (ordered_scores >= start_similarity) & (ordered_scores <= end_similarity)
    candidates = order[in_range]

    knowledges = knowledge_base.knowledges
    if exclude_self or exclude_reference:
        problem_id = problem.id
        # The attribute is only read when the corresponding flag is set, so
        # items lacking `reference_to` are fine when exclude_reference is off.
        candidates = [
            i for i in candidates
            if not (exclude_self and problem_id == knowledges[i].id)
            and not (exclude_reference and problem_id == knowledges[i].reference_to)
        ]

    if random_from_k:
        if len(candidates) > top_k:
            # A private RandomState draws the same sample as seeding the
            # global generator did, without clobbering NumPy's process-wide
            # random state for unrelated code.
            rng = np.random.RandomState(seed)
            candidates = rng.choice(candidates, top_k, replace=False)
    else:
        candidates = candidates[:top_k]

    # Explicit integer dtype keeps the fancy indexing valid for an empty selection.
    selected = np.asarray(candidates, dtype=np.intp)
    if selected.size == 0:
        return None, None

    filtered_knowledges = [knowledges[i] for i in selected]
    scores = similarity[selected].tolist()

    logging.getLogger(__name__).debug("Selected similarities: %s", scores)
    return filtered_knowledges, scores