import networkx as nx
from source.controller.retriever.passage_retrieval import main_modified
from munch import DefaultMunch
from data.dataset import DATASET_DIR, dict_path, config

def page_rank(graph, alpha=0.85):
    """Return the PageRank scores of ``graph`` as computed by ``networkx.pagerank``.

    Args:
        graph: The networkx graph to rank.
        alpha: Damping parameter forwarded unchanged to ``networkx.pagerank``.

    Returns:
        Whatever ``networkx.pagerank`` returns for this graph (a mapping of
        node to score, per the networkx documentation).
    """
    # Pass alpha by keyword so the call does not depend on pagerank's
    # positional parameter order.
    scores = nx.pagerank(graph, alpha=alpha)
    return scores

def top_k(queries, dataset_name, k=10, retriever=None, tokenizer=None, collection = None):
    if dataset_name == "decode_dict":
        answers = ""
        for word in queries[0].split():
            answers += f"{word} : {str(collection[word])}\n"
    
        answers = [{"ctxs": [{"text" : str(answers)}]}]
        
    elif "med_records" in dataset_name :
        if type(queries) != list:
            queries = [queries]
        for query in queries:
            patient_id = query.split()[-1] + ".txt"
            args = config(dataset_name, k)
            args["passages_embeddings"] = args["passages_embeddings"] + patient_id + "/"
            args["passages"] = args["passages"] + patient_id + "/data.tsv"
            
            if "weaken" in dataset_name:
                if "medium" in dataset_name:
                    args["n_subquantizers"] = 8
                    args["n_bits"] = 2
                elif "max" in dataset_name:
                    args["n_subquantizers"] = 16
                    args["n_bits"] = 2
                else:
                    args["n_subquantizers"] = 4
                    args["n_bits"] = 2
            
            
            
            args = DefaultMunch.fromDict(args)
            answers = main_modified(args, [query], retriever, tokenizer, collection )

    else:   
    
        args = config(dataset_name, k)
        
        args = DefaultMunch.fromDict(args)
        
        answers = main_modified(args, queries, retriever, tokenizer, collection )
    return answers