import string
from collections import Counter
from itertools import groupby
from operator import itemgetter

import spacy

# Load the `en_core_web_sm` spaCy pipeline once at import time; extract_features
# calls `nlp(doc)` to obtain tokens, sentences and noun chunks.
nlp = spacy.load("en_core_web_sm")
from nltk.corpus import wordnet as wn

from util.spelling_corrector import SpellingCorrector

# Single shared corrector instance; extract_features asks it for the
# correction of every non-empty word.
spellCorrector = SpellingCorrector()

# Memo keyed by lemma, holding the set of lower-cased WordNet lemma names
# computed for that lemma by extract_features.
syns_cache = {}

# Whitespace-separated entries of the word-list file, read once at import.
# str.split() with no argument never yields empty strings, so no extra
# filtering is needed.
with open('util/words_db/dolch-fry-1000-words.txt', 'r') as content_file:
    fry_1000_words = set(content_file.read().split())


def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is zero."""
    return numerator / denominator if denominator else 0


def extract_features(doc, flatten_features=True):
    """Compute lexical, syntactic and structural features of a text.

    Args:
        doc: The text to analyse, as a ``str``.
        flatten_features: When true, the list/dict valued features (longest
            words, hapax words, single-occurrence n-grams, word-length
            frequencies) are expanded into a fixed number of scalar entries
            named ``"<feature>-<index>"`` and padded with ``""`` / ``0``.
            When false, they are stored as lists / a dict under one key each.

    Returns:
        A ``dict`` mapping feature names to values. Every ratio whose
        denominator is zero (e.g. for an empty document) is reported as 0
        instead of raising ``ZeroDivisionError``.
    """
    # ////////////////// Lexical /////////////////
    total_chars = len(doc)  # LF1

    letters = upper_cases = digits = whitespace = tabs = 0
    for ch in doc:
        # Independent checks: an upper-case letter is also a letter and a tab
        # is also whitespace, so an if/elif chain would never count them.
        if ch.isalpha():
            letters += 1
        if ch.isupper():
            upper_cases += 1
        if ch.isdigit():
            digits += 1
        if ch.isspace():
            whitespace += 1
        if ch == '\t':
            tabs += 1

    avg_letters = _safe_div(letters, total_chars)  # LF2
    avg_upper_cases = _safe_div(upper_cases, total_chars)  # LF3
    avg_digits = _safe_div(digits, total_chars)  # LF4
    avg_ws = _safe_div(whitespace, total_chars)  # LF5
    avg_tabs = _safe_div(tabs, total_chars)  # LF6

    punctuation_to_strip = "!:,.?(){}[]"
    raw_tokens = doc.split()
    # Words have surrounding digits and punctuation removed; a token made only
    # of such characters becomes "" and is still counted as a word.
    words = [token.strip("0123456789" + punctuation_to_strip) for token in raw_tokens]
    total_words = len(words)  # LF7

    count_by_word = Counter(words)
    count_by_word_length = Counter(len(word) for word in words)
    longer_words = {word: len(word) for word in words if len(word) > 8}

    # Lookup tables are built once here instead of once per word.
    negations = frozenset(["not", "no"])
    personal_pronouns = frozenset([
        "i", "me", "my", "mine", "we", "us", "our", "ours", "you", "yours", "your", "he", "him",
        "his", "they", "them", "their", "theirs", "she", "her", "hers", "it", "its"])
    demonstratives = frozenset(["this", "these", "that", "those", "former", "latter"])
    relatives = frozenset(["who", "whom", "whose", "what", "which", "that"])
    indefinites = frozenset([
        "one", "one's", "oneself", "something", "anything", "nothing", "someone", "anyone", "noone",
        "somebody", "anybody", "nobody"])
    reflexives = frozenset(["self", "selves"])
    multi_word_prepositions = frozenset([
        "ahead of", "except for", "instead of", "owing to", "apart from", "in addition to",
        "near to", "such as", "as for", "in front of", "on account of", "thanks to", "as well as",
        "in place of", "on top of", "up to", "because of", "in spite of", "out of", "due to",
        "inside of", "outside of"])
    shorten_forms = frozenset([
        "i'm", "he's", "she's", "it's", "we're", "you're", "they're", "i've", "we've", "you've",
        "they've", "i'd", "he'd", "she'd", "we'd", "you'd", "they'd", "i'll", "he'll", "she'll",
        "we'll", "you'll", "they'll", "who's", "who're", "who've", "who'd", "who'll", "what's",
        "what're", "what've", "what'd", "what'll", "there's", "there're", "there've", "there'd",
        "there'll", "where's", "where're", "where've", "where'd", "where'll", "why's", "why're",
        "why've", "why'd", "why'll", "whom's", "whom're", "whom've", "whom'd", "whom'll",
        "these're", "these've", "these'd", "these'll", "those're", "those've", "those'd",
        "those'll", "that's", "that're", "that've", "that'd", "that'll", "this's", "this'd",
        "this'll", "isn't", "aren't", "wasn't", "weren't", "hasn't", "haven't", "hadn't", "don't",
        "doesn't", "didn't", "can't", "couldn't", "won't", "wouldn't", "shouldn't", "mustn't",
        "needn't"])
    interrogatives = frozenset(["who", "whose", "whom", "which", "what", "how", "why", "when", "where"])
    articles = frozenset(["the", "a", "an"])

    long_word_count = 0
    short_word_count = 0
    corrections_count = 0
    words_with_3_or_more_syllables = 0
    syllables_total = 0
    total_numeric_exprs = 0  # LF28
    total_fry_words_count = 0  # LF31
    total_negations = 0  # LF34
    personal_pronouns_count = 0  # STF16,1
    demonstrative_words_count = 0  # STF16,2
    relative_words_count = 0  # STF16,3
    indefinete_words_count = 0  # STF16,4
    reflexive_words_count = 0  # STF16,6
    two_words_prepositions_count = 0  # STF16,8
    total_shorten_forms = 0  # STF15
    total_interrogatives_count = 0  # STF16,5
    total_articles = 0  # STF17

    for i, (raw_token, word) in enumerate(zip(raw_tokens, words)):
        word_len = len(word)
        if word_len > 6:
            long_word_count += 1
        elif 1 <= word_len <= 3:
            short_word_count += 1

        # spell corrector
        if word and spellCorrector.correction(word) != word:
            corrections_count += 1

        # syllable
        syllable_count_for_word = syllable_count(word)
        if syllable_count_for_word > 2:
            words_with_3_or_more_syllables += 1
        syllables_total += syllable_count_for_word

        # numeric expressions: tested on the raw token, because `word` has
        # already had its leading/trailing digits stripped.
        if raw_token.strip(punctuation_to_strip).isdigit():
            total_numeric_exprs += 1

        # fry-words
        if word in fry_1000_words:
            total_fry_words_count += 1

        lowered_word = word.lower()
        if lowered_word in negations:
            total_negations += 1
        if lowered_word in personal_pronouns:
            personal_pronouns_count += 1
        if lowered_word in demonstratives:
            demonstrative_words_count += 1
        if lowered_word in relatives:
            relative_words_count += 1
        if lowered_word in indefinites:
            indefinete_words_count += 1
        if lowered_word in reflexives:
            reflexive_words_count += 1

        # The preposition table mixes two- and three-word entries, so both the
        # bigram and the trigram starting at this word are tried.
        if i + 1 < total_words:
            bigram = lowered_word + " " + words[i + 1].lower()
            if bigram in multi_word_prepositions:
                two_words_prepositions_count += 1
            elif i + 2 < total_words and (bigram + " " + words[i + 2].lower()) in multi_word_prepositions:
                two_words_prepositions_count += 1

        if lowered_word in shorten_forms:
            total_shorten_forms += 1
        if lowered_word in interrogatives:
            total_interrogatives_count += 1
        if lowered_word in articles:
            total_articles += 1

    top_20_longer_words = sorted(longer_words.items(), key=itemgetter(1), reverse=True)[:20]  # LF11
    mean_word_len = _safe_div(total_chars, total_words)  # LF8
    unique_words = len(count_by_word)  # LF9
    total_longer_words = _safe_div(long_word_count, total_words)  # LF10
    total_short_words = _safe_div(short_word_count, total_words)  # LF12

    hapax_legomenon = [word for word, freq in count_by_word.items() if freq == 1]  # LF13
    hapax_dislegomenon = [word for word, freq in count_by_word.items() if freq == 2]  # LF14

    # M1: number of distinct words; M2: sum of squared word frequencies.
    M1 = float(len(count_by_word))
    M2 = sum(freq ** 2 for freq in count_by_word.values())
    yulesK = _safe_div(M1 * M1, M2 - M1)  # LF15

    # Brunet's W = N ** (V ** -a) with a = 0.172; the exponent must be negative.
    brunet_w = total_words ** (unique_words ** -0.172) if unique_words else 0  # LF16

    # Ten largest word lengths (sorted by length, descending) with their frequencies.
    word_len_freq_top_10 = dict(sorted(count_by_word_length.items(), key=itemgetter(0), reverse=True)[:10])  # LF18
    avg_corrections = _safe_div(corrections_count, total_words)  # LF20

    tagged_doc = nlp(doc)

    avg_words_with_3_or_more_syllables = _safe_div(words_with_3_or_more_syllables, total_words)  # LF22
    avg_syllables = _safe_div(syllables_total, total_words)  # LF23

    total_polysemous_words = 0  # LF24
    checked_words = set()
    total_content_words = 0  # LF21
    total_passive_verbs = 0  # LF32
    total_passive_sents = 0  # LF33
    total_conjunctions = 0  # LF37
    total_prepositions = 0  # STF16,7
    total_cordinative_conj = 0  # STF16,9
    total_correlative_conj = 0  # STF16,10
    total_verbs = 0
    total_in_tags = 0
    total_quantifiers = 0  # STF16,11

    for token in tagged_doc:
        if token.pos_ in ('NOUN', 'ADJ', 'ADV', 'VERB'):
            total_content_words += 1

        # Past participles (tag VBN) are used as the passive-voice signal.
        if token.tag_ == 'VBN':
            total_passive_verbs += 1
            if token.dep_ == 'ROOT':
                total_passive_sents += 1

        if token.pos_ == 'CCONJ' or token.dep_ == 'cc':
            total_conjunctions += 1
        if token.pos_ == 'VERB':
            total_verbs += 1
        if token.tag_ == 'IN':
            total_in_tags += 1
        if token.dep_ == 'prep':
            total_prepositions += 1
        if token.dep_ == 'cc':
            total_cordinative_conj += 1
        if token.dep_ == 'conj':
            total_correlative_conj += 1
        if token.dep_ == 'quantmod':
            total_quantifiers += 1

        # synset implementation: skip punctuation tokens, the "-PRON-"
        # placeholder lemma and lemmas that were already examined.
        lemma = token.lemma_
        if token.is_punct or lemma == "-PRON-" or lemma in checked_words:
            continue
        checked_words.add(lemma)

        syns = syns_cache.get(lemma)
        if syns is None:
            syns = {lemma_entry.name().lower()
                    for synset in wn.synsets(lemma)
                    for lemma_entry in synset.lemmas()}
            syns.discard(lemma)
            syns_cache[lemma] = syns

        if len(syns) > 1:
            total_polysemous_words += 1

    if total_content_words > 0:
        polysemous_words_cont_words_ratio = total_polysemous_words / total_content_words  # LF25
    else:
        polysemous_words_cont_words_ratio = total_polysemous_words

    type_token_ratio = _safe_div(total_content_words, total_words)  # LF26

    avg_negations = _safe_div(total_negations, total_words)  # LF35
    lemma_diversity = _safe_div(len(checked_words), total_words)  # LF38

    # NOTE: n_grams is applied to the raw string, so these are character n-grams.
    bi_occurences = Counter(n_grams(doc, 2))
    tri_occurences = Counter(n_grams(doc, 3))
    four_occurences = Counter(n_grams(doc, 4))

    bi_single_occurences = [','.join(gram) for gram, freq in bi_occurences.items() if freq == 1]  # LF39
    tri_single_occurences = [','.join(gram) for gram, freq in tri_occurences.items() if freq == 1]  # LF40

    # /////////////// Syntactic /////////////////////
    total_single_quotes = doc.count("'")  # SF1
    total_commas = doc.count(',')  # SF2
    total_periods = doc.count('.')  # SF3
    total_colons = doc.count(':')  # SF4
    total_semi_colons = doc.count(';')  # SF5
    total_question_mark = doc.count('?')  # SF6
    total_exclamanation_mark = doc.count('!')  # SF7
    # Counted on its own (non-overlapping "..." runs); the dots involved are
    # also included in total_periods.
    total_ellipsis = doc.count('...')  # SF8
    punctuation_chars = set(string.punctuation)
    total_special_chars = sum(1 for ch in doc if ch in punctuation_chars)  # SF9

    mean_dependents_clauses_length = 0  # SF10,3
    mean_phrases_length = 0  # SF10,4
    mean_phrases_per_sentence = 0  # SF10,5
    mean_dependents_clauses_per_sentence = 0  # SF10,6

    # Noun chunks whose root is a direct object feed the "dependent clause"
    # features; every other noun chunk feeds the "phrase" features.
    for chunk in tagged_doc.noun_chunks:
        if chunk.root.dep_ == 'dobj':
            mean_dependents_clauses_length += len(chunk.text)
            mean_dependents_clauses_per_sentence += 1
        else:
            mean_phrases_length += len(chunk.text)
            mean_phrases_per_sentence += 1

    sentences = list(tagged_doc.sents)
    total_sentences = len(sentences)  # STF2
    mean_verbs_per_sentence = _safe_div(total_verbs, total_sentences)  # SF10,1
    mean_prepositions_per_sentence = _safe_div(total_in_tags, total_sentences)  # SF10,2

    # Lengths are averaged over the chunk counts before those counts are
    # themselves turned into per-sentence means.
    if mean_phrases_per_sentence > 0:
        mean_phrases_length /= mean_phrases_per_sentence

    if mean_dependents_clauses_per_sentence > 0:
        mean_dependents_clauses_length /= mean_dependents_clauses_per_sentence

    mean_phrases_per_sentence = _safe_div(mean_phrases_per_sentence, total_sentences)
    mean_dependents_clauses_per_sentence = _safe_div(mean_dependents_clauses_per_sentence, total_sentences)

    mean_bi_gram_per_sent = _safe_div(len(bi_occurences), total_sentences)  # SF11, 1
    mean_tri_gram_per_sent = _safe_div(len(tri_occurences), total_sentences)  # SF11, 2
    mean_four_gram_per_sent = _safe_div(len(four_occurences), total_sentences)  # SF11, 3

    mean_len_of_sentence = _safe_div(len(doc), total_sentences)  # SF11, 4

    # /////////////////////// Structural /////////////////////////////
    lines = doc.splitlines()
    total_lines = len(lines)  # STF1
    total_long_sentences = 0  # STF3
    total_sents_with_ucase_start = 0  # STF9
    total_sents_with_lcase_start = 0  # STF10

    total_subj_verb_obj_sents = 0
    total_subj_obj_verb_sents = 0
    total_obj_subj_verb_sents = 0
    total_obj_verb_subj_sents = 0
    total_verb_subj_obj_sents = 0
    total_verb_obj_subj_sents = 0

    for sent in sentences:
        if len(sent.text.split()) > 15:
            total_long_sentences += 1
        # Anything that does not start with an upper-case character is
        # counted as a lower-case start.
        if sent.text[:1].isupper():
            total_sents_with_ucase_start += 1
        else:
            total_sents_with_lcase_start += 1

        # check sentence category: position of the last subject, root and
        # object token inside the sentence (-1 when absent).
        subj_index = -1
        verb_index = -1
        obj_index = -1
        for i, token in enumerate(sent):
            if token.dep_ in ("nsubj", "nsubjpass", "csubj", "csubjpass"):
                subj_index = i
            elif token.dep_ == "ROOT":
                verb_index = i
            elif token.dep_ in ("pobj", "obj", "dobj", "iobj"):
                obj_index = i

        if subj_index != -1 and verb_index != -1 and obj_index != -1:
            if subj_index < verb_index < obj_index:
                total_subj_verb_obj_sents += 1
            elif subj_index < obj_index < verb_index:
                total_subj_obj_verb_sents += 1
            elif obj_index < subj_index < verb_index:
                total_obj_subj_verb_sents += 1
            elif obj_index < verb_index < subj_index:
                total_obj_verb_subj_sents += 1
            elif verb_index < subj_index < obj_index:
                total_verb_subj_obj_sents += 1
            elif verb_index < obj_index < subj_index:
                total_verb_obj_subj_sents += 1

    # Counts blank-line separators ("\n\n"), so a text without any blank line
    # yields 0; the averages below therefore divide by at least 1.
    total_paragraphs = doc.count("\n\n")  # STF4
    avg_sents_per_paragraph = total_sentences / max(1, total_paragraphs)  # STF5
    avg_words_per_paragraph = total_words / max(1, total_paragraphs)  # STF6
    avg_chars_per_paragraph = total_chars / max(1, total_paragraphs)  # STF7
    avg_words_per_sentence = total_words / max(1, total_sentences)  # STF8
    empty_lines = sum(1 for line in lines if not line.strip())
    total_empty_lines_ratio = _safe_div(empty_lines, total_lines)  # STF11
    avg_len_of_line = _safe_div(total_chars, total_lines)  # STF12

    features_dict = {"total_chars": total_chars,
                     "avg_letters": avg_letters,
                     "avg_upper_cases": avg_upper_cases,
                     "avg_digits": avg_digits,
                     "avg_ws": avg_ws,
                     "avg_tabs": avg_tabs,
                     "total_words": total_words,
                     "mean_word_len": mean_word_len,
                     "unique_words": unique_words,
                     "total_longer_words": total_longer_words,
                     "total_short_words": total_short_words,
                     "yulesK": yulesK,
                     "brunet_w": brunet_w,
                     "avg_corrections": avg_corrections,
                     "total_content_words": total_content_words,
                     "avg_words_with_3_or_more_syllables": avg_words_with_3_or_more_syllables,
                     "avg_syllables": avg_syllables,
                     "total_polysemous_words": total_polysemous_words,
                     "polysemous_words_cont_words_ratio": polysemous_words_cont_words_ratio,
                     "type_token_ratio": type_token_ratio,
                     "total_numeric_exprs": total_numeric_exprs,
                     "total_fry_words_count": total_fry_words_count,
                     "total_passive_verbs": total_passive_verbs,
                     "total_passive_sents": total_passive_sents,
                     "total_negations": total_negations,
                     "avg_negations": avg_negations,
                     "total_conjunctions": total_conjunctions,
                     "lemma_diversity": lemma_diversity,
                     # Syntactic
                     "total_single_quotes": total_single_quotes,
                     "total_commas": total_commas,
                     "total_periods": total_periods,
                     "total_colons": total_colons,
                     "total_semi_colons": total_semi_colons,
                     "total_question_mark": total_question_mark,
                     "total_exclamanation_mark": total_exclamanation_mark,
                     "total_ellipsis": total_ellipsis,
                     "total_special_chars": total_special_chars,
                     "mean_verbs_per_sentence": mean_verbs_per_sentence,
                     "mean_prepositions_per_sentence": mean_prepositions_per_sentence,
                     "mean_dependents_clauses_length": mean_dependents_clauses_length,
                     "mean_phrases_length": mean_phrases_length,
                     "mean_phrases_per_sentence": mean_phrases_per_sentence,
                     "mean_dependents_clauses_per_sentence": mean_dependents_clauses_per_sentence,
                     "mean_bi_gram_per_sent": mean_bi_gram_per_sent,
                     "mean_tri_gram_per_sent": mean_tri_gram_per_sent,
                     "mean_four_gram_per_sent": mean_four_gram_per_sent,
                     "mean_len_of_sentence": mean_len_of_sentence,
                     # Structural
                     "total_lines": total_lines,
                     "total_sentences": total_sentences,
                     "total_long_sentences": total_long_sentences,
                     "total_paragraphs": total_paragraphs,
                     "avg_sents_per_paragraph": avg_sents_per_paragraph,
                     "avg_words_per_paragraph": avg_words_per_paragraph,
                     "avg_chars_per_paragraph": avg_chars_per_paragraph,
                     "avg_words_per_sentence": avg_words_per_sentence,
                     "total_empty_lines_ratio": total_empty_lines_ratio,
                     "avg_len_of_line": avg_len_of_line,
                     "total_shorten_forms": total_shorten_forms,
                     "personal_pronouns_count": personal_pronouns_count,
                     "demonstrative_words_count": demonstrative_words_count,
                     "relative_words_count": relative_words_count,
                     "indefinete_words_count": indefinete_words_count,
                     "total_interrogatives_count": total_interrogatives_count,
                     "reflexive_words_count": reflexive_words_count,
                     "total_prepositions": total_prepositions,
                     "two_words_prepositions_count": two_words_prepositions_count,
                     "total_cordinative_conj": total_cordinative_conj,
                     "total_correlative_conj": total_correlative_conj,
                     "total_quantifiers": total_quantifiers,
                     "total_articles": total_articles,
                     "total_subj_verb_obj_sents": total_subj_verb_obj_sents,
                     "total_subj_obj_verb_sents": total_subj_obj_verb_sents,
                     "total_obj_subj_verb_sents": total_obj_subj_verb_sents,
                     "total_obj_verb_subj_sents": total_obj_verb_subj_sents,
                     "total_verb_subj_obj_sents": total_verb_subj_obj_sents,
                     "total_verb_obj_subj_sents": total_verb_obj_subj_sents,
                     }

    if flatten_features:
        # Each list is sliced to its target width before padding so that every
        # document produces exactly the same set of keys; slicing also copies,
        # which keeps pad() from mutating the lists computed above.
        for i, longer_word in enumerate(pad(top_20_longer_words[:20], ("", 0), 20)):
            features_dict["longer_word-{}".format(i)] = longer_word[0]

        for i, hapax_legomenon_w in enumerate(pad(hapax_legomenon[:100], "", 100)):
            features_dict["hapax_legomenon-{}".format(i)] = hapax_legomenon_w

        for i, hapax_dislegomenon_w in enumerate(pad(hapax_dislegomenon[:100], "", 100)):
            features_dict["hapax_dislegomenon-{}".format(i)] = hapax_dislegomenon_w

        for i, bi_single_occurence in enumerate(pad(bi_single_occurences[:200], '', 200)):
            features_dict["bi_single_occurence-{}".format(i)] = bi_single_occurence

        for i, tri_single_occurence in enumerate(pad(tri_single_occurences[:200], '', 200)):
            features_dict["tri_single_occurence-{}".format(i)] = tri_single_occurence

        # dictionary padding: real entries take the first slots, zeros fill
        # the remaining slots up to index 9.
        for i, (length, freq) in enumerate(word_len_freq_top_10.items()):
            features_dict["word_len-{}".format(i)] = length
            features_dict["word_freq-{}".format(i)] = freq

        for i in range(len(word_len_freq_top_10), 10):
            features_dict["word_len-{}".format(i)] = 0
            features_dict["word_freq-{}".format(i)] = 0
    else:
        features_dict["longer_word"] = top_20_longer_words
        features_dict["hapax_legomenon"] = hapax_legomenon
        features_dict["hapax_dislegomenon"] = hapax_dislegomenon
        features_dict["bi_single_occurence"] = bi_single_occurences
        features_dict["tri_single_occurence"] = tri_single_occurences
        features_dict["word_len_freq_top_10"] = word_len_freq_top_10

    return features_dict


def pad(l, content, width):
    """Grow ``l`` in place with copies of ``content`` until it has ``width`` items.

    A list that already holds ``width`` or more items is left untouched (it is
    never shortened). The same list object is returned.
    """
    missing = width - len(l)
    if missing > 0:
        l += [content] * missing
    return l


def syllable_count(word):
    """Estimate the number of syllables in ``word`` with a vowel-group heuristic.

    Each run of consecutive vowels (``aeiouy``, case-insensitive) counts as one
    syllable, a trailing ``e`` removes one, and any non-empty word is reported
    as having at least one syllable. The empty string yields 0.
    """
    if word == '':
        return 0

    lowered = word.lower()
    vowels = "aeiouy"

    # A vowel opens a new group when it is the first character or follows a
    # non-vowel.
    groups = 1 if lowered[0] in vowels else 0
    groups += sum(
        1
        for previous, current in zip(lowered, lowered[1:])
        if current in vowels and previous not in vowels
    )

    if lowered.endswith("e"):
        groups -= 1

    return groups or 1


def n_grams(seq, n=1):
    """Return an iterator of ``n``-tuples of consecutive items of ``seq``.

    ``seq`` is enumerated once per offset ``0 .. n-1`` and the shifted streams
    are zipped together, so iteration stops as soon as the shortest (most
    shifted) stream is exhausted.
    """

    def _shifted(offset):
        # Items of seq starting at position `offset`.
        return (item for position, item in enumerate(seq) if position >= offset)

    streams = [_shifted(offset) for offset in range(n)]
    return zip(*streams)
