import json, csv, argparse, logging, sys
import numpy as np
from tqdm import tqdm
from pathlib import Path
from dpr.options import setup_logger
from dpr.utils.evaluation_utils import (
    mean_average_precision,
    precision_k,
    recall_k,
    hit_k,
    mrr
)

import hydra
from omegaconf import DictConfig, OmegaConf

# Root logger used by the whole script; setup_logger is the project helper
# imported from dpr.options (presumably attaches handlers/format — confirm there).
logger = logging.getLogger()
setup_logger(logger)

# Raise the csv module's per-field size limit as far as the platform allows so
# very long text fields do not abort parsing. sys.maxsize does not fit into a
# C long on every platform (csv.field_size_limit then raises OverflowError),
# so back off until a value is accepted.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 10

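# Optional file logging (currently disabled): uncomment the four lines below
# to additionally write the log to eval_metrics.log.
#file_format = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
#log_filehandler = logging.FileHandler('eval_metrics.log', mode='w')
#log_filehandler.setFormatter(file_format)
#logger.addHandler(log_filehandler)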

def get_arguments():
    """Parse the evaluation options from ``sys.argv`` with argparse.

    Options (in registration order):
        --score_file    required, str
        --context_file  optional, str
        --query_file    required, str
        --answer_file   required, str
        --topk          optional, int, defaults to 5
        --topk_output   optional, str

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.

    Note: nothing in the visible part of this file calls this function; the
    script entry point is configured through Hydra instead.
    """
    # (flag, add_argument keyword arguments), kept in the order they should
    # appear in the generated help text.
    option_table = (
        ('--score_file', {'required': True, 'type': str}),
        ('--context_file', {'type': str}),
        ('--query_file', {'required': True, 'type': str}),
        ('--answer_file', {'required': True, 'type': str}),
        ('--topk', {'type': int, 'default': 5}),
        ('--topk_output', {'type': str}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def _read_queries(query_path):
    """Read the tab-separated query file into a ``{qid: question}`` dict.

    Every non-blank row must have exactly three columns: the question text,
    the query id, and a third column that is ignored here.
    """
    query_dict = dict()
    with open(query_path, 'r', encoding='utf-8') as query_file:
        for row in csv.reader(query_file, delimiter='\t'):
            if not row:
                continue  # csv.reader yields [] for blank lines
            que, qid, _ = row
            query_dict[qid] = que
    return query_dict


def _read_passages(context_path):
    """Read the tab-separated context file into a ``{pid: passage}`` dict.

    Only the first two columns of each row are used (id, passage text).
    """
    all_psgs = dict()
    with open(context_path, 'r', encoding='utf-8') as context_f:
        for row in csv.reader(context_f, delimiter='\t'):
            if not row:
                continue  # csv.reader yields [] for blank lines
            all_psgs[row[0]] = row[1]
    return all_psgs


def _collect_scores(dpr_results, query_dict):
    """Reduce the raw retrieval output to ``{qid: [(pid, score), ...]}``.

    Entries whose ``qid`` is not a known query are dropped. If a qid occurs
    more than once, the last entry wins.
    """
    score_dict = dict()
    for entry in tqdm(dpr_results):
        qid = entry['qid']
        if qid not in query_dict:  # dict lookup instead of scanning a list
            continue
        # SECURITY: scores come from an external file. They used to be parsed
        # with eval(), which executes arbitrary expressions; float() accepts
        # numeric strings as well as plain numbers without executing anything.
        score_dict[qid] = [(ctx['id'], float(ctx['score'])) for ctx in entry['ctxs']]
    return score_dict


def _read_answers(answer_file):
    """Load the relevance labels as ``{qid: [pid, ...]}``.

    ``.json``/``.jsonl`` files are parsed as one JSON object mapping each qid
    to a sequence whose first element holds the label(s); ``.csv``/``.tsv``
    files are parsed as tab-separated rows of exactly four columns with the
    qid in column 1 and the pid in column 3.

    Raises:
        ValueError: if the file suffix is not one of the supported ones.
    """
    answer_path = Path(answer_file)
    json_suffixes = ('.json', '.jsonl')
    table_suffixes = ('.csv', '.tsv')
    # Explicit raise instead of assert: asserts disappear under ``python -O``.
    if answer_path.suffix not in json_suffixes + table_suffixes:
        raise ValueError(
            'Unsupported answer file suffix %r (expected one of %s)'
            % (answer_path.suffix, ', '.join(json_suffixes + table_suffixes))
        )

    if answer_path.suffix in json_suffixes:
        # NOTE(review): '.jsonl' is accepted but parsed with json.load, i.e. as
        # a single JSON document — confirm such files are not line-delimited.
        with open(answer_file, 'r', encoding='utf-8') as json_file:
            raw_labels = json.load(json_file)
        qrel_dict = dict()
        for qid in raw_labels:
            labels = raw_labels[qid][0]
            # A scalar label must become a one-element list; otherwise
            # ``pid in labels`` would be a substring test on a string.
            qrel_dict[qid] = labels if isinstance(labels, list) else [labels]
        return qrel_dict

    qrel_dict = dict()
    with open(answer_file, 'r', encoding='utf-8') as qrel_file:
        for row in csv.reader(qrel_file, delimiter='\t'):
            if not row:
                continue  # csv.reader yields [] for blank lines
            qid, _, pid, _ = row
            qrel_dict.setdefault(qid, list()).append(pid)
    return qrel_dict


def _summarise_query(question, ranked, label_pids, topk, all_psgs, with_passages):
    """Build the per-query record that is written to ``topk_docs.json``.

    Args:
        question: query text.
        ranked: ``(pid, score)`` pairs sorted by descending score.
        label_pids: list of gold passage ids for the query.
        topk: number of leading results to inspect.
        all_psgs: ``{pid: passage}``; only consulted when ``with_passages``.
        with_passages: whether passage text is added to the record.

    Returns:
        dict with the keys ``question``, ``pid``, ``correct_pid``,
        ``incorrect_pid``, ``source_pid``, optionally ``source_psg``,
        ``source_score`` and ``label``.
    """
    gold = set(label_pids)
    doc = dict()
    doc['question'] = question
    doc['pid'] = list()
    doc['correct_pid'] = list()
    doc['incorrect_pid'] = list()
    doc['source_pid'] = label_pids[0] if label_pids else None
    if with_passages and doc['source_pid'] in all_psgs:
        doc['source_psg'] = all_psgs[doc['source_pid']]

    # Score of the best-ranked gold passage inside the top-k; stays 0.0 when no
    # gold passage is retrieved (or when there are no results at all).
    source_score = 0.0
    in_topk = False
    for pid, score in ranked[:max(topk, 0)]:
        is_gold = pid in gold
        if with_passages:
            # .get: a retrieved pid missing from the context file yields None
            # instead of aborting the whole evaluation.
            doc['pid'].append((pid, all_psgs.get(pid), is_gold, score))
        else:
            doc['pid'].append((pid, is_gold, score))
        if is_gold:
            doc['correct_pid'].append(pid)
            if not in_topk:
                source_score = score
                in_topk = True
    doc['source_score'] = source_score

    # Despite the key name, these are the gold pids that were NOT found in the
    # top-k. Built in label order so the output is deterministic.
    found = set(doc['correct_pid'])
    doc['incorrect_pid'] = [pid for pid in dict.fromkeys(label_pids) if pid not in found]
    doc['label'] = in_topk
    return doc


@hydra.main(config_path="conf", config_name="evaluation")
def evaluate_metrics(cfg : DictConfig):
    """Score retrieval results against relevance labels and log the metrics.

    Config fields read from ``cfg``:
        query_file: TSV of (question, qid, <ignored>) rows.
        score_file: JSON list of entries with ``qid`` and ``ctxs``; each
            context carries ``id`` and ``score``.
        context_file: optional TSV of (pid, passage, ...) rows; when set,
            passage text is added to the per-query output.
        answer_file: relevance labels (.json/.jsonl or tab-separated
            .csv/.tsv).
        topk: cut-off used for MAP, Precision, Hit and the per-query output.
        topk_output: when truthy, ``topk_docs.json`` and
            ``psg_prediction.json`` are written to the current working
            directory (the value itself is not used as a path).

    Queries without labels are skipped, as are labelled queries that have no
    entry in the score file (a warning is logged for the latter).

    Raises:
        ValueError: if ``answer_file`` has an unsupported suffix.
    """
    logger.info('Reading queries from file %s ...', cfg.query_file)
    query_dict = _read_queries(cfg.query_file)

    logger.info('Reading scores from file %s ...', cfg.score_file)
    with open(cfg.score_file, 'r', encoding='utf-8') as json_file:
        dpr_results = json.load(json_file)

    with_passages = bool(cfg.context_file)
    all_psgs = _read_passages(cfg.context_file) if with_passages else dict()

    logger.info('Simplifying scores ...')
    score_dict = _collect_scores(dpr_results, query_dict)

    logger.info('Loading answers from file ...')
    qrel_dict = _read_answers(cfg.answer_file)

    logger.info('Calculating Metrics')
    sorted_scores_dpr = list()
    topk_docs = dict()

    for qid in tqdm(query_dict):
        if qid not in qrel_dict:
            continue
        if qid not in score_dict:
            logger.warning('No retrieval results for query %s; excluding it from the metrics', qid)
            continue
        label_pids = qrel_dict[qid]
        gold = set(label_pids)
        ranked = sorted(score_dict[qid], key=lambda entry: entry[1], reverse=True)

        # One structured row per retrieved passage; 'U20' keeps at most the
        # first 20 characters of the id, the label is stored as 0.0/1.0.
        sorted_scores_dpr.append(np.array(
            [(pid, score, pid in gold) for pid, score in ranked],
            dtype=[('id', 'U20'), ('score', float), ('label', float)]
        ))
        topk_docs[qid] = _summarise_query(
            query_dict[qid], ranked, label_pids, cfg.topk, all_psgs, with_passages
        )

    # NOTE(review): this stacks the per-query arrays, which presumably relies on
    # every query having the same number of retrieved contexts — confirm
    # against the metric helpers in dpr.utils.evaluation_utils.
    sorted_scores_dpr = np.array(sorted_scores_dpr)
    logger.info('MAP_DPR : %.4f', mean_average_precision(sorted_scores_dpr, k=cfg.topk))
    logger.info('Precision@%d_DPR : %.4f', cfg.topk, precision_k(sorted_scores_dpr, k=cfg.topk))
    for cutoff in (1, 5, 10, 50):
        logger.info('Recall@%d_DPR : %.4f', cutoff, recall_k(sorted_scores_dpr, k=cutoff))
    logger.info('Hit@%d_DPR : %.4f', cfg.topk, hit_k(sorted_scores_dpr, k=cfg.topk))
    logger.info('MRR_DPR : %.4f', mrr(sorted_scores_dpr))
    logger.info('End of evaluation')

    if cfg.topk_output:
        with open('topk_docs.json', 'w') as doc_file:
            json.dump(topk_docs, doc_file, indent=4)

        # Compact view: retrieved pids plus found / missed gold pids per query.
        topk_pred = {
            qid: {
                'top': [hit[0] for hit in doc['pid']],
                'correct': doc['correct_pid'],
                'incorrect': doc['incorrect_pid'],
            }
            for qid, doc in topk_docs.items()
        }
        with open('psg_prediction.json', 'w') as prediction_file:
            json.dump(topk_pred, prediction_file, indent=4)

if __name__ == "__main__":
    logger.info("Sys.argv: %s", sys.argv)
    # Hydra takes overrides as `key=value`, so drop a leading "--" from every
    # command-line token (e.g. the `--key=value` params added by
    # torch.distributed.launch). Tokens without that prefix pass through
    # untouched.
    # NOTE: this applies to every token starting with "--", without exception.
    rewritten_argv = [
        token[len("--"):] if token.startswith("--") else token
        for token in sys.argv
    ]
    logger.info("Hydra formatted Sys.argv: %s", rewritten_argv)
    sys.argv = rewritten_argv

    evaluate_metrics()