import csv, os, pickle, time, logging, nltk
import matplotlib as mpl

import numpy as np
from tqdm import tqdm

from GenericTools.PlotTools.mpl_tools import load_plot_settings
from GenericTools.StayOrganizedTools.VeryCustomSacred import CustomExperiment

mpl.use('PS')
import matplotlib.pyplot as plt
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

from GenericTools.LeanguageTreatmentTools.nlp import Vocabulary
from ariel_tests.models.lmArielEmbedding import LmArielEmbedding

# Apply the project's shared matplotlib settings (load_plot_settings is defined
# in GenericTools; presumably it updates rcParams) and rebind `mpl` to the
# object it returns.
mpl = load_plot_settings(mpl, figsize=(10, 10))

# Directory containing this script; the grammar file and the trained-model
# path are resolved against it.
CDIR = os.path.dirname(os.path.realpath(__file__))
# Project experiment wrapper (defined in GenericTools; presumably sacred-based)
# that provides the @ex.config / @ex.automain decorators and artifact storage
# used below, with a fixed seed of 11.
ex = CustomExperiment('-interpolations', base_dir=CDIR, seed=11)
# Named logger used by main() for progress and result messages.
logger = logging.getLogger('mylogger')


@ex.config
def cfg():
    # Experiment configuration scope. Every name assigned here presumably
    # becomes a config entry (sacred-style; CustomExperiment is defined in
    # another module) -- main() takes parameters with exactly these names.
    # Avoid temporary locals here, since they would presumably leak into the config.
    # params
    n_interpolation_sentences = 20  # sentences sampled once and cached for pairing
    n_sentence_pairs = 15  # random sentence pairs interpolated per latent dimension
    interpolation_steps = 100  # evenly spaced points decoded between each pair

    # AriEL latent dimensions to evaluate (x axis of the resulting plot)
    latent_dims = [1, 2, 4, 6, 8, 10, 12, 14, 16, 32, 256, 512, 1024]  # , int(1e6)]
    # latent_dims = [2, 1]
    # latent_dims = [1, 2, 3]

    # CUDA device index to expose; -1 yields GPU_count = 0
    GPU = 1
    GPU_count = 0 if GPU == -1 else 1
    # Not read by any code visible in this file; presumably consumed by the
    # experiment framework to cap GPU memory use -- TODO confirm.
    GPU_fraction = .1  # .45

    # Side effect of building the config: select the CUDA device.
    # NOTE(review): tensorflow is already imported at the top of the file when
    # this runs; confirm the device selection still takes effect.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU)

    # Directory holding the trained language model; main() loads one of its
    # entries. Machine-specific absolute path -- override it on other machines.
    trained_model_path = r'D:\work\ariel_tests\experiments\good_models\lmariel_16\experiment-2020_01_05_at_12_09_48-embedding_lmariel_16d_v0017\model/'


def timeStructured():
    """Return the current local time formatted as ``YYYY-MM-DD-HH-MM-SS``.

    The result contains only digits and dashes, so it can be used directly
    as a file-name prefix.
    """
    stamp_format = "%Y-%m-%d-%H-%M-%S"
    now = time.localtime()
    return time.strftime(stamp_format, now)


@ex.automain
def plot_interpolation(diversity=None, latent_dims=None, diversity_std=None):
    """Plot interpolation diversity against latent dimension and save it as a PDF.

    Called without arguments it plots the hard-coded reference values below
    (presumably results of an earlier run of ``main`` -- TODO confirm). As the
    ``@ex.automain`` command, ``latent_dims`` may also be filled in by the
    experiment framework from the ``cfg`` entry of the same name.

    Args:
        diversity: Mean interpolation diversity per latent dimension; any
            array-like (list, tuple or ndarray).
        latent_dims: Latent dimensions (x values), one per ``diversity`` entry.
            The x axis is log-scaled.
        diversity_std: Spread of the diversity per latent dimension. The shaded
            band covers ``diversity +/- diversity_std / 2``.

    Raises:
        ValueError: If the three sequences do not have the same length.

    Writes ``experiments/<timestamp>_log_interpolations.pdf`` relative to the
    current working directory, creating ``experiments/`` when it is missing.
    """
    if diversity is None:
        diversity = [1, 1, 0.956, 0.83933, 0.78266, 0.76933, 0.7533, 0.746, 0.746, 0.746, 0.746, 0.746, 0.746]
    if diversity_std is None:
        diversity_std = [0., 0., 0.052, 0.088, 0.077, 0.075, 0.074, 0.074, 0.074, 0.074, 0.074, 0.074, 0.074]
    if latent_dims is None:
        latent_dims = [1, 2, 4, 6, 8, 10, 12, 14, 16, 32, 256, 512, 1024]

    # np.asarray accepts lists, tuples and arrays alike, so the band arithmetic
    # below works for any array-like input.
    diversity = np.asarray(diversity, dtype=float)
    diversity_std = np.asarray(diversity_std, dtype=float)
    if not len(diversity) == len(diversity_std) == len(latent_dims):
        raise ValueError(
            'diversity ({}), diversity_std ({}) and latent_dims ({}) must have the same length'.format(
                len(diversity), len(diversity_std), len(latent_dims)))

    print(diversity)
    print(diversity_std)

    fig = plt.figure()
    # upper cell of a 2x1 grid; bbox_inches='tight' below crops the unused lower half
    ax = fig.add_subplot(2, 1, 1)
    ax.plot(latent_dims, diversity, color='tab:blue', lw=1)
    ax.fill_between(
        latent_dims, diversity - diversity_std / 2, diversity + diversity_std / 2, alpha=0.5, color='tab:blue'
    )
    ax.set_xlabel(r'latent dimension $d$')
    ax.set_ylabel(r'interpolation diversity')

    ax.set_xscale('log')
    ax.set_yticks([.7, .85, 1.])
    for pos in ('right', 'left', 'bottom', 'top'):
        ax.spines[pos].set_visible(False)

    os.makedirs('experiments', exist_ok=True)
    fig.savefig('experiments/{}_log_interpolations.pdf'.format(timeStructured()), bbox_inches='tight')
    # release the figure so repeated calls do not accumulate open figures
    plt.close(fig)


def _latest_model_path(trained_model_path):
    """Return the path of the last entry, in sorted order, of the trained-model directory.

    ``os.listdir`` returns entries in arbitrary order, so the listing is sorted
    to make the choice deterministic. Presumably the entries are checkpoints
    whose names sort chronologically -- TODO confirm.

    Raises:
        FileNotFoundError: If the directory is empty.
    """
    # Joining with CDIR is a no-op for absolute paths; relative paths are
    # resolved against the script directory for both the listing and the result.
    model_dir = os.path.join(CDIR, trained_model_path)
    entries = sorted(os.listdir(model_dir))
    if not entries:
        raise FileNotFoundError('no trained model found in {}'.format(model_dir))
    return os.path.join(model_dir, entries[-1])


def _load_or_sample_sentences(vocabulary, model_path, n_interpolation_sentences, latent_dim=16):
    """Return ``n_interpolation_sentences`` sentences decoded from random latent points.

    The sentences are cached in ``data/interpolation_sentences_<n>.pkl``
    (relative to the current working directory) and reused on later runs.
    Each decoded output is cut right after its first ``'?'``; when there is
    none, the full output is kept rather than aborting the whole sampling.
    """
    cache_path = 'data/interpolation_sentences_{}.pkl'.format(n_interpolation_sentences)
    if os.path.exists(cache_path):
        # The cache is written by this script itself; never point it at untrusted pickles.
        with open(cache_path, 'rb') as fp:
            return pickle.load(fp)

    lmariel_embedding = LmArielEmbedding(vocabulary, latent_dim=latent_dim)
    lmariel_embedding.load(model_path)
    decoder = lmariel_embedding.getDecoder()

    sentences = []
    for _ in range(n_interpolation_sentences):
        # uniform noise in [0, 1) on every latent coordinate
        decoded = decoder.decode(np.random.rand(latent_dim))
        try:
            end = decoded.index('?') + 1
        except ValueError:
            end = len(decoded)
        sentences.append(decoded[:end])

    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, 'wb') as fp:
        pickle.dump(sentences, fp)
    return sentences


def _interpolation_diversities(encoder, decoder, pairs_sentences, interpolation_steps):
    """Return, for every sentence pair, the fraction of distinct decodings along its interpolation.

    Both sentences are encoded, ``interpolation_steps`` evenly spaced points on
    the segment between the two codes are decoded, and the pair is scored as
    ``number of unique decoded sentences / interpolation_steps``.
    """
    diversities = []
    for first, second in tqdm(pairs_sentences):
        point_i = encoder.encode(first)
        point_f = encoder.encode(second)
        interpolated = []
        for alpha in np.linspace(0, 1, interpolation_steps):
            decoded = decoder.decode(alpha * point_i + (1 - alpha) * point_f)
            interpolated.append(decoded)
            logger.info('decoded:  %s', decoded)
        # TODO: count only unique AND grammatical sentences
        diversities.append(len(np.unique(interpolated)) / interpolation_steps)
    return diversities


def _write_diversity_csv(path, rows):
    """(Over)write ``rows`` to ``path`` as tab-separated values without a header."""
    # newline='' is required by the csv module; without it Windows gets blank lines between rows.
    with open(path, 'w', newline='') as f:
        csv.writer(f, delimiter='\t').writerows(rows)


def main(n_interpolation_sentences,
         n_sentence_pairs,
         interpolation_steps,
         latent_dims,
         trained_model_path):
    """Measure how AriEL's interpolation diversity depends on its latent dimension.

    Steps:
      1. Build the vocabulary from ``language/house_questions_grammar.cfg``.
      2. Load, or sample once with a 16-d model and cache, ``n_interpolation_sentences`` sentences.
      3. Draw ``n_sentence_pairs`` random pairs of sentences (distinct positions).
      4. For every ``latent_dim``: wrap the trained LM in ``LmArielEmbedding`` and
         score each pair by the fraction of unique sentences decoded along the
         linear interpolation between the two encodings.
      5. Write ``latent_dim, mean diversity, std`` rows to
         ``experiments/<timestamp>_interpolations.csv`` (tab separated, relative
         to the current working directory). The file is rewritten after every
         latent dim so partial results survive a crash. Then register it as an
         experiment artifact and plot the results.

    Parameter names match the entries of the ``cfg`` config scope.

    Raises:
        ValueError: If ``n_sentence_pairs`` or ``interpolation_steps`` is < 1,
            or ``latent_dims`` is empty.
        FileNotFoundError: If the trained-model directory is empty.

    NOTE(review): this function carries no ``ex`` decorator, and ``@ex.automain``
    sits on ``plot_interpolation`` (defined above), so running the script does
    not call ``main`` as written -- confirm which entry point is intended.
    """
    if n_sentence_pairs < 1 or interpolation_steps < 1:
        raise ValueError(
            'n_sentence_pairs and interpolation_steps must be >= 1, got {} and {}'.format(
                n_sentence_pairs, interpolation_steps))
    if len(latent_dims) == 0:
        raise ValueError('latent_dims must not be empty')

    # load dataset to test
    grammar_path = os.path.join(CDIR, 'language/house_questions_grammar.cfg')
    grammar = nltk.data.load('file:' + grammar_path)
    vocabulary = Vocabulary.fromGrammar(grammar)

    # trained language model that AriEL wraps for every latent dimension
    model_path = _latest_model_path(trained_model_path)

    list_sentences = _load_or_sample_sentences(vocabulary, model_path, n_interpolation_sentences)
    logger.info(list_sentences)
    pairs_sentences = [np.random.choice(list_sentences, size=2, replace=False).tolist()
                       for _ in range(n_sentence_pairs)]
    logger.info(pairs_sentences)
    logger.info('n pairs: {}'.format(len(pairs_sentences)))

    # one results file per run, checkpointed after every latent dimension
    os.makedirs('experiments', exist_ok=True)
    path_to_csv = 'experiments/{}_interpolations.csv'.format(timeStructured())

    diversity_per_latent_dim = []
    std_per_latent_dim = []
    for latent_dim in latent_dims:
        # plug LM inside AriEL
        lmariel_embedding = LmArielEmbedding(vocabulary, latent_dim=latent_dim)
        lmariel_embedding.load(model_path)
        decoder = lmariel_embedding.getDecoder()
        encoder = lmariel_embedding.getEncoder()

        diversities = _interpolation_diversities(encoder, decoder, pairs_sentences, interpolation_steps)
        diversity_score = sum(diversities) / len(diversities)
        diversity_std = np.std(diversities)
        logger.info('mean: {}, std: {}'.format(np.mean(diversities), diversity_std))
        logger.info('latent_dim: {}, diversity_score: {}'.format(latent_dim, diversity_score))
        diversity_per_latent_dim.append(diversity_score)
        std_per_latent_dim.append(diversity_std)

        # zip stops at the shortest list, so only finished latent dims are written
        _write_diversity_csv(path_to_csv, zip(latent_dims, diversity_per_latent_dim, std_per_latent_dim))

    ex.add_artifact(path_to_csv)
    plot_interpolation(diversity_per_latent_dim, latent_dims, std_per_latent_dim)
