import torch
from utils_proof.similarity import *

class Retriever:
    """Rank a fixed set of training embeddings by similarity to a query.

    The similarity function is looked up once, by name, via ``get_simlarity``
    (presumably supplied by the ``utils_proof.similarity`` star import).
    """

    def __init__(self, train_embs, sim_name) -> None:
        """Keep the candidate embeddings and resolve the similarity function.

        Args:
            train_embs: Candidate embeddings; passed unchanged as the second
                argument of the similarity function on every ``retrieve`` call.
            sim_name: Name handed to ``get_simlarity`` to select the function.
        """
        self.train_embs = train_embs
        self.sim_func = get_simlarity(sim_name)

    def retrieve(self, test_emb, k):
        """Return the positions of the ``k`` highest-scoring candidates.

        Args:
            test_emb: Query embedding. A torch tensor gets a new leading axis of
                size 1 (``torch.unsqueeze(test_emb, 0)``) before scoring; any
                other object is given to the similarity function untouched.
            k: Upper slice bound applied to the ranking (``ranking[:k]``), so a
                value above the number of scores returns every position.

        Returns:
            List of positions into the similarity scores, best first. Python's
            sort is stable even with ``reverse=True``, so equal scores stay in
            ascending-position order.
        """
        query = torch.unsqueeze(test_emb, 0) if torch.is_tensor(test_emb) else test_emb
        scores = self.sim_func(query, self.train_embs)

        # NOTE(review): ranking walks ``scores`` along its first axis. If
        # sim_func returns a (1, N) tensor for the unsqueezed query, only a
        # single "row" is ranked — confirm sim_func's output shape.
        ranking = sorted(
            enumerate(scores),
            key=lambda pos_score: pos_score[1],
            reverse=True,
        )
        best = ranking[:k]
        return [pos for pos, _score in best]