import os
import csv
import argparse
from collections import defaultdict
from tqdm import tqdm
import pytrec_eval
import pandas as pd

def load_beir_qrels(qrels_file):
    """Load a BEIR-style qrels TSV into a nested relevance dict.

    The file must start with a header row containing at least the columns
    ``query-id``, ``corpus-id`` and ``score``; other columns are ignored.

    Args:
        qrels_file: Path to the tab-separated qrels file.

    Returns:
        ``{query_id: {corpus_id: relevance}}`` with ids kept as strings and
        relevance parsed as ``int``. If a (query, corpus) pair appears more
        than once, the last row wins.

    Raises:
        KeyError: If one of the required columns is missing.
        ValueError: If a ``score`` value is not an integer.
    """
    qrels = {}
    # The csv module requires newline="" to parse quoted fields (including
    # embedded newlines) correctly; an explicit encoding avoids depending on
    # the platform locale.
    with open(qrels_file, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f, delimiter="\t"):
            query_judgements = qrels.setdefault(row["query-id"], {})
            query_judgements[row["corpus-id"]] = int(row["score"])
    return qrels



if __name__ == "__main__":
    # Score a run with nDCG@10 against BEIR qrels and record the result as the
    # `dataset_name` column of a one-row CSV log.
    parser = argparse.ArgumentParser(
        description="Evaluate a run with nDCG@10 and log the score to a CSV."
    )
    parser.add_argument(
        "--dataset_name", required=True,
        help="Column name under which the score is recorded in the log.",
    )
    parser.add_argument(
        "--qrels_path", required=True,
        help="BEIR qrels TSV with query-id, corpus-id and score columns.",
    )
    parser.add_argument(
        "--trec_path", required=True,
        help="Run file, parsed with pytrec_eval.parse_run.",
    )
    parser.add_argument(
        "--log_path", required=True,
        help="CSV log to create, or update in place if it already exists.",
    )
    args = parser.parse_args()

    qrels = load_beir_qrels(args.qrels_path)
    with open(args.trec_path, "r") as f_run:
        run = pytrec_eval.parse_run(f_run)

    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"ndcg_cut.10"})
    eval_results = evaluator.evaluate(run)
    if not eval_results:
        # Without any per-query results there is nothing to aggregate.
        raise SystemExit(
            "pytrec_eval returned no per-query results; "
            "check that the run and qrels share query ids."
        )

    # Only one measure was requested; take the name pytrec_eval reported it
    # under (the last in sorted order, should several ever be present).
    measure = max(next(iter(eval_results.values())))
    score = pytrec_eval.compute_aggregated_measure(
        measure,
        [per_query[measure] for per_query in eval_results.values()],
    )

    new_column = ["{:.3f}".format(score)]
    if os.path.exists(args.log_path):
        # dtype=str keeps earlier entries exactly as written (e.g. "0.500")
        # instead of round-tripping them through float ("0.5").
        table = pd.read_csv(args.log_path, dtype=str)
        # Overwrites the column if this dataset was logged before.
        table[args.dataset_name] = new_column
    else:
        table = pd.DataFrame({args.dataset_name: new_column})
    table.to_csv(args.log_path, mode="w", index=False)
