import os
from collections import OrderedDict

import joblib
import numpy as np
import pandas as pd    # to load dataset
import seaborn as sns
from nltk.corpus import stopwords   # to get collection of stopwords
from sklearn import metrics
from sklearn.base import BaseEstimator
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.pipeline import TransformerMixin
from tensorflow.keras.callbacks import ModelCheckpoint   # save model
from tensorflow.keras.layers import Embedding, LSTM, Dense, Bidirectional # layers of the architecture
from tensorflow.keras.models import Sequential     # the model
from tensorflow.keras.models import load_model   # load saved model
from tensorflow.keras.preprocessing.sequence import pad_sequences   # to do padding or truncating
from tensorflow.keras.preprocessing.text import Tokenizer  # to encode text to int
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from lime.lime_text import LimeTextExplainer

from explainer.SvsvlExp_text import SvsvlTextExp

def load_dataset(data):
    """Split a review DataFrame into cleaned texts and 0/1 sentiment labels.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``'review'`` column (raw review text) and a
        ``'sentiment'`` column holding ``'positive'`` / ``'negative'`` labels.

    Returns
    -------
    x_data : pandas.Series
        Reviews with HTML tags removed and every character that is not an
        ASCII letter replaced by a space. Case is preserved and stop words
        are kept (no stop-word removal or lower-casing is applied).
    y_data : pandas.Series
        Sentiments with ``'positive'`` mapped to 1 and ``'negative'`` to 0;
        any other value is left unchanged.
    """
    x_data = data['review']       # Reviews/Input
    y_data = data['sentiment']    # Sentiment/Output

    # PRE-PROCESS REVIEW
    # Non-greedy match so each tag is removed separately rather than
    # everything between the first '<' and the last '>'.
    x_data = x_data.replace({r'<.*?>': ''}, regex=True)        # remove html tags
    x_data = x_data.replace({r'[^A-Za-z]': ' '}, regex=True)   # remove non alphabet

    # ENCODE SENTIMENT -> 0 & 1
    y_data = y_data.replace({'positive': 1, 'negative': 0})

    return x_data, y_data

class TextsToSequences(Tokenizer, BaseEstimator, TransformerMixin):
    """Sklearn transformer that converts texts to lists of word indices.

    Wraps the Keras ``Tokenizer`` so it can be used as a pipeline step,
    e.g. ``["the cute cat", "the dog"] -> [[1, 2, 3], [1, 4]]``.

    All keyword arguments (e.g. ``num_words``) are forwarded unchanged to
    ``Tokenizer.__init__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def fit(self, texts, y=None):
        """Build the tokenizer vocabulary from *texts*; *y* is ignored."""
        self.fit_on_texts(texts)
        return self

    def transform(self, texts, y=None):
        """Convert *texts* to word-index sequences; *y* is ignored.

        Returns a 1-D ``object`` array whose i-th element is the list of
        indices for the i-th text. The sequences generally have different
        lengths, and ``np.array(list_of_ragged_lists)`` raises ``ValueError``
        on NumPy >= 1.24. Filling a pre-allocated object array element by
        element always yields the same 1-D shape, whether or not the lengths
        happen to match.
        """
        sequences = self.texts_to_sequences(texts)
        result = np.empty(len(sequences), dtype=object)
        for i, seq in enumerate(sequences):
            result[i] = seq
        return result

class Padder(BaseEstimator, TransformerMixin):
    """ Pad and crop uneven lists to the same length.

    Only the last ``maxlen`` items of lists longer than ``maxlen`` are kept,
    and lists shorter than ``maxlen`` are left-padded with zeros (the
    ``pad_sequences`` defaults, which this class relies on).

    Attributes
    ----------
    maxlen: int
        Size of the sequences after padding.
    max_index: int or None
        Maximum index known by the Padder, computed in ``fit`` over the
        cropped sequences. Any higher index met during ``transform`` is
        replaced by 0. ``None`` until ``fit`` has been called.
    """
    def __init__(self, maxlen=500):
        self.maxlen = maxlen
        self.max_index = None

    def fit(self, X, y=None):
        """Record the largest index present in the padded/cropped *X*."""
        padded = pad_sequences(X, maxlen=self.maxlen)
        # An empty X yields a zero-size array, whose .max() raises ValueError.
        self.max_index = padded.max() if padded.size else 0
        return self

    def transform(self, X, y=None):
        """Pad/crop *X* to ``maxlen`` and zero out indices unseen during fit.

        Raises
        ------
        ValueError
            If called before ``fit``.
        """
        if self.max_index is None:
            raise ValueError(
                "This Padder instance is not fitted yet; call 'fit' first."
            )
        X = pad_sequences(X, maxlen=self.maxlen)
        X[X > self.max_index] = 0
        return X

def create_model(max_features):
    """Build and compile the Bidirectional-LSTM sentiment classifier.

    Architecture: a 256-dimensional embedding over ``max_features`` input
    indices, a bidirectional 256-unit LSTM (dropout and recurrent dropout of
    0.2), and a single sigmoid output unit. The model is compiled with the
    Adam optimizer and binary cross-entropy loss, and tracks accuracy.

    Parameters
    ----------
    max_features : int
        Input dimension of the embedding layer.

    Returns
    -------
    Sequential
        The compiled Keras model.
    """
    stack = [
        Embedding(max_features, 256),
        Bidirectional(LSTM(256, dropout=0.2, recurrent_dropout=0.2)),
        Dense(1, activation='sigmoid'),
    ]
    net = Sequential(stack)
    net.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return net

### Script: train (or load) the pipeline, evaluate it, then explain predictions.
training = False  # True: fit and save the pipeline; False: load the saved one

vocab_size = 30000  # Max number of different word, i.e. model input dimension
maxlen = 150  # Max number of words kept at the end of each text

sequencer = TextsToSequences(num_words=vocab_size)
padder = Padder(maxlen)

batch_size = 64
max_features = vocab_size + 1

data = pd.read_csv('imdb.csv')
x_data, y_data = load_dataset(data)
texts_train, texts_test, y_train, y_test = train_test_split(
    x_data, y_data, random_state=42
)

model_path = 'models/lstm3.pkl'

if training:
    sklearn_lstm = KerasClassifier(
        build_fn=create_model, epochs=10, batch_size=batch_size,
        max_features=max_features, verbose=1,
    )

    # Build the Scikit-learn pipeline: texts -> index sequences -> padded -> LSTM
    pipeline = make_pipeline(sequencer, padder, sklearn_lstm)
    pipeline.fit(texts_train, y_train)

    print("Saving the model...")
    # Create the output directory first so a finished training run is not
    # lost to a missing-directory error at save time.
    os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)
    joblib.dump(pipeline, model_path)
else:
    # NOTE: joblib.load unpickles the file; only load model files you trust.
    pipeline = joblib.load(model_path)


print('Computing predictions on test set...')
y_preds = pipeline.predict(texts_test)

# accuracy_score's signature is (y_true, y_pred).
print('Test accuracy: {:.2f} %'.format(100 * metrics.accuracy_score(y_test, y_preds)))

idx = 100
text_sample = texts_test.iloc[idx]
class_names = ['negative', 'positive']

# Sorted by length, so indices 0 and 1 below select the two shortest reviews.
sorted_texts = sorted(texts_test, key=len)

svsvlexp = SvsvlTextExp(class_names=class_names, feature_selection='lasso')
# NOTE(review): this explains sorted_texts[1], while LIME below explains
# sorted_texts[0]; confirm the two explainers are meant to see different texts.
svsvlexplanation = svsvlexp.explain_instance(sorted_texts[1], pipeline.predict_proba, num_features=10)
svsvlexplanation.as_pyplot_figure(interacting_features=3, plot_features=3, plot_type="interaction")


lime = LimeTextExplainer(class_names=class_names)
lime_exp = lime.explain_instance(sorted_texts[0], pipeline.predict_proba, num_features=10, num_samples=50)

# Keep LIME's (word, weight) ordering for the bar plot.
weights = OrderedDict(lime_exp.as_list())
lime_weights = pd.DataFrame({'words': list(weights.keys()), 'weights': list(weights.values())})

sns.barplot(x="words", y="weights", data=lime_weights)

print("done!")