import time
import argparse
from pathlib import Path
import yaml
import logging
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import pickle
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split
import scipy.io as scio
from sklearn.cluster import KMeans
import os
from sklearn.metrics.cluster import normalized_mutual_info_score
import uuid
from ast import literal_eval as make_tuple
from keras.datasets import cifar10
from UTKFace.numpy import UTKFaceDataLoader
from UTKFace.utils.labels import Label

from source.data import DataGenerator

# Short aliases for the TensorFlow / TensorFlow-Probability namespaces used below.
tfd, tfkl, tfpl, tfk = tfp.distributions, tf.keras.layers, tfp.layers, tf.keras

import source.utils as utils
from source.model import GMMVAE

# Project-wide constants.
ROOT_LOGGER_STR = "ConstrainedVADE"  # prefix of this module's logger name
LOGGER_RESULT_FILE = "logs.txt"
# Directory under which run_experiment writes model checkpoints
# (an older value noted here was "autoencoder/cp.ckpt").
CHECKPOINT_PATH = 'models'

logger = logging.getLogger(".".join((ROOT_LOGGER_STR, __name__)))

# Only let TensorFlow's (v1-compat) logger emit ERROR-level messages.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Lazily created UTKFace loader, shared by every get_data() call ("singleton").
utkface_loader = None


def _build_utkface_loader(configs):
    """Create a ``UTKFaceDataLoader`` for the label named in ``configs['data']['label']``.

    Label names other than ``age``, ``age_bins``, ``age_bins_manual``, ``gender`` and
    ``ethnicity`` fall back to ``Label.Type.NONE``. Images are read from
    ``dataset/utkface`` and resized to 64x64.
    """
    label_types = {
        "age": Label.Type.AGE,
        "age_bins": Label.Type.AGE_BINS_UNIFORM,
        "age_bins_manual": Label.Type.AGE_BINS_MANUAL,
        "gender": Label.Type.GENDER,
        "ethnicity": Label.Type.ETHNICITY,
    }
    label_type = label_types.get(configs['data']['label'], Label.Type.NONE)
    utkface_image_path = "dataset/utkface"

    if label_type == Label.Type.ETHNICITY:
        # Ethnicity experiments are restricted to ages 18-50 and ethnicity ids 0-3.
        return UTKFaceDataLoader(label=label_type, filter_age=(18, 50), filter_ethnicity=[0, 1, 2, 3],
                                 resize=(64, 64), images_path=utkface_image_path)
    if label_type == Label.Type.AGE_BINS_UNIFORM:
        # One uniform age bin per cluster of the model.
        return UTKFaceDataLoader(label=label_type, filter_age=None,
                                 num_age_bins=configs['training']['num_clusters'],
                                 resize=(64, 64), images_path=utkface_image_path)
    return UTKFaceDataLoader(label=label_type, filter_age=None, resize=(64, 64),
                             images_path=utkface_image_path)


def get_data(args, configs):
    """Load the dataset selected by ``args.data`` and return its train/test split.

    Args:
        args: parsed command-line namespace; only ``args.data`` is read.
        configs: experiment configuration dict. Only used for ``'utkface'``
            (``configs['data']['label']`` and, for uniform age bins,
            ``configs['training']['num_clusters']``).

    Returns:
        Tuple ``(x_train, x_test, y_train, y_test)``.

    Raises:
        ValueError: if ``args.data`` names an unsupported dataset.
    """
    global utkface_loader

    if args.data in ('MNIST', 'fMNIST'):
        dataset = tf.keras.datasets.mnist if args.data == 'MNIST' else tf.keras.datasets.fashion_mnist
        (x_train, y_train), (x_test, y_test) = dataset.load_data()
        # Scale pixels to [0, 1] and flatten each 28x28 image into a 784-vector.
        x_train = np.reshape(x_train / 255., (-1, 28 * 28))
        x_test = np.reshape(x_test / 255., (-1, 28 * 28))

    elif args.data == 'Reuters':
        rtk10k_train = np.load("dataset/reuters/reutersidf10k_train.npy", allow_pickle=True).item()
        rtk10k_test = np.load("dataset/reuters/reutersidf10k_test.npy", allow_pickle=True).item()
        x_train = rtk10k_train['data']
        y_train = rtk10k_train['label']
        x_test = rtk10k_test['data']
        y_test = rtk10k_test['label']

    elif args.data == 'har':
        data = scio.loadmat('dataset/har/HAR.mat')
        # Labels in the .mat file start at 1; shift them to start at 0.
        X = data['X'].astype('float32')[:10200]
        Y = np.reshape((data['Y'] - 1)[:10200], (-1))
        x_train, y_train = X[:8000], Y[:8000]
        x_test, y_test = X[8000:], Y[8000:]

    elif args.data == 'cifar10':
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        # Normalize to range 0-1 and flatten the label arrays.
        x_train = x_train.astype('float32') / 255.0
        x_test = x_test.astype('float32') / 255.0
        y_train = np.reshape(y_train, (-1))
        y_test = np.reshape(y_test, (-1))

    elif args.data == 'stl10':
        x_train = np.load("dataset/stl10/embedding_train.npy")
        y_train = np.load("dataset/stl10/label_train.npy")
        x_test = np.load("dataset/stl10/embedding_test.npy")
        y_test = np.load("dataset/stl10/label_test.npy")
        # Up to 70000 unlabelled embeddings are appended to the training split only.
        unlabeled = np.load("dataset/stl10/embedding_unlabeled.npy")[:70000]
        x_train = np.concatenate([x_train, unlabeled], axis=0)
        x_train = np.reshape(x_train.astype('float32'), (-1, 27648))
        x_test = np.reshape(x_test.astype('float32'), (-1, 27648))
        # Normalize both splits by the *training* maximum. Read the scale before
        # rescaling x_train; afterwards x_train.max() is 1.0 and x_test would be
        # left unnormalized.
        scale = x_train.max()
        x_train = x_train / scale
        x_test = x_test / scale
        # The unlabelled samples get the placeholder label 1.
        y_train = np.concatenate([y_train, np.ones(len(unlabeled), dtype=np.int8)], axis=0)

    elif args.data == 'utkface':
        # The loader is cached in the module-level ``utkface_loader`` and reused on
        # later calls, so a different ``configs`` on a later call has no effect on it.
        if utkface_loader is None:
            utkface_loader = _build_utkface_loader(configs)
        X, y = utkface_loader.get_data()
        X = X / 255.
        x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=3)

    else:
        raise ValueError("Unsupported dataset: {}".format(args.data))

    return x_train, x_test, y_train, y_test


def loss_GMMVAE_mnist(inp, x_decoded_mean):
    """Reconstruction loss used for MNIST and fMNIST.

    Keras' ``BinaryCrossentropy`` averages over elements. The factor 784 (= 28 * 28,
    the flattened image size produced by ``get_data``) rescales that mean to the
    scale of a per-sample sum.
    """
    bce = tf.keras.losses.BinaryCrossentropy()
    return 784 * bce(inp, x_decoded_mean)


def loss_GMMVAE_reuters(inp, x_decoded_mean):
    """Reconstruction loss for Reuters: mean squared error scaled by 2000.

    NOTE(review): 2000 presumably matches the Reuters feature dimension — confirm.
    """
    mse = tf.keras.losses.MeanSquaredError()
    return 2000 * mse(inp, x_decoded_mean)


def loss_GMMVAE_har(inp, x_decoded_mean):
    """Reconstruction loss for HAR: mean squared error scaled by 5 * 561.

    NOTE(review): 561 presumably is the HAR feature count and 5 an extra weight on
    the reconstruction term — confirm.
    """
    mse = tf.keras.losses.MeanSquaredError()
    return 5 * 561 * mse(inp, x_decoded_mean)


def loss_GMMVAE_cifar10(inp, x_decoded_mean):
    """Reconstruction loss for CIFAR-10: binary cross-entropy scaled by 3072 (= 32 * 32 * 3).

    An MSE variant with the same scale was tried previously.
    """
    bce = tf.keras.losses.BinaryCrossentropy()
    return 3072 * bce(inp, x_decoded_mean)


def loss_GMMVAE_stl10(inp, x_decoded_mean):
    """Reconstruction loss for STL-10: mean squared error scaled by 27648.

    27648 is the flattened sample size ``get_data`` reshapes STL-10 inputs to
    (an alternative scale of 2304 was noted previously).
    """
    mse = tf.keras.losses.MeanSquaredError()
    return 27648 * mse(inp, x_decoded_mean)


def loss_GMMVAE_utkface(inp, x_decoded_mean):
    """Reconstruction loss for UTKFace: binary cross-entropy scaled by 12288.

    12288 = 64 * 64 * 3. ``get_data`` resizes faces to 64x64; 3 colour channels are
    assumed. An MSE variant with the same scale was tried previously.
    """
    bce = tf.keras.losses.BinaryCrossentropy()
    return 12288 * bce(inp, x_decoded_mean)


def accuracy_metric(inp, p_c_z):
    """Keras metric: clustering accuracy of ``argmax(p_c_z)`` against the labels ``inp``.

    ``utils.cluster_acc`` runs on NumPy arrays, so it is wrapped in
    ``tf.numpy_function``, which is declared to return ``tf.float64``.
    """
    cluster_ids = tf.math.argmax(p_c_z, axis=-1)
    return tf.numpy_function(utils.cluster_acc, [inp, cluster_ids], tf.float64)


# Sub-directory of ``pretrain/`` holding shipped autoencoder weights and GMM per dataset.
_PRETRAINED_DIRS = {'MNIST': 'MNIST', 'fMNIST': 'fMNIST', 'Reuters': 'Reuters', 'har': 'HHAR'}


def _build_autoencoder(model, input_shape, net_type):
    """Wire ``model``'s encoder-mean and decoder layers into a deterministic Keras autoencoder.

    Only ``model.encoder.mu`` is used on the encoder side, so the returned model maps
    an input directly to its reconstruction. Layers are shared with ``model``, so
    training the autoencoder updates ``model``'s weights.

    Raises:
        ValueError: if ``net_type`` is not ``"FC"``, ``"CNN"`` or ``"VGG"``.
    """
    inputs = tfkl.Input(shape=input_shape)

    if net_type == "FC":
        h = tfkl.Flatten()(inputs)
        for layer in (model.encoder.dense1, model.encoder.dense2, model.encoder.dense3, model.encoder.mu,
                      model.decoder.dense1, model.decoder.dense2, model.decoder.dense3, model.decoder.dense4):
            h = layer(h)
        dec = h
    elif net_type == "CNN":
        h = model.encoder.conv2(model.encoder.conv1(inputs))
        h = model.encoder.mu(tfkl.Flatten()(h))
        for layer in (model.decoder.dense, model.decoder.reshape, model.decoder.convT1,
                      model.decoder.convT2, model.decoder.convT3):
            h = layer(h)
        dec = tf.sigmoid(h)
    elif net_type == "VGG":
        h = inputs
        for block in model.encoder.layers:
            h = block(h)
        h = model.encoder.mu(tfkl.Flatten()(h))
        h = model.decoder.reshape(model.decoder.dense(h))
        for block in model.decoder.layers:
            h = block(h)
        dec = tf.sigmoid(model.decoder.convT(h))
    else:
        raise ValueError("Unsupported network type: {}".format(net_type))

    return tfk.Model(inputs=inputs, outputs=dec)


def _build_latent_model(model, input_shape):
    """Return a Keras model mapping an input to the first output of ``model.encoder``."""
    inputs = tfkl.Input(shape=input_shape)
    z, _ = model.encoder(inputs)
    return tfk.Model(inputs=inputs, outputs=z)


def pretrain(model, args, ex_name, configs):
    """Pretrain (or load) the autoencoder and initialise the GMM parameters of ``model``.

    With ``args.pretrain`` set, the autoencoder built from ``model``'s layers is trained
    on train+test data for ``args.epochs_pretrain`` epochs. A diagonal
    ``GaussianMixture`` with ``num_clusters`` components is then fitted on the
    encoder's first output. The weights and the pickled GMM are written under
    ``pretrain/autoencoder_tmp`` and ``pretrain/gmm_tmp``. Otherwise, shipped weights
    and a GMM are loaded from ``pretrain/<dataset>/``. Finally, the GMM means and
    covariances are assigned to ``model.c_mu`` / ``model.c_sigma``.

    Returns:
        ``(model, pretrain_acc)``. ``pretrain_acc`` is the clustering accuracy of the
        GMM's assignments on train+test data.

    Raises:
        ValueError: for an unknown ``configs['training']['type']``.
        SystemExit: when no pretrained weights exist for ``args.data`` and
            ``args.pretrain`` is not set.
    """
    input_shape = configs['training']['inp_shape']
    num_clusters = configs['training']['num_clusters']
    net_type = configs['training']['type']

    # Convolutional image configs give the shape as a Python-literal string.
    if configs['data']['data_name'] in ["cifar10", "utkface"] and net_type in ["CNN", "VGG"]:
        input_shape = make_tuple(input_shape)

    autoencoder = _build_autoencoder(model, input_shape, net_type)

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    # Same split as the reconstruction losses in this file: BCE for the image
    # datasets scaled to [0, 1], MSE for the rest.
    ae_loss = "binary_crossentropy" if args.data in ('MNIST', 'fMNIST', 'cifar10', 'utkface') else "mse"
    autoencoder.compile(optimizer=optimizer, loss=ae_loss)
    autoencoder.summary()

    x_train, x_test, y_train, y_test = get_data(args, configs)
    X = np.concatenate((x_train, x_test))
    Y = np.concatenate((y_train, y_test))

    if args.pretrain:
        # Run the pretraining from scratch.
        os.makedirs("pretrain/autoencoder_tmp", exist_ok=True)
        os.makedirs("pretrain/gmm_tmp", exist_ok=True)

        print('\n******************** Pretraining **************************')
        cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath="pretrain/autoencoder_tmp/" + ex_name + "/cp.ckpt",
                                                         save_weights_only=True, verbose=1)
        autoencoder.fit(X, X, epochs=args.epochs_pretrain, batch_size=32, callbacks=cp_callback)

        z = _build_latent_model(model, input_shape).predict(X)
        estimator = GaussianMixture(n_components=num_clusters, covariance_type='diag', n_init=3)
        estimator.fit(z)
        # Close the file deterministically so the pickle is fully flushed to disk.
        with open("pretrain/gmm_tmp/" + ex_name + "_gmm_save.sav", 'wb') as fh:
            pickle.dump(estimator, fh)

        print('\n******************** Pretraining Done**************************')
    else:
        subdir = _PRETRAINED_DIRS.get(args.data)
        if subdir is None:
            print('\nPretrained weights for {} not available, please rerun with \'--pretrain True option\''.format(
                args.data))
            raise SystemExit(1)
        autoencoder.load_weights("pretrain/" + subdir + "/autoencoder/cp.ckpt")
        # NOTE: unpickling can execute arbitrary code; only load GMM files you trust.
        with open("pretrain/" + subdir + "/gmm_save.sav", 'rb') as fh:
            estimator = pickle.load(fh)
        if args.data in ('MNIST', 'fMNIST'):
            print('\n******************** Loaded {} Pretrain Weights **************************'.format(args.data))

    z_model = _build_latent_model(model, input_shape)

    # Assign the fitted GMM parameters to the mixture components of the model.
    model.c_mu.assign(estimator.means_)
    model.c_sigma.assign(estimator.covariances_)

    yy = estimator.predict(z_model.predict(X))
    # (true labels, predictions), the same argument order as the other cluster_acc calls here.
    pretrain_acc = utils.cluster_acc(Y, yy)
    print('\nPretrain accuracy: ' + str(pretrain_acc))

    return model, pretrain_acc


# Dataset-specific part of the result file names written by run_experiment.
_RESULT_FILE_SUFFIX = {
    'MNIST': 'MNIST',
    'fMNIST': 'fMNIST',
    'Reuters': 'reuters',
    'har': 'har',
    'cifar10': 'cifar',
    'stl10': 'stl',
    'utkface': 'utkface',
}


def _result_file_name(prefix, data):
    """Return ``"<prefix>_<suffix>.txt"`` for dataset ``data`` (e.g. ``results_MNIST.txt``).

    Raises:
        ValueError: if ``data`` has no registered result file.
    """
    try:
        return "{}_{}.txt".format(prefix, _RESULT_FILE_SUFFIX[data])
    except KeyError:
        raise ValueError("No result file configured for dataset {}".format(data)) from None


def _format_hparams(args, alpha):
    """Return the hyper-parameter prefix shared by per-run and summary result lines."""
    return ("Epochs= %d, num_constrains= %d, ml= %d, alpha= %d, batch_size= %d, learning_rate= %f, q= %f, "
            "pretrain_e= %d,  "
            % (args.num_epochs, args.num_constrains, args.ml, alpha, args.batch_size, args.lr, args.q,
               args.epochs_pretrain))


def _cluster_scores(model, x, y):
    """Return ``(accuracy, NMI)`` of the hard assignments ``argmax p_c_z`` predicted for ``x``."""
    # The model takes a second input; zeros are fed at evaluation time.
    _, _, _, p_c_z = model.predict([x, np.zeros(len(x))])
    y_pred = np.argmax(p_c_z, axis=-1)
    return utils.cluster_acc(y, y_pred), normalized_mutual_info_score(y, y_pred)


def run_experiment(args, configs, loss):
    """Pretrain, train and evaluate a constrained GMMVAE ``args.runs`` times.

    Each run appends one line with the hyper-parameters, pretrain accuracy and
    train/test accuracy and NMI to ``results_<dataset>.txt``. With more than one
    run, the mean/std over runs is appended to ``evaluation_<dataset>.txt``.

    Args:
        args: parsed command-line namespace.
        configs: experiment configuration dict; ``configs['training']`` is passed
            to ``GMMVAE`` as keyword arguments.
        loss: reconstruction loss attached to the model's ``"output_1"``.

    Raises:
        ValueError: if ``args.data`` has no registered result file.
    """
    # Unique experiment name: timestamp plus a short random suffix.
    timestr = time.strftime("%Y%m%d-%H%M%S")
    ex_name = "{}_{}".format(timestr, uuid.uuid4().hex[:5])
    experiment_path = args.results_dir / configs['data']['data_name'] / ex_name
    experiment_path.mkdir(parents=True)

    # Resolve output files up front so an unknown dataset fails before training.
    results_file = _result_file_name("results", args.data)

    x_train, x_test, y_train, y_test = get_data(args, configs)

    acc_tot = []
    nmi_tot = []

    for _ in range(args.runs):
        model = GMMVAE(**configs['training'])

        optimizer = tf.keras.optimizers.Adam(learning_rate=args.lr, decay=args.decay)

        cp_callback = [tf.keras.callbacks.TensorBoard(log_dir='logs/' + ex_name)]
        if args.save_model:
            checkpoint_path = CHECKPOINT_PATH + '/' + configs['data']['data_name'] + '/' + ex_name
            cp_callback.append(tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_path,
                verbose=1,
                save_weights_only=True,
                period=100))

        model.compile(optimizer, loss={"output_1": loss}, metrics={"output_4": accuracy_metric})

        model, pretrain_acc = pretrain(model, args, ex_name, configs)

        # With a label-flip probability q > 0, alpha is a scaled log-odds of a
        # constraint being correct; otherwise the --alpha value is used directly.
        if args.q > 0:
            alpha = 3500 * np.log((1 - args.q) / args.q)
        else:
            alpha = args.alpha

        # stl10 additionally passes l=5000; build the generator only once.
        extra_gen_kwargs = {'l': 5000} if args.data == 'stl10' else {}
        train_gen = DataGenerator(x_train, y_train, num_constrains=args.num_constrains, alpha=alpha, q=args.q,
                                  batch_size=args.batch_size, ml=args.ml, **extra_gen_kwargs)
        test_gen = DataGenerator(x_test, y_test)

        model.fit(train_gen, validation_data=test_gen, epochs=args.num_epochs, callbacks=cp_callback, workers=15)

        # NOTE(review): for stl10, y_train includes placeholder labels for the
        # appended unlabelled samples, which affects the train scores.
        train_acc, train_nmi = _cluster_scores(model, x_train, y_train)
        acc, nmi = _cluster_scores(model, x_test, y_test)

        acc_tot.append(acc)
        nmi_tot.append(nmi)

        with open(results_file, "a+") as f:
            f.write(_format_hparams(args, alpha))
            f.write("decay= %f, name= %s. " % (args.decay, ex_name))
            f.write("Pretrain accuracy: %f , " % (pretrain_acc))
            f.write("Accuracy train: %f, NMI: %f. " % (train_acc, train_nmi))
            f.write("Accuracy test: %f, NMI: %f.\n " % (acc, nmi))
        print(str(acc))
        print(str(nmi))

    if args.runs > 1:
        acc_tot = np.array(acc_tot)
        nmi_tot = np.array(nmi_tot)

        # alpha and pretrain_acc are taken from the last run.
        with open(_result_file_name("evaluation", args.data), "a+") as f:
            f.write(_format_hparams(args, alpha))
            f.write(
                "decay= %f, runs= %d, name= %s. "
                % (args.decay, args.runs, ex_name))
            f.write("Pretrain accuracy: %f , " % (pretrain_acc))
            f.write("Accuracy: %f std %f, NMI: %f std %f. \n" % (
                np.mean(acc_tot), np.std(acc_tot), np.mean(nmi_tot), np.std(nmi_tot)))


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--pretrain False`` meant True). This parser fixes that.

    Raises:
        argparse.ArgumentTypeError: for strings that are not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string was already False under the old ``type=bool``.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def main():
    """Parse command-line options, load the dataset's YAML config and run the experiment."""
    project_dir = Path(__file__).absolute().parent

    parser = argparse.ArgumentParser()

    # parameters of the model
    parser.add_argument('--data',
                        default='MNIST',
                        type=str,
                        choices=['MNIST', 'fMNIST', 'Reuters', 'cifar10', 'stl10', 'har', 'utkface'],
                        help='specify the data (MNIST, fMNIST, Reuters, cifar10, stl10, har, utkface)')
    parser.add_argument('--num_epochs',
                        default=1000,
                        type=int,
                        help='specify the number of epochs')
    parser.add_argument('--num_constrains',
                        default=70000,
                        type=int,
                        help='specify the number of constrains')
    parser.add_argument('--batch_size',
                        default=1024,
                        type=int,
                        help='specify the batch size')
    parser.add_argument('--alpha',
                        default=10000,
                        type=int,
                        help='specify alpha, the weight importance of the constraints (higher means higher confidence)')
    parser.add_argument('--q',
                        default=0,
                        type=float,
                        help='specify the flip probability of the labels')
    parser.add_argument('--lr',
                        default=0.001,
                        type=float,
                        help='specify learning rate')
    parser.add_argument('--decay',
                        default=0.00001,
                        type=float,
                        help='specify decay')
    parser.add_argument('--ml',
                        default=0,
                        type=int,
                        choices=[0, 1, -1],
                        help='0: random choice, 1: only must-link, -1: only cannot-link')
    parser.add_argument('--w',
                        default=1,
                        type=float,
                        help='w')

    # other parameters
    parser.add_argument('--runs',
                        default=1,
                        type=int,
                        help='number of runs, the results will be averaged')
    parser.add_argument('--results_dir',
                        default=project_dir / 'experiments',
                        type=lambda p: Path(p).absolute(),
                        help='specify the folder where the results get saved')
    parser.add_argument('--pretrain', default=False, type=_str2bool,
                        help='True to pretrain the autoencoder, False to use pretrained weights')
    parser.add_argument('--epochs_pretrain', default=10, type=int,
                        help='Specify the number of pre-training epochs')
    parser.add_argument('--save_model', default=False, type=_str2bool,
                        help='True to save the model')

    args = parser.parse_args()

    # Config file and reconstruction loss per dataset (fMNIST reuses the MNIST config).
    # Built after parsing so that only the selected entry matters.
    dataset_setup = {
        "MNIST": ('MNIST.yml', loss_GMMVAE_mnist),
        "fMNIST": ('MNIST.yml', loss_GMMVAE_mnist),
        "Reuters": ('Reuters.yml', loss_GMMVAE_reuters),
        "har": ('har.yml', loss_GMMVAE_har),
        "cifar10": ('cifar10.yml', loss_GMMVAE_cifar10),
        "stl10": ('stl10.yml', loss_GMMVAE_stl10),
        "utkface": ('utkface.yml', loss_GMMVAE_utkface),
    }
    config_name, loss = dataset_setup[args.data]
    config_path = project_dir / 'configs' / config_name

    with config_path.open(mode='r') as yamlfile:
        configs = yaml.safe_load(yamlfile)

    run_experiment(args, configs, loss)


# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()
