import json
import os
import argparse
import numpy as np
import random
import spacy
from rank_bm25 import BM25Okapi

from tqdm import tqdm

# spaCy English pipeline, loaded once at import time. It is only referenced by
# the disabled "synthetic entail neg" snippet in sample_pairs (sentence splitting).
spacy_tagger = spacy.load("en_core_web_sm")


def _load_json(path):
    """Read and parse a UTF-8 JSON file, closing the file handle afterwards."""
    with open(path, encoding='utf-8') as fp:
        return json.load(fp)


def _sample_topic_negative(wiki_docs, titles, title, answer_text, max_tries=10000):
    """Draw a random passage from a *different* article that lacks the answer.

    Args:
        wiki_docs: mapping title -> list of passage dicts with a 'text' key.
        titles: indexable list of the keys of ``wiki_docs`` (for uniform draws).
        title: the gold article's title, which must not be picked.
        answer_text: string that the negative passage must not contain.
        max_tries: upper bound on draws, so degenerate inputs (e.g. a wiki with a
            single title) cannot loop forever.

    Returns:
        ``(neg_title, passage_text)``, or ``(None, None)`` if no suitable passage
        was drawn within ``max_tries`` attempts.
    """
    for _ in range(max_tries):
        neg_title = titles[np.random.randint(len(titles))]
        candidates = wiki_docs[neg_title]
        if not candidates:
            # np.random.randint(0) raises, and an empty article can never serve anyway.
            continue
        topic_neg = candidates[np.random.randint(len(candidates))]['text']
        if neg_title != title and answer_text not in topic_neg:
            return neg_title, topic_neg
    return None, None


def _rank_passages(passages, bm25_index, question):
    """Return ``passages`` ordered by descending BM25 score for ``question``.

    When ``bm25_index`` is None, a randomly shuffled copy is returned instead
    ("hard-random" negatives). ``passages`` itself is never mutated.
    """
    if bm25_index is None:
        ranked = list(passages)
        random.shuffle(ranked)
        return ranked
    scores = bm25_index.get_scores(question.split(" "))
    return [passages[i] for i in np.argsort(scores)[::-1]]


def sample_pairs(args, use_bm25=True):
    """Build (gold, topic-negative, entail-negative) passage triples and save them.

    For every dev paragraph whose article title exists in the wiki dump, the first
    question and its first answer are used. Two negatives are drawn:
      * topic negative: a random passage from a different article that does not
        contain the answer text;
      * entail negative: the highest-ranked passage of the *same* article that does
        not contain the answer text (BM25-ranked by the question, or random order
        when ``use_bm25`` is False).
    Paragraphs without a usable question/answer or without negatives are skipped
    with a printed message.

    Args:
        args: namespace with ``wiki_file`` (JSON: title -> list of {'text': ...}),
            ``dev_file`` (JSON with a top-level 'data' list of articles that carry
            'title' and 'paragraphs') and ``output_path`` (destination JSON file).
        use_bm25: rank same-article candidates with BM25 (default) instead of
            shuffling them.

    Side effects:
        Writes ``{'data': [pair, ...]}`` to ``args.output_path`` and prints progress.
    """
    wiki_docs = _load_json(args.wiki_file)
    # Indexable list for uniform random draws; membership checks use the dict.
    titles = list(wiki_docs)
    dev_data = _load_json(args.dev_file)['data']

    passage_pairs = []
    for article in tqdm(dev_data):
        title = article['title']
        if title not in wiki_docs:
            print(f'missing title: {title}')
            continue

        # The candidate pool and the BM25 index depend only on the article, so build
        # them once per article rather than once per paragraph.
        # TODO maybe fix BM25 to include entire Wiki for TF/IDF
        passages = [passage['text'] for passage in wiki_docs[title]]
        bm25_index = None
        if use_bm25 and passages:  # BM25Okapi divides by corpus size: skip empty articles
            bm25_index = BM25Okapi([passage.split(" ") for passage in passages])

        for paragraph in article['paragraphs']:  # TODO change this to gold passages in DPR
            gold_passage = paragraph['context']
            qas = paragraph['qas']
            # Only the first question and its first answer are used; look at nothing
            # else so unanswerable sibling questions (empty 'answers') cannot crash us.
            if not qas or not qas[0]['answers']:
                print("Unanswerable question")
                continue
            first_qa = qas[0]
            question = first_qa['question']
            first_answer = first_qa['answers'][0]
            answer = (first_answer['text'], first_answer['answer_start'])
            answer_text = answer[0]

            if answer_text not in gold_passage:
                print("Unanswerable question")
                continue

            # Get topic negatives
            neg_title, topic_neg = _sample_topic_negative(wiki_docs, titles, title, answer_text)
            if topic_neg is None:
                print("Cannot find topic negatives")
                continue

            # Get entail negatives: best-ranked same-article passage without the answer.
            ranked = _rank_passages(passages, bm25_index, question)
            entail_neg = next((passage for passage in ranked if answer_text not in passage), None)

            # Alternative (disabled) synthetic entail negative: drop the single
            # sentence of the gold passage that contains the answer.
            #   sents = [X.text for X in spacy_tagger(gold_passage).sents]
            #   filtered_sents = [sent for sent in sents if answer[0] not in sent]
            #   entail_neg = None
            #   if len(filtered_sents) == len(sents) - 1:  # only single-sentence removals
            #       entail_neg = ' '.join(filtered_sents)

            if entail_neg is None:
                print("Cannot find entail negatives")
                continue

            if not entail_neg.strip():
                print("no sent for entail negatives")
                continue

            # Internal invariants guaranteed by the samplers above.
            assert neg_title != title
            assert answer_text not in topic_neg
            assert answer_text not in entail_neg

            passage_pairs.append({
                'question': question,
                'answer': answer,
                'title': title,
                'gold_passage': gold_passage,
                'entail_neg_passage': entail_neg,
                'neg_title': neg_title,
                'topic_neg_passage': topic_neg,
            })

    with open(args.output_path, 'w', encoding='utf-8') as fp:
        json.dump({'data': passage_pairs}, fp)
    print(f'passage pairs saved at {args.output_path} with {len(passage_pairs)} pairs')


def get_args(argv=None):
    """Parse command-line arguments for passage-pair sampling.

    Args:
        argv: optional list of argument strings; defaults to ``sys.argv[1:]``
            (argparse's behaviour when given None).

    Returns:
        argparse.Namespace with ``wiki_file``, ``dev_file`` and ``output_path``.
        On invalid arguments argparse prints usage and raises ``SystemExit``.
    """
    parser = argparse.ArgumentParser(
        description='Sample gold / topic-negative / entail-negative passage pairs.')
    parser.add_argument('wiki_file',
                        help='JSON mapping article title -> list of passages with a "text" field')
    parser.add_argument('dev_file',
                        help='dev-set JSON with a top-level "data" list of articles')
    # Required: the output is only written after all processing, so a missing value
    # previously failed at the very end with open(None, 'w').
    parser.add_argument('--output_path', required=True,
                        help='where to write the sampled passage pairs (JSON)')
    return parser.parse_args(argv)


def main():
    """Script entry point: read CLI arguments and run the pair sampler on them."""
    sample_pairs(get_args())


# Run the sampler only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
