import tensorflow as tf
import numpy as np
from scriptify import scriptify
from dbify import dbify

import sys

sys.path.append("../")

from latent_space.models import VAE, AE, AE_Conv
from latent_space.utils import get_data

import tensorflow.keras.backend as K
from tensorflow.keras.losses import Loss
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.callbacks import Callback

from sklearn.preprocessing import MinMaxScaler
class UnnormalizedBinaryCrossentropy(Loss):
    """Binary cross-entropy where targets and predictions are both unsquashed.

    ``call`` passes ``y_true`` through a sigmoid and hands ``y_pred``
    unchanged to a wrapped ``BinaryCrossentropy`` that is always built with
    ``from_logits=True``.

    Args:
        **kwargs: Keyword arguments for the wrapped ``BinaryCrossentropy``.
            ``from_logits`` is always overridden to ``True``. ``name`` and
            ``reduction``, when supplied, are also forwarded to the base
            ``Loss`` so they take effect on this object.
    """

    def __init__(self, **kwargs):
        # The wrapped loss is only ever invoked through ``call`` below, so
        # its own ``reduction``/``name`` never come into play; this outer
        # object is the one Keras reduces and reports. Forward those two
        # options to the base class instead of silently dropping them.
        base_kwargs = {
            key: kwargs[key] for key in ('name', 'reduction') if key in kwargs
        }
        super().__init__(**base_kwargs)

        # Copy instead of mutating so the caller's mapping is left alone.
        inner_kwargs = dict(kwargs, from_logits=True)
        self._loss = BinaryCrossentropy(**inner_kwargs)

    def call(self, y_true, y_pred):
        """Return the wrapped loss of ``sigmoid(y_true)`` against ``y_pred``."""
        return self._loss.call(tf.math.sigmoid(y_true), y_pred)

    def get_config(self):
        """Return the configuration of the wrapped ``BinaryCrossentropy``."""
        return self._loss.get_config()

if __name__ == "__main__":

    @scriptify
    @dbify('project_invalidation', 'vae_training', skip_duplicates=True)
    def script(data,
               normalize=True,
               model_type='ae',
               arch='dense,d1024.d128',
               latent_dim=32,
               opt='adam',
               learning_rate=1.e-4,
               batch_size=512,
               epochs=100,
               loss='binary_crossentropy',
               gpu=0):
        """Train an AE or VAE to reconstruct its input and save the weights.

        Args:
            data: Dataset identifier passed to ``get_data``; also used in the
                weights file name.
            normalize: If true, min-max scale every feature to ``[0, 1]``
                using statistics fitted on the training split.
            model_type: ``'ae'`` or ``'vae'`` (case-insensitive).
            arch: ``'<kind>,<spec>'``. ``kind == 'dense'`` flattens each
                sample; any other kind keeps the sample shape and appends a
                trailing axis when the array has fewer than 4 dimensions.
                ``spec`` is handed to the model as ``arch_string``.
            latent_dim: Latent size passed to the model constructor.
            opt: Optimizer identifier resolved by ``tf.keras.optimizers.get``.
            learning_rate: Learning rate given to the optimizer.
            batch_size: Batch size for ``model.fit``.
            epochs: Number of epochs for ``model.fit``.
            loss: ``'unnormalized_bce'`` selects
                ``UnnormalizedBinaryCrossentropy``; any other value is passed
                to ``model.compile`` unchanged.
            gpu: Index into the list of physical GPUs to make visible.

        Returns:
            dict with ``'final_weights'`` (path of the saved weights) and
            ``'data_min'`` / ``'data_max'`` (min and max of the training
            data as fed to the model, i.e. after optional normalization).

        Raises:
            ValueError: If ``model_type`` is not ``'ae'``/``'vae'`` or
                ``arch`` has no ``','`` separator.
        """
        # Local import: only this entry point needs it.
        import os

        # Validate up front so a typo fails before any data is loaded,
        # instead of surfacing later as a NameError on ``model``.
        model_classes = {'ae': AE, 'vae': VAE}
        model_key = model_type.lower()
        if model_key not in model_classes:
            raise ValueError(
                "model_type must be 'ae' or 'vae', got %r" % (model_type, ))

        arch_parts = arch.split(',')
        if len(arch_parts) < 2:
            raise ValueError(
                "arch must look like '<kind>,<spec>', got %r" % (arch, ))
        arch, arch_string = arch_parts[0], arch_parts[1]

        # Expose only the requested GPU and let it allocate memory lazily.
        gpus = tf.config.experimental.list_physical_devices('GPU')
        tf.config.experimental.set_visible_devices(gpus[gpu], 'GPU')
        for visible_gpu in tf.config.experimental.get_visible_devices('GPU'):
            tf.config.experimental.set_memory_growth(visible_gpu, True)

        (X_train, _), (X_test, _), _ = get_data(data)

        if normalize:
            # MinMaxScaler only accepts 2-D input, so scale a flattened view
            # and restore the original shape afterwards. For data that is
            # already 2-D this is a no-op around the scaler.
            train_shape, test_shape = X_train.shape, X_test.shape
            scaler = MinMaxScaler(feature_range=(0, 1))
            X_train = scaler.fit_transform(
                np.reshape(X_train, (train_shape[0], -1))).reshape(train_shape)
            X_test = scaler.transform(
                np.reshape(X_test, (test_shape[0], -1))).reshape(test_shape)

        # Range of the training data exactly as the model will see it.
        data_range = [X_train.min(), X_train.max()]

        if arch == 'dense':
            X_train = np.reshape(X_train, (X_train.shape[0], -1))
            # Flatten the *test* split here; previously this line reshaped
            # X_train, so validation silently ran on the training data.
            X_test = np.reshape(X_test, (X_test.shape[0], -1))
            input_shape = (X_train.shape[1], )
        else:
            if len(X_train.shape) < 4:
                X_train = X_train[:, :, :, None]
                X_test = X_test[:, :, :, None]
            input_shape = X_train.shape[1:]

        # Re-instantiate the resolved optimizer class with our learning rate.
        optimizer = tf.keras.optimizers.get(opt).__class__(
            learning_rate=learning_rate)

        if loss == 'unnormalized_bce':
            loss_fn = UnnormalizedBinaryCrossentropy()
        else:
            loss_fn = loss

        model = model_classes[model_key](input_shape,
                                         latent_dim,
                                         arch_string=arch_string,
                                         normalize=data_range)

        model.compile(optimizer=optimizer,
                      loss=loss_fn,
                      metrics=[tf.keras.losses.MeanSquaredError()])
        model.fit(X_train,
                  X_train,
                  batch_size=batch_size,
                  validation_data=(X_test, X_test),
                  epochs=epochs)

        weights_path = ("weights/" + model_type + "_" + arch + "_" + data +
                        "_" + loss + ".h5")
        # Make sure the target directory exists before writing the file.
        os.makedirs(os.path.dirname(weights_path), exist_ok=True)
        model.save_weights(weights_path)

        return {
            'final_weights': weights_path,
            'data_min': float(data_range[0]),
            'data_max': float(data_range[1])
        }