#! /usr/bin/env python3
# coding=utf-8


import os
import re
import time
import pickle
import sng_parser
import spacy
import random
import pandas as pd
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from sentence_transformers import SentenceTransformer, util
from sklearn.model_selection import train_test_split
from functools import partial
from nltk.tokenize import RegexpTokenizer
from parallel_configs import CONFIGS
from spacy.tokens import Token
from spacy.lang.en.stop_words import STOP_WORDS  # import stop words from language data

# word tokenizer over \w+ runs; not referenced by any function in this file
tokenizer = RegexpTokenizer(r'\w+')
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# theta_NE = 0.5
# theta_ENT = 0.3
# theta_REL = 0.2

# for allsides
# beta handed to compute_OEI by pick_by_LM_kg; with beta = 1 the score (1 + b^2)PR / (b^2 P + R) is plain F1
BETA = 1

# seed for the shuffle and the train/eval split in prepare_for_training
RANDOM_SEED = 25536

# replace to trf if speed ok
# (presumably: switch to the transformer-based "en_core_web_trf" pipeline if it is fast enough — TODO confirm)
nlp = spacy.load("en_core_web_sm")
# not referenced by any function in this file
lemmatizer = nlp.get_pipe("lemmatizer")
# NOTE(review): presumably silences the HuggingFace tokenizers parallelism warning that appears once
# multiprocessing Pools are used alongside SentenceTransformer — confirm
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Registers a custom attribute token._.is_stop that also treats a token as a stop word when its lower-cased
# text or its lemma is in spaCy's English STOP_WORDS list.
# NOTE(review): return_entities filters on the built-in token.is_stop, not on token._.is_stop, so this
# extension does not currently affect the extracted entities — confirm which one is intended.
stop_words_getter = lambda token: token.is_stop or token.lower_ in STOP_WORDS or token.lemma_ in STOP_WORDS
Token.set_extension('is_stop', getter=stop_words_getter)  # set attribute with getter


def text_normalizer(text):
    """
    normalise raw text: newlines become sentence breaks, URLs and @usernames are removed,
    '#' is stripped from hashtags, and everything except ASCII letters and . , ! ? is collapsed to spaces
    :param text: any object; it is converted with str() first
    :return: the normalised string
    """
    text = str(text)
    # turn every run of newlines into a sentence break
    text = re.sub(r'\n+', '. ', text)

    # remove URLs (raw strings: '\.' / '\s' in plain literals are invalid escapes on Python 3.12+).
    # Note that 'http?://' also matches 'htt://'; kept as-is to preserve behaviour.
    text = re.sub(r'((www\.[^\s]+)|(https?://[^\s]+)|(http?://[^\s]+))', '', text)
    text = re.sub(r'http\S+', '', text)
    # remove usernames
    text = re.sub(r'@[^\s]+', '', text)
    # remove the # in #hashtag
    text = re.sub(r'#([^\s]+)', r'\1', text)
    # keep only ASCII letters and . , ! ? ; every other run of characters becomes a single space
    text = re.sub(r"[^a-zA-Z.,!?]+", r" ", text)
    # drop a space that precedes punctuation: "word , next" -> "word, next"
    text = re.sub(r'\s([,?.!"](?:\s|$))', r'\1', text)
    return text


def get_allsides_data(data_path, num_stories=2299, min_words=10):
    """
    load the allsides stories (story_0.txt ... story_<num_stories - 1>.txt) into lists of sentences
    Every line of a story file is one candidate sentence; lines with fewer than min_words
    space-separated tokens are dropped.
    :param data_path: folder prefix, concatenated directly with the file name (so it should end with '/')
    :param num_stories: number of story files to read (default: the 2299 AllSides stories)
    :param min_words: minimum number of space-separated tokens for a line to be kept
    :return: (features_set, features_sentences, features_sentences_index)
        features_set: one list of kept lines per story
        features_sentences: all kept lines, flattened in story order
        features_sentences_index: cumulative offsets [0, n_0, n_0 + n_1, ...], so story i's lines are
            features_sentences[index[i]:index[i + 1]]
    """
    features_set = []
    features_sentences = []
    features_sentences_index = [0]
    for story_id in range(num_stories):
        # 'with' closes every story file instead of leaking 2299 open handles to the garbage collector
        with open(data_path + 'story_' + str(story_id) + '.txt', 'r', encoding='utf-8') as f:
            lines = [line.strip('\n') for line in f]
        features = [seq for seq in lines if len(seq.split(' ')) >= min_words]

        features_set.append(features)
        features_sentences.extend(features)
        features_sentences_index.append(features_sentences_index[-1] + len(features))
    return features_set, features_sentences, features_sentences_index


def get_precision_recall(a, b):
    """
    overlap-based precision and recall between two collections of hashable items
    The overlap is counted on distinct items. Precision is overlap / #distinct(b) and
    recall is overlap / #distinct(a); (0, 0) is returned when either input is empty.
    :param a: first collection (must support len())
    :param b: second collection (must support len())
    :return: (precision, recall)
    """
    unique_a, unique_b = set(a), set(b)
    overlap = len(unique_a & unique_b)
    if not (len(a) and len(b)):
        return 0, 0
    return overlap / len(unique_b), overlap / len(unique_a)


def spice_f1(p, r):
    """
    harmonic mean (F1) of a precision and a recall value
    :param p: precision
    :param r: recall
    :return: 2pr / (p + r), or 0 when either p or r is exactly 0
    """
    both_nonzero = p != 0 and r != 0
    return (2 * p * r) / (p + r) if both_nonzero else 0


# def compute_spice(ENT_a, REL_a, ENT_b, REL_b):
#     """
#     compute spice score between pairs
#     :param ENT_a:
#     :param REL_a:
#     :param ENT_b:
#     :param REL_b:
#     :return:
#     """
#     assert theta_ENT + theta_REL == 1, "Please make sure the sum of THETAs is 1."
#
#     p_ENT, r_ENT = get_precision_recall(ENT_a, ENT_b)
#     p_REL, r_REL = get_precision_recall(REL_a, REL_b)
#
#     return theta_ENT * spice_f1(p_ENT, r_ENT) + theta_REL * spice_f1(p_REL, r_REL)


# def _get_random(sent1, sent2s, n):
#     temp_dict, _sent2s = {}, []
#     temp_dict['sent1'] = sent1
#     if len(sent2s) <= n:
#         _sent2s = sent2s
#     else:
#         for i in [random.randint(0, len(sent2s) - 1) for _ in range(n)]:
#             _sent2s.append(sent2s[i])
#     temp_dict['sent2'] = _sent2s
#     return temp_dict

def _get_random(sent1, sent2s, n):
    temp_dict, _sent2s = {}, []
    temp_dict['sent1'] = sent1
    if len(sent2s) <= n:
        _sent2s = sent2s
    else:
        for i in [random.randint(0, len(sent2s) - 1) for _ in range(n)]:
            _sent2s.append(sent2s[i])
    temp_dict['sent2'] = _sent2s
    return temp_dict


def pick_by_random(sent1s, sent2s, n, cache_path, out_csv_name):
    """
    randomly choose n sents from sent2s as parallels of each sent1 in sent1s
    (draws are with replacement, see _get_random)
    :param sent1s: source sentences
    :param sent2s: candidate parallel sentences
    :param n: number of random draws per sent1
    :param cache_path: output folder (created if missing)
    :param out_csv_name: name of the csv written into cache_path
    :return: DataFrame with columns sent1, sent2 (list of picked sentences)
    """
    os.makedirs(cache_path, exist_ok=True)

    # bind sent2s/n once; each task then only ships its own sent1
    _part_get_random = partial(_get_random, sent2s=sent2s, n=n)
    num_workers = cpu_count()
    # a chunksize > 1 avoids pickling the (large) sent2s list once per sentence
    chunksize = max(1, len(sent1s) // (4 * num_workers))

    with Pool(num_workers) as proc:
        # imap yields results as they finish, so tqdm reports real progress
        # (the old starmap call had already finished before tqdm ever saw it)
        _temp_dicts = list(tqdm(proc.imap(_part_get_random, sent1s, chunksize=chunksize),
                                total=len(sent1s)))

    # DataFrame.append was removed in pandas 2.0; build the frame in one go instead
    out_df = pd.DataFrame(_temp_dicts, columns=['sent1', 'sent2'])
    out_df.to_csv(os.path.join(cache_path, out_csv_name), index=False)

    return out_df


def iterate_sents_batches(sents, batch_size):
    """
    yield consecutive slices of sents of length batch_size; the last slice may be shorter
    :param sents: list to iterate over
    :param batch_size: slice length
    :return: generator of lists
    """
    assert isinstance(sents, list)
    assert isinstance(batch_size, int)

    ofs = 0
    while True:
        batch = sents[ofs * batch_size:(ofs + 1) * batch_size]
        # stop only when nothing is left: the former `len(batch) <= 1` test silently dropped a
        # trailing single-element batch and yielded nothing at all for batch_size == 1
        if not batch:
            break
        yield batch
        ofs += 1


def iterate_batches_lm(sents, sents_embedding, batch_size):
    """
    iterate the sents and their (row-aligned) embeddings together in slices of batch_size
    The last batch may be shorter; iteration stops as soon as either sequence is exhausted.
    :param sents: list of sentences
    :param sents_embedding: sliceable sequence aligned with sents (list, array or tensor)
    :param batch_size: number of items per batch
    :return: generator of (sents_slice, embedding_slice) tuples
    """
    assert isinstance(sents, list)
    assert isinstance(batch_size, int)

    ofs = 0
    while True:
        batch = (sents[ofs * batch_size:(ofs + 1) * batch_size],
                 sents_embedding[ofs * batch_size:(ofs + 1) * batch_size])
        # len() (not truthiness) so tensors/arrays work; stop only on an EMPTY slice — the former
        # `<= 1` test dropped a final one-sentence batch, so pick_by_LM silently lost that sentence
        if len(batch[0]) == 0 or len(batch[1]) == 0:
            break
        yield batch
        ofs += 1


def iterate_batches_lm_kg(sents, sents_embedding, sents_kg, batch_size):
    """
    iterate the sents together with their (row-aligned) embeddings and entity lists in slices of batch_size
    The last batch may be shorter; iteration stops as soon as any of the three sequences is exhausted.
    :param sents: list of sentences
    :param sents_embedding: sliceable sequence aligned with sents (list, array or tensor)
    :param sents_kg: sliceable sequence of entity lists aligned with sents
    :param batch_size: number of items per batch
    :return: generator of (sents_slice, embedding_slice, kg_slice) tuples
    """
    assert isinstance(sents, list)
    assert isinstance(batch_size, int)

    ofs = 0
    while True:
        batch = (sents[ofs * batch_size:(ofs + 1) * batch_size],
                 sents_embedding[ofs * batch_size:(ofs + 1) * batch_size],
                 sents_kg[ofs * batch_size:(ofs + 1) * batch_size])
        # stop only on an EMPTY slice — the former `<= 1` test dropped a final one-sentence batch
        if any(len(part) == 0 for part in batch):
            break
        yield batch
        ofs += 1


def iterate_kg_batches(sents, sents_kg, batch_size):
    """
    iterate the sents together with their (row-aligned) entity lists in slices of batch_size
    The last batch may be shorter; iteration stops as soon as either list is exhausted.
    :param sents: list of sentences
    :param sents_kg: list of entity lists aligned with sents
    :param batch_size: number of items per batch
    :return: generator of (sents_slice, kg_slice) tuples
    """
    assert isinstance(sents, list)
    assert isinstance(sents_kg, list)
    assert isinstance(batch_size, int)

    ofs = 0
    while True:
        batch = (sents[ofs * batch_size:(ofs + 1) * batch_size],
                 sents_kg[ofs * batch_size:(ofs + 1) * batch_size])
        # stop only on an EMPTY slice — the former `<= 1` test dropped a final one-sentence batch
        if not batch[0] or not batch[1]:
            break
        yield batch
        ofs += 1


def return_entities(sent):
    """
    extract the entity words of one sentence
    Runs the module-level spaCy pipeline on sent and returns, in token order, the lower-cased
    lemma of every NOUN / PROPN token that spaCy does not flag as a stop word.
    :param sent: raw sentence string
    :return: list of lower-cased lemmas (duplicates are kept)
    """
    # NOTE(review): the module registers a custom token._.is_stop (which also checks lower_ and lemma_
    # against STOP_WORDS), but this filter uses the built-in token.is_stop — confirm which is intended
    entity_pos = ('PROPN', 'NOUN')
    lemmas = []
    for tok in nlp(sent):
        if tok.pos_ not in entity_pos or tok.is_stop:
            continue
        lemmas.append(tok.lemma_.lower())
    return lemmas


def build_kg(sents, cache_path, cache_name):
    """
    extract the entity list of every sentence (return_entities, in parallel) and pickle it
    :param sents: sentences to process
    :param cache_path: folder for the cache file (created if missing)
    :param cache_name: file name of the pickle cache
    :return: one entity list per sentence, in the same order as sents
    """
    # spaCy parsing is CPU-bound, so spread it over every core
    # (original note: saves time from 30min to 3min on a 20-cpu machine)
    with Pool(cpu_count()) as workers:
        entity_lists = list(tqdm(workers.imap(return_entities, sents), total=len(sents)))

    print("Saving to pickle ...")
    if not os.path.exists(cache_path):
        os.makedirs(cache_path)

    cache_file = os.path.join(cache_path, cache_name)
    with open(cache_file, 'wb') as f:
        pickle.dump(entity_lists, f, protocol=4)
        print("Saved cache!")

    return entity_lists


def build_embed(sents, model, cache_path, cache_name, device='cuda:0'):
    """
    encode sents with a sentence-embedding model and pickle the embeddings into cache_path/cache_name
    :param sents: list of sentences to encode
    :param model: encoder exposing encode(...) (a SentenceTransformer in this file)
    :param cache_path: folder for the cache file (created if missing)
    :param cache_name: file name of the pickle cache
    :param device: device passed to model.encode; defaults to the previously hard-coded 'cuda:0'
                   (pass 'cpu' on machines without a GPU)
    :return: the embeddings returned by model.encode, moved to CPU
    """
    embeddings = model.encode(sents, convert_to_tensor=True, show_progress_bar=True, device=device).cpu()

    print("Saving to pickle ...")
    os.makedirs(cache_path, exist_ok=True)

    with open(os.path.join(cache_path, cache_name), 'wb') as f:
        pickle.dump(embeddings, f, protocol=4)
        print("Saved cache!")

    return embeddings


def _compute_lm_toppk(sent1s, sent1s_emb, sent2s, sent2s_emb, topp, topk):
    """
    single proc version compute topp and topk
    For every sent1, shortlist (at most) the topk sent2s with the highest cosine similarity,
    then keep only those whose similarity is strictly greater than topp.
    :param sent1s: batch of source sentences, row-aligned with sent1s_emb
    :param sent1s_emb: embeddings of sent1s
    :param sent2s: candidate sentences, row-aligned with sent2s_emb
    :param sent2s_emb: embeddings of sent2s
    :param topp: similarity threshold
    :param topk: maximum number of candidates per sent1
    :return: list of dicts {'sent1', 'sent2': [picked sentences], 'similarity_score': [their cosine scores]};
             sent1s without any surviving candidate are omitted
    """
    _temp_dicts = []

    cosine_matrix = util.pytorch_cos_sim(sent1s_emb, sent2s_emb).cpu().numpy()
    sent2s = np.array(sent2s, dtype=str)

    for sent1, row in zip(sent1s, cosine_matrix):
        if len(row) <= topk:
            cur_topk_idx = np.arange(len(row))
        else:
            cur_topk_idx = np.argpartition(-row, topk)[:topk]  # fast nlargest (unordered)
        if len(cur_topk_idx) == 0:
            continue

        # read the scores through the same indices: a separate np.partition call is not
        # guaranteed to order its values the same way argpartition orders the indices
        cur_topk_sc = row[cur_topk_idx]
        keep = cur_topk_sc > topp
        picked_sents = sent2s[cur_topk_idx[keep]].tolist()
        picked_scores = list(cur_topk_sc[keep])
        if picked_sents:
            _temp_dicts.append({'sent1': sent1, 'sent2': picked_sents, 'similarity_score': picked_scores})

    return _temp_dicts


def compute_OEI(entities_1, entities_2, beta):
    """
    compute the F-beta style overlap score (OEI) between two entity lists
    The overlap is the number of distinct shared entities. precision = overlap / len(entities_1)
    and recall = overlap / len(entities_2), both using list lengths (duplicates included).
    :param entities_1: entities of sent1
    :param entities_2: entities of sent2
    :param beta: weight of recall relative to precision (beta = 1 gives F1)
    :return: (1 + beta^2) * P * R / (beta^2 * P + R), or 0 when nothing is shared
    """
    shared = len(set(entities_1) & set(entities_2))
    # no shared entity also covers the empty-list cases, so the divisions below are safe
    if shared == 0:
        return 0

    precision = shared / len(entities_1)
    recall = shared / len(entities_2)
    beta_sq = beta ** 2
    return (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)


def _compute_lm_kg_toppk(sent1s, sent1s_emb, sent1s_ent, sent2s, sent2s_emb, sent2s_ent, topp, topk, beta):
    """
    single proc version compute topp and topk
    For every sent1, shortlist (at most) the topk sent2s with the highest cosine similarity, score
    each shortlisted pair with compute_OEI on their entity lists, and keep pairs scoring > topp.
    :param sent1s: batch of source sentences, row-aligned with sent1s_emb and sent1s_ent
    :param sent1s_emb: embeddings of sent1s
    :param sent1s_ent: entity lists of sent1s
    :param sent2s: ALL candidate sentences, row-aligned with sent2s_emb and sent2s_ent
    :param sent2s_emb: embeddings of sent2s
    :param sent2s_ent: entity lists of sent2s
    :param topp: OEI threshold
    :param topk: size of the cosine shortlist per sent1
    :param beta: beta passed to compute_OEI
    :return: list of dicts {'sent1', 'sent2': [picked sentences], 'similarity_score': [their OEI scores]};
             sent1s without any surviving candidate are omitted
    """
    _temp_dicts = []

    cosine_matrix = util.pytorch_cos_sim(sent1s_emb, sent2s_emb).cpu().numpy()
    # round to 3 decimals (kept from the original implementation)
    cosine_matrix = np.around(cosine_matrix, decimals=3)
    sent2s = np.array(sent2s, dtype=str)

    for sent1_idx, sent1 in enumerate(sent1s):
        row = cosine_matrix[sent1_idx]
        if len(row) <= topk:
            picked_idx = np.arange(len(row))
        else:
            picked_idx = np.argpartition(-row, topk)[:topk]  # fast nlargest (unordered)
        picked_sent2s = sent2s[picked_idx].tolist()
        sent1_ents = sent1s_ent[sent1_idx]

        picked_sents, picked_scores = [], []
        for sent2_idx, sent2 in zip(picked_idx, picked_sent2s):
            # sent2_idx is the position in the FULL sent2s list. The old code used the position inside
            # the shortlist, which compared sent1 against the entities of the wrong sentences.
            score = compute_OEI(sent1_ents, sent2s_ent[sent2_idx], beta)
            if score > topp:
                picked_sents.append(sent2)
                picked_scores.append(float(score))

        if picked_sents:
            _temp_dicts.append({'sent1': sent1, 'sent2': picked_sents, 'similarity_score': picked_scores})

    return _temp_dicts


def pick_by_LM(sent1s, sent2s, topk, topp, batch_size, use_cache, cache_path, out_csv_name):
    """
    use LM embedding to pick parallels
    For every sent1 the topk most cosine-similar sent2s are shortlisted and those with
    similarity > topp are kept (see _compute_lm_toppk).
    :param sent1s: source sentences
    :param sent2s: candidate parallel sentences
    :param topk: maximum number of candidates per sent1
    :param topp: cosine-similarity threshold
    :param batch_size: number of sent1s scored per cosine-similarity batch
    :param use_cache: reuse pickled embeddings from cache_path when present
    :param cache_path: folder for the embedding caches and the output csv
    :param out_csv_name: output csv file name (written into cache_path)
    :return: DataFrame with columns sent1, sent2 (list), similarity_score (list)
    """
    batch_size = min(batch_size, len(sent1s))
    model = SentenceTransformer('paraphrase-distilroberta-base-v1')

    def _load_or_build(sents, cache_name, which):
        # load the pickled embeddings if allowed and present, otherwise encode (which also writes the cache).
        # NOTE: pickle.load can execute arbitrary code — only point cache_path at caches this script wrote.
        cache_file = os.path.join(cache_path, cache_name)
        if use_cache:
            if os.path.exists(cache_file):
                with open(cache_file, 'rb') as f:
                    return pickle.load(f)
            print("You choose use_cache but cache does not exist. Now build for " + which + " ...")
        return build_embed(sents, model, cache_path, cache_name)

    embeddings1 = _load_or_build(sent1s, 'cached_allsides_emb_sent1s.pickle', 'sent1s')
    embeddings2 = _load_or_build(sent2s, 'cached_allsides_emb_sent2s.pickle', 'sent2s')
    if use_cache:
        print("Loaded")
    else:
        print("Build cache finished! In the future you can set use_cache to True.")

    _part_compute_emb_toppk = partial(_compute_lm_toppk, sent2s_emb=embeddings2, topp=topp, topk=topk)

    print("Now working on top p and top k picking ...")
    step = max(batch_size, 1)  # batch_size is 0 when sent1s is empty
    num_batches = (len(sent1s) + step - 1) // step
    picked = []
    # distinct loop names: the old code rebound `sent1s` inside this loop
    for batch_sent1s, batch_sent1_embs in tqdm(iterate_batches_lm(sent1s, embeddings1, batch_size),
                                               total=num_batches):
        picked.extend(_part_compute_emb_toppk(batch_sent1s, batch_sent1_embs, sent2s))

    # DataFrame.append was removed in pandas 2.0 (and re-copied the frame every batch); build it once
    out_df = pd.DataFrame(picked, columns=['sent1', 'sent2', 'similarity_score'])
    out_df.to_csv(os.path.join(cache_path, out_csv_name), index=False)
    return out_df


# def pick_by_LM(sent1s, sent1s_index, sent2s, sent2s_index, topk, topp, batch_size, use_cache, cache_path, out_csv_name):
#     """
#     use LM embedding to pick parallels
#     :param sent1s:
#     :param sent2s:
#     :param topk:
#     :param topp:
#     :param batch_size:
#     :param use_cache:
#     :param cache_path:
#     :param out_csv_name:
#     :return:
#     """
#     out_df = pd.DataFrame(columns=['sent1', 'sent2', 'similarity_score'])
#     batch_size = min(batch_size, len(sent1s))
#     model = SentenceTransformer('paraphrase-distilroberta-base-v1')
#
#     if use_cache:
#         if os.path.exists(os.path.join(cache_path, 'cached_allsides_emb_sent1s.pickle')):
#             with open(os.path.join(cache_path, 'cached_allsides_emb_sent1s.pickle'), 'rb') as f:
#                 embeddings1 = pickle.load(f)
#         else:
#             print("You choose use_cache but cache does not exist. Now build for sent1s ...")
#             embeddings1 = build_embed(sent1s, model, cache_path, 'cached_allsides_emb_sent1s.pickle')
#
#         if os.path.exists(os.path.join(cache_path, 'cached_allsides_emb_sent2s.pickle')):
#             with open(os.path.join(cache_path, 'cached_allsides_emb_sent2s.pickle'), 'rb') as f:
#                 embeddings2 = pickle.load(f)
#         else:
#             print("You choose use_cache but cache does not exist. Now build for sent2s ...")
#             embeddings2 = build_embed(sent2s, model, cache_path, 'cached_allsides_emb_sent2s.pickle')
#         print("Loaded")
#     else:
#         embeddings1 = build_embed(sent1s, model, cache_path, 'cached_allsides_emb_sent1s.pickle')
#         embeddings2 = build_embed(sent2s, model, cache_path, 'cached_allsides_emb_sent2s.pickle')
#         print("Build cache finished! In the future you can set use_cache to True.")
#
#     _part_compute_emb_toppk = partial(_compute_lm_toppk, topp=topp, topk=topk)
#
#     print("Now working on top p and top k picking ...")
#     for i in tqdm(range(len(sent1s_index) - 1), total=(len(sent1s_index) - 1)):
#         sent1s_story = sent1s[sent1s_index[i]:sent1s_index[i + 1]]
#         sent1_embs_story = embeddings1[sent1s_index[i]:sent1s_index[i + 1]]
#         sent2s_story = sent2s[sent2s_index[i]:sent2s_index[i + 1]]
#         sent2_embs_story = embeddings2[sent2s_index[i]:sent2s_index[i + 1]]
#
#         _temp_dicts = _part_compute_emb_toppk(sent1s_story, sent1_embs_story, sent2s_story, sent2_embs_story)
#         out_df = out_df.append(pd.DataFrame(_temp_dicts), ignore_index=True)
#
#     out_df.to_csv(os.path.join(cache_path, out_csv_name), index=False)
#
#     return out_df


# def pick_by_LM_kg(sent1s, sent1s_index, sent2s, sent2s_index, topk, topp, batch_size, use_cache, cache_path,
#                   out_csv_name):
#     """
#     use LM embedding to pick parallels
#     :param sent1s:
#     :param sent2s:
#     :param topk:
#     :param topp:
#     :param batch_size:
#     :param use_cache:
#     :param cache_path:
#     :param out_csv_name:
#     :return:
#     """
#     out_df = pd.DataFrame(columns=['sent1', 'sent2', 'similarity_score'])
#     batch_size = min(batch_size, len(sent1s))
#     model = SentenceTransformer('paraphrase-distilroberta-base-v1')
#
#     if use_cache:
#         # sents embeddings
#         if os.path.exists(os.path.join(cache_path, 'cached_allsides_emb_sent1s.pickle')):
#             with open(os.path.join(cache_path, 'cached_allsides_emb_sent1s.pickle'), 'rb') as f:
#                 embeddings1 = pickle.load(f)
#         else:
#             print("You choose use_cache but cache does not exist. Now build for sent1s ...")
#             embeddings1 = build_embed(sent1s, model, cache_path, 'cached_allsides_emb_sent1s.pickle')
#
#         if os.path.exists(os.path.join(cache_path, 'cached_allsides_emb_sent2s.pickle')):
#             with open(os.path.join(cache_path, 'cached_allsides_emb_sent2s.pickle'), 'rb') as f:
#                 embeddings2 = pickle.load(f)
#         else:
#             print("You choose use_cache but cache does not exist. Now build for sent2s ...")
#             embeddings2 = build_embed(sent2s, model, cache_path, 'cached_allsides_emb_sent2s.pickle')
#
#         # sents kgs
#         if os.path.exists(os.path.join(cache_path, 'cached_allsides_kg_sent1s.pickle')):
#             with open(os.path.join(cache_path, 'cached_allsides_kg_sent1s.pickle'), 'rb') as f:
#                 sent1s_kg = pickle.load(f)
#         else:
#             print("You choose use_cache but cache does not exist. Now build for sent1s ...")
#             sent1s_kg = build_kg(sent1s, cache_path, 'cached_allsides_kg_sent1s.pickle')
#
#         if os.path.exists(os.path.join(cache_path, 'cached_allsides_kg_sent2s.pickle')):
#             with open(os.path.join(cache_path, 'cached_allsides_kg_sent2s.pickle'), 'rb') as f:
#                 sent2s_kg = pickle.load(f)
#         else:
#             print("You choose use_cache but cache does not exist. Now build for sent2s ...")
#             sent2s_kg = build_kg(sent2s, cache_path, 'cached_allsides_kg_sent2s.pickle')
#
#         print("Loaded")
#     else:
#         embeddings1 = build_embed(sent1s, model, cache_path, 'cached_allsides_emb_sent1s.pickle')
#         embeddings2 = build_embed(sent2s, model, cache_path, 'cached_allsides_emb_sent2s.pickle')
#         sent1s_kg = build_kg(sent1s, cache_path, 'cached_allsides_kg_sent1s.pickle')
#         sent2s_kg = build_kg(sent2s, cache_path, 'cached_allsides_kg_sent2s.pickle')
#         print("Build cache finished! In the future you can set use_cache to True.")
#
#     _part_compute_emb_toppk = partial(_compute_lm_kg_toppk, topp=topp, topk=topk, beta=BETA)
#
#     print("Now working on top p and top k picking ...")
#     for i in tqdm(range(len(sent1s_index) - 1), total=(len(sent1s_index) - 1)):
#         sent1s_story = sent1s[sent1s_index[i]:sent1s_index[i + 1]]
#         sent1_embs_story = embeddings1[sent1s_index[i]:sent1s_index[i + 1]]
#         sent1_kg_story = sent1s_kg[sent1s_index[i]:sent1s_index[i + 1]]
#
#         sent2s_story = sent2s[sent2s_index[i]:sent2s_index[i + 1]]
#         sent2_embs_story = embeddings2[sent2s_index[i]:sent2s_index[i + 1]]
#         sent2_kg_story = sent2s_kg[sent2s_index[i]:sent2s_index[i + 1]]
#
#         _temp_dicts = _part_compute_emb_toppk(sent1s_story, sent1_embs_story, sent1_kg_story, sent2s_story,
#                                               sent2_embs_story, sent2_kg_story)
#         out_df = out_df.append(pd.DataFrame(_temp_dicts), ignore_index=True)
#
#     out_df.to_csv(os.path.join(cache_path, out_csv_name), index=False)
#
#     return out_df


def pick_by_LM_kg(sent1s, sent2s, topk, topp, batch_size, use_cache, cache_path,
                  out_csv_name):
    """
    use LM embedding to shortlist candidates and entity overlap to pick parallels
    For every sent1 the topk most cosine-similar sent2s are shortlisted, then only those whose
    compute_OEI score (beta=BETA) against sent1 is > topp are kept (see _compute_lm_kg_toppk).
    :param sent1s: source sentences
    :param sent2s: candidate parallel sentences
    :param topk: size of the cosine-similarity shortlist per sent1
    :param topp: OEI threshold
    :param batch_size: number of sent1s scored per cosine-similarity batch
    :param use_cache: reuse pickled embeddings / entity lists from cache_path when present
    :param cache_path: folder for the caches and the output csv
    :param out_csv_name: output csv file name (written into cache_path)
    :return: DataFrame with columns sent1, sent2 (list), similarity_score (list of OEI scores)
    """
    batch_size = min(batch_size, len(sent1s))
    model = SentenceTransformer('paraphrase-distilroberta-base-v1')

    def _load_or_build(sents, cache_name, builder, which):
        # load the pickled cache if allowed and present, otherwise build it (the builders also write the cache).
        # NOTE: pickle.load can execute arbitrary code — only point cache_path at caches this script wrote.
        cache_file = os.path.join(cache_path, cache_name)
        if use_cache:
            if os.path.exists(cache_file):
                with open(cache_file, 'rb') as f:
                    return pickle.load(f)
            print("You choose use_cache but cache does not exist. Now build for " + which + " ...")
        return builder(sents, cache_name=cache_name)

    embed = partial(build_embed, model=model, cache_path=cache_path)
    extract = partial(build_kg, cache_path=cache_path)

    # sents embeddings
    embeddings1 = _load_or_build(sent1s, 'cached_allsides_emb_sent1s.pickle', embed, 'sent1s')
    embeddings2 = _load_or_build(sent2s, 'cached_allsides_emb_sent2s.pickle', embed, 'sent2s')
    # sents kgs
    sent1s_kg = _load_or_build(sent1s, 'cached_allsides_kg_sent1s.pickle', extract, 'sent1s')
    sent2s_kg = _load_or_build(sent2s, 'cached_allsides_kg_sent2s.pickle', extract, 'sent2s')

    if use_cache:
        print("Loaded")
    else:
        print("Build cache finished! In the future you can set use_cache to True.")

    _part_compute_emb_toppk = partial(_compute_lm_kg_toppk, sent2s_ent=sent2s_kg,
                                      sent2s_emb=embeddings2, topp=topp, topk=topk, beta=BETA)

    print("Now working on top p and top k picking ...")
    step = max(batch_size, 1)  # batch_size is 0 when sent1s is empty
    num_batches = (len(sent1s) + step - 1) // step
    picked = []
    # distinct loop names: the old code rebound `sent1s` inside this loop
    for batch_sent1s, batch_embs, batch_ents in tqdm(
            iterate_batches_lm_kg(sent1s, embeddings1, sent1s_kg, batch_size), total=num_batches):
        picked.extend(_part_compute_emb_toppk(batch_sent1s, batch_embs, batch_ents, sent2s))

    # DataFrame.append was removed in pandas 2.0 (and re-copied the frame every batch); build it once
    out_df = pd.DataFrame(picked, columns=['sent1', 'sent2', 'similarity_score'])
    out_df.to_csv(os.path.join(cache_path, out_csv_name), index=False)

    return out_df


def prepare_for_training(df, for_rl, output_path, output_root_name):
    """
    shuffle the raw df and split it 90/10 into train and eval csv files
    :param df: the raw df (columns sent1, sent2 [list of sentences], optionally similarity_score)
    :param for_rl: <bool> whether preparing for RL training. If False, similarity_score is dropped and
                   the sent2 lists are exploded into one (sent1, sent2) pair per row. If True, rows stay grouped.
    :param output_path: the output folder (created if missing)
    :param output_root_name: the root name shared by train and eval
    :return: (train_df, evaluation_df); they are also written to <output_path>/train_<root>.csv and
             <output_path>/eval_<root>.csv
    """
    # shuffle; df.sample returns a new frame, so the caller's df is left untouched
    df = df.sample(frac=1.0, random_state=RANDOM_SEED)

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    if not for_rl:
        df = df.drop(columns=['similarity_score'], errors='ignore')
        # keyword argument: positional use of ignore_index is fragile across pandas versions
        ungroup_df = df.explode('sent2', ignore_index=True)
        # an empty sent2 list explodes into a NaN row, which is not a usable training pair
        ungroup_df = ungroup_df.dropna(subset=['sent2'])
    else:
        ungroup_df = df

    # drop pairs whose target equals the source
    # (in the grouped RL frame sent2 holds lists, so this presumably keeps every row there)
    ungroup_df = ungroup_df[ungroup_df['sent1'] != ungroup_df['sent2']]
    train_df, evaluation_df = train_test_split(ungroup_df, test_size=0.1, random_state=RANDOM_SEED)

    train_df.to_csv(os.path.join(str(output_path), "train_" + str(output_root_name) + ".csv"), index=False)
    evaluation_df.to_csv(os.path.join(str(output_path), "eval_" + str(output_root_name) + ".csv"), index=False)
    return train_df, evaluation_df


if __name__ == '__main__':
    # Script entry: builds parallel sentence pairs for the 'allsides_l2r' config and writes train/eval csv
    # files to config.t2t_cache_path (exploded pairs) and config.rl_cache_path (grouped rows).
    # load the dataset
    config = CONFIGS['allsides_l2r']
    # NOTE: the index list returned by get_allsides_data starts with 0, i.e.
    # [0, len(story 1), len(story 1) + len(story 2), ...], so story i is sentences[index[i]:index[i + 1]]
    '''
    seq_pos_story_list [[story 1 sentences], [story 2 sentences], ... ]
    seq_pos_sentence [story 1 sentences, ... ]
    seq_pos_sentences_index [len(story 1 sentences), len(story 1 sentences) + len(story 2 sentences),]
    '''
    # for l2r: sentences from left_out/ become sent1 (sources), sentences from right_out/ become sent2 (candidates)
    seq_pos_story_list, seq_pos_sentence, seq_pos_sentences_index = get_allsides_data('./data/allsides/left_out/')
    seq_neg_story_list, seq_neg_sentence, seq_neg_sentences_index = get_allsides_data('./data/allsides/right_out/')

    # swapped direction (r2l).
    # NOTE(review): the embedding/KG cache file names used by pick_by_LM / pick_by_LM_kg do not depend on the
    # direction and use_cache=True below, so stale caches would be reused — clear config.cache_path or set
    # use_cache=False when switching (and presumably use a matching CONFIGS key).
    # seq_pos_story_list, seq_pos_sentence, seq_pos_sentences_index = get_allsides_data('data/allsides/right_out/')
    # seq_neg_story_list, seq_neg_sentence, seq_neg_sentences_index = get_allsides_data('data/allsides/left_out/')

    # method 1: randomly choose
    # for each pos sentence randomly find n parallel neg sentences
    # random_df = pick_by_random(seq_pos_sentence, seq_neg_sentence, n=config.random_topk,
    #                            cache_path=config.cache_path,
    #                            out_csv_name=config.random_cache_file)
    #
    # prepare_for_training(random_df, False, config.t2t_cache_path, config.random_root_name)
    # prepare_for_training(random_df, True, config.rl_cache_path, config.random_root_name)

    # method 2: sentence similarity
    # first encode each sentence, and then find k most similarity ones
    # lm_df = pick_by_LM(seq_pos_sentence,
    #                    seq_neg_sentence,
    #                    topk=config.lm_topk, topp=config.lm_topp,
    #                    batch_size=config.lm_batch_size,
    #                    use_cache=True,
    #                    cache_path=config.cache_path,
    #                    out_csv_name=config.lm_cache_file)
    #
    # prepare_for_training(lm_df, False, config.t2t_cache_path, config.lm_root_name)
    # prepare_for_training(lm_df, True, config.rl_cache_path, config.lm_root_name)

    # method 3: lm shortlist + entity-overlap score (compute_OEI over noun / proper-noun lemmas; the
    # "spice" score in older comments)
    # (the commented-out call below uses an older signature with sentence-index arguments and does not
    # match the current pick_by_LM_kg)
    # lm_kg_df = pick_by_LM_kg(seq_pos_sentence, seq_pos_sentences_index,
    #                          seq_neg_sentence, seq_neg_sentences_index,
    #                          topk=config.lm_kg_topk, topp=config.lm_kg_topp,
    #                          batch_size=config.lm_kg_batch_size,
    #                          use_cache=True,
    #                          cache_path=config.cache_path,
    #                          out_csv_name=config.lm_kg_cache_file)

    lm_kg_df = pick_by_LM_kg(seq_pos_sentence,
                             seq_neg_sentence,
                             topk=config.lm_kg_topk, topp=config.lm_kg_topp,
                             batch_size=config.lm_kg_batch_size,
                             use_cache=True,
                             cache_path=config.cache_path,
                             out_csv_name=config.lm_kg_cache_file)

    # text-to-text data: one (sent1, sent2) pair per row, no scores
    prepare_for_training(lm_kg_df, False, config.t2t_cache_path, config.lm_kg_root_name)
    # RL data: one row per sent1 with its list of sent2s and their scores
    prepare_for_training(lm_kg_df, True, config.rl_cache_path, config.lm_kg_root_name)
