#! /usr/bin/env python3
# coding=utf-8


import os
import re
import time
import re
import ast
import pickle
import sng_parser
import spacy
import random
import pandas as pd
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from sentence_transformers import SentenceTransformer, util
from functools import partial
from sklearn.model_selection import train_test_split
from nltk.tokenize import RegexpTokenizer
from parallel_configs import CONFIGS
from spacy.tokens import Token
# import stop words from language data
from spacy.lang.en.stop_words import STOP_WORDS

# Regex tokenizer that keeps runs of word characters (\w+).
# NOTE(review): not referenced by any function visible in this file.
tokenizer = RegexpTokenizer(r'\w+')
# NOTE(review): presumably disables parallelism in the HuggingFace tokenizers
# backend so it does not clash with the multiprocessing pools used below -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Seed handed to train_test_split in prepare_for_training.
RANDOM_SEED = 25536

# theta_NE = 0.5
# theta_ENT = 0.3
# theta_REL = 0.2

# for yelp
# NOTE(review): the __main__ section takes beta from the config object, not
# from this constant.
BETA = 1

# Shared spaCy pipeline used by return_entities; the 'ner' and 'parser'
# components are disabled at load time.
# replace to trf if speed ok
nlp = spacy.load("en_core_web_sm", disable=['ner', 'parser'])
# lemmatizer = nlp.get_pipe("lemmatizer")


def stop_words_getter(token):
    """
    getter backing the custom `is_stop` token extension
    :param token: token object exposing is_stop, lower_ and lemma_
    :return: truthy when the token is already flagged as a stop word, or when
             its lowercase text or its lemma is listed in STOP_WORDS
    """
    flagged = token.is_stop
    if flagged:
        return flagged
    if token.lower_ in STOP_WORDS:
        return True
    return token.lemma_ in STOP_WORDS


# set attribute with getter
# Registers stop_words_getter as a custom Token extension named 'is_stop'.
# spaCy exposes custom extensions under the `._` namespace (token._.is_stop),
# which is separate from the built-in token.is_stop attribute.
Token.set_extension('is_stop', getter=stop_words_getter)


def get_form_data(data_path):
    """
    load a one-sentence-per-line text file into a list of sentences
    :param data_path: path of a UTF-8 encoded text file
    :return: list with one string per line, newline characters stripped
    """
    # the context manager closes the handle even if decoding fails midway
    with open(data_path, 'r', encoding='utf-8') as f:
        return [line.strip('\n') for line in f]


def build_embed(sents, model, cache_path, cache_name):
    """
    encode sents with model and cache the embeddings as a pickle
    :param sents: sentences to encode
    :param model: encoder exposing encode(sents, convert_to_tensor=..., show_progress_bar=...)
    :param cache_path: folder the cache file is written to (created when missing)
    :param cache_name: file name of the pickle inside cache_path
    :return: the embeddings returned by model.encode, after .cpu()
    """
    embeddings = model.encode(
        sents, convert_to_tensor=True, show_progress_bar=True).cpu()

    print("Saving to pickle ...")
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(cache_path, exist_ok=True)

    with open(os.path.join(cache_path, cache_name), 'wb') as f:
        pickle.dump(embeddings, f, protocol=4)
    print("Saved cache!")

    return embeddings


def return_entities(sent):
    """
    extract the entity words of one sentence
    :param sent: raw sentence
    :return: lowercased lemmas of the NOUN / PROPN tokens that are not stop words
    """
    # Read the custom extension registered on Token in this module
    # (token._.is_stop), which also checks the lemma against STOP_WORDS.
    # The built-in token.is_stop never consults that getter, so using it
    # left the registered extension without any effect.
    return [token.lemma_.lower() for token in nlp(sent)
            if token.pos_ in ('PROPN', 'NOUN') and not token._.is_stop]


def build_ent(sents, cache_path, cache_name):
    """
    extract the entities of every sentence in parallel and cache them as a pickle
    :param sents: sentences to process
    :param cache_path: folder the cache file is written to (created when missing)
    :param cache_name: file name of the pickle inside cache_path
    :return: list holding, for each sentence, the list returned by return_entities
    """
    with Pool(cpu_count()) as proc:  # save time from 30min to 3min (on 20 cpus machine 2)
        sents_ent = list(
            tqdm(proc.imap(return_entities, sents), total=len(sents)))

    # The workers are released before touching the disk; they are not needed
    # for the serialisation step.
    print("Saving to pickle ...")
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(cache_path, exist_ok=True)

    with open(os.path.join(cache_path, cache_name), 'wb') as f:
        pickle.dump(sents_ent, f, protocol=4)
    print("Saved cache!")

    return sents_ent


def iterate_sents_batches(sents, batch_size):
    """
    support iterate the sents by batches
    :param sents: list of sentences
    :param batch_size: maximum number of sentences per batch
    :return: generator yielding consecutive slices of sents; the last batch
             may be shorter than batch_size
    """
    assert isinstance(sents, list)
    assert isinstance(batch_size, int)

    start = 0
    while True:
        batch = sents[start:start + batch_size]
        # Stop only when nothing is left. The former `len(batch) <= 1` test
        # also discarded a trailing batch holding a single sentence.
        if not batch:
            break
        yield batch
        start += batch_size


def iterate_batches_lm(sents, sents_embedding, batch_size):
    """
    support iterate the sents (and its embeddings) by batches
    :param sents: list of sentences
    :param sents_embedding: sliceable sequence aligned with sents
    :param batch_size: maximum number of sentences per batch
    :return: generator yielding (sents_batch, embedding_batch) tuples; the
             last batch may be shorter than batch_size
    """
    assert isinstance(sents, list)
    assert isinstance(batch_size, int)

    start = 0
    while True:
        stop = start + batch_size
        sents_batch = sents[start:stop]
        emb_batch = sents_embedding[start:stop]
        # len() instead of truthiness: the embeddings may be a tensor/array.
        # Stop only when one side is exhausted. The former `<= 1` test also
        # discarded a trailing batch holding a single sentence.
        if len(sents_batch) == 0 or len(emb_batch) == 0:
            break
        yield sents_batch, emb_batch
        start = stop

def iterate_batches_lm_kg(sents, sents_embedding, sents_ent, batch_size):
    """
    support iterate the sents (and its embeddings and kg) by batches
    :param sents: list of sentences
    :param sents_embedding: sliceable sequence aligned with sents
    :param sents_ent: sliceable sequence of per-sentence entities aligned with sents
    :param batch_size: maximum number of sentences per batch
    :return: generator yielding (sents_batch, embedding_batch, entity_batch)
             tuples; the last batch may be shorter than batch_size
    """
    assert isinstance(sents, list)
    assert isinstance(batch_size, int)

    start = 0
    while True:
        stop = start + batch_size
        batch = (sents[start:stop],
                 sents_embedding[start:stop],
                 sents_ent[start:stop])
        # len() instead of truthiness: the embeddings may be a tensor/array.
        # Stop only when one side is exhausted. The former `<= 1` test also
        # discarded a trailing batch holding a single sentence.
        if any(len(part) == 0 for part in batch):
            break
        yield batch
        start = stop


def get_yelp_data(data_path):
    """
    load yelp dataset into list of sentences
    :param data_path: path of a UTF-8 encoded text file, one sentence per line
    :return: list with one string per line after the first, newline characters stripped
    """
    # the context manager closes the handle even if decoding fails midway
    with open(data_path, 'r', encoding='utf-8') as f:
        # the first line is skipped (presumably a header row -- confirm);
        # the default keeps an empty file from raising StopIteration
        next(f, None)
        return [line.strip('\n') for line in f]


def compute_OEI(entities_1, entities_2, beta):
    """
    F-beta style overlap score between two entity lists
    :param entities_1: entities of the first sentence (precision is taken over its length)
    :param entities_2: entities of the second sentence (recall is taken over its length)
    :param beta: weight of recall against precision
    :return: 0 when the lists share no entity, otherwise the F-beta score
    """
    shared = len(set(entities_1).intersection(entities_2))

    # no common entity means precision and recall are both zero
    if not shared:
        return 0

    # a shared entity guarantees both lists are non-empty
    precision = shared / len(entities_1)
    recall = shared / len(entities_2)
    beta_sq = beta ** 2
    return (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)


def _get_random(sent1, sent2s, n):
    temp_dict, _sent2s = {}, []
    temp_dict['sent1'] = sent1
    for i in [random.randint(0, len(sent2s) - 1) for _ in range(n)]:
        _sent2s.append(sent2s[i])
    temp_dict['sent2'] = _sent2s
    return temp_dict


def pick_by_random(sent1s, sent2s, n, cache_path, out_csv_name):
    """
    randomly choose n sents from sent2s as parallels of each sent1 in sent1s
    :param sent1s: source sentences
    :param sent2s: candidate sentences to draw from
    :param n: number of candidates drawn for each source sentence
    :param cache_path: output folder (created when missing)
    :param out_csv_name: file name of the csv written inside cache_path
    :return: DataFrame with columns 'sent1' and 'sent2' (list of drawn sentences)
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(cache_path, exist_ok=True)

    _part_get_random = partial(_get_random, n=n)
    args = [(sent1, sent2s) for sent1 in sent1s]

    with Pool(cpu_count()) as proc:
        # starmap only returns once every task is done, so the bar advances
        # while the finished result list is walked, not while workers run
        _temp_dicts = list(
            tqdm(proc.starmap(_part_get_random, args),
                 total=len(sent1s)))

    # Build the frame in one go: DataFrame.append was removed in pandas 2.0.
    # Passing the columns keeps the header even when there are no rows.
    out_df = pd.DataFrame(_temp_dicts, columns=['sent1', 'sent2'])
    out_df.to_csv(os.path.join(cache_path, out_csv_name), index=False)

    return out_df


def _compute_lm_toppk(sent1s, sent1s_emb, sent2s, sent2s_emb, topp, topk):
    """
    single proc version compute topp and topk
    :param sent1s: batch of source sentences
    :param sent1s_emb: embeddings of sent1s
    :param sent2s: all candidate sentences
    :param sent2s_emb: embeddings of sent2s
    :param topp: threshold; only candidates scoring strictly above it are kept
    :param topk: number of best-scoring candidates considered per source sentence
    :return: list of dicts with keys 'sent1', 'sent2' (picked sentences) and
             'similarity_score' (their cosine scores, aligned with 'sent2');
             source sentences without any pick are left out
    """
    cosine_matrix = util.pytorch_cos_sim(sent1s_emb, sent2s_emb).cpu().numpy()

    # scores are rounded to 3 decimals before ranking and thresholding
    cosine_matrix = np.around(cosine_matrix, decimals=3)
    sent2s = np.array(sent2s, dtype=str)
    num_candidates = cosine_matrix.shape[1]

    _temp_dicts = []
    for sent1, row in zip(sent1s, cosine_matrix):
        if topk < num_candidates:
            top_idx = np.argpartition(-row, topk)[:topk]  # fast nlargest
        else:
            # argpartition rejects kth >= len(row); with that few candidates
            # every one of them is in the top k anyway
            top_idx = np.arange(num_candidates)

        # Look the scores up through the chosen indices. Taking them from a
        # separate np.partition call does not guarantee the same ordering,
        # so scores could end up attached to the wrong sentence.
        top_scores = row[top_idx]
        keep = top_scores > topp
        if keep.any():
            _temp_dicts.append({'sent1': sent1,
                                'sent2': sent2s[top_idx[keep]].tolist(),
                                'similarity_score': list(top_scores[keep])})

    return _temp_dicts


def _compute_lm_kg_toppk(sent1s, sent1s_emb, sent1s_ent, sent2s, sent2s_emb, sent2s_ent, topp, topk, beta):
    """
    single proc version compute topp and topk, re-scoring the LM top k by entity overlap
    :param sent1s: batch of source sentences
    :param sent1s_emb: embeddings of sent1s
    :param sent1s_ent: entity lists of sent1s
    :param sent2s: all candidate sentences
    :param sent2s_emb: embeddings of sent2s
    :param sent2s_ent: entity lists of sent2s, indexed like sent2s
    :param topp: threshold; only candidates whose entity score is strictly above it are kept
    :param topk: number of best cosine candidates considered per source sentence
    :param beta: beta forwarded to compute_OEI
    :return: list of dicts with keys 'sent1', 'sent2' (picked sentences) and
             'similarity_score' (their compute_OEI scores, aligned with 'sent2');
             source sentences without any pick are left out
    """
    cosine_matrix = util.pytorch_cos_sim(sent1s_emb, sent2s_emb).cpu().numpy()

    # scores are rounded to 3 decimals before ranking
    cosine_matrix = np.around(cosine_matrix, decimals=3)
    sent2s = np.array(sent2s, dtype=str)
    num_candidates = cosine_matrix.shape[1]

    _temp_dicts = []
    for sent1, sent1_ents, row in zip(sent1s, sent1s_ent, cosine_matrix):
        if topk < num_candidates:
            top_idx = np.argpartition(-row, topk)[:topk]  # fast nlargest
        else:
            # argpartition rejects kth >= len(row); with that few candidates
            # every one of them is in the top k anyway
            top_idx = np.arange(num_candidates)

        # Index sent2s_ent with the candidate's position in sent2s. Using the
        # position inside the top-k list compared sent1 against the entities
        # of unrelated sentences (always the first topk entries of sent2s_ent).
        spice_scores = np.array(
            [compute_OEI(sent1_ents, sent2s_ent[cand_idx], beta)
             for cand_idx in top_idx], dtype=float)

        keep = spice_scores > topp
        if keep.any():
            _temp_dicts.append({'sent1': sent1,
                                'sent2': sent2s[top_idx[keep]].tolist(),
                                'similarity_score': list(spice_scores[keep])})

    return _temp_dicts


def pick_by_LM(sent1s, sent2s, topk, topp, batch_size, use_cache, cache_path, out_csv_name):
    """
    use LM embedding to pick parallels
    :param sent1s: source sentences
    :param sent2s: candidate sentences
    :param topk: number of best cosine candidates considered per source sentence
    :param topp: threshold a candidate's cosine score must exceed to be kept
    :param batch_size: number of source sentences scored per step
    :param use_cache: reuse pickled embeddings found in cache_path; missing
                      caches are built, and with False everything is rebuilt
    :param cache_path: folder holding the embedding caches and the output csv
    :param out_csv_name: file name of the csv written inside cache_path
    :return: DataFrame with columns 'sent1', 'sent2' and 'similarity_score'
    """
    batch_size = min(batch_size, len(sent1s))
    # loaded on first use: no model is needed when both caches are hit
    model = None

    embeddings = {}
    for label, sents, cache_name in (
            ('sent1s', sent1s, 'cached_emb_sent1s.pickle'),
            ('sent2s', sent2s, 'cached_emb_sent2s.pickle')):
        cache_file = os.path.join(cache_path, cache_name)
        if use_cache and os.path.exists(cache_file):
            # NOTE: unpickling runs arbitrary code -- cache_path must only
            # contain files produced by this script
            with open(cache_file, 'rb') as f:
                embeddings[label] = pickle.load(f)
            continue
        if use_cache:
            print(
                "You choose use_cache but cache does not exist. Now build for {} ...".format(label))
        if model is None:
            model = SentenceTransformer('paraphrase-distilroberta-base-v1')
        embeddings[label] = build_embed(sents, model, cache_path, cache_name)

    if use_cache:
        print("Loaded")
    else:
        print("Build cache finished! In the future you can set use_cache to True.")

    _part_compute_emb_toppk = partial(
        _compute_lm_toppk, sent2s_emb=embeddings['sent2s'], topp=topp, topk=topk)

    print("Now working on top p and top k picking ...")
    # batch_size is 0 when sent1s is empty; avoid dividing by it
    total_batches = len(sent1s) // batch_size if batch_size else 0
    records = []
    for batch_sents, batch_embs in tqdm(
            iterate_batches_lm(sent1s, embeddings['sent1s'], batch_size),
            total=total_batches):
        records.extend(_part_compute_emb_toppk(batch_sents, batch_embs, sent2s))

    # Build the frame once at the end: DataFrame.append was removed in
    # pandas 2.0 and re-copied the whole frame on every batch.
    out_df = pd.DataFrame(
        records, columns=['sent1', 'sent2', 'similarity_score'])
    out_df.to_csv(os.path.join(cache_path, out_csv_name), index=False)

    return out_df


def pick_by_LM_kg(sent1s, sent2s, topk, topp, beta, batch_size, use_cache, cache_path, out_csv_name):
    """
    use LM embedding to shortlist parallels, then keep them by entity overlap score
    :param sent1s: source sentences
    :param sent2s: candidate sentences
    :param topk: number of best cosine candidates considered per source sentence
    :param topp: threshold a candidate's entity overlap score must exceed to be kept
    :param beta: beta forwarded to the entity overlap score
    :param batch_size: number of source sentences scored per step
    :param use_cache: reuse pickled embeddings / entities found in cache_path;
                      missing caches are built, and with False everything is rebuilt
    :param cache_path: folder holding the caches and the output csv
    :param out_csv_name: file name of the csv written inside cache_path
    :return: DataFrame with columns 'sent1', 'sent2' and 'similarity_score'
    """
    batch_size = min(batch_size, len(sent1s))
    # loaded on first use: no model is needed when both embedding caches are hit
    model = None

    # sents embeddings
    embeddings = {}
    for label, sents, cache_name in (
            ('sent1s', sent1s, 'cached_emb_sent1s.pickle'),
            ('sent2s', sent2s, 'cached_emb_sent2s.pickle')):
        cache_file = os.path.join(cache_path, cache_name)
        if use_cache and os.path.exists(cache_file):
            # NOTE: unpickling runs arbitrary code -- cache_path must only
            # contain files produced by this script
            with open(cache_file, 'rb') as f:
                embeddings[label] = pickle.load(f)
            continue
        if use_cache:
            print(
                "You choose use_cache but cache does not exist. Now build for {} ...".format(label))
        if model is None:
            model = SentenceTransformer('paraphrase-distilroberta-base-v1')
        embeddings[label] = build_embed(sents, model, cache_path, cache_name)

    # sents kgs
    entities = {}
    for label, sents, cache_name in (
            ('sent1s', sent1s, 'cached_kg_sent1s.pickle'),
            ('sent2s', sent2s, 'cached_kg_sent2s.pickle')):
        cache_file = os.path.join(cache_path, cache_name)
        if use_cache and os.path.exists(cache_file):
            with open(cache_file, 'rb') as f:
                entities[label] = pickle.load(f)
            continue
        if use_cache:
            print(
                "You choose use_cache but cache does not exist. Now build for {} ...".format(label))
        entities[label] = build_ent(sents, cache_path, cache_name)

    if use_cache:
        print("Loaded")
    else:
        print("Build cache finished! In the future you can set use_cache to True.")

    _part_compute_emb_toppk = partial(_compute_lm_kg_toppk, sent2s_ent=entities['sent2s'],
                                      sent2s_emb=embeddings['sent2s'], topp=topp, topk=topk, beta=beta)

    print("Now working on top p and top k picking ...")
    # batch_size is 0 when sent1s is empty; avoid dividing by it
    total_batches = len(sent1s) // batch_size if batch_size else 0
    records = []
    for batch_sents, batch_embs, batch_ents in tqdm(
            iterate_batches_lm_kg(
                sent1s, embeddings['sent1s'], entities['sent1s'], batch_size),
            total=total_batches):
        records.extend(_part_compute_emb_toppk(
            batch_sents, batch_embs, batch_ents, sent2s))

    # Build the frame once at the end: DataFrame.append was removed in
    # pandas 2.0 and re-copied the whole frame on every batch.
    out_df = pd.DataFrame(
        records, columns=['sent1', 'sent2', 'similarity_score'])
    out_df.to_csv(os.path.join(cache_path, out_csv_name), index=False)

    return out_df


def prepare_for_training(df, for_rl, output_path, output_root_name):
    """
    split the raw df file for train and eval
    :param df: the raw df; it is not modified
    :param for_rl: <bool> whether preparing for RL training. When False the
                   'similarity_score' column is dropped and every element of
                   the 'sent2' lists gets its own row; when True df is split as is
    :param output_path: the output folder (created when missing)
    :param output_root_name: the root name shared by the train and eval files
    :return: None; writes train_<root>.csv and eval_<root>.csv into output_path
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(output_path, exist_ok=True)

    if not for_rl:
        # drop() returns a new frame. Deleting the column in place also
        # removed it from the caller's df, so a later for_rl=True call on the
        # same frame silently lost its scores.
        ungroup_df = df.drop(columns=['similarity_score'], errors='ignore').explode(
            'sent2', ignore_index=True)
    else:
        ungroup_df = df

    train_df, evaluation_df = train_test_split(
        ungroup_df, test_size=0.2, random_state=RANDOM_SEED)
    # evaluation_df, test_df = train_test_split(_test_df, test_size=0.5, random_state=RANDOM_SEED)

    train_df.to_csv(os.path.join(
        str(output_path), "train_" + str(output_root_name) + ".csv"), index=False)
    evaluation_df.to_csv(os.path.join(
        str(output_path), "eval_" + str(output_root_name) + ".csv"), index=False)
    # test_df.to_csv(str(output_path) + "/test_" + str(output_root_name) + ".csv", index=False)


# Script entry point: builds three pseudo-parallel datasets (random, LM
# similarity, LM + entity overlap) from the two files named in the chosen
# config and writes train/eval splits for each of them.
if __name__ == '__main__':

    # choose the config
    config = CONFIGS['yelp_pos2neg']
    # load dataset
    # sent1_list = get_form_data('data/GYAFC_Corpus/Family_Relationships/train/formal')
    # sent2_list = get_form_data('data/GYAFC_Corpus/Family_Relationships/train/informal')

    # load dataset
    # sent1 = sentences of config.pos_file_name, sent2 = sentences of config.neg_file_name
    sent1_list = get_yelp_data(config.pos_file_name)
    sent2_list = get_yelp_data(config.neg_file_name)

    # method 1: randomly choose
    # for each pos sentence randomly find n parallel neg sentences
    random_df = pick_by_random(sent1_list, sent2_list, n=config.random_topk,
                               cache_path=config.cache_path,
                               out_csv_name=config.random_cache_file)

    # each result frame is split twice: once with for_rl=False into
    # config.t2t_cache_path, once with for_rl=True into config.rl_cache_path
    prepare_for_training(
        random_df, False, config.t2t_cache_path, config.random_root_name)
    prepare_for_training(
        random_df, True, config.rl_cache_path, config.random_root_name)

    # method 2: sentence similarity
    # first encode each sentence, and then find k most similarity ones
    lm_df = pick_by_LM(sent1_list, sent2_list,
                       topk=config.lm_topk, topp=config.lm_topp,
                       batch_size=config.lm_batch_size,
                       use_cache=True,
                       cache_path=config.cache_path,
                       out_csv_name=config.lm_cache_file)

    prepare_for_training(
        lm_df, False, config.t2t_cache_path, config.lm_root_name)
    prepare_for_training(
        lm_df, True, config.rl_cache_path, config.lm_root_name)

    # method 3: lm + spice score
    # same cache_path as method 2, so the embedding caches written there are
    # picked up again here (use_cache=True)
    lm_kg_df = pick_by_LM_kg(sent1_list, sent2_list,
                             topk=config.lm_kg_topk, topp=config.lm_kg_topp,
                             beta=config.beta,
                             batch_size=config.lm_kg_batch_size,
                             use_cache=True,
                             cache_path=config.cache_path,
                             out_csv_name=config.lm_kg_cache_file)

    prepare_for_training(
        lm_kg_df, False, config.t2t_cache_path, config.lm_kg_root_name)
    prepare_for_training(
        lm_kg_df, True, config.rl_cache_path, config.lm_kg_root_name)
