import gzip, itertools, logging, math, os, random, nltk

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PyPDF2 import PdfFileMerger
from scipy.spatial import Delaunay
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm
from umap import UMAP

from ariel_tests.helpers.radar_plots import ComplexRadar
from ariel_tests.language.nlp import tokenize, postprocessSentence
from ariel_tests.language.utils import get_encoded_sentences, get_samples_and_labels, parse2grammarLabel

# ``np.printoptions`` only builds a context manager and has no effect unless it
# is used in a ``with`` statement; ``np.set_printoptions`` applies the settings.
np.set_printoptions(precision=5, suppress=True)
logger = logging.getLogger(__name__)


# Directory containing this file, and the sibling ``data`` directory one level up.
CDIR = os.path.dirname(os.path.realpath(__file__))
DATADIR = os.path.abspath(os.path.join(CDIR, '..', 'data'))


def pickProjectionMethod(projection_method, data, n_components):
    """
    Reduce ``data`` to ``n_components`` dimensions with the requested projector.

    @param projection_method: 'tSNE', 'PCA' or 'UMAP'.
    @param data: samples handed unchanged to the selected projector.
    @param n_components: number of output dimensions.
    @return: the projected data as returned by the projector.
    @raise Exception: when ``projection_method`` is none of the names above.
    """
    if projection_method == 'tSNE':
        # perplexity and seed are fixed
        tsne = TSNE(n_components=n_components, perplexity=30, random_state=0)
        return tsne.fit_transform(data)

    if projection_method == 'PCA':
        return PCA(n_components=n_components).fit_transform(data)

    if projection_method == 'UMAP':
        # fit first, then transform the very same data with the fitted reducer
        fitted = UMAP(n_components=n_components).fit(data)
        return fitted.transform(data)

    raise Exception('Unsupported projection method: %s' % (projection_method))


def alpha_shape(points, alpha):
    """
    http://blog.thehumangeo.com/2014/05/12/drawing-boundaries-in-python/

    Compute the alpha shape (concave hull) of a set
    of points.
    @param points: Sized, iterable container of points; each point must
        expose ``.coords`` (e.g. shapely Points).
    @param alpha: alpha value to influence the
        gooeyness of the border. Smaller numbers
        don't fall inward as much as larger numbers.
        Too large, and you lose everything!
    @return: tuple ``(hull, edge_points)``. ``hull`` is the union of the
        polygons built from the kept triangle edges and ``edge_points`` the
        list of those edges as coordinate pairs. With fewer than four points
        the convex hull is returned together with an empty edge list, so the
        result can always be unpacked the same way.
    """
    import shapely.geometry as geometry
    # unary_union supersedes cascaded_union, which newer shapely releases dropped
    from shapely.ops import polygonize, unary_union

    if len(points) < 4:
        # When you have a triangle, there is no sense
        # in computing an alpha shape.
        return geometry.MultiPoint(list(points)).convex_hull, []

    coords = np.array([point.coords[0]
                       for point in points])
    edges = set()
    edge_points = []

    def add_edge(i, j):
        """Record the segment between the i-th and j-th points once."""
        if (i, j) in edges or (j, i) in edges:
            # already added
            return
        edges.add((i, j))
        edge_points.append(coords[[i, j]])

    tri = Delaunay(coords)
    # ``simplices`` holds the corner indices of each triangle
    # (``vertices`` is the deprecated alias of the same attribute).
    for ia, ib, ic in tri.simplices:
        pa = coords[ia]
        pb = coords[ib]
        pc = coords[ic]
        # Lengths of sides of triangle
        a = math.sqrt((pa[0] - pb[0]) ** 2 + (pa[1] - pb[1]) ** 2)
        b = math.sqrt((pb[0] - pc[0]) ** 2 + (pb[1] - pc[1]) ** 2)
        c = math.sqrt((pc[0] - pa[0]) ** 2 + (pc[1] - pa[1]) ** 2)
        # Semiperimeter of triangle
        s = (a + b + c) / 2.0
        # Area of triangle by Heron's formula; rounding can push the product
        # slightly below zero for nearly flat triangles, so clamp it.
        area = math.sqrt(max(s * (s - a) * (s - b) * (s - c), 0.0))
        if area == 0:
            # avoid dividing by zero for degenerate triangles
            area = .001
        circum_r = a * b * c / (4.0 * area)
        # Radius filter: keep only triangles with a small enough circumcircle.
        if circum_r < 1.0 / alpha:
            add_edge(ia, ib)
            add_edge(ib, ic)
            add_edge(ic, ia)

    m = geometry.MultiLineString(edge_points)
    triangles = list(polygonize(m))
    return unary_union(triangles), edge_points


def plot_polygon(polygon, ax=None, fc='#999999', ec='#000000', fill=True, zorder=0, linewidth=1, margin=.0):
    """
    http://blog.thehumangeo.com/2014/05/12/drawing-boundaries-in-python/

    Draw ``polygon`` as a patch and fit the axes limits to its bounds.
    @param polygon: geometry exposing ``bounds``; passed to descartes' PolygonPatch.
    @param ax: axes to draw on; a new 10x10 figure is created when omitted.
    @param fc: face colour of the patch.
    @param ec: edge colour of the patch.
    @param fill: whether the patch is filled.
    @param zorder: drawing order of the patch.
    @param linewidth: width of the patch outline.
    @param margin: extra space added around the polygon bounds on every side.
    """
    from descartes import PolygonPatch

    # identity test: ``== None`` would be delegated to the object's own __eq__
    if ax is None:
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111)

    x_min, y_min, x_max, y_max = polygon.bounds
    ax.set_xlim([x_min - margin, x_max + margin])
    ax.set_ylim([y_min - margin, y_max + margin])
    patch = PolygonPatch(polygon, fc=fc,
                         ec=ec, fill=fill,
                         zorder=zorder,
                         linewidth=linewidth)
    ax.add_patch(patch)


def _scatter_projection_by_class(ax, projections, classes, colors, labels):
    """Scatter ``projections`` on ``ax`` with one colour and legend entry per class index."""
    for class_index, (color, label) in enumerate(zip(colors, labels)):
        p = projections[classes == class_index]
        ax.scatter(x=p[:, 0], y=p[:, 1], color=color, label=label)
    ax.legend()
    ax.axes.xaxis.set_ticklabels([])
    ax.axes.yaxis.set_ticklabels([])


def _radar_figure(variables, data):
    """Create a 9x9 figure holding a radar chart of ``data``; every axis ranges over (0, 1)."""
    figure = plt.figure(figsize=(9, 9))
    radar = ComplexRadar(figure, variables, [(0, 1)] * len(variables))
    radar.plot(data)
    radar.fill(data, alpha=0.2)
    return figure


def PlotResults(DataClass, projection_method='PCA'):
    """
    Create the projection scatter plots and the coverage radar charts of an
    experiment and save each of them as ``<experiment_folder>/plots/<key>.pdf``.

    @param DataClass: experiment object. This function reads ``dataset_name``,
        ``biasedSentences_Test`` (a list of sentences), ``encoder.encode``,
        ``grammar``, the coverage/score attributes used in the radar charts
        and ``experiment_folder``.
    @param projection_method: name forwarded to pickProjectionMethod.
    """
    figures = {}

    # Visualize embedding in 2D (PCA, t-SNE)
    logger.info('\n\nVisualizing embedding in 2D')

    if DataClass.dataset_name == 'HoME':
        customSentences = ['is it red ?', 'is it pale violet red ?', 'is the object red and blue ?',  # Color-related
                           'is it movable ?', 'is the object recognizable ?', 'is it eatable ?',  # Affordance-related
                           'is the object small ?', 'is it narrow ?', 'is it average-sized ?',  # Size-related
                           'is it an indoor lamp ?', 'is the object a household appliance ?', 'is it a vase ?',
                           # Category-related
                           'is the object very light ?', 'is it slightly heavy ?', 'is it heavy ?']  # Mass-related
    else:
        customSentences = random.choices(DataClass.biasedSentences_Test, k=15)

    # Test sentences and custom sentences are encoded and projected together
    # so both share one projection space.
    z = np.array([DataClass.encoder.encode(sentence)
                  for sentence in tqdm(DataClass.biasedSentences_Test + customSentences)],
                 dtype=np.float32)

    projections = pickProjectionMethod(projection_method=projection_method,
                                       data=z,
                                       n_components=2)

    nbBiased = len(projections) - len(customSentences)
    projectionsBiased = projections[:nbBiased]
    projectionsCustoms = projections[nbBiased:]

    # Projection for custom sentences
    key = 'projection-custom-sentences_' + projection_method
    figures[key] = plt.figure(figsize=(8, 8))
    ax = figures[key].add_subplot(111)
    ax.set_title('%s projection for sentences' % (projection_method))
    ax.scatter(x=projectionsCustoms[:, 0], y=projectionsCustoms[:, 1], color='k')

    for sentence, (px, py) in zip(customSentences, projectionsCustoms):
        ax.annotate(sentence, xy=(px, py), xytext=(0, 8),
                    textcoords='offset points', ha='center')
    ax.axes.xaxis.set_ticklabels([])
    ax.axes.yaxis.set_ticklabels([])

    # Projection with respect to sentence length
    key = 'projection-sentence-length_' + projection_method
    figures[key] = plt.figure(figsize=(8, 8))
    ax = figures[key].add_subplot(111)
    ax.set_title('%s projection for sentence length' % (projection_method))
    colors = ['r', 'g', 'b', 'm']
    labels = ['L < 3', '3 <= length < 6', '6 <= length < 10', 'length >= 10']
    nbTokensClasses = []
    for sentence in DataClass.biasedSentences_Test:
        nbTokens = len(tokenize(sentence))
        if nbTokens < 3:
            nbTokensClasses.append(0)
        elif nbTokens < 6:
            nbTokensClasses.append(1)
        elif nbTokens < 10:
            nbTokensClasses.append(2)
        else:
            nbTokensClasses.append(3)
    # the ``np.int`` alias no longer exists in current NumPy; builtin int is equivalent
    nbTokensClasses = np.array(nbTokensClasses, dtype=int)
    _scatter_projection_by_class(ax, projectionsBiased, nbTokensClasses, colors, labels)

    try:
        # Projection with respect to number of attributes in sentence (e.g. adjectives, nouns)
        parser = nltk.ChartParser(DataClass.grammar)
        for attribute in ['adjective', 'noun']:
            key = 'projection-attributes-' + attribute + '_' + projection_method
            figures[key] = plt.figure(figsize=(8, 8))
            ax = figures[key].add_subplot(111)
            ax.set_title('%s projection for number of attributes \'%s\'' % (projection_method, attribute))
            colors = ['r', 'g', 'b', 'm']
            # the last bucket collects every sentence with three or more attributes
            labels = ['N = 0', 'N = 1', 'N = 2', 'N >= 3']
            nbAttributesClasses = []
            for sentence in tqdm(DataClass.biasedSentences_Test):
                tokens = tokenize(sentence)
                # first child of the first parse tree
                tree = list(next(parser.parse(tokens)))[0]
                nbAttributes = len(list(tree.subtrees(filter=lambda x: x.label() == attribute)))
                nbAttributesClasses.append(min(nbAttributes, 3))
            nbAttributesClasses = np.array(nbAttributesClasses, dtype=int)
            _scatter_projection_by_class(ax, projectionsBiased, nbAttributesClasses, colors, labels)
    except Exception:
        # Best effort: the remaining plots are still produced when this one fails.
        logger.info('This plot works only for the HQ data')
        logger.debug('Attribute projection plot failed', exc_info=True)

    # plot different coverages in a radar chart
    figures['spyder-plot-coverages'] = _radar_figure(
        variables=('reconstruction\nbiased',
                   'coverage Vocabulary Generation',
                   'grammatical Reconstruction',
                   'semantic Reconstruction',
                   'unique sentences Generation',
                   'correct and unique sentences Generation',
                   'coverage Grammar Generation',
                   'semantic Generation',
                   'reconstruction unbiased',
                   'semantic unbiased',
                   ),
        data=(DataClass.coverageBiased,
              DataClass.coverageVocabulary,
              DataClass.grammaticalReconstruction,
              DataClass.semantic_reconstruction,
              DataClass.rtUniqueSentences,
              DataClass.rtCorrectAndUniqueSentences,
              DataClass.coverageGrammar,
              DataClass.semantic_generation,
              DataClass.coverageUnbiased,
              DataClass.semantic_generalization,
              ))

    # do the same only for reconstruction
    figures['spyder-plot-coverages-reconstruction'] = _radar_figure(
        variables=('Reconstruction\nBiased',
                   'Reconstruction\nUnbiased',
                   'grammatical',
                   'semantic'
                   ),
        data=(DataClass.coverageBiased,
              DataClass.coverageUnbiased,
              DataClass.grammaticalReconstruction,
              DataClass.semantic_reconstruction
              ))
    # acts on the current figure, i.e. the reconstruction radar created just above
    plt.subplots_adjust(top=50, right=50)

    # do the same only for generation
    figures['spyder-plot-coverages-generation'] = _radar_figure(
        variables=('coverage\nvocabulary',
                   'unique\nsentences',
                   'correct\nand\nunique',
                   'coverage\ngrammar',
                   'semantic'),
        data=(DataClass.coverageVocabulary,
              DataClass.rtUniqueSentences,
              DataClass.rtCorrectAndUniqueSentences,
              DataClass.coverageGrammar,
              DataClass.semantic_generation))

    # save plots
    for key, figure in figures.items():
        figure.savefig(DataClass.experiment_folder + '/plots/' + key + '.pdf', bbox_inches='tight')


def Interpolations(DataClass):
    """
    Decode points on the straight line between the encodings of fixed sentence
    pairs and write them to ``<experiment_folder>/text/interpolation.txt``.

    For every pair the file gets a 'From: ... to ...' header followed by eleven
    samples. Each sample is written as the index and the interpolation weight
    ``x`` on one line and the decoded sentence on the next. Sample 0 (x = 0) is
    the encoding of the first sentence and the last sample (x = 1) the encoding
    of the second one, matching the direction announced in the header.

    @param DataClass: experiment object; ``experiment_folder``,
        ``encoder.encode`` and ``decoder.decode`` are used.
    """
    pairs = [['is it red ?', 'is the object a household appliance ?'],  # change of grammar rule
             ['is it the average-sized and rosy brown shelving ?', 'is it the large and rosy brown shelving ?'],
             # change of one adjective
             ['is it a cadet blue and average-sized shelving ?', 'is it a average-sized and cadet blue shelving ?']
             # change order of adjectives
             ]
    nb_samples_test = 11

    # save generated samples
    with open(DataClass.experiment_folder + '/text/' + 'interpolation.txt', 'w') as file:
        for start, end in pairs:
            z0 = DataClass.encoder.encode(start)
            z1 = DataClass.encoder.encode(end)
            file.write('\n\nFrom:\n' + start + '\nto\n' + end + '\n\n')

            for i, x in enumerate(np.linspace(0, 1, nb_samples_test)):
                # x is the weight of the target sentence: x = 0 -> start, x = 1 -> end
                c = (1 - x) * z0 + x * z1
                sentence = DataClass.decoder.decode(c)
                file.write(str(i) + '\t' + str(x) + '\n')
                file.write(' ' + sentence + '\n')


def LatSpaceAlgebra(DataClass):
    """
    Run three latent-space arithmetic experiments and write them to
    ``<experiment_folder>/text/linear_algebra.txt``.

    Each experiment encodes its sentences, combines the codes with the listed
    formula, decodes the result and writes the input sentences, the decoded
    reply and the hoped-for reply.

    @param DataClass: experiment object; ``experiment_folder``,
        ``encoder.encode`` and ``decoder.decode`` are used.
    """
    # (title, input sentences, combination of the codes, result prefix, ideal reply line)
    experiments = (
        ('plus minus plus',
         ['is it red , average-sized and dark khaki ?', 'is it red ?', 'is it blue ?'],
         lambda first, second, third: first - second + third,
         '\n z = z0 - z1 + z2            ',
         ' ideally:                    is it blue , average-sized and dark khaki ?\n'),
        ('plus plus',
         ['is it heavy ?', 'is the object small ?'],
         lambda first, second: first + second,
         '\n z = z0 + z1                 ',
         ' ideally:                    is it heavy and small ?\n'),
        ('minus minus',
         ['is it recognizable , light coral and observable ?', 'is it light coral ?',
          'is it recognizable ?'],
         lambda first, second, third: first - second - third,
         '\n z = z0 - z1 - z2:           ',
         ' ideally:                    is it observable ?\n'),
    )

    out_path = DataClass.experiment_folder + '/text/' + 'linear_algebra.txt'
    with open(out_path, 'w') as file:
        for title, sentences, combine, result_prefix, ideal_line in experiments:
            file.write('\n\n\n' + title + '\n')

            codes = [DataClass.encoder.encode(sentence) for sentence in sentences]
            z_reply = DataClass.decoder.decode(combine(*codes))

            for sentence in sentences:
                file.write('     ' + sentence + '\n')
            file.write(result_prefix + z_reply + '\n')
            file.write(ideal_line)


def selection_to_scatter_and_plot(ax, selection, dim_1, dim_2, color, label, is_geopandas=False,
                                  alpha=40., buffer_distance=.02):
    """
    Scatter two dimensions of ``selection`` on ``ax`` and optionally outline
    the points with their concave hull.

    @param ax: axes providing ``scatter``.
    @param selection: container indexable by ``dim_1`` and ``dim_2``.
    @param dim_1: key of the values plotted on the x axis.
    @param dim_2: key of the values plotted on the y axis.
    @param color: colour of the markers (and of the hull outline).
    @param label: legend label of the markers.
    @param is_geopandas: when True, also draw the alpha shape of the points
        (requires geopandas, shapely and descartes).
    @param alpha: alpha value handed to alpha_shape.
    @param buffer_distance: distance by which the hull is buffered before drawing.
    """
    ax.scatter(selection[dim_1], selection[dim_2], s=2, c=color, alpha=.7,
               label=label, zorder=1)

    if not is_geopandas:
        return

    from geopandas import GeoSeries
    from shapely.geometry import Point

    # http://blog.thehumangeo.com/2014/05/12/drawing-boundaries-in-python/
    X = np.column_stack([np.array(selection[dim_1]), np.array(selection[dim_2])])
    points = GeoSeries([Point(xy) for xy in X])

    hull = alpha_shape(points, alpha=alpha)
    # alpha_shape yields (hull, edge_points); accept a bare geometry as well
    concave_hull = hull[0] if isinstance(hull, tuple) else hull
    plot_polygon(concave_hull.buffer(buffer_distance), ax, ec=color, fill=False, linewidth=1, zorder=2)


def PlotEncodingGrammarClustering(DataClass=None,
                                  encoded_sentences=None,
                                  grammar_labels=None,
                                  experiments_folder=None,
                                  experiment_folder=None,
                                  ax_out=None,
                                  dims_chosen=None,
                                  projection_method=None):
    """
    Project encoded sentences to 3 dimensions, plot every pair of projected
    dimensions coloured by grammar label and save the figure under
    ``<experiments_folder>/plots/``.

    @param DataClass: optional experiment object. When given, it overrides
        ``encoded_sentences`` (its ``encoded_sentences`` attribute, or
        get_encoded_sentences), ``grammar_labels``, ``experiments_folder``
        (taken from ``DataClass.experiment_folder``) and sets
        ``experiment_folder`` to the string 'None'.
    @param encoded_sentences: encodings, one row per sentence.
    @param grammar_labels: one label per encoded sentence; -1 is drawn in black.
    @param experiments_folder: folder that contains the ``plots`` directory.
    @param experiment_folder: string appended to the output file name.
    @param ax_out: axes used for the extra plot requested by ``dims_chosen``;
        defaults to the current axes.
    @param dims_chosen: pair of projected dimensions to additionally draw on
        ``ax_out``; nothing extra is drawn when None.
    @param projection_method: name forwarded to pickProjectionMethod.
    @return: ``ax_out`` when ``dims_chosen`` is given, otherwise None.
    """
    if DataClass is not None:

        if hasattr(DataClass, 'encoded_sentences'):
            encoded_sentences = DataClass.encoded_sentences
        else:
            encoded_sentences = get_encoded_sentences(DataClass)

        grammar_labels = DataClass.grammar_labels
        experiments_folder = DataClass.experiment_folder
        experiment_folder = 'None'

    np_e = np.array(encoded_sentences)
    unique_grammar_clusters = np.unique(grammar_labels)

    # calculate different language classes
    n_components = 3
    transformed = pickProjectionMethod(projection_method=projection_method,
                                       data=np_e,
                                       n_components=n_components)

    # Scale once; the result does not depend on the subplot being drawn.
    df_transformed = pd.DataFrame(transformed, index=grammar_labels)
    scaled_transformed = MinMaxScaler().fit_transform(df_transformed.values)
    df_scaled_transformed = pd.DataFrame(scaled_transformed, columns=df_transformed.columns,
                                         index=df_transformed.index)

    # plt.get_cmap is the pyplot-level accessor; plt.cm.get_cmap was removed from matplotlib
    cmap = plt.get_cmap('tab20', 20)

    def scatter_clusters(ax, dim_1, dim_2):
        """Scatter every grammar cluster on ``ax`` with its own colour and label."""
        for cluster_i in reversed(unique_grammar_clusters):
            color = 'black' if cluster_i == -1 else cmap(cluster_i)
            label = 'grammar rule %s' % (cluster_i)
            selection = df_scaled_transformed.loc[cluster_i, :]
            selection_to_scatter_and_plot(ax, selection, dim_1, dim_2, color, label)

    # plots
    fig = plt.figure(figsize=(18, 18), facecolor='white')

    combinations = list(itertools.combinations(range(n_components), r=2))
    n_rows = len(combinations)

    for i, (dim_1, dim_2) in enumerate(combinations):
        # plot different language classes
        ax = fig.add_subplot(n_rows, 2, 2 * i + 1)
        ax.set_title('Sentence classes (' + projection_method + ' dimensions '
                     + str(dim_1) + ' vs ' + str(dim_2) + ')')
        scatter_clusters(ax, dim_1, dim_2)

        fig.canvas.draw()
        ax.legend()

    fig.savefig(experiments_folder + '/plots/Encoder_Grammar_Latent_Representation_' + projection_method + '_s'
                + str(np_e.shape[1]) + '_' + experiment_folder + '.pdf')

    if dims_chosen is not None:
        if ax_out is None:
            ax_out = plt.gca()
        # every cluster gets its own label here (previously the last label of
        # the loop above was reused for all clusters)
        scatter_clusters(ax_out, dims_chosen[0], dims_chosen[1])
        return ax_out

    return None


def PlotEncodingLengthClustering(DataClass=None,
                                 encoded_sentences=None,
                                 length_labels=None,
                                 experiments_folder=None,
                                 experiment_folder=None,
                                 ax_out=None,
                                 dims_chosen=None,
                                 projection_method=None):
    """
    Project encoded sentences to 3 dimensions, plot every pair of projected
    dimensions coloured by sentence-length cluster and save the figure under
    ``<experiments_folder>/plots/``.

    @param DataClass: optional experiment object. When given, it overrides
        ``encoded_sentences`` (its ``encoded_sentences`` attribute, or
        get_encoded_sentences), ``length_labels``, ``experiments_folder``
        (taken from ``DataClass.experiment_folder``) and sets
        ``experiment_folder`` to the string 'None'.
    @param encoded_sentences: encodings, one row per sentence.
    @param length_labels: one numeric length-cluster label per encoded sentence.
    @param experiments_folder: folder that contains the ``plots`` directory.
    @param experiment_folder: string appended to the output file name.
    @param ax_out: axes used for the extra plot requested by ``dims_chosen``;
        defaults to the current axes.
    @param dims_chosen: pair of projected dimensions to additionally draw on
        ``ax_out``; nothing extra is drawn when None.
    @param projection_method: name forwarded to pickProjectionMethod.
    @return: ``ax_out`` when ``dims_chosen`` is given, otherwise None.
    """
    # identity tests: ``== None`` misbehaves for array-like arguments
    if DataClass is not None:

        if hasattr(DataClass, 'encoded_sentences'):
            encoded_sentences = DataClass.encoded_sentences
        else:
            encoded_sentences = get_encoded_sentences(DataClass)

        length_labels = DataClass.length_labels
        experiments_folder = DataClass.experiment_folder
        experiment_folder = 'None'

    np_e = np.array(encoded_sentences)
    unique_length_clusters = np.unique(length_labels)

    # calculate different language classes
    n_components = 3
    transformed = pickProjectionMethod(projection_method=projection_method,
                                       data=np_e,
                                       n_components=n_components)
    df_transformed = pd.DataFrame(transformed, index=length_labels)
    scaled_transformed = MinMaxScaler().fit_transform(df_transformed.values)
    df_transformed = pd.DataFrame(scaled_transformed, columns=df_transformed.columns,
                                  index=df_transformed.index)

    # plt.get_cmap is the pyplot-level accessor; plt.cm.get_cmap was removed from matplotlib
    cmap = plt.get_cmap('tab20', 20)

    def scatter_clusters(ax, clusters, dim_1, dim_2):
        """Scatter the given length clusters on ``ax``, labelled by percentile range."""
        for cluster_i in clusters:
            color = cmap(cluster_i + 5)
            selection = df_transformed.loc[cluster_i, :]
            # maps cluster 0..3 to the upper percentile bound 25..100
            percentile = int((100 - 25) * cluster_i / 3 + 25)
            label = 'length percentile %s-%s' % (percentile - 25, percentile)
            selection_to_scatter_and_plot(ax, selection, dim_1, dim_2, color, label)

    # plots
    fig = plt.figure(figsize=(18, 18), facecolor='white')

    combinations = list(itertools.combinations(range(n_components), r=2))
    n_rows = len(combinations)

    for i, (dim_1, dim_2) in enumerate(combinations):
        # plot different language classes
        ax = fig.add_subplot(n_rows, 2, 2 * i + 1)
        ax.set_title('Sentence classes (' + projection_method + ' dimensions '
                     + str(dim_1) + ' vs ' + str(dim_2) + ')')
        scatter_clusters(ax, unique_length_clusters, dim_1, dim_2)

        fig.canvas.draw()
        ax.legend()

    fig.savefig(
        experiments_folder + '/plots/Encoder_Length_Latent_Representation_' + projection_method + '_' + experiment_folder + '.pdf')

    if dims_chosen is not None:
        if ax_out is None:
            ax_out = plt.gca()
        scatter_clusters(ax_out, reversed(unique_length_clusters), dims_chosen[0], dims_chosen[1])
        return ax_out

    return None


def PlotDecodingGrammarClustering(DataClass=None,
                                  samples=None,
                                  labels=None,
                                  lat_dim=None,
                                  experiments_folder=None,
                                  experiment_folder=None,
                                  ax_out=None,
                                  dims_chosen=None):
    """
    Scatter latent samples along three randomly drawn pairs of dimensions,
    coloured by grammar label, and save the figure under
    ``<experiments_folder>/plots/``.

    @param DataClass: optional experiment object. When given, it overrides
        ``samples`` and ``labels`` (its ``samples`` / ``samples_grammar_rule``
        attributes, or get_samples_and_labels), ``lat_dim``,
        ``experiments_folder`` (taken from ``DataClass.experiment_folder``),
        sets ``experiment_folder`` to the string 'None' and stores samples and
        labels as ``.npy`` files in ``<experiments_folder>/text/``.
    @param samples: latent samples, one row per sample.
    @param labels: one grammar label per sample; -1 is drawn in black.
    @param lat_dim: number of latent dimensions to draw the plotted pairs from.
    @param experiments_folder: folder that contains the ``plots`` directory.
    @param experiment_folder: string whose last three characters are dropped
        before it is appended to the output file name.
    @param ax_out: axes used for the extra plot requested by ``dims_chosen``;
        defaults to the current axes.
    @param dims_chosen: pair of dimensions to additionally draw on ``ax_out``;
        nothing extra is drawn when None.
    @return: ``ax_out`` when ``dims_chosen`` is given, otherwise None.
    """
    if DataClass is not None:

        if hasattr(DataClass, 'samples'):
            samples = DataClass.samples
            labels = DataClass.samples_grammar_rule
        else:
            samples, labels = get_samples_and_labels(DataClass)

        lat_dim = DataClass.lat_dim
        experiments_folder = DataClass.experiment_folder
        experiment_folder = 'None'

        np.save(experiments_folder + '/text/samples.npy', samples)
        np.save(experiments_folder + '/text/samples_grammar_labels.npy', labels)

    unique_grammar_clusters = np.unique(labels)
    if 'lmariel' in experiment_folder:
        unique_grammar_clusters = list(reversed(unique_grammar_clusters))

    # plots
    fig = plt.figure(figsize=(18, 18), facecolor='white')
    n_rows = 3

    df_samples = pd.DataFrame(samples, index=labels)
    scaled_samples = MinMaxScaler().fit_transform(df_samples.values)
    df_samples = pd.DataFrame(scaled_samples, columns=df_samples.columns, index=df_samples.index)

    # plt.get_cmap is the pyplot-level accessor; plt.cm.get_cmap was removed from matplotlib
    cmap = plt.get_cmap('tab20', 20)

    for i in range(n_rows):
        # Draw two distinct dimensions. ``replace`` has to be a real boolean
        # (the string 'False' is truthy); replacement is only allowed when
        # fewer than two dimensions exist.
        dim_1, dim_2 = np.random.choice(lat_dim, 2, replace=lat_dim < 2)

        ax = fig.add_subplot(n_rows, 2, 2 * i + 1)
        ax.set_title('Sentence classes\n(dimensions '
                     + str(dim_1) + ' vs ' + str(dim_2) + ')')

        for cluster_i in unique_grammar_clusters:
            color = 'black' if cluster_i == -1 else cmap(cluster_i)
            selection = df_samples.loc[cluster_i, :]
            label = 'grammar rule %s' % (cluster_i)
            selection_to_scatter_and_plot(ax, selection, dim_1, dim_2, color, label)

        # one legend per subplot, built after all clusters are drawn
        ax.legend()

    # NOTE(review): ``[:-3]`` drops the last three characters of
    # experiment_folder ('None' becomes 'N', short names become empty);
    # presumably meant to strip a file extension - confirm.
    fig.savefig(experiments_folder + '/plots/Decoder_Grammar_Latent_Representation_' + experiment_folder[:-3] + '.pdf')

    if dims_chosen is not None:
        if ax_out is None:
            ax_out = plt.gca()

        for cluster_i in unique_grammar_clusters:
            if cluster_i == -1:
                color = 'black'
                legend = 'ungrammatical'
            else:
                color = cmap(cluster_i)
                if cluster_i == 1:
                    legend = '{} adjective'.format(cluster_i)
                else:
                    legend = '{} adjectives'.format(cluster_i)
            selection = df_samples.loc[cluster_i, :]

            selection_to_scatter_and_plot(ax_out, selection, dims_chosen[0], dims_chosen[1], color, legend)

        return ax_out

    return None


def PlotLsnnVoltageAndBias(DataClass=None,
                           samples=None,
                           labels=None,
                           lat_dim=None,
                           experiments_folder=None,
                           experiment_folder=None,
                           ax_out=None,
                           dims_chosen=None,
                           projection_method=None):
    """
    Split the decoder samples and the encoded sentences into the state blocks
    of the layer "decode-rnn-0" and draw the grammar and length clustering
    plots of the encodings for every block.

    The samples and their grammar labels are also stored as ``.npy`` files in
    ``<DataClass.experiment_folder>/text/``.

    @param DataClass: experiment object (required). ``decoder.model``,
        ``experiment_folder``, ``grammar_labels``, ``length_labels`` and either
        ``samples``/``parses`` or get_samples_and_labels are used.
    @param samples: unused, always taken from ``DataClass``.
    @param labels: unused, always derived from ``DataClass``.
    @param lat_dim: unused.
    @param experiments_folder: unused, always taken from ``DataClass``.
    @param experiment_folder: unused, always reset internally.
    @param ax_out: forwarded to the clustering plot functions.
    @param dims_chosen: forwarded to the clustering plot functions.
    @param projection_method: forwarded to the encoding plot functions.
    @raise ValueError: if ``DataClass`` is None.
    """
    if DataClass is None:
        # The state sizes are read from DataClass.decoder, so nothing can be
        # plotted without it; fail with a clear message instead of an
        # AttributeError on None.
        raise ValueError('PlotLsnnVoltageAndBias requires a DataClass')

    # Decoding

    if hasattr(DataClass, 'samples'):
        samples = DataClass.samples
        labels = [parse2grammarLabel(parse) for parse in DataClass.parses]
    else:
        samples, labels = get_samples_and_labels(DataClass)

    experiments_folder = DataClass.experiment_folder
    experiment_folder = 'None'

    np.save(experiments_folder + '/text/samples.npy', samples)
    np.save(experiments_folder + '/text/samples_grammar_labels.npy', labels)

    state_size = DataClass.decoder.model.get_layer("decode-rnn-0").cell.state_size[0]
    # cumulative sizes are the split positions; the trailing remainder is dropped
    c_state_size = np.cumsum(state_size).tolist()
    states = np.split(samples, c_state_size, axis=1)[:-1]

    # NOTE(review): experiment_folder is always 'None' at this point, so this
    # branch never runs; it looks like a manual switch for the decoder plots -
    # confirm the intended condition before enabling it.
    if experiment_folder == 'PCA':
        for i, state in enumerate(states):
            PlotDecodingGrammarClustering(DataClass=None,
                                          samples=state,
                                          labels=labels,
                                          lat_dim=state.shape[1],
                                          experiments_folder=experiments_folder,
                                          experiment_folder=str(i),
                                          ax_out=ax_out,
                                          dims_chosen=dims_chosen)

    # Encodings

    if hasattr(DataClass, 'encoded_sentences'):
        encoded_sentences = DataClass.encoded_sentences
    else:
        encoded_sentences = get_encoded_sentences(DataClass)

    grammar_labels = DataClass.grammar_labels
    length_labels = DataClass.length_labels

    states = np.split(encoded_sentences, c_state_size, axis=1)[:-1]

    for i, state in enumerate(states):
        PlotEncodingGrammarClustering(DataClass=None,
                                      encoded_sentences=state,
                                      grammar_labels=grammar_labels,
                                      experiments_folder=experiments_folder,
                                      experiment_folder=str(i),
                                      ax_out=ax_out,
                                      dims_chosen=dims_chosen,
                                      projection_method=projection_method)

        PlotEncodingLengthClustering(DataClass=None,
                                     encoded_sentences=state,
                                     length_labels=length_labels,
                                     experiments_folder=experiments_folder,
                                     experiment_folder=str(i),
                                     ax_out=ax_out,
                                     dims_chosen=dims_chosen,
                                     projection_method=projection_method)


def mergePlotsInSummary(experiments_folder):
    """
    Merge the PDF files found in ``<experiments_folder>/plots/`` into one
    summary PDF stored in the same directory.

    Only ``.pdf`` files are merged, in alphabetical order, and a summary left
    over from a previous call is skipped so it is not merged into itself.

    @param experiments_folder: folder that contains the ``plots`` directory.
    @return: path of the written summary PDF.
    """
    summary_name = '00000_summary_00000.pdf'
    plots_dir = str(experiments_folder) + '/plots/'

    # os.listdir returns entries in arbitrary order and includes every file type
    pdfs = sorted(name for name in os.listdir(plots_dir)
                  if name.lower().endswith('.pdf') and name != summary_name)

    plot_filepath = experiments_folder + '/plots/' + summary_name
    merger = PdfFileMerger()
    try:
        for pdf in pdfs:
            merger.append(plots_dir + pdf)
        merger.write(plot_filepath)
    finally:
        # release the merger even when a file cannot be read or written
        merger.close()

    return plot_filepath


def SaveText(experiment_folder, DataClass):
    """
    Write the generated samples, sentences and parses of ``DataClass`` to text
    files in ``<experiment_folder>/text/``.

    Files written:
      - samples_and_sentence.txt: index, sample, sentence and parse per entry;
      - sampled_sentences_and_parses.txt: index, sentence and parse per entry;
      - sampled_sentences.txt: one sentence per line;
      - sampled_sentences_and_indices.txt: each sentence followed by the result
        of ``DataClass.vocabulary.sentenceToIndices`` for it.

    @param experiment_folder: folder that contains the ``text`` directory.
    @param DataClass: experiment object; ``samples``, ``sentences``, ``parses``
        and ``vocabulary.sentenceToIndices`` are used.
    """
    text_dir = experiment_folder + '/text/'

    # save generated samples
    with open(text_dir + 'samples_and_sentence.txt', 'w') as file:
        file.write('\n')
        for i, (sample, sentence, parse) in enumerate(zip(DataClass.samples, DataClass.sentences, DataClass.parses)):
            file.write(str(i) + '\n')
            file.write('sample:   ' + str(sample) + '\n')
            file.write('sentence: ' + sentence + '\n')
            file.write('parse:    ' + str(parse) + '\n')
            file.write('\n')

    # I don't write the sample here, since it can become quite massive and the .txt hard to read.
    # The context manager closes all three files even if writing fails half-way.
    with open(text_dir + 'sampled_sentences_and_parses.txt', 'w') as file_sp, \
            open(text_dir + 'sampled_sentences.txt', 'w') as file_s, \
            open(text_dir + 'sampled_sentences_and_indices.txt', 'w') as file_si:
        file_sp.write('\n')

        for i, (sentence, parse) in enumerate(zip(DataClass.sentences, DataClass.parses)):
            file_sp.write(str(i) + '\n')
            file_sp.write('sentence: ' + sentence + '\n')
            file_sp.write('parse:    ' + str(parse) + '\n')
            file_sp.write('\n')

            file_s.write(sentence + '\n')

            file_si.write(sentence + '\n')
            file_si.write(str(DataClass.vocabulary.sentenceToIndices(sentence)) + '\n')


def CompareTexts(
        experiment_folder=None, DataClass=None,
        texts_path='experiments/experiment/text/sampled_sentences.txt'
):
    """
    Compare generated sentences with the sentences of a gzipped training set
    and log a few statistics about their overlap.

    Every non-empty line of the generated text file counts as one sample; a
    line is cut right after its first '?' when it contains one.

    @param experiment_folder: when given, the generated sentences are read
        from ``<experiment_folder>/text/sampled_sentences.txt``.
    @param DataClass: when given, the training set is read from
        ``DataClass.biasedFilename_train``; otherwise from
        ``GW_questions_train.gz`` in DATADIR.
    @param texts_path: file with the generated sentences, used when
        ``experiment_folder`` is None.
    @return: number of distinct training sentences that were generated,
        divided by the number of generated samples (0.0 without samples).
    """
    if experiment_folder is not None:
        texts_path = experiment_folder + '/text/' + 'sampled_sentences.txt'

    if DataClass is not None:
        gzipDatasetFilepath = DataClass.biasedFilename_train
    else:
        # gzipDatasetFilepath = 'data/biased_train.gz'  # HoME
        gzipDatasetFilepath = os.path.join(DATADIR, 'GW_questions_train.gz')        # GW

    # Read every line, including the first one, and ignore blank lines.
    generated_lines = []
    with open(texts_path, 'r') as file_s:
        for raw_line in file_s:
            line = raw_line.replace('\n', '')
            if not line:
                continue
            qm = line.index('?') + 1 if '?' in line else None
            generated_lines.append(line[:qm])

    unique_generated = set(generated_lines)
    n_samples = len(generated_lines)
    n_unique = len(unique_generated)

    sentences_in_dataset = []
    lengths = []
    with gzip.open(gzipDatasetFilepath, 'rb') as f:
        for line in tqdm(f):
            sentence = line.strip()
            ps = postprocessSentence(sentence.decode('windows-1252'))  # # .decode("utf-8")
            lengths.append(len(ps.split(' ')))
            sentences_in_dataset.append(ps)

    # set intersection instead of scanning the generated sentences once per
    # training sentence
    count = len(set(sentences_in_dataset) & unique_generated)

    # guard the ratios against empty inputs
    ratio_samples = count / n_samples if n_samples else 0.0
    ratio_unique_generated = n_unique / n_samples if n_samples else 0.0
    ratio_unique = count / n_unique if n_unique else 0.0

    if lengths:
        logger.info('    np.max(lengths):                      {}'.format(np.max(lengths)))
        logger.info('    np.mean(lengths):                     {}'.format(np.mean(lengths)))
    logger.info('    generated_lines:                      {}'.format(generated_lines))
    logger.info('    n_samples: {}, n_unique: {}'.format(n_samples, n_unique))
    logger.info('    unique and in the training set:       {}'.format(count))
    logger.info('    ratio unique and in the training set: {}'.format(ratio_samples))
    logger.info('    ratio unique of generated:            {}'.format(ratio_unique_generated))
    logger.info('    ratio unique and in the training set: {}'.format(ratio_unique))

    return ratio_samples
