import copy, logging, os, pathlib, time, shutil

from shutil import copyfile
from time import strftime, localtime
import numpy as np

os.environ['HDF5_DISABLE_VERSION_CHECK'] = '2'

import tensorflow as tf

tf.compat.v1.experimental.output_all_intermediates(True)

import tensorflow.keras.backend as K
from tensorflow.python.framework import ops

from ariel_tests.helpers.VeryCustomSacred import CustomExperiment, ChooseGPU
from ariel_tests.helpers.utils import setReproducible

from ariel_tests.test_analysis.visualization import PlotResults, PlotLsnnVoltageAndBias, mergePlotsInSummary, SaveText, \
    PlotDecodingGrammarClustering, PlotEncodingGrammarClustering, PlotEncodingLengthClustering, CompareTexts
from ariel_tests.models.getEmbeddings import getEmbedding
from ariel_tests.test_analysis.evaluation_studies import ReconstructionAndGeneralizationStudy, \
    DecodingStudy, EncodingStudy
from ariel_tests.language.utils import make_directories, DataClass, alreadyTrained, SaveConfig, check_studies_are_unique

# Logger shared by the whole module.
logger = logging.getLogger('mylogger')

# Absolute directory that contains this script (symlinks resolved).
CDIR = os.path.dirname(os.path.realpath(__file__))

# Four random decimal digits, used as a prefix of the experiment name.
random_string = ''.join(map(str, np.random.choice(10, 4)))

# Experiment object whose decorators register the config and the entry point below.
ex = CustomExperiment(f'{random_string}-ariel', base_dir=CDIR, seed=11)

# Default configuration of the experiment.
#
# The names assigned in this function (batch_size, epochs, steps_per_epoch,
# dataset_name, nb_samples_test, studies, do_tests, GPU, seed) match the
# parameters requested by `main` further down in this file.
# NOTE(review): presumably a Sacred-style config scope, i.e. every local
# variable assigned here becomes a config entry -- confirm in CustomExperiment
# before renaming, adding or removing locals.
@ex.config
def cfg():
    # parameters
    batch_size = 256

    # NOTE(review): epochs = 0 presumably means no training is performed -- confirm in getEmbedding
    epochs = 0  # alternatives: 20, 10, 2
    steps_per_epoch = 10  # alternatives: 'all', 10
    dataset_name = 'GW'  # HQ for House Questions, GW for GuessWhat?!

    # when False, trainAndPlot returns right after the embedding is obtained and its config saved
    do_tests = True
    nb_samples_test = 10  # alternatives: 100, 10000

    # index of the GPU to use; also selects which list of studies is run (see the end of this function)
    GPU = 0
    # GPU_count and GPU_fraction are only referenced by the commented-out session setup below
    GPU_count = 1  # 0 if GPU == -1 else 1
    GPU_fraction = .95

    seed = 0

    # make only the selected GPU visible to the process, with devices ordered by PCI bus id
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU)

    # config = tf.ConfigProto(device_count={'GPU': GPU_count})
    # config.gpu_options.per_process_gpu_memory_fraction = GPU_fraction
    # set_session(tf.Session(config=config))

    # Each study is a pair [model name, second element]:
    #   - the model name is appended to the dataset name in trainAndPlot and must
    #     contain a '<N>d' fragment, from which the latent dimension is parsed;
    #   - the second element is forwarded to alreadyTrained in trainAndPlot.
    # small latent space
    studies_GPU0 = [
        ['embedding_lmariel_16d', 'new'],
        ['embedding_transformer_16d', 'new'],
        ['embedding_vae_gru_16d_75keep_2nrl_activationtanh', 'new'],
        ['embedding_ae_gru_16d_75keep_2nrl_activationtanh', 'new'],
    ]


    # large latent space
    studies_GPU1 = [
        # ['embedding_lmariel_512d', 'new'],
        # ['embedding_ae_gru_512d_75keep_2nrl_activationtanh', 'new'],
        # ['embedding_vae_gru_512d_75keep_2nrl_activationtanh', 'new'],
        ['embedding_transformer_512d', 'new'],
    ]

    # GPU 0 runs the small-latent-space studies, GPU 1 or -1 the large ones;
    # any other value is rejected
    if GPU == 0:
        studies = studies_GPU0
    elif GPU in [1, -1]:
        studies = studies_GPU1
    else:
        raise NotImplementedError


# @ray.remote(num_gpus=1)
class GrammarCoverage_S2S_Study(object):
    """Obtain an embedding per study, evaluate it, plot it and archive the results.

    The dataset is loaded once in ``__init__``; every call to ``trainAndPlot``
    then works on deep copies of that base data and base config, so one
    instance can be reused for several studies in a row.
    """

    def __init__(self,
                 batch_size, epochs, steps_per_epoch, dataset_name,
                 nb_samples_test=10000, do_tests=True):
        """Store the run parameters and load the shared dataset.

        Args:
            batch_size: forwarded to ``getEmbedding`` as ``batch_size``.
            epochs: forwarded to ``getEmbedding`` as ``epochs``.
            steps_per_epoch: forwarded to ``getEmbedding`` as ``steps_per_epoch``.
            dataset_name: passed to ``DataClass`` and used as the prefix of
                every study's method name.
            nb_samples_test: passed to ``DataClass`` and stored in each config.
            do_tests: when false, ``trainAndPlot`` returns right after the
                embedding has been obtained and its config saved.
        """
        self.batch_size = batch_size
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.dataset_name = dataset_name
        self.nb_samples_test = nb_samples_test
        self.do_tests = do_tests

        ############################################################
        #     Load Grammar and Dataset
        ############################################################

        # Timestamp shared by the experiment folders of all studies of this instance
        self.time_string = strftime("%Y_%m_%d_at_%H_%M_%S", localtime())

        self.base_config = {'times': {'start __init__': self.time_string}}

        self.base_DataClass = DataClass(nb_samples_test=nb_samples_test, dataset_name=dataset_name)
        self.base_config.update(self.base_DataClass.config)

        # Load the grammar
        self.vocabulary = self.base_DataClass.vocabulary

        # Model filename of the previous study; the string 'None' until one has completed
        self.previousEmbeddingFilename = 'None'

    @staticmethod
    def _parse_latent_dim(method_name):
        """Return the latent dimension encoded as a ``<N>d`` fragment of ``method_name``.

        The name is split on ``'_'`` and the first fragment of the form
        ``<integer>d`` wins (e.g. ``'GW_embedding_lmariel_16d'`` gives 16).
        Empty fragments and fragments that merely end in ``'d'`` are skipped.

        Raises:
            ValueError: if no fragment encodes an integer dimension.
        """
        for fragment in method_name.split('_'):
            if not fragment.endswith('d'):
                continue
            try:
                return int(fragment[:-1])
            except ValueError:
                # ends in 'd' but is not '<integer>d' (e.g. 'word'); keep looking
                continue
        raise ValueError('Unspecified latent dimension size!!')

    @staticmethod
    def _snapshot_source_files(experiment_folder):
        """Copy main.py and the python files of models/ into ``<experiment_folder>/files``."""
        files_dir = os.path.join(experiment_folder, 'files')
        copyfile(os.path.join(CDIR, 'main.py'), os.path.join(files_dir, 'main.py'))

        models_dir = os.path.join(CDIR, 'models/')
        for filename in os.listdir(models_dir):
            # endswith rather than a substring test, so '.pyc', '.py.bak', ... are not picked up
            if filename.endswith('.py'):
                copyfile(os.path.join(models_dir, filename), os.path.join(files_dir, filename))

    def _run_study(self, study_class, experiment_folder):
        """Run one evaluation study and adopt the config and DataClass it exposes."""
        study = study_class(self.DataClass, self.config)

        self.config = study.config
        SaveConfig(experiment_folder, self.config)

        self.DataClass = study.DataClass

    def trainAndPlot(self, study_model=None):
        """Run one study: obtain the embedding, then optionally test, plot and archive.

        Args:
            study_model: pair ``[model name, second element]``. The model name
                is appended to the dataset name to form the method name, which
                must contain a ``<N>d`` fragment giving the latent dimension.
                The second element is forwarded to ``alreadyTrained``.
                Defaults to
                ``['embedding_arithmetic_16d', 'previously_saved_model']``.

        Raises:
            ValueError: if the method name does not encode a latent dimension.
        """
        if study_model is None:
            # built per call instead of sharing one mutable default list
            study_model = ['embedding_arithmetic_16d', 'previously_saved_model']

        K.clear_session()
        ops.reset_default_graph()

        # Work on copies so that every study starts from the same base state
        self.DataClass = copy.deepcopy(self.base_DataClass)
        self.config = copy.deepcopy(self.base_config)
        self.config['results'] = {}

        self.config['times'].update({'start ' + str(study_model):
                                         strftime("%Y-%m-%d-at-%H:%M:%S", localtime())})

        self.config['study_this'] = str(study_model)
        self.config['study_previous'] = self.previousEmbeddingFilename

        MethodName = self.dataset_name + '_' + study_model[0]
        # NOTE(review): despite its name, get_data is called with a (name, value) pair;
        # presumably it stores the value on the DataClass -- confirm in DataClass
        self.DataClass.get_data('MethodName', MethodName)

        experiment_folder = make_directories(self.time_string + '-' + MethodName)
        self.DataClass.get_data('experiment_folder', experiment_folder)

        # copy files so one can understand how they were after the experiment is run
        self._snapshot_source_files(experiment_folder)

        logger.info('\n\nStarted.')
        logger.info(experiment_folder)
        logger.info(str(study_model))
        logger.info('\n\n')

        lat_dim = self._parse_latent_dim(MethodName)
        self.DataClass.get_data('lat_dim', lat_dim)

        ############################################################
        logger.info('\n\n Create Embedding \n\n')
        ############################################################

        embedding_filename = os.path.join(CDIR, experiment_folder + '/model/' + MethodName)
        begin = time.time()

        # check if the model has been trained before
        isAlreadyTrained, embedding = alreadyTrained(
            study_model[1],
            embedding_filename,
            self.previousEmbeddingFilename,
            self.DataClass)
        if isAlreadyTrained:
            # a description string in this branch, a duration in seconds in the other
            LoadOrTrainEmbeddingTime = 'loaded from ' + study_model[1]
        else:
            embedding = getEmbedding(model_filename=embedding_filename,
                                     gzip_filename_train=self.DataClass.biasedFilename_train,
                                     gzip_filename_val=self.DataClass.biasedFilename_val,
                                     grammar=self.DataClass.grammar,
                                     epochs=self.epochs,
                                     steps_per_epoch=self.steps_per_epoch,
                                     batch_size=self.batch_size)
            LoadOrTrainEmbeddingTime = time.time() - begin

        self.DataClass.get_data('encoder', embedding.getEncoder())
        self.DataClass.get_data('decoder', embedding.getDecoder())

        self.config.update({
            'lat_dim': lat_dim,
            'MethodName': MethodName,
            'LoadOrTrainEmbeddingTime': LoadOrTrainEmbeddingTime,
            'batch_size': self.batch_size,
            'steps_per_epoch': self.steps_per_epoch,
            'epochs': self.epochs,
            'nb_samples_test': self.nb_samples_test,
        })

        SaveConfig(experiment_folder, self.config)

        if not self.do_tests:
            # NOTE(review): as in the original code, the cleanup at the end of this
            # method and the update of previousEmbeddingFilename only happen when the
            # tests are run -- confirm that this is intended.
            return

        begin_test = time.time()

        ############################################################
        logger.info('\n\n Tests: Model for Reconstruction \n\n')
        ############################################################

        self._run_study(ReconstructionAndGeneralizationStudy, experiment_folder)

        ############################################################
        logger.info('\n\n Tests: Model for Generation \n\n')
        ############################################################

        self._run_study(DecodingStudy, experiment_folder)
        SaveText(experiment_folder, self.DataClass)

        ###########################################################
        logger.info('\n\n Tests: Model for Encoding \n\n')
        ###########################################################

        self._run_study(EncodingStudy, experiment_folder)

        ############################################################
        logger.info('\n\n Visualizations \n\n')
        ############################################################

        PlotDecodingGrammarClustering(self.DataClass)
        for projection_method in ['PCA', 'tSNE', 'UMAP']:
            PlotResults(self.DataClass, projection_method=projection_method)
            PlotEncodingGrammarClustering(self.DataClass, projection_method=projection_method)
            PlotEncodingLengthClustering(self.DataClass, projection_method=projection_method)

        mergePlotsInSummary(self.DataClass.experiment_folder)

        ############################################################
        logger.info('\n\n Save Results \n\n')
        ############################################################

        self.config['memorized_ratio'] = CompareTexts(experiment_folder, self.DataClass)

        self.config['times'].update({'end ' + str(study_model):
                                         strftime("%Y-%m-%d-at-%H:%M:%S", localtime())})

        TestTime = time.time() - begin_test
        # Merge the durations into the existing dict: assigning a new dict here would
        # throw away the start/end timestamps recorded above.
        self.config['times'].update({'train': LoadOrTrainEmbeddingTime, 'test': TestTime})

        SaveConfig(experiment_folder, self.config)

        logger.info('\n\nAll done.')
        logger.info(experiment_folder)
        logger.info('\n\n')
        # writes '<experiment_folder>.zip' with the content of the experiment folder
        shutil.make_archive(experiment_folder, 'zip', experiment_folder)

        # Drop the per-study state and remember this model for the next study
        del self.DataClass
        del self.config
        del embedding
        self.previousEmbeddingFilename = embedding_filename


@ex.automain
def main(batch_size, epochs, steps_per_epoch, dataset_name, nb_samples_test, studies, do_tests, GPU, seed, _log):
    """Entry point of the experiment: run every configured study in sequence.

    The parameter names match the entries defined in ``cfg``. ``GPU`` and
    ``_log`` are accepted but not used in the body.
    NOTE(review): presumably all arguments are injected by the experiment
    object by name -- confirm in CustomExperiment before renaming any of them.
    """
    # GPU selection through ChooseGPU(GPU) is currently disabled.
    setReproducible(seed)

    # Let the helper validate/normalise the list of studies before running them
    unique_studies = check_studies_are_unique(studies)

    # One runner is shared by all studies, so the dataset is loaded only once
    runner = GrammarCoverage_S2S_Study(
        batch_size=batch_size,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        dataset_name=dataset_name,
        nb_samples_test=nb_samples_test,
        do_tests=do_tests,
    )
    for study_model in unique_studies:
        runner.trainAndPlot(study_model)
