#! /usr/bin/env python3
# coding=utf-8


import os
import spacy
from collections import Counter
import random
import pickle
import pandas as pd
import numpy as np
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from sentence_transformers import SentenceTransformer, util
from functools import partial
from sklearn.model_selection import train_test_split
from nltk.tokenize import RegexpTokenizer
from parallel_configs import CONFIGS

from spacy.tokens import Token
from spacy.lang.en.stop_words import STOP_WORDS  # import stop words from language data

# Regex tokenizer matching runs of word characters.
# Not referenced by the code visible in this file.
tokenizer = RegexpTokenizer(r'\w+')
# Sets TOKENIZERS_PARALLELISM for this process and its children.
# NOTE(review): presumably meant to stop a tokenizers backend from spawning
# its own threads alongside the multiprocessing pool below -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Not referenced by the code visible in this file.
RANDOM_SEED = 25536

# for yelp
# Passed as ``beta`` to compute_OEI by the script entry point.
BETA = 1

# Number of sentence pairs per batch used by the script entry point.
batch_size = 20

# Small English pipeline; switch to the "trf" model if its speed is acceptable.
nlp = spacy.load("en_core_web_sm")
# Handle on the pipeline's lemmatizer component.
# Not referenced by the code visible in this file.
lemmatizer = nlp.get_pipe("lemmatizer")

# Getter that also treats a token as a stop word when its lower-cased text or
# its lemma is in STOP_WORDS.
# NOTE(review): spaCy exposes custom extensions under ``token._.<name>``; the
# visible code reads the built-in ``token.is_stop`` instead, so this getter
# does not appear to be consulted anywhere -- confirm whether
# ``token._.is_stop`` was intended.
stop_words_getter = lambda token: token.is_stop or token.lower_ in STOP_WORDS or token.lemma_ in STOP_WORDS
Token.set_extension('is_stop', getter=stop_words_getter)  # set attribute with getter


def iterate_batches(sent1s, sent1_kgs, sent2s, sent2_kgs, batch_size):
    """
    Yield consecutive, aligned batches of the sentences and their entity lists.

    All four sequences are sliced at the same offsets, so position ``i`` of
    every slice in a batch belongs to the same sentence pair. Iteration stops
    as soon as any of the four slices is empty, i.e. when the shortest
    sequence is exhausted. A trailing batch that is smaller than
    ``batch_size`` (including a batch of a single pair) is still yielded.

    :param sent1s: list of first sentences
    :param sent1_kgs: entity lists aligned with ``sent1s``
    :param sent2s: list of second sentences
    :param sent2_kgs: entity lists aligned with ``sent2s``
    :param batch_size: positive number of pairs per batch
    :return: generator of ``(batch_index, (sent1_batch, sent1_kg_batch,
             sent2_batch, sent2_kg_batch))`` with ``batch_index`` starting at 0
    :raises ValueError: when iteration starts with ``batch_size`` < 1
    """
    assert isinstance(sent1s, list)
    assert isinstance(sent2s, list)
    assert isinstance(batch_size, int)
    if batch_size < 1:
        # A zero or negative size would produce empty or nonsensical slices.
        raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))

    ofs = 0
    while True:
        start = ofs * batch_size
        stop = start + batch_size
        batch = (sent1s[start:stop],
                 sent1_kgs[start:stop],
                 sent2s[start:stop],
                 sent2_kgs[start:stop])
        # Stop only on an *empty* slice. The previous ``<= 1`` test silently
        # dropped a final batch holding one pair and produced no batches at
        # all for batch_size == 1.
        if any(len(part) == 0 for part in batch):
            break
        yield ofs, batch
        ofs += 1


def return_entities(sent):
    """
    Collect the entity lemmas of one sentence.

    The sentence is run through the module-level spaCy pipeline ``nlp``. A
    token is kept when its part-of-speech tag is ``PROPN`` or ``NOUN`` and its
    ``is_stop`` attribute is false.

    :param sent: text to analyse
    :return: list of lower-cased lemmas of the kept tokens, in token order
             (duplicates are preserved)
    """
    wanted_pos = ['PROPN', 'NOUN']
    entities = []
    for token in nlp(sent):
        if token.pos_ not in wanted_pos or token.is_stop:
            continue
        entities.append(token.lemma_.lower())
    return entities


def build_ent(sents, cache_path, cache_name):
    """
    Extract the entities of every sentence in parallel and cache the result.

    ``return_entities`` is mapped over ``sents`` with one worker process per
    CPU. The resulting list is pickled (protocol 4) to
    ``cache_path/cache_name``; the directory is created when missing.

    :param sents: sequence of sentences (its length drives the progress bar)
    :param cache_path: directory that receives the pickle file
    :param cache_name: file name of the pickle inside ``cache_path``
    :return: list with one entity list per sentence, in input order
    """
    with Pool(cpu_count()) as proc:  # save time from 30min to 3min (on 20 cpus machine 2)
        sents_kg = list(
            tqdm(proc.imap(return_entities, sents),
                 total=len(sents)))

    # All results are materialised at this point, so the workers are released
    # before the (potentially slow) disk write instead of idling through it.
    print("Saving to pickle ...")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(cache_path, exist_ok=True)

    with open(os.path.join(cache_path, cache_name), 'wb') as f:
        pickle.dump(sents_kg, f, protocol=4)
    print("Saved cache!")

    return sents_kg


def get_entities(text_list):
    """
    Count how often each entity occurs across a collection of texts.

    :param text_list: iterable of texts
    :return: ``Counter`` mapping each entity (as produced by
             ``return_entities``) to its number of occurrences over all texts
    """
    ent_counter = Counter()
    for text in text_list:
        # Delegate to the shared extractor instead of repeating its filter
        # here, so the two definitions of "entity" cannot drift apart.
        ent_counter.update(return_entities(text))

    return ent_counter


def compute_OEI(sent1, sent1_ents, sent2, sent2_ents, beta):
    """
    Compute the entity-overlap F-beta score (OEI) and size statistics of a batch.

    For every aligned pair the *distinct* entities of both sides are compared:
    precision is the share of sentence-1 entities that also occur in
    sentence 2, recall the share of sentence-2 entities that also occur in
    sentence 1. Both the match count and the denominators are based on
    distinct entities, so a repeated entity cannot lower the score and two
    identical entity sets always score 1. A pair without any common entity
    scores 0.

    The four inputs are walked in lockstep; extra items of longer inputs are
    ignored. For an empty batch the averaged values are ``nan`` (NumPy's mean
    of an empty list).

    :param sent1: first sentences of the batch
    :param sent1_ents: entity lists aligned with ``sent1``
    :param sent2: second sentences of the batch
    :param sent2_ents: entity lists aligned with ``sent2``
    :param beta: F-beta weight applied to the precision term
    :return: tuple ``(mean OEI, mean length of sent1, mean length of sent2,
             mean distinct entities in sent1, mean distinct entities in sent2,
             total distinct entities in sent1, total distinct entities in
             sent2, mean number of matches)``; lengths are the number of
             pieces obtained by splitting on single spaces
    """
    oei_scores, len_1, len_2, total_1, total_2, matches = [], [], [], [], [], []
    count_1 = count_2 = 0
    beta_sq = beta ** 2

    for s1, sent1_ent, s2, sent2_ent in zip(sent1, sent1_ents, sent2, sent2_ents):
        # Build each set once; it feeds the statistics and the score.
        uniq_1, uniq_2 = set(sent1_ent), set(sent2_ent)

        total_1.append(len(uniq_1))
        total_2.append(len(uniq_2))
        count_1 += len(uniq_1)
        count_2 += len(uniq_2)
        len_1.append(len(s1.split(' ')))
        len_2.append(len(s2.split(' ')))

        num_matches = len(uniq_1 & uniq_2)
        matches.append(num_matches)

        if num_matches == 0:
            # Covers empty entity lists too; avoids a 0/0 in the F-beta formula.
            oei_scores.append(0)
            continue

        # Denominators use the distinct counts to stay consistent with the
        # set-based numerator.
        precision = num_matches / len(uniq_1)
        recall = num_matches / len(uniq_2)
        oei_scores.append((1 + beta_sq) * precision * recall / (beta_sq * precision + recall))

    return (np.mean(oei_scores), np.mean(len_1), np.mean(len_2), np.mean(total_1),
            np.mean(total_2), count_1, count_2, np.mean(matches))


def _load_or_build_ents(sents, cache_path, cache_name, use_cache):
    """Return the entity lists of ``sents``: from the pickle cache when allowed and present, else freshly built."""
    cache_file = os.path.join(cache_path, cache_name)
    if use_cache and os.path.exists(cache_file):
        # NOTE: unpickling can execute arbitrary code; only load cache files
        # that this script wrote itself.
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    return build_ent(sents, cache_path, cache_name)


if __name__ == '__main__':
    # Script entry point: read the test CSV, extract (or load cached) entities
    # for both sentence columns, score them batch by batch and print the
    # corpus-level statistics.

    use_cache = False

    cache_path = './.cache/temp/'
    test_df = pd.read_csv("./data/formal_family/t2t/f2i/formal_test.csv", header=0)
    sent1s, sent2s = test_df['sent1'].tolist(), test_df['sent2_d'].tolist()

    sent1s_ent = _load_or_build_ents(sent1s, cache_path, 'cached_kg_sent1s.pickle', use_cache)
    sent2s_ent = _load_or_build_ents(sent2s, cache_path, 'cached_kg_sent2s.pickle', use_cache)

    # Collect one row of statistics per batch that was actually produced.
    # Pre-sized zero arrays left unfilled trailing slots that dragged every
    # average towards 0.
    batch_stats = []
    batch_weights = []
    for _, ents in iterate_batches(sent1s, sent1s_ent, sent2s, sent2s_ent, batch_size):
        sent1, sent1_ent, sent2, sent2_ent = ents
        batch_stats.append(compute_OEI(sent1, sent1_ent, sent2, sent2_ent, BETA))
        # compute_OEI walks the four inputs in lockstep, so the number of
        # scored pairs is the length of the shortest one.
        batch_weights.append(min(len(sent1), len(sent1_ent), len(sent2), len(sent2_ent)))

    if not batch_stats:
        raise SystemExit("No sentence pairs to evaluate.")

    OEIs, LEN_1, LEN_2, TOT_1s, TOT_2s, CNT_1s, CNT_2s, MAT = np.asarray(batch_stats, dtype=float).T

    def _pair_mean(per_batch_means):
        # Weight each batch mean by its number of pairs so a short final
        # batch does not count as much as a full one.
        return np.average(per_batch_means, weights=batch_weights)

    # OEI
    print("Avg OEI:", _pair_mean(OEIs))

    # sent length
    print("Avg LEN 1:", _pair_mean(LEN_1))
    print("Avg LEN 2:", _pair_mean(LEN_2))

    # entity per sent
    print("Avg TOT_1:", _pair_mean(TOT_1s))
    print("Avg TOT_2:", _pair_mean(TOT_2s))

    # total entity count (the label says "Avg" but the value is the sum over
    # all batches; the label is kept unchanged for existing log consumers)
    print("Avg CNT_1:", np.sum(CNT_1s))
    print("Avg CNT_2:", np.sum(CNT_2s))

    # avg overlaps
    print("Avg MAT:", _pair_mean(MAT))
