import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import spacy
from sklearn.feature_extraction.text import TfidfVectorizer


class CustomAnalyzer(object):
    """Callable analyzer that lowercases a document and lemmatizes it with spaCy.

    An instance can be used wherever a plain ``doc -> list of tokens``
    function is expected.
    """

    def __init__(self):
        # Load spaCy's small English pipeline; it acts as tokenizer and lemmatizer.
        self.nlp_ = spacy.load("en_core_web_sm")

    def __call__(self, doc):
        """Return the lemmas of ``doc`` after lowercasing it.

        Tokens whose lemma is the placeholder ``"-PRON-"`` are represented by
        their (lowercased) surface text instead of the placeholder.
        """
        lemmas = []
        for tok in self.nlp_(doc.lower()):
            lemma = tok.lemma_
            lemmas.append(tok.text if lemma == "-PRON-" else lemma)
        return lemmas


def create_tfidf_vectorizer(**kwargs):
    """Build an unfitted, lowercasing ``TfidfVectorizer`` from keyword options.

    Only the following keys are read from ``kwargs``; any others are ignored:

    - ``ngram_range``: ``(min_n, max_n)`` tuple. Defaults to ``(1, 1)``.
    - ``analyzer``: ``"word"``, ``"char"``, ``"char_wb"`` or a callable such
      as a ``CustomAnalyzer`` instance. Defaults to ``"word"``.
    - ``stop_words``: forwarded unchanged (``None`` when absent).

    Returns:
        A new ``TfidfVectorizer`` configured with the options above.
    """
    # kwargs.get() yields None for a missing key, but None is not a valid
    # value for ngram_range or analyzer and only fails later, at fit time.
    # Fall back to scikit-learn's own defaults instead.
    ngram_range = kwargs.get("ngram_range")
    if ngram_range is None:
        ngram_range = (1, 1)

    analyzer = kwargs.get("analyzer")
    if analyzer is None:
        analyzer = "word"

    return TfidfVectorizer(
        analyzer=analyzer,
        lowercase=True,
        ngram_range=ngram_range,
        stop_words=kwargs.get("stop_words"),
    )


def run_tfidf_insights(xx_train, yy_train, **kwargs):
    """Fit a TF-IDF vectorizer on the training texts and plot top features per class.

    Args:
        xx_train: Raw documents handed to the vectorizer's ``fit_transform``.
        yy_train: Class labels, one per document in ``xx_train``.
        **kwargs: Forwarded to ``create_tfidf_vectorizer``. The ``categories``
            entry is additionally passed to ``plot_tfidf_classfeats_h`` for
            the subplot titles.
    """
    categories = kwargs.get("categories")
    vectorizer = create_tfidf_vectorizer(**kwargs)
    X = vectorizer.fit_transform(xx_train)

    # get_feature_names() was deprecated in scikit-learn 1.0 and removed in
    # 1.2; prefer its replacement and fall back only on older releases.
    if hasattr(vectorizer, "get_feature_names_out"):
        feature_names = vectorizer.get_feature_names_out()
    else:
        feature_names = vectorizer.get_feature_names()

    dfs = top_feats_by_class(X, yy_train, feature_names)
    plot_tfidf_classfeats_h(dfs, categories)


def plot_tfidf_classfeats_h(dfs, categories):
    """Plot the data frames returned by top_feats_by_class() as horizontal bar charts.

    One subplot is drawn per data frame, side by side in a single row.

    Args:
        dfs: List of data frames with ``feature`` and ``tfidf`` columns. Each
            frame must also carry a ``label`` attribute, which is used as an
            index into ``categories``.
        categories: Sequence indexed by label; the first element of each entry
            is used in the subplot title (entries are presumably tuples or
            similar sequences - confirm against the caller).

    Nothing is drawn when ``dfs`` is empty.
    """
    if not dfs:
        # Nothing to plot; avoid opening an empty figure window.
        return

    cat_labels = [cat[0] for cat in categories]
    fig = plt.figure(figsize=(12, 9), facecolor="w")
    for i, df in enumerate(dfs):
        # Bar positions are derived per frame so frames of different lengths
        # still line up with their own tick labels.
        x = np.arange(len(df))
        ax = fig.add_subplot(1, len(dfs), i + 1)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_frame_on(False)
        ax.get_xaxis().tick_bottom()
        ax.get_yaxis().tick_left()
        ax.set_xlabel("Mean Tf-Idf Score", labelpad=16, fontsize=14)
        ax.set_title("label = " + str(cat_labels[df.label]), fontsize=16)
        ax.ticklabel_format(axis='x', style='sci', scilimits=(-2, 2))
        ax.barh(x, df.tfidf, align='center', color='#3F5D7D')
        ax.set_yticks(x)
        # Same limits as [-1, x[-1] + 1], but also valid for an empty frame.
        ax.set_ylim([-1, len(df)])
        ax.set_yticklabels(df.feature)
    # The layout is a figure-level setting, so apply it once after all
    # subplots exist rather than on every iteration.
    plt.subplots_adjust(bottom=0.09, right=0.97, left=0.15, top=0.95, wspace=0.52)
    plt.show()


def top_mean_feats(Xtr, features, grp_ids=None, min_tfidf=0.1, top_n=25):
    """Return the top n features that on average are most important amongst
    the documents in the rows identified by the indices in grp_ids.

    Args:
        Xtr: Document-term matrix exposing ``toarray()`` (e.g. a SciPy sparse
            matrix); rows are documents, columns are features.
        features: Feature names, indexable by column position.
        grp_ids: Row selector applied as ``Xtr[grp_ids]``. ``None`` (the
            default) means all rows are used.
        min_tfidf: Scores below this threshold are treated as 0 before
            averaging.
        top_n: Number of features to return.

    Returns:
        The data frame produced by ``top_tfidf_feats`` for the per-feature
        mean scores.
    """
    # Compare against None explicitly. Truth-testing an index array is
    # ambiguous (it raises for multi-element arrays and is False for
    # array([0])), and an empty selection must not silently fall back to
    # "all rows".
    if grp_ids is not None:
        D = Xtr[grp_ids].toarray()
    else:
        D = Xtr.toarray()

    # Zero out weak scores so they do not contribute to the mean.
    D[D < min_tfidf] = 0
    tfidf_means = np.mean(D, axis=0)
    return top_tfidf_feats(tfidf_means, features, top_n)


def top_tfidf_feats(row, features, top_n=25):
    """Get the top n tfidf values in row and return them with their feature names.

    Args:
        row: 1-D array of scores, one per feature.
        features: Feature names, indexable by the positions of ``row``.
        top_n: Maximum number of entries to return.

    Returns:
        A data frame with columns ``feature`` and ``tfidf``, ordered by
        descending score. It is empty (but still has both columns) when
        ``row`` is empty or ``top_n`` is 0.
    """
    # argsort is ascending, so reverse it before taking the first top_n.
    topn_ids = np.argsort(row)[::-1][:top_n]
    top_feats = [(features[i], row[i]) for i in topn_ids]
    # Name the columns in the constructor: assigning df.columns afterwards
    # raises ValueError when there are no rows, because the frame then has
    # zero columns to rename.
    return pd.DataFrame(top_feats, columns=['feature', 'tfidf'])


def top_feats_by_class(Xtr, y, features, min_tfidf=0.1, top_n=25):
    """Return a list of dfs, where each df holds the top_n features and their
    mean tfidf value calculated across documents with the same class label.

    Args:
        Xtr: Document-term matrix passed through to ``top_mean_feats``.
        y: Class labels, one per row of ``Xtr`` (list, array or similar).
        features: Feature names, indexable by column position.
        min_tfidf: Threshold forwarded to ``top_mean_feats``.
        top_n: Number of features per class.

    Returns:
        One data frame per distinct label, in ``np.unique`` order. Each frame
        has its class stored in a ``label`` attribute (a plain Python
        attribute, not a column).
    """
    # Coerce to an array so that `y == label` is an element-wise comparison.
    # On a plain list it is a single False, which would select no rows for
    # every class.
    y = np.asarray(y)
    dfs = []
    for label in np.unique(y):
        # np.where returns a 1-tuple holding the array of matching row indices.
        ids = np.where(y == label)
        feats_df = top_mean_feats(Xtr, features, ids, min_tfidf=min_tfidf, top_n=top_n)
        feats_df.label = label
        dfs.append(feats_df)
    return dfs
