from .base_retrieve import BaseRetriever
from .topk import TopKRetriever
from .rm import RandomRetriever

def get_retriever(retriever_type, task, ice_dataloader, candidate_dataloader, inferencer, device, metric_model, metric_tokenizer):
    """Construct the retriever selected by ``retriever_type``.

    Args:
        retriever_type: ``'topk'`` selects :class:`TopKRetriever`; ``'rm'``
            selects :class:`RandomRetriever`.
        task, ice_dataloader, candidate_dataloader: passed positionally to
            the chosen retriever's constructor.
        inferencer, device, metric_model, metric_tokenizer: passed as
            keyword arguments to the chosen retriever's constructor.

    Returns:
        The newly constructed retriever instance.

    Raises:
        ValueError: if ``retriever_type`` is not ``'topk'`` or ``'rm'``.
            (Previously this case printed a message and then failed with an
            opaque ``UnboundLocalError``.)
    """
    if retriever_type == 'topk':
        print('topk')
        retriever_cls = TopKRetriever
    elif retriever_type == 'rm':
        # Previously printed 'topk' here by mistake.
        print('rm')
        retriever_cls = RandomRetriever
    else:
        print("Error Retriever")
        raise ValueError(
            f"Unknown retriever_type {retriever_type!r}; expected 'topk' or 'rm'"
        )
    # Both retriever classes share the same constructor signature, so build once.
    return retriever_cls(
        task,
        ice_dataloader,
        candidate_dataloader,
        inferencer=inferencer,
        device=device,
        metric_model=metric_model,
        metric_tokenizer=metric_tokenizer,
    )
    