'''TopK Retriever'''
import torch
import torch.nn as nn

import numpy as np
from tqdm import tqdm

from common import get_logger
from .base_retrieve import BaseRetriever

logger = get_logger(__name__)


class TopKRetriever(BaseRetriever):
    """Retrieve in-context examples (ICEs) by nearest-neighbour search.

    For each test entry, the ``ice_num`` nearest ICEs are looked up in a
    pre-built vector index. Each hit is then scored by its cosine similarity
    to the test entry's embedding.
    """

    def __init__(self, task, ice_dataloader, candidate_dataloader, inferencer, device, metric_model, metric_tokenizer):
        super().__init__(task, ice_dataloader, candidate_dataloader, inferencer, device, metric_model, metric_tokenizer)

    def knn_search(self, test_forward, ice_embed_list, ice_forward, base_index, ice_num, texts, labels):
        """Find the top-``ice_num`` ICEs for every test entry and score them.

        Args:
            test_forward: Sequence of dicts. Each has ``entry['metadata']['id']``,
                used as the output position for that entry, and ``entry['embed']``,
                the query embedding.
            ice_embed_list: Unused here; kept for interface compatibility.
            ice_forward: Indexable by neighbour id; each item has an ``'embed'``.
            base_index: Vector index queried as ``base_index.search(query, k)``.
                Only element ``[1][0]`` (the neighbour ids of the single query)
                is used. This assumes a faiss-style ``(distances, indices)``
                result; TODO confirm against the index builder.
            ice_num: Number of neighbours to request per test entry.
            texts, labels: Unused here; kept for interface compatibility.

        Returns:
            ``(rtr_idx_list, rtr_score_list)``: two lists of length
            ``len(test_forward)``, positioned by each entry's metadata id.
            Entry ``i`` holds that test entry's neighbour ids and the matching
            cosine similarities, as Python floats in the same order.
        """
        rtr_idx_list = [[] for _ in range(len(test_forward))]
        rtr_score_list = [[] for _ in range(len(test_forward))]
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)

        logger.info("Retrieving data for test set...")

        for entry in tqdm(test_forward):
            idx = entry['metadata']['id']
            query = np.expand_dims(entry['embed'], axis=0)
            hits = base_index.search(query, ice_num)[1][0].tolist()
            # Faiss-style indices pad with -1 when fewer than ice_num vectors
            # exist. Indexing ice_forward[-1] would silently pick the last ICE,
            # so negative ids are dropped. NOTE(review): assumes faiss semantics.
            neighbour_ids = [i for i in hits if i >= 0]
            rtr_idx_list[idx] = neighbour_ids

            if not neighbour_ids:
                rtr_score_list[idx] = []
                continue

            # Score all neighbours in one call instead of one tensor per
            # neighbour. expand_as keeps this working on torch versions where
            # CosineSimilarity does not broadcast.
            neighbours_t = torch.from_numpy(
                np.stack([np.asarray(ice_forward[i]['embed']) for i in neighbour_ids])
            )
            query_t = torch.from_numpy(query).expand_as(neighbours_t)
            # .tolist() yields Python floats. This avoids float() on an ndim>0
            # array, which NumPy >= 1.25 deprecates.
            rtr_score_list[idx] = cos(neighbours_t, query_t).tolist()

        return rtr_idx_list, rtr_score_list

    def raw_retrieve(self, test_forward, ice_embed_list, ice_forward, base_index, ice_num, texts, labels):
        """Retrieve ICEs for the test set; delegates directly to ``knn_search``."""
        return self.knn_search(test_forward, ice_embed_list, ice_forward, base_index, ice_num, texts, labels)

    
    
