
import sys
import numpy as np
np.random.seed(2)
import tensorflow as tf
# tf.random.set_random_seed(2)
sys.path.append('./precision_recall_distributions')
import os
from models.rae import make_raes
from models.std_vae import std_vae
from models.wae_mmd import wae_mmd
from dataloaders.dataloader import DataLoader
from configurations import config
from my_utility import config_parser
import keras.backend as K
from my_utility.my_callbacks import LatentSpaceSampler
from my_utility import save_batches_of_images
from my_utility import interpolations
from my_utility import fid_from_dir_computer
from my_utility import estimate_density_and_sample
import time
from precision_recall_distributions import prd_from_image_folders as prd


def predict_2stage(encoder, decoder, qz_sampler, recon_original):
    """Encode, pass through the second-stage sampler, and decode.

    Args:
        encoder: model with ``outputs`` and ``predict``; when it has more than
            one output, only the first element of its prediction is used.
        decoder: model whose ``predict`` maps latent codes back to images.
        qz_sampler: object exposing ``reconstruct(z)``.
        recon_original: input batch handed to ``encoder.predict``.

    Returns:
        Whatever ``decoder.predict`` returns for the reconstructed codes.
    """
    has_several_outputs = len(encoder.outputs) > 1
    latent = encoder.predict(recon_original)
    if has_several_outputs:
        # Multi-output encoders return a sequence; the codes come first.
        latent = latent[0]
    return decoder.predict(qz_sampler.reconstruct(latent))


# Fixed pairs of sample indices, keyed by dataset name. Each inner list is
# [index_a, index_b]; presumably the two endpoints of a latent-space
# interpolation (cf. the `interpolations` import) -- TODO confirm against the
# consumer and which data split the indices refer to.
# MNIST and CIFAR_10 use identical pairs. Not referenced by main() in this file.
pairs_interpolation = {'MNIST': [[537, 9749],
                                 [1327, 6570],
                                 [1703, 4717],
                                 [1838, 1399],
                                 [2028, 8637],
                                 [2543, 5672],
                                 [4118, 4817],
                                 [4471, 170],
                                 [4656, 8901],
                                 [5134, 2283],
                                 [5320, 912],
                                 [5676, 2381],
                                 [5977, 2686],
                                 [5983, 3868],
                                 [6816, 9143],
                                 [7409, 1415],
                                 [8027, 1636],
                                 [8739, 5640],
                                 [8960, 4306],
                                 [9316, 825]],
                       'CIFAR_10': [[537, 9749],
                                    [1327, 6570],
                                    [1703, 4717],
                                    [1838, 1399],
                                    [2028, 8637],
                                    [2543, 5672],
                                    [4118, 4817],
                                    [4471, 170],
                                    [4656, 8901],
                                    [5134, 2283],
                                    [5320, 912],
                                    [5676, 2381],
                                    [5977, 2686],
                                    [5983, 3868],
                                    [6816, 9143],
                                    [7409, 1415],
                                    [8027, 1636],
                                    [8739, 5640],
                                    [8960, 4306],
                                    [9316, 825]],
                       'CELEBA': [[190, 1526],
                                  [526, 15140],
                                  [1185, 1384],
                                  [3328, 9392],
                                  [5832, 8602],
                                  [7674, 10954],
                                  [8481, 787],
                                  [8765, 127],
                                  [9230, 11958],
                                  [10572, 16050],
                                  [10856, 12309],
                                  [11047, 1344],
                                  [11228, 11558],
                                  [11388, 14825],
                                  [11487, 17382],
                                  [13806, 6168],
                                  [15064, 15036],
                                  [15798, 14732],
                                  [17953, 7791],
                                  [18488, 16407],]}


def _build_models(input_shape, configurations, maj_cfg_idx, minor_cfg_idx):
    """Build ``(encoder, decoder, auto_encoder)`` for the configured base model.

    The model family is chosen by a case-insensitive substring match on the
    major configuration's ``'base_model_name'``, tested in the order RAE, WAE,
    STD_VAE (so the first matching family wins).

    Raises:
        NotImplementedError: if the name matches none of the known families.
    """
    base_model_name = configurations[maj_cfg_idx][0]['base_model_name']
    upper_name = base_model_name.upper()
    if 'RAE' in upper_name:
        return make_raes.get_vae(input_shape, configurations, maj_cfg_idx, minor_cfg_idx)
    if 'WAE' in upper_name:
        return wae_mmd.get_wae(input_shape, configurations, maj_cfg_idx, minor_cfg_idx)
    if 'STD_VAE' in upper_name:
        return std_vae.get_vae(input_shape, configurations, maj_cfg_idx, minor_cfg_idx)
    raise NotImplementedError("No implemntation for " + str(base_model_name) + " found.")


def _encode_training_set(encoder, generator, n_batches, z_dims, multi_output):
    """Encode ``n_batches`` batches drawn from ``generator`` and stack the codes.

    Args:
        encoder: model whose ``predict(x)`` returns the latent codes, or, when
            ``multi_output`` is True, a sequence whose first two items are
            (z, log_sigma).
        generator: object whose ``next()`` returns an ``(x, labels)`` pair.
        n_batches: number of batches to draw (non-positive -> none).
        z_dims: latent width; only used to shape the result when no batch
            was drawn.
        multi_output: whether the encoder returns (z, log_sigma).

    Returns:
        ``(zs, log_sigmas)``: per-batch predictions concatenated along axis 0.
        ``log_sigmas`` is None for single-output encoders. Only rows actually
        produced by the encoder are returned (no zero padding).
    """
    z_chunks = []
    log_sigma_chunks = []
    for _ in range(int(n_batches)):
        x, _labels = generator.next()
        prediction = encoder.predict(x)
        if multi_output:
            z_chunks.append(prediction[0])
            log_sigma_chunks.append(prediction[1])
        else:
            z_chunks.append(prediction)

    def _stack(chunks):
        # np.concatenate rejects an empty list, so build an empty array instead.
        return np.concatenate(chunks, axis=0) if chunks else np.zeros((0, z_dims))

    zs = _stack(z_chunks)
    log_sigmas = _stack(log_sigma_chunks) if multi_output else None
    return zs, log_sigmas


def main():
    """Sample images from a trained autoencoder, with and without a fitted latent density.

    Usage: ``python <script> <config_index>``, where ``config_index`` is a single
    integer that ``config_parser.get_config_idxs`` maps to (major, minor)
    configuration indices.

    Steps:
      1. Resolve the experiment's log directory and model path from the config.
      2. Build the model and load weights from ``<model_path>_best``.
      3. Encode the training set (or load the cached ``..._training_embedding.npz``,
         key ``'zs'``) and cache the result.
      4. Decode 10000 standard-normal latents into ``<log_dir>/one_gaussian_sampled``.
      5. For each estimator in ``qz_est_name_list``, fit a density on the training
         embeddings, save it, and decode 10000 samples into
         ``<log_dir>/<estimator>_sampled``.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

    if len(sys.argv) < 2:
        raise SystemExit("usage: {} <config_index>".format(sys.argv[0] if sys.argv else 'script'))

    # Setting up logging / output locations
    maj_cfg_idx, minor_cfg_idx = config_parser.get_config_idxs(int(sys.argv[1]), config.configurations)
    major_cfg = config.configurations[maj_cfg_idx][0]
    minor_cfg = config.configurations[maj_cfg_idx][minor_cfg_idx]

    expt_name = minor_cfg['expt_name']
    log_root = os.path.join(major_cfg['log_root'], str(maj_cfg_idx))
    log_dir = os.path.join(log_root, expt_name + '_' + str(minor_cfg_idx))

    model_name = major_cfg['base_model_name'] + '_' + major_cfg['dataset_name'] + '.h5'
    model_path = os.path.join(log_dir, model_name)
    n_components = minor_cfg['n_components']

    # Preparing data generator (unshuffled so embeddings follow dataset order)
    batch_size = major_cfg['batch_size']
    dataloader = DataLoader(batch_size=batch_size)
    (train_generator, _, _), input_shape, (train_steps, _, _) = \
        dataloader.get_data_loader(dataset_name=major_cfg['dataset_name'], shuffle=False)

    # Preparing model
    encoder, decoder, auto_encoder = _build_models(input_shape, config.configurations,
                                                   maj_cfg_idx, minor_cfg_idx)
    multi_output_enc = len(encoder.outputs) > 1
    auto_encoder.load_weights(model_path + '_best')

    z_dims = K.get_variable_shape(encoder.outputs[0])[-1]

    # Training embeddings. The filename keeps its historical form (model_name[:-2]
    # drops only 'h5', leaving the '.') so existing cache files are still found.
    embedding_path = os.path.join(log_dir, model_name[:-2] + '_' + expt_name + '_' + '_training_embedding.npz')
    if os.path.exists(embedding_path):
        zs_trn = np.load(embedding_path)['zs']
    else:
        # NOTE(review): the final training step is still skipped, as before (probably to
        # avoid a partial last batch). Batches are now concatenated, so the skipped step no
        # longer leaves zero rows that would bias the density fit -- confirm whether it
        # should be included.
        zs_trn, zs_trn_log_sigma = _encode_training_set(encoder, train_generator,
                                                        int(train_steps - 1), z_dims,
                                                        multi_output_enc)
        to_save = {'zs': zs_trn}
        if zs_trn_log_sigma is not None:
            to_save['log_sigma'] = zs_trn_log_sigma
        # Populate the cache read by the branch above.
        np.savez(embedding_path, **to_save)

    # Re-seed right before sampling so the drawn latents are reproducible.
    # (tf.compat.v1 works on both TF 1.x and 2.x; tf.random.set_random_seed is TF1-only.)
    np.random.seed(2)
    tf.compat.v1.random.set_random_seed(2)

    n_samples = 10000

    # Save images decoded from the N(0, I) prior
    prior_z = np.random.normal(loc=0.0, scale=1.0, size=(n_samples, zs_trn.shape[-1]))
    sampled_images = decoder.predict(prior_z)
    save_batches_of_images.save_set_of_images(os.path.join(log_dir, 'one_gaussian_sampled'), sampled_images)

    qz_est_name_list = ['GMM_10']
    for estimator_name in qz_est_name_list:
        # Q(z) estimation on the training embeddings
        start = time.time()
        is_aux_vae = 'AUX_VAE' in estimator_name.upper()
        second_stage_beta = minor_cfg['second_stage_beta'] if is_aux_vae else 0

        qz_sampler = estimate_density_and_sample.DensityEstimator(training_set=zs_trn,
                                                                  method_name=estimator_name,
                                                                  n_components=n_components,
                                                                  log_dir=log_dir,
                                                                  second_stage_beta=second_stage_beta)
        sampler_base_path = model_path[:-3] + "_" + estimator_name + '_2nd_stage'
        if is_aux_vae:
            weights_path = sampler_base_path + '.h5'
            if os.path.exists(weights_path):
                qz_sampler.fitorload(weights_path)
            else:
                qz_sampler.fitorload()
                qz_sampler.save(weights_path)
        else:
            qz_sampler.fitorload()
            qz_sampler.save(sampler_base_path)
        print("Time taken to fit " + str(time.time() - start))

        print("Sampling using " + estimator_name)
        zs = qz_sampler.get_samples(n_samples=n_samples)
        sampled_images = decoder.predict(zs)
        save_batches_of_images.save_set_of_images(os.path.join(log_dir, estimator_name + '_sampled'),
                                                  sampled_images)

if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main()
