import json
import re
import string
from collections import Counter

import nltk
from nltk.corpus import stopwords


def get_topk_word_ner_data(file, config):
    """Return the ``config['topk']`` most frequent words of a line-based file.

    For every non-blank line, only the first whitespace-separated field is
    used (presumably the token column of a token/tag file -- confirm against
    the data). The field is lower-cased and skipped if it is in
    ``config['stop_words']``. Words are ranked by descending count; ties
    keep first-seen order.

    Args:
        file: Path to the text file (read as UTF-8).
        config: Mapping with ``'stop_words'`` (container supporting ``in``)
            and ``'topk'`` (slice bound applied to the ranked word list).

    Returns:
        List of at most ``config['topk']`` words, most frequent first.
    """
    stop_words = config['stop_words']
    counts = Counter()
    # ``with`` closes the file even if reading raises; iterating the file
    # avoids materialising every line in memory at once.
    with open(file, encoding='utf-8') as fin:
        for line in fin:
            fields = line.split()
            if not fields:
                continue  # blank / whitespace-only line
            word = fields[0].lower()
            if word not in stop_words:
                counts[word] += 1
    # most_common() with no argument is a stable sort by count (descending),
    # identical to sorting items by count with reverse=True.
    ranked = [word for word, _ in counts.most_common()]
    return ranked[:config['topk']]


def get_topk_word_qa_data(file, config):
    """Return the ``config['topk']`` most frequent words of a QA JSON file.

    The file is expected to have the layout
    ``data -> [article] -> paragraphs -> [paragraph]`` where each paragraph
    has a ``'context'`` string and a ``'qas'`` list whose items carry a
    ``'question'`` string and an ``'answers'`` list of ``{'text': ...}``.
    Words are counted from every context, every question, and the first
    answer of each question. Questions whose answers list is empty or
    absent contribute only their question text.

    Text is lower-cased and split into maximal runs of ``\\w`` characters
    (the same tokens ``nltk.RegexpTokenizer(r"\\w+")`` yields). Tokens in
    ``config['stop_words']`` are skipped. Words are ranked by descending
    count; ties keep first-seen order. The total vocabulary size is printed.

    Args:
        file: Path to the JSON file (read as UTF-8).
        config: Mapping with ``'stop_words'`` (container supporting ``in``)
            and ``'topk'`` (slice bound applied to the ranked word list).

    Returns:
        List of at most ``config['topk']`` words, most frequent first.
    """
    stop_words = config['stop_words']
    word_re = re.compile(r"\w+")
    counts = Counter()

    def count_text(text):
        # Case-insensitive counting: lower-case before tokenizing.
        counts.update(w for w in word_re.findall(text.lower()) if w not in stop_words)

    with open(file, encoding='utf-8') as f:
        data = json.load(f)
    for article in data['data']:
        for paragraph in article['paragraphs']:
            count_text(paragraph['context'])
            for qa in paragraph['qas']:
                count_text(qa['question'])
                # Unanswerable questions can carry an empty answers list;
                # indexing [0] unconditionally would raise IndexError.
                answers = qa.get('answers') or []
                if answers:
                    count_text(answers[0]['text'])
    word_list = [word for word, _ in counts.most_common()]
    print('word num', len(word_list))
    return word_list[:config['topk']]


def get_overlap(in_domain_words, out_domain_words):
    """Return the fraction of ``in_domain_words`` also present in ``out_domain_words``.

    Each entry of ``in_domain_words`` is tested independently, so a repeated
    entry counts once per occurrence. Words must be hashable (e.g. strings).

    Args:
        in_domain_words: Sequence of words whose coverage is measured.
        out_domain_words: Sequence of words to test membership against.

    Returns:
        Float in ``[0, 1]``.

    Raises:
        ValueError: if the two sequences differ in length or are empty.
    """
    # Explicit check instead of ``assert`` so it is not stripped under -O.
    if len(in_domain_words) != len(out_domain_words):
        raise ValueError('word lists differ in length: {} vs {}'.format(
            len(in_domain_words), len(out_domain_words)))
    if not in_domain_words:
        raise ValueError('cannot compute overlap of empty word lists')
    # Build the lookup set once: set membership is O(1), whereas testing
    # against the list made the whole computation O(n^2).
    out_domain_set = set(out_domain_words)
    overlap_num = sum(1 for word in in_domain_words if word in out_domain_set)
    return overlap_num / len(in_domain_words)


def run_vocab_overlap_ner_data(in_domain_file='xdomain-person/large_twitter.txt',
                               out_domain_file='xdomain-person/large_twitter.txt',
                               topk=1000):
    """Print and return the top-``topk`` vocabulary overlap of two NER files.

    Stop words are NLTK's English stop-word list plus every character of
    ``string.punctuation``.

    Args:
        in_domain_file: Path of the in-domain file.
        out_domain_file: Path of the out-domain file.
        topk: Number of most frequent words compared from each file.

    Returns:
        The overlap rate computed by ``get_overlap`` (also printed).
    """
    # NOTE(review): the default out-domain path is identical to the in-domain
    # path, so a default run compares the corpus with itself and reports 1.0
    # for any non-empty corpus -- confirm the intended out-domain file and
    # pass it explicitly.
    stop_words = set(stopwords.words('english')) | set(string.punctuation)
    config = {'topk': topk, 'stop_words': stop_words}
    topk_words_in_domain = get_topk_word_ner_data(in_domain_file, config)
    topk_words_out_domain = get_topk_word_ner_data(out_domain_file, config)
    overlap_rate = get_overlap(topk_words_in_domain, topk_words_out_domain)
    print('overlap rate', overlap_rate)
    return overlap_rate


def run_vocab_overlap_qa_data(in_domain_file='QA-data/xdomain-QA-squad/large_squad.json',
                              out_domain_file='QA-data/xdomain-QA-squad/large_triviaqa.json',
                              topk=1000):
    """Print and return the top-``topk`` vocabulary overlap of two QA JSON files.

    Stop words are NLTK's English stop-word list plus every character of
    ``string.punctuation``.

    Args:
        in_domain_file: Path of the in-domain QA JSON file.
        out_domain_file: Path of the out-domain QA JSON file.
        topk: Number of most frequent words compared from each file.

    Returns:
        The overlap rate computed by ``get_overlap`` (also printed).
    """
    stop_words = set(stopwords.words('english')) | set(string.punctuation)
    config = {'topk': topk, 'stop_words': stop_words}
    topk_words_in_domain = get_topk_word_qa_data(in_domain_file, config)
    topk_words_out_domain = get_topk_word_qa_data(out_domain_file, config)
    overlap_rate = get_overlap(topk_words_in_domain, topk_words_out_domain)
    print('overlap rate', overlap_rate)
    return overlap_rate


if __name__ == '__main__':
    # Script entry point: only the QA-corpus comparison runs by default.
    # To compare the NER corpora instead, call run_vocab_overlap_ner_data() here.
    run_vocab_overlap_qa_data()

