import json
import copy
import os
import argparse
import numpy as np
import random
import spacy
from rank_bm25 import BM25Okapi

from tqdm import tqdm

# Load spaCy's small English pipeline once, at import time.
# NOTE(review): `spacy_tagger` is not referenced anywhere in the visible part
# of this file -- confirm it is still needed before paying the load cost.
spacy_tagger = spacy.load("en_core_web_sm")


def _sample_context_with_answer(contexts, answer_text, max_draws=1000002):
    """Randomly draw a ``(text, title)`` pair whose text contains the answer.

    Matching is a case-insensitive substring test. Draws are independent and
    uniform over ``contexts`` (with replacement).

    Args:
        contexts: list of ``(text, title)`` tuples to sample from.
        answer_text: answer string that must occur in the sampled text.
        max_draws: upper bound on the number of random draws before giving up.

    Returns:
        The first drawn ``(text, title)`` pair containing the answer, or
        ``None`` if no draw within the budget matched.
    """
    needle = answer_text.lower()  # hoisted: invariant across draws
    n_contexts = len(contexts)
    for _ in range(max_draws):
        candidate = contexts[np.random.choice(n_contexts)]
        if needle in candidate[0].lower():
            return candidate
    return None


def phrase_negatives(args):
    """Attach a randomly sampled answer-containing negative to each paragraph.

    Reads the training set from ``args.train_file`` (a JSON object whose
    ``'data'`` entry is a list of articles, each with a ``'title'`` and a list
    of ``'paragraphs'``; every paragraph carries exactly one entry in
    ``'qas'``) and the passages from ``args.wiki_file`` (a JSON object mapping
    a title to a list of dicts with a ``'text'`` field).

    For every paragraph of every article whose title exists in the wiki file:

    * paragraphs without any answer are dropped from the output;
    * otherwise a random wiki passage containing the first answer's text is
      stored under ``'neg_title'`` / ``'neg_context'``;
    * if no such passage is found within the sampling budget, the paragraph
      is kept unchanged (without the ``neg_*`` fields).

    Articles whose title is missing from the wiki file are kept untouched.
    Articles left with no paragraphs are removed. The result is written as
    ``{'data': [...]}`` JSON to ``args.output_path``.

    Args:
        args: namespace with ``train_file``, ``wiki_file`` and ``output_path``
            attributes.
    """
    # Context managers so the input files are closed deterministically.
    with open(args.train_file, encoding='utf-8') as fp:
        train_data = json.load(fp)['data']
    with open(args.wiki_file, encoding='utf-8') as fp:
        wiki_docs = json.load(fp)

    contexts = [(v['text'], title) for title, val in wiki_docs.items() for v in val]
    random.shuffle(contexts)

    # Work on a copy; `train_data` is only used afterwards for the size report.
    new_train_data = copy.deepcopy(train_data)
    seen_questions = set()
    no_ans_cnt = 0
    no_q_cnt = 0
    no_hard_cnt = 0

    for article in tqdm(new_train_data):
        title = article['title']

        # Dict membership is O(1); no need for a separate list of titles.
        if title not in wiki_docs:
            print(f'missing title: {title}')
            continue

        # Collect survivors in a new list instead of deleting by index while
        # iterating: in-place deletion shifts the indices of all following
        # paragraphs and makes later updates hit the wrong element.
        kept_paragraphs = []
        for paragraph in article['paragraphs']:
            qas = paragraph['qas']
            assert len(qas) == 1  # only one question is given for NQ
            question = qas[0]['question']
            answers = qas[0]['answers']

            if not answers:
                no_ans_cnt += 1
                continue

            kept_paragraphs.append(paragraph)
            answer_text = answers[0]['text']

            sampled = _sample_context_with_answer(contexts, answer_text)
            if sampled is None:
                # Paragraph stays in the output, just without a negative.
                print(f'over tolerance: {answer_text}')
                no_ans_cnt += 1
                continue

            hard_negative, hard_title = sampled
            paragraph['neg_title'] = hard_title
            paragraph['neg_context'] = hard_negative

            # Report duplicated questions without dropping into a debugger,
            # which would stall unattended runs.
            if question in seen_questions:
                print(question)
            seen_questions.add(question)

        article['paragraphs'] = kept_paragraphs

    new_train_data = [data for data in new_train_data if len(data['paragraphs']) > 0]
    print(f'original dataset size: {len(train_data)} samples')
    print(f'no ans: {no_ans_cnt}, no question: {no_q_cnt}, no hard: {no_hard_cnt}, sum: {no_ans_cnt + no_q_cnt + no_hard_cnt}')
    with open(args.output_path, 'w') as fp:
        json.dump({'data': new_train_data}, fp)
    print(f'bm25 neg saved at {args.output_path} with {len(new_train_data)} samples')


def get_args():
    """Parse the command-line arguments of the script.

    Returns:
        argparse.Namespace with ``train_file``, ``wiki_file`` and
        ``output_path`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('train_file', help='path to the training JSON file')
    parser.add_argument('wiki_file', help='path to the wiki passages JSON file')
    # Required: without it the script would do all the processing and only
    # then fail when trying to open `None` for writing.
    parser.add_argument('--output_path', required=True,
                        help='where to write the augmented training JSON')
    return parser.parse_args()


def main():
    """Script entry point: parse the CLI arguments and build the negatives file."""
    phrase_negatives(get_args())


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
