

import gzip, json, logging, os, pickle, pprint, nltk
import numpy as np
import pandas as pd
import tensorflow.keras.backend as K
from nltk import CFG
from tensorflow.python.framework import ops
from tqdm import tqdm
import urllib.request


from ariel_tests.language.nlp import Vocabulary, tokenize, preprocessSentence
from ariel_tests.helpers.utils import timeStructured
from ariel_tests.models.getEmbeddings import getEmbedding
from ariel_tests.language.main_biased import run_biased_generation
from ariel_tests.language.main_unbiased import run_unbiased_generation

# Module-level logger used by every helper below.
logger = logging.getLogger(__name__)

# Directory containing this source file.
CDIR = os.path.dirname(os.path.realpath(__file__))
# Parent of CDIR; the data/, language/ and experiments/ folders are resolved relative to it.
CDIR_, _ = os.path.split(CDIR)
# Raw downloads, preprocessed .gz sentence files and cached label/feature files.
DATADIR = os.path.join(CDIR_, 'data')
# Grammar (.cfg) files loaded by selectDataset.
LANGUAGEDIR = os.path.join(CDIR_, 'language')


def download_GW():
    """Download the GuessWhat?! train and validation archives into DATADIR.

    DATADIR is created if needed and files already present there are not
    downloaded again. Each archive is first written under a temporary
    ``.part`` name and only renamed once the transfer has finished, so an
    interrupted download never leaves a truncated archive that a later run
    would mistake for a complete one.
    """
    logger.info('Beginning GW files download...')

    base_url = 'https://florian-strub.com/'
    train_url = base_url + 'guesswhat.train.jsonl.gz'
    test_url = base_url + 'guesswhat.test.jsonl.gz'  # noqa: F841 -- see FIXME below
    val_url = base_url + 'guesswhat.valid.jsonl.gz'

    # FIXME: test_url doesn't work anymore (the authors have been contacted), so
    # only the train and validation splits are fetched for now.
    os.makedirs(DATADIR, exist_ok=True)
    for url in [train_url, val_url]:
        logger.info(url)
        _, file_name = os.path.split(url)
        destination = os.path.join(DATADIR, file_name)
        if os.path.isfile(destination):
            continue
        partial = destination + '.part'
        urllib.request.urlretrieve(url, partial)
        # Atomic rename: `destination` only ever exists as a complete file.
        os.replace(partial, destination)

def reformat_GW():
    """Extract the questions of each GuessWhat?! split into a plain-text gzip file.

    For every split, reads ``DATADIR/guesswhat.<split>.jsonl.gz`` (one JSON game
    per line), collects ``interaction['q']`` for every entry of ``game['qas']``,
    runs each question through ``preprocessSentence`` and writes the results,
    one CRLF-terminated sentence per line, to ``DATADIR/GW_questions_<split>.gz``.
    Splits whose output file already exists are skipped. Questions that
    ``preprocessSentence`` cannot handle are dropped and their number is logged.
    """
    # FIXME: add 'test' once the test archive can be downloaded again.
    for key in ['train', 'valid']:

        jsonfilename = os.path.join(DATADIR, 'guesswhat.' + key + '.jsonl.gz')
        destination = 'GW_questions_' + key + '.gz'
        datasetFilename = os.path.join(DATADIR, destination)

        # Test the path that is actually written below (inside DATADIR); testing
        # the bare file name would look in the current working directory instead.
        if os.path.isfile(datasetFilename):
            continue

        questions = []
        with gzip.open(jsonfilename) as f:
            for line in tqdm(f):
                # NOTE(review): cp437 decodes any byte sequence without error, but
                # non-ASCII UTF-8 text comes out garbled -- confirm the intended encoding.
                line = line.decode('cp437')
                game = json.loads(line.strip('\n'))
                for interaction in game['qas']:
                    questions.append(interaction['q'])

        skipped = 0
        with gzip.open(datasetFilename, mode='wt') as f:
            for sentence in tqdm(questions):
                try:
                    text = preprocessSentence(sentence) + '\r\n'
                except Exception:
                    # Best effort: drop questions that cannot be preprocessed, but
                    # count them instead of swallowing the failure silently.
                    skipped += 1
                else:
                    # Written outside the try so I/O errors are not mistaken for bad input.
                    f.write(text)

        if skipped:
            logger.warning('Skipped %d of %d GW %s questions that could not be preprocessed',
                           skipped, len(questions), key)


def selectDataset(dataset_name='HQ'):
    """Resolve the grammar, vocabulary and data files of a named dataset.

    Supported names:
      * 'HQ'    -- House Questions: runs the biased and unbiased generators, then
                   loads ``house_questions_grammar.cfg`` from LANGUAGEDIR.
      * 'GW'    -- GuessWhat?! questions: downloads / reformats the corpus when its
                   files are missing and builds a flat "dummy" grammar whose start
                   symbol S rewrites to any single vocabulary token.
      * 'REBER' -- loads ``REBER_grammar.cfg`` from LANGUAGEDIR.

    Returns:
        tuple: (grammar_name, grammar, vocabulary, unbiasedFilename,
        biasedFilename_train, biasedFilename_test, biasedFilename_val).

    Raises:
        NotImplementedError: if ``dataset_name`` is not one of the names above.
    """
    if dataset_name == 'HQ':
        # HQ stands for House Questions.
        run_biased_generation()
        run_unbiased_generation()

        grammar_path = "file://" + os.path.join(LANGUAGEDIR, "house_questions_grammar.cfg")
        grammar = nltk.load(grammar_path)
        vocabulary = Vocabulary.fromGrammar(grammar)
        grammar_name = grammar_path

        # Define Data Paths
        unbiasedFilename = os.path.join(DATADIR, 'questions_unbiased.gz')
        biasedFilename_train = os.path.join(DATADIR, 'questions_biased_train.gz')
        biasedFilename_test = os.path.join(DATADIR, 'questions_biased_test.gz')
        biasedFilename_val = os.path.join(DATADIR, 'questions_biased_val.gz')

    elif dataset_name == 'GW':
        raw_files = ['guesswhat.train.jsonl.gz', 'guesswhat.valid.jsonl.gz']
        question_files = ['GW_questions_train.gz', 'GW_questions_valid.gz']
        # Check every file the steps below rely on, not just the validation split,
        # so a missing train file is also (re)downloaded / (re)built.
        if not all(os.path.isfile(os.path.join(DATADIR, name)) for name in raw_files):
            download_GW()
        if not all(os.path.isfile(os.path.join(DATADIR, name)) for name in question_files):
            reformat_GW()

        # FIXME: GW_questions_test is unavailable, so the validation split also
        # stands in for the unbiased and biased-test sets.
        unbiasedFilename = os.path.join(DATADIR, 'GW_questions_valid.gz')
        biasedFilename_train = os.path.join(DATADIR, 'GW_questions_train.gz')
        biasedFilename_test = os.path.join(DATADIR, 'GW_questions_valid.gz')
        biasedFilename_val = os.path.join(DATADIR, 'GW_questions_valid.gz')
        dataFilenames = [unbiasedFilename, biasedFilename_train, biasedFilename_test, biasedFilename_val]

        vocabulary = Vocabulary.fromGz(dataFilenames)

        def _as_terminal(token):
            # nltk's CFG reader accepts terminals quoted with either " or ', so use
            # the quote character that does not occur in the token. (A token that
            # contains both quote characters still cannot be expressed.)
            return "'" + token + "'" if '"' in token else '"' + token + '"'

        grammar_string = "S -> " + " | ".join(_as_terminal(token) for token in vocabulary.tokens) + "\n"

        grammar = CFG.fromstring(grammar_string)
        grammar_name = 'dummy grammar'

    elif dataset_name == 'REBER':

        grammar_path = os.path.join(LANGUAGEDIR, 'REBER_grammar.cfg')
        grammar = nltk.load(grammar_path)
        vocabulary = Vocabulary.fromGrammar(grammar)
        grammar_name = grammar_path

        unbiasedFilename = os.path.join(DATADIR, 'REBER_unbiased.gz')
        biasedFilename_train = os.path.join(DATADIR, 'REBER_biased_train.gz')
        biasedFilename_test = os.path.join(DATADIR, 'REBER_biased_test.gz')
        biasedFilename_val = os.path.join(DATADIR, 'REBER_biased_val.gz')
    else:
        raise NotImplementedError(
            'Unknown dataset_name {!r}; expected one of HQ, GW, REBER'.format(dataset_name))

    return grammar_name, grammar, vocabulary, unbiasedFilename, biasedFilename_train, biasedFilename_test, biasedFilename_val


class DataClass(object):
    """Test split of one dataset plus pre-computed labels for it.

    Construction resolves ``dataset_name`` with ``selectDataset`` and sets:

    * ``grammar``, ``vocabulary`` and the four data file paths it returns;
    * ``nbLines`` -- number of non-empty lines in the biased training file (also
      recorded in the 'nbLines' column of ``DATADIR/datasetsFeatures.csv``);
    * ``biasedSentences_Test`` / ``unbiasedSentences_Test`` -- the last
      ``nb_samples_test`` non-empty lines of the biased-test / unbiased files;
    * ``grammar_labels`` / ``length_labels`` -- numpy arrays with one label per
      entry of ``biasedSentences_Test`` (from ``label_by_grammar`` /
      ``label_by_length``), cached as ``.npy`` files in DATADIR;
    * ``config`` -- a dict summarising the above.

    Further attributes can be attached later with ``get_data``.
    """

    def __init__(self, nb_samples_test, dataset_name='HQ'):

        self.config = {}

        self.dataset_name = dataset_name
        self.nb_samples_test = nb_samples_test
        (grammar_name, self.grammar, self.vocabulary, self.unbiasedFilename,
         self.biasedFilename_train, self.biasedFilename_test,
         self.biasedFilename_val) = selectDataset(dataset_name)
        self.config['grammar_name'] = grammar_name

        # Size of the biased training set.
        self.nbLines = self._record_biased_train_size(dataset_name)
        self.config['nbBiasedSentences_train'] = self.nbLines

        # Load biased test dataset from file
        biasedSentences_test = self._read_sentences(self.biasedFilename_test)
        logger.info('\n\nNumber of sentences from biased dataset: %d' % (len(biasedSentences_test)))
        # This key has always ended up holding the biased *test* set size (the
        # training count used to be written here first and then overwritten); it is
        # kept as-is for readers of the config, the training count now lives under
        # 'nbBiasedSentences_train'.
        self.config['nbBiasedSentences'] = len(biasedSentences_test)

        # NOTE: only keep a limited number of sentences for the tests
        self.config['nbMaxSentences_biased'] = nb_samples_test
        self.biasedSentences_Test = self._last(biasedSentences_test, nb_samples_test)

        # Load unbiased dataset from file
        unbiasedSentences = self._read_sentences(self.unbiasedFilename)
        logger.info('\n\nNumber of sentences from unbiased dataset: %d' % (len(unbiasedSentences)))
        self.config['nbUnbiasedSentences'] = len(unbiasedSentences)

        # NOTE: only keep a limited number of sentences
        self.config['nbMaxSentences_unbiased'] = nb_samples_test
        self.unbiasedSentences_Test = self._last(unbiasedSentences, nb_samples_test)

        # Which grammar rule each biased test sentence belongs to (cached on disk).
        self.grammar_labels = self._cached_labels(
            'labels_grammar_biasedTest_{}_{}.npy'.format(dataset_name, nb_samples_test),
            lambda: label_by_grammar(self.biasedSentences_Test, self.grammar))

        # Which length bucket each biased test sentence falls into (cached on disk).
        self.length_labels = self._cached_labels(
            'labels_length_biasedTest_{}_{}.npy'.format(dataset_name, nb_samples_test),
            lambda: label_by_length(self.biasedSentences_Test, n_clusters=4))

    def _record_biased_train_size(self, dataset_name):
        """Count the non-empty lines of the biased training file and record the count.

        The count is stored in the 'nbLines' column of ``DATADIR/datasetsFeatures.csv``
        on the row whose 'dataset_names' equals ``dataset_name`` (added if missing,
        updated otherwise). The CSV is created on first use.
        """
        datasetFeaturesFilename = os.path.join(DATADIR, 'datasetsFeatures.csv')
        if os.path.isfile(datasetFeaturesFilename):
            datasetsHistory = pd.read_csv(datasetFeaturesFilename)
        else:
            column_names = ['dataset_names', 'length_percentiles', 'initial_symbol_median',
                            'random_word_frequency_median', 'random_word_chosen_for_grammar', 'nbLines']
            # The first row records where the CSV itself lives.
            datasetsHistory = pd.DataFrame(
                [{'dataset_names': 'csvPath', 'nbLines': datasetFeaturesFilename}],
                columns=column_names)

        biasedSentences_Train = self._read_sentences(self.biasedFilename_train)
        logger.info('\n\nNumber of sentences from biased dataset: %d' % (len(biasedSentences_Train)))
        nbLines = len(biasedSentences_Train)

        # Compare against the column's *values*; ``name in df[[col]]`` would only
        # test the column labels and therefore never match a dataset name.
        known = datasetsHistory['dataset_names'].astype(str) == dataset_name
        if known.any():
            datasetsHistory.loc[known, 'nbLines'] = nbLines
        else:
            new_row = pd.DataFrame([{'dataset_names': dataset_name, 'nbLines': nbLines}])
            # DataFrame.append was removed in pandas 2.0; concat is its replacement.
            datasetsHistory = pd.concat([datasetsHistory, new_row], ignore_index=True)
        datasetsHistory.to_csv(datasetFeaturesFilename, index=None, header=True)
        return nbLines

    @staticmethod
    def _read_sentences(filename):
        """Return the stripped, non-empty lines of a gzip-compressed text file."""
        with gzip.open(filename, mode='rt') as f:
            stripped = (line.strip() for line in f)
            return [sentence for sentence in stripped if sentence]

    @staticmethod
    def _last(items, n):
        """Return the last ``n`` items (all of them if fewer, none if ``n <= 0``).

        A plain ``items[-n:]`` would return the *whole* sequence for ``n == 0``.
        """
        if n <= 0:
            return items[:0]
        return items[-n:]

    @staticmethod
    def _cached_labels(filename, compute_labels):
        """Load labels from ``DATADIR/filename``, or compute them with ``compute_labels()`` and save them.

        Always returns a numpy array, so the type does not depend on whether the
        cache already existed. NOTE(review): the cache is keyed on the file name
        only and is not invalidated if the underlying sentences change.
        """
        path = os.path.join(DATADIR, filename)
        if os.path.isfile(path):
            labels = np.load(path)
        else:
            labels = np.asarray(compute_labels())
            np.save(path, labels)
        logger.debug('labels: %s', labels)
        logger.debug('unique labels: %s', np.unique(labels))
        return labels

    def get_data(self, variable_name, variable):
        """Attach ``variable`` to this instance as the attribute ``variable_name``.

        Despite its name this is a setter. Presumably used to attach things such
        as ``encoder``, ``decoder`` or ``experiment_folder``, which other helpers
        in this module read -- verify against callers.
        """
        self.__dict__.update({variable_name: variable})


def get_encoded_sentences(DataClass):
    """Encode every biased test sentence with the attached encoder.

    ``DataClass`` is an object (the parameter name shadows the class of the same
    name) exposing ``biasedSentences_Test`` and an ``encoder`` whose ``encode``
    method is called once per sentence, in order, under a tqdm progress bar.

    Returns:
        list: the encoder outputs, one per sentence.
    """
    return [DataClass.encoder.encode(text) for text in tqdm(DataClass.biasedSentences_Test)]


def label_by_grammar(sentences, grammar):
    """Label each sentence by the first parse a chart parser finds for it.

    Each sentence is tokenized with ``tokenize`` and parsed with
    ``nltk.ChartParser(grammar)``. The label is ``parse2grammarLabel`` applied to
    the string form of the first parse tree's children. A sentence that cannot be
    parsed (no parse, or a token the grammar does not cover) is labelled as
    ``parse2grammarLabel('[]')``, i.e. -1.

    Returns:
        list: one label per sentence, in input order.
    """
    parser = nltk.ChartParser(grammar)
    grammar_labels = []
    for sentence in sentences:
        tokens = tokenize(sentence)
        try:
            # Only the first parse is used, so do not enumerate every parse tree:
            # for an ambiguous grammar that can be exponentially many.
            first_parse = next(iter(parser.parse(tokens)), None)
        except ValueError:
            # nltk raises ValueError when the grammar does not cover some input
            # token; such a sentence has no parse either.
            first_parse = None

        parsed = '[]' if first_parse is None else str(list(first_parse))
        grammar_labels.append(parse2grammarLabel(parsed))

    return grammar_labels


def get_samples_and_volume(DataClass):
    """Draw points uniformly at random inside an experiment's latent hypercube.

    Reads ``<experiment_folder>/text/hypercube-zmin.npy`` and
    ``hypercube-zmax.npy`` (presumably per-dimension lower / upper bounds --
    confirm against the code that writes them) and draws
    ``DataClass.nb_samples_test`` points using the global ``np.random`` state.

    Args:
        DataClass: object with ``experiment_folder`` and ``nb_samples_test``
            attributes (the parameter name shadows the class of the same name).

    Returns:
        tuple: ``(samples, volume)``. For 1-D bounds ``samples`` has shape
        ``(nb_samples_test, len(zmin))``; ``volume`` is a placeholder, always 1.
    """
    text_folder = os.path.join(DataClass.experiment_folder, 'text')
    zmin = np.load(os.path.join(text_folder, 'hypercube-zmin.npy'))
    zmax = np.load(os.path.join(text_folder, 'hypercube-zmax.npy'))

    # Column vectors so they broadcast against alpha, shaped (dims, nb_samples).
    zmax = np.expand_dims(zmax, 1)
    zmin = np.expand_dims(zmin, 1)

    # One independent uniform interpolation weight per dimension and per sample.
    alpha = np.random.rand(zmax.shape[0], DataClass.nb_samples_test)

    samples = alpha * zmax + (1 - alpha) * zmin
    samples = np.transpose(samples)
    volume = 1  # NOTE: placeholder, the hypercube volume is not computed.

    return samples, volume


def get_samples_and_labels(DataClass):
    """Decode random latent samples and label each decoded sentence by its parses.

    Points come from ``get_samples_and_volume``. Each is decoded with
    ``DataClass.decoder.decode`` and the sentence is tokenized and parsed with a
    chart parser over ``DataClass.grammar``. The label is ``parse2grammarLabel``
    applied to the list of *all* parses (-1 when there is none).

    Returns:
        tuple: ``(samples, grammar_labels)`` with one label per sample.
    """
    samples, _ = get_samples_and_volume(DataClass)

    parser = nltk.ChartParser(DataClass.grammar)
    # Decoded sentences often repeat; parse each distinct sentence only once.
    # A dict gives O(1) lookups (assumes decode returns a hashable sentence,
    # e.g. a str -- it is passed to tokenize like the file-based sentences).
    parse_cache = {}
    grammar_labels = []

    for sample in samples:
        sentence = DataClass.decoder.decode(sample)
        if sentence not in parse_cache:
            tokens = tokenize(sentence)
            parse_cache[sentence] = list(parser.parse(tokens))
        grammar_labels.append(parse2grammarLabel(parse_cache[sentence]))

    return samples, grammar_labels


def parse2grammarLabel(parse):
    """Turn a parse (or its string form) into an integer grammar label.

    ``parse`` is converted with ``str``. If that text contains ``'[]'`` (e.g. an
    empty list of parses) the label is -1; otherwise it is the number of
    occurrences of ``'adjective',`` -- presumably one per ``Tree('adjective', ...)``
    node in nltk's printed tree form.
    """
    text = str(parse)
    return -1 if '[]' in text else text.count("'adjective',")


def label_by_length(sentences, n_clusters=5):
    """Bucket sentences by character length into ``n_clusters`` quantile groups.

    The thresholds are the 0th, 100/n, ..., 100*(n-1)/n percentiles of the
    lengths. Each sentence's label is the number of thresholds strictly greater
    than its length, so labels lie in 0..n_clusters-1 and *shorter* sentences
    get *larger* labels.

    Args:
        sentences: sequence of sentences; the length used is ``len(sentence)``.
        n_clusters: number of quantile buckets.

    Returns:
        list[int]: one label per sentence; an empty list for empty input.
    """
    if len(sentences) == 0:
        # np.percentile cannot operate on an empty array.
        return []

    lengths = np.array([len(sentence) for sentence in sentences])
    percentiles = np.percentile(lengths, np.arange(n_clusters) / n_clusters * 100)

    # Compare every length with every threshold and count the hits per sentence.
    return (lengths[:, None] < percentiles[None, :]).sum(axis=1).tolist()


def make_directories(time_string=None):
    """Create the folder tree of a new experiment and return its root.

    Under ``CDIR_/experiments`` this creates ``experiment-<time_string>/`` with
    the sub-folders ``model``, ``log``, ``text``, ``plots`` and ``files``, plus a
    shared ``good_models`` folder next to the experiment folders. Folders that
    already exist are left untouched.

    Args:
        time_string: suffix of the experiment folder name; defaults to
            ``timeStructured()``.

    Returns:
        str: path of the experiment folder.
    """
    if time_string is None:
        time_string = timeStructured()

    experiments_folder = os.path.join(CDIR_, "experiments")
    experiment_folder = os.path.join(experiments_folder, 'experiment-' + time_string)

    leaf_folders = [
        os.path.join(experiment_folder, 'model'),
        os.path.join(experiment_folder, 'log'),
        os.path.join(experiments_folder, 'good_models'),
        os.path.join(experiment_folder, 'text'),
        os.path.join(experiment_folder, 'plots'),
        os.path.join(experiment_folder, 'files'),
    ]
    # makedirs also creates the missing parents (experiments/ and the experiment
    # folder itself); exist_ok avoids the check-then-create race of isdir + mkdir.
    for path in leaf_folders:
        os.makedirs(path, exist_ok=True)

    return experiment_folder


def alreadyTrained(partial_string, embedding_filename, previousEmbeddingPath, DataClass):
    """Restore a previously trained embedding instead of training a new one.

    The source is selected by substrings of ``partial_string``:

    * ``'experiment'`` -- ``partial_string`` is matched against the folder names
      in ``CDIR_/experiments`` (the last match in ``os.listdir`` order wins). The
      first entry containing ``'.h5'`` in that folder's ``model/`` sub-folder is
      loaded into an embedding built by ``getEmbedding`` with ``epochs=0``.
    * ``'previous'`` -- a pickled embedding is loaded from ``previousEmbeddingPath``.

    In both cases the restored embedding is saved to ``embedding_filename``.

    Returns:
        tuple: ``(isAlreadyTrained, embedding)``; ``(False, [])`` when neither
        substring is present.

    Raises:
        FileNotFoundError: if no experiment folder matches ``partial_string`` or
            its ``model/`` folder contains no ``.h5`` entry.
    """
    isAlreadyTrained = False
    embedding = []

    if 'experiment' in partial_string:
        # load from an old experiment, e.g. partial_string = 'experiment-2019-04-08-at-19:16:37-'
        experiments_dir = os.path.join(CDIR_, 'experiments')
        directories = os.listdir(experiments_dir)
        logger.debug(directories)

        matches = [one_dir for one_dir in directories if partial_string in one_dir]
        if not matches:
            raise FileNotFoundError(
                'No experiment folder in {} matches {!r}'.format(experiments_dir, partial_string))
        where_to_load_from = os.path.join(experiments_dir, matches[-1], 'model')
        logger.info('CDIR_: %s', CDIR_)
        logger.info('where_to_load_from: %s', where_to_load_from)

        h5_files = [path for path in os.listdir(where_to_load_from) if '.h5' in path]
        if not h5_files:
            raise FileNotFoundError('No .h5 file found in {}'.format(where_to_load_from))
        h5_path = os.path.join(where_to_load_from, h5_files[0].replace('.h5', ''))

        K.clear_session()
        ops.reset_default_graph()

        embedding = getEmbedding(model_filename=embedding_filename,
                                 gzip_filename_train=DataClass.biasedFilename_train,
                                 gzip_filename_val=DataClass.biasedFilename_val,
                                 grammar=DataClass.grammar,
                                 epochs=0,
                                 steps_per_epoch=0,
                                 batch_size=64)
        embedding.load(h5_path)
        embedding.save(embedding_filename)

        isAlreadyTrained = True

    if 'previous' in partial_string:
        # SECURITY: pickle.load can execute arbitrary code stored in the file; only
        # load embeddings this project produced itself.
        with open(previousEmbeddingPath, 'rb') as pkl_file:
            embedding = pickle.load(pkl_file)
        embedding.save(embedding_filename)

        isAlreadyTrained = True

    return isAlreadyTrained, embedding


def test():
    """Debug helper: unpickle the embedding of the most recent experiment.

    Lists the experiment folders in ``CDIR_/experiments`` (the location used by
    ``make_directories`` and ``alreadyTrained``), takes the last one in sorted
    order and loads the first file whose name contains ``'pkl'`` from its
    ``model/`` sub-folder.

    Returns:
        The unpickled embedding object.
    """
    partial_string = 'previous'

    experiments_dir = os.path.join(CDIR_, 'experiments')
    print(os.listdir(experiments_dir))

    embedding = None
    if partial_string == 'previous':
        # Folder names are 'experiment-<timestamp>'; sorting makes "last" deterministic
        # (and chronological, assuming the timestamp format sorts lexicographically).
        directories = sorted(os.path.join(experiments_dir, o) for o in os.listdir(experiments_dir)
                             if os.path.isdir(os.path.join(experiments_dir, o)) and 'experiment' in o)

        print('')
        for d in directories:
            print(d)

        where_to_load_from = os.path.join(directories[-1], 'model')
        print(where_to_load_from)
        pkl_path = [path for path in os.listdir(where_to_load_from) if 'pkl' in path][0]
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(os.path.join(where_to_load_from, pkl_path), 'rb') as pkl_file:
            embedding = pickle.load(pkl_file)

    return embedding


def findMethodNameInString(string):
    """Infer the embedding method name from an experiment / model name.

    The first of 'vae', 'ae', 'arithmetic', 'lmariel', 'transformer', 'None'
    that occurs in ``string`` is used. Order matters: 'vae' must be tried
    before 'ae', which it contains. The suffix is '_no_LM' when 'no-LM'
    occurs in ``string``, '_LM' otherwise.

    Raises:
        AttributeError: if no method name occurs; ``string`` is printed first.
    """
    candidates = ('vae', 'ae', 'arithmetic', 'lmariel', 'transformer', 'None')
    found = next((name for name in candidates if name in string), None)
    if found is None:
        print(string)
        raise AttributeError

    suffix = '_no_LM' if 'no-LM' in string else '_LM'
    return found + suffix


def samples_and_grammar_labels_from_file(text_path):
    """Parse latent samples and their grammar labels back out of a text report.

    Two kinds of entries are recognised. The format is inferred from the slicing
    below; confirm against the code that writes these files.

    * A line containing ``'sample'`` starts a whitespace-separated numeric array.
      Its first 10 characters are dropped (presumably a fixed-width prefix). The
      array may continue over the following lines up to a line containing ``']'``.
    * A line containing ``'parse'`` has its first 6 characters dropped; the rest
      is turned into a label with ``parse2grammarLabel``.

    Returns:
        tuple: ``(samples, grammar_labels)`` -- a list of 1-D float arrays and a
        list of labels, each in file order.
    """
    grammar_labels = []
    samples = []
    # Must start as an empty *string*: a ']' line that appears before the first
    # 'sample' line still runs the .replace() chain below, which fails on a list.
    sample = ''
    sampleOpen = False
    with open(text_path, 'r') as file:

        for line in file:
            if 'sample' in line:
                sample = line[10:]
                sampleOpen = True

            if sampleOpen and 'sample' not in line:
                # Continuation line of a multi-line array.
                sample += line

            if ']' in line:
                sample = sample.replace("\n", "").replace("[", "").replace("]", "")
                if sampleOpen:
                    np_array = np.fromstring(sample, dtype=float, sep=' ')
                    samples.append(np_array)

                sampleOpen = False
                sample = ''

            if 'parse' in line:
                line2parse = line[6:]
                grammar_labels.append(parse2grammarLabel(line2parse))

    return samples, grammar_labels


def SaveConfig(config_path, config):
    """Pretty-print ``config`` into ``<CDIR_>/<config_path>/config.txt``.

    Any existing file is overwritten. The output uses pprint with an indent of 4;
    LoadConfig appears to be the intended reader.
    """
    target = os.path.join(CDIR_, config_path + '/config.txt')
    with open(target, 'w') as handle:
        pprint.pprint(config, stream=handle, indent=4)

def LoadConfig(config_path):
    """Read a pprint-formatted config file back into Python objects via ``json``.

    The text is turned into JSON by plain string replacement, applied in this order:
    1. Single quotes become double quotes.
    2. Tabs, newlines and *all* spaces are removed, including spaces inside string values.
    3. Triple double quotes collapse to one.
    4. Trailing commas before '}' and ']' are dropped.
    5. Quotes around bracketed values are removed, so stringified lists become lists.

    NOTE(review): this is lossy and fragile. Spaces and apostrophes inside values
    are not preserved.
    """
    with open(config_path, 'r') as handle:
        text = handle.read()

    rewrites = (
        ("'", '"'),
        ('\t', ''),
        ('\n', ''),
        (' ', ''),
        ('"""', '"'),
        (',}', '}'),
        (',]', ']'),
        ('"[', '['),
        (']"', ']'),
    )
    for old, new in rewrites:
        text = text.replace(old, new)

    return json.loads(text)


def check_studies_are_unique(studies):
    """Give every study a distinct name so that their outputs don't overwrite each other.

    Each study ``[name, value, ...]`` becomes ``[name + '_v' + DDDD, value]``,
    where DDDD are 4 random digits drawn from the global ``np.random`` state.
    Only the first two elements of each study are kept. If a drawn name was
    already produced, the suffix is re-drawn, so the returned names are
    guaranteed to be pairwise distinct. A single draw alone would collide
    quite often with a few hundred studies.

    Raises:
        ValueError: if more than 10**4 studies share one base name, since every
            4-digit suffix would then be taken.
    """
    new_studies = []
    used_names = set()
    per_base_count = {}
    for study in studies:
        base = study[0]
        per_base_count[base] = per_base_count.get(base, 0) + 1
        if per_base_count[base] > 10 ** 4:
            raise ValueError('Too many studies named {!r} to give each a unique 4-digit suffix'.format(base))

        while True:
            random_string = ''.join([str(r) for r in np.random.choice(10, 4)])
            name = base + '_v' + random_string
            if name not in used_names:
                break
        used_names.add(name)
        new_studies.append([name, study[1]])

    return new_studies


if __name__ == '__main__':
    # Running this module directly prepares the GuessWhat?! dataset via
    # selectDataset('GW'): it downloads / reformats the corpus when its files are
    # missing and builds the vocabulary and dummy grammar.
    selectDataset('GW')