import json
import copy
import os
import argparse
import numpy as np
import random
import spacy
from rank_bm25 import BM25Okapi

from tqdm import tqdm

# spaCy English pipeline, loaded once when this module is imported.
# NOTE(review): not referenced anywhere in this file's visible code; presumably
# leftover or used by importers of this module -- confirm before removing.
# spacy.load raises at import time if the en_core_web_sm model is not installed.
spacy_tagger = spacy.load("en_core_web_sm")


def bm25_negatives(args):
    """Attach the first DPR hard negative to every usable NQ training paragraph.

    Reads a SQuAD-style training file (``args.train_file``: a top-level
    ``data`` list of articles, each with ``paragraphs`` whose ``qas`` hold
    exactly one question) and a DPR training file (``args.dpr_train_file``:
    a list of records with ``question`` and ``hard_negative_ctxs``).
    Questions are matched case-insensitively. For every kept paragraph, the
    first hard negative's ``title``/``text`` are stored as
    ``neg_title``/``neg_context``.

    A paragraph is dropped when it has no answers, when its question is not
    in the DPR file, or when its DPR record has no hard negatives. Articles
    left with no paragraphs are dropped too. The result is written to
    ``args.output_path`` as ``{'data': [...]}``, and drop counts are printed.

    Args:
        args: Namespace with ``train_file``, ``dpr_train_file`` and
            ``output_path`` attributes.

    Raises:
        ValueError: if a paragraph does not contain exactly one question.
    """
    with open(args.train_file) as fp:
        train_data = json.load(fp)['data']
    with open(args.dpr_train_file) as fp:
        dpr_records = json.load(fp)
    # Index DPR records by lower-cased question for case-insensitive lookup.
    dpr_by_question = {record['question'].lower(): record for record in dpr_records}

    new_train_data = []
    seen_questions = set()
    no_ans_cnt = 0
    no_q_cnt = 0
    no_hard_cnt = 0

    for article in tqdm(train_data):
        # Build the surviving paragraphs into a fresh list instead of deleting
        # by index from a copy while enumerating the original. Each deletion
        # shifted the later indices, so subsequent writes and deletions landed
        # on the wrong paragraph.
        kept_paragraphs = []
        for paragraph in article['paragraphs']:
            qas = paragraph['qas']
            if len(qas) != 1:  # only one question is given for NQ
                raise ValueError(
                    f"expected exactly one question per paragraph, got {len(qas)} "
                    f"(article {article.get('title')!r})"
                )
            question = qas[0]['question']

            if len(qas[0]['answers']) == 0:
                no_ans_cnt += 1
                continue

            dpr_record = dpr_by_question.get(question.lower())
            if dpr_record is None:
                no_q_cnt += 1
                continue

            hard_negative_ctxs = dpr_record['hard_negative_ctxs']
            if len(hard_negative_ctxs) == 0:
                no_hard_cnt += 1
                continue

            # Duplicates are reported, not fatal (this used to drop into pdb).
            if question in seen_questions:
                print(f'duplicate question: {question}')
            seen_questions.add(question)

            top_negative = hard_negative_ctxs[0]
            # A shallow copy suffices: only new top-level keys are added and
            # nothing mutates the shared nested objects.
            kept_paragraphs.append({
                **paragraph,
                'neg_title': top_negative['title'],
                'neg_context': top_negative['text'],
            })

        if kept_paragraphs:
            # Keep the article's other fields (and key order); swap in only the
            # filtered paragraph list.
            new_train_data.append({**article, 'paragraphs': kept_paragraphs})

    print(f'original dataset size: {len(train_data)} samples')
    print(f'no ans: {no_ans_cnt}, no question: {no_q_cnt}, no hard: {no_hard_cnt}, sum: {no_ans_cnt + no_q_cnt + no_hard_cnt}')
    with open(args.output_path, 'w') as fp:
        json.dump({'data': new_train_data}, fp)
    print(f'bm25 neg saved at {args.output_path} with {len(new_train_data)} samples')


def get_args():
    """Parse the command line for the hard-negative attachment script.

    Returns:
        argparse.Namespace with ``train_file``, ``dpr_train_file``,
        ``wiki_file`` and ``output_path`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Attach the first DPR hard negative to each NQ training paragraph.'
    )
    parser.add_argument('train_file',
                        help='SQuAD-style NQ training JSON with a top-level "data" list.')
    parser.add_argument('dpr_train_file',
                        help='DPR training JSON whose records carry "hard_negative_ctxs".')
    # Currently not read by the processing code; kept positional so existing
    # command lines continue to parse.
    parser.add_argument('wiki_file', help='Wikipedia documents JSON (currently unused).')
    # Required: the output is written to this path, and the old default of None
    # only failed at open() after all of the processing had already run.
    parser.add_argument('--output_path', required=True,
                        help='Path to write the output JSON.')
    return parser.parse_args()


def main():
    """Script entry point: parse the command line, then run the negative-mining pass."""
    bm25_negatives(get_args())


if __name__ == '__main__':
    main()
