from tensorflow.keras.layers import Concatenate, Dense, Flatten, Input, TimeDistributed, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.losses import categorical_crossentropy
import tensorflow as tf
from tensorflow_probability import distributions as tfd
import numpy as np

from modules.latent_space import latentspace
from modules.vae.reconstruction_losses import gaussian_loss
from modules.vae import architectures


class ShapeCultureVAE:
    """VAE that disentangles (r) rotations, (s) shape properties, and (c) culture properties.

    Each model input stacks ``n_r`` views of a sample, i.e. it has shape
    ``(n_data, n_r, *input_shape)``. Per-view encoders feed three latent spaces:

    * ``r`` -- a ``latentspace.HyperSphericalLatentSpace``. Its samples are inverse-transformed
      with ``n_r`` fixed angles evenly spaced in ``[0, 2*pi)``. They are then averaged with the
      space's ``avg_layer`` and re-transformed with the same angles.
    * ``s`` and ``c`` -- ``latentspace.GaussianLatentSpace`` instances. Their samples are
      averaged with each space's ``avg_layer``.

    The averaged samples are concatenated and decoded per view. Softmax classifiers predict
    shape (``n_s`` classes) and culture (``n_c`` classes) labels from the averaged ``s`` / ``c``
    samples. Two Keras models are compiled and trained alternately by ``train_supervised``:
    ``full_model_s`` (shape labels) and ``full_model_c`` (culture labels).
    """

    def __init__(self, enc_dec_architecture, separate_encoders, input_shape,
                 latent_dim_s, latent_dim_c, reconstr_loss_function=None, n_r=12, n_s=8, n_c=6,
                 weight_kl=1.0, weight_clf=1.0, weight_kl_posterior_prior=1.0,
                 weight_dist_to_avg_r=1.0, weight_dist_to_avg_s=1.0, weight_dist_to_avg_c=1.0):
        """Build and compile all sub-models.

        Args:
            enc_dec_architecture: Encoder/decoder architecture name; one of ``"vgg"``,
                ``"dis_lib"`` or ``"dense"`` (see ``architectures_function``).
            separate_encoders: If True, the s and c latent spaces get their own encoder
                networks. Otherwise all three latent spaces share the r encoder.
            input_shape: Shape of a single view, as a tuple. It is concatenated with
                ``(n_r,)`` to form the model input shape.
            latent_dim_s: Dimension passed to the shape latent space.
            latent_dim_c: Dimension passed to the culture latent space.
            reconstr_loss_function: Callable ``f(x_true, x_pred)`` giving the reconstruction
                loss. ``None`` (the default) selects ``gaussian_loss()``.
            n_r: Number of views / transformation angles (evenly spaced in ``[0, 2*pi)``).
            n_s: Number of shape classes.
            n_c: Number of culture classes.
            weight_kl: Weight of the summed per-latent-space KL losses.
            weight_clf: Weight of the classification loss.
            weight_kl_posterior_prior: Weight of the KL between the approximate posterior
                and the label-conditional prior.
            weight_dist_to_avg_r: Weight of the r distance-to-average loss.
            weight_dist_to_avg_s: Weight of the s distance-to-average loss.
            weight_dist_to_avg_c: Weight of the c distance-to-average loss.
        """
        self.input_shape = input_shape
        self.enc_dec_architecture = enc_dec_architecture
        self.n_r = n_r
        self.n_s = n_s
        self.n_c = n_c
        # Build the default loss per instance. Calling gaussian_loss() in the signature would
        # evaluate it once at class-definition time and share the result across instances.
        if reconstr_loss_function is None:
            reconstr_loss_function = gaussian_loss()
        self.reconstr_loss_function = reconstr_loss_function

        # SETUP ARCHITECTURES
        # The r encoder and the decoder come from one factory call. Separate s/c encoders are
        # fresh factory calls whose decoders are discarded.
        self.enc_r, self.decoder = self.architectures_function(*self.input_shape)
        if separate_encoders:
            self.enc_s = self.architectures_function(*self.input_shape)[0]
            self.enc_c = self.architectures_function(*self.input_shape)[0]
        else:
            # The same encoder object (and therefore its weights) serves all three spaces.
            self.enc_s = self.enc_r
            self.enc_c = self.enc_r
        self.clf_s = Dense(self.n_s, activation="softmax")
        self.clf_c = Dense(self.n_c, activation="softmax")

        # Projection from the concatenated latent samples to the decoder's input width.
        units = int(self.decoder.input.shape[1])
        self.dec_in_layer = Dense(units, activation="relu", name="dec_in")  # layer must be used in all models

        # SETUP TRANSFORMATIONS
        # n_r angles evenly spaced in [0, 2*pi), reshaped to (1, n_r, 1).
        transformations = np.linspace(0, 2 * np.pi, num=self.n_r, endpoint=False)
        transformations = np.expand_dims(transformations, axis=-1)
        transformations = np.expand_dims(transformations, axis=0)
        # NOTE(review): np.linspace yields float64, so this constant is float64; confirm the
        # latent-space transform layers expect that dtype.
        self.transformations = tf.constant(transformations, name="transformations")
        # TODO: this has shape (1, n_r, 1), but should be (batch_size, n_r, 1); check if broadcasting solves this

        # SETUP LATENT SPACES
        ls_r = latentspace.HyperSphericalLatentSpace(1, name="ls_r", log_t_limit=(-10.0, -5.0))
        ls_s = latentspace.GaussianLatentSpace(latent_dim_s, name="ls_s")
        ls_c = latentspace.GaussianLatentSpace(latent_dim_c, name="ls_c")
        # Separate Gaussian spaces parameterise the label-conditional priors for s and c.
        ls_s_prior = latentspace.GaussianLatentSpace(latent_dim_s, name="ls_s_prior")
        ls_c_prior = latentspace.GaussianLatentSpace(latent_dim_c, name="ls_c_prior")

        # SETUP FULL ENCODER
        x_in = Input(shape=(self.n_r,) + self.input_shape)  # shape (n_data, n_r, *data_shape)
        # Each encoder is applied to every one of the n_r views.
        h_enc_r = TimeDistributed(self.enc_r, name="h_enc_r_s")(x_in)  # shape (n_data, n_r, hidden_shape)
        h_enc_s = TimeDistributed(self.enc_s, name="h_enc_s_s")(x_in)
        h_enc_c = TimeDistributed(self.enc_c, name="h_enc_c_s")(x_in)
        z_r_params = ls_r.get_params(h_enc_r)
        z_s_params = ls_s.get_params(h_enc_s)
        z_c_params = ls_c.get_params(h_enc_c)
        z_r_sample = ls_r.sample_layer(z_r_params)  # shape (n_data, n_r, latent_dim)
        z_s_sample = ls_s.sample_layer(z_s_params)
        z_c_sample = ls_c.sample_layer(z_c_params)
        # r: inverse-transform with the fixed angles, average, then re-apply the same angles.
        z_r_sample_anchored = ls_r.inverse_transform_layer([z_r_sample, self.transformations])
        z_r_sample_anchored_avg = ls_r.avg_layer(z_r_sample_anchored)
        z_r_sample_avg = ls_r.transform_layer([z_r_sample_anchored_avg, self.transformations])
        z_s_sample_avg = ls_s.avg_layer(z_s_sample)
        z_c_sample_avg = ls_c.avg_layer(z_c_sample)

        # SETUP FULL DECODER
        # Standalone decoder model over the three averaged samples, stored as self.full_decoder.
        samples_in = [Input(batch_shape=K.int_shape(sample))
                      for sample in [z_r_sample_avg, z_s_sample_avg, z_c_sample_avg]]
        samples_concat = Concatenate(name="samples_concat")(samples_in)
        dec_in = self.dec_in_layer(samples_concat)
        dec_out = TimeDistributed(self.decoder, name="dec_out")(dec_in)
        self.full_decoder = Model(samples_in, dec_out)

        x_out = self.full_decoder([z_r_sample_avg, z_s_sample_avg, z_c_sample_avg])

        # SETUP CLASSIFIERS
        y_s_pred = self.clf_s(z_s_sample_avg)
        y_c_pred = self.clf_c(z_c_sample_avg)

        # SETUP CONDITIONAL PRIORS
        # Labels get a new axis at -2 that is repeated n_r times:
        # (n_data, n_classes) -> (n_data, n_r, n_classes).
        y_s_in = Input(shape=(self.n_s,))
        y_c_in = Input(shape=(self.n_c,))
        y_s_repeated = K.repeat_elements(K.expand_dims(y_s_in, axis=-2), rep=self.n_r, axis=-2)
        y_c_repeated = K.repeat_elements(K.expand_dims(y_c_in, axis=-2), rep=self.n_r, axis=-2)
        h_cond_prior_s = Dense(latent_dim_s, activation="relu")(y_s_repeated)
        h_cond_prior_c = Dense(latent_dim_c, activation="relu")(y_c_repeated)
        cond_prior_s_params = ls_s_prior.get_params(h_cond_prior_s)
        cond_prior_c_params = ls_c_prior.get_params(h_cond_prior_c)

        # SETUP LOSSES
        # The reconstruction loss self.reconstr_loss_function(x_in, x_out) is evaluated inside
        # the loss closures below rather than here, to avoid an error.

        # Regular KL losses per latent space, summed over axis 1.
        kl_r = K.sum(ls_r.kl_loss(z_r_params), axis=1)
        kl_s = K.sum(ls_s.kl_loss(z_s_params), axis=1)
        kl_c = K.sum(ls_c.kl_loss(z_c_params), axis=1)
        kl_full = kl_r + kl_s + kl_c

        # Classification losses. The repeated labels have shape (..., n_r, n_classes), matching
        # y_pred, so the cross-entropy is computed per view and then summed over the last axis.
        clf_loss_s = tf.reduce_sum(categorical_crossentropy(y_s_repeated, y_s_pred), axis=-1)
        clf_loss_c = tf.reduce_sum(categorical_crossentropy(y_c_repeated, y_c_pred), axis=-1)

        # KL(approximate posterior || label-conditional prior), summed over axes 1 and 2
        # (presumably views and latent dims).
        q_zs_xs = tfd.Normal(*z_s_params)  # approximate posterior
        q_zs_ys = tfd.Normal(*cond_prior_s_params)  # conditional prior
        kl_posterior_prior_s = tf.reduce_sum(q_zs_xs.kl_divergence(q_zs_ys), axis=(1, 2))
        q_zc_xc = tfd.Normal(*z_c_params)  # approximate posterior
        q_zc_yc = tfd.Normal(*cond_prior_c_params)  # conditional prior
        kl_posterior_prior_c = tf.reduce_sum(q_zc_xc.kl_divergence(q_zc_yc), axis=(1, 2))

        # Averaging losses: distance of each sample to its averaged counterpart.
        dist_to_avg_loss_r = tf.reduce_sum(ls_r.distance(z_r_sample, z_r_sample_avg), axis=1)
        dist_to_avg_loss_s = tf.reduce_sum(ls_s.distance(z_s_sample, z_s_sample_avg), axis=1)
        dist_to_avg_loss_c = tf.reduce_sum(ls_c.distance(z_c_sample, z_c_sample_avg), axis=1)
        dist_to_avg_loss_full_weighted = (weight_dist_to_avg_r * dist_to_avg_loss_r
                                          + weight_dist_to_avg_s * dist_to_avg_loss_s
                                          + weight_dist_to_avg_c * dist_to_avg_loss_c)

        # SETUP FULL LOSS FUNCTIONS
        # Keras-compatible signatures. Both arguments are ignored because every term is built
        # from the graph tensors captured above.
        def full_loss_s(fake_true, fake_pred):
            return self.reconstr_loss_function(x_in, x_out) + weight_kl * kl_full + weight_clf * clf_loss_s \
                   + weight_kl_posterior_prior * kl_posterior_prior_s + dist_to_avg_loss_full_weighted
        full_loss_s.__name__ = "loss_s"

        def full_loss_c(fake_true, fake_pred):
            return self.reconstr_loss_function(x_in, x_out) + weight_kl * kl_full + weight_clf * clf_loss_c \
                   + weight_kl_posterior_prior * kl_posterior_prior_c + dist_to_avg_loss_full_weighted
        # Name this closure (not full_loss_s) so the culture model reports its loss as "loss_c".
        full_loss_c.__name__ = "loss_c"

        # SETUP FULL MODELS
        # NOTE(review): Keras applies a single compile() loss to every model output. With
        # outputs [x_out, y_pred], this loss is therefore presumably counted twice; confirm
        # that this is intended.
        self.full_model_s = Model([x_in, y_s_in], [x_out, y_s_pred])
        self.full_model_s.compile(loss=full_loss_s, optimizer="adam", experimental_run_tf_function=False)
        print("SHAPE MODEL SUMMARY:")
        self.full_model_s.summary()

        self.full_model_c = Model([x_in, y_c_in], [x_out, y_c_pred])
        self.full_model_c.compile(loss=full_loss_c, optimizer="adam", experimental_run_tf_function=False)
        print("CULTURE MODEL SUMMARY:")
        self.full_model_c.summary()

        # Inference models reusing the layers built above.
        self.full_encoder_mu = Model(x_in, [z_r_params[1], z_s_params[0], z_c_params[0]])  # outputs mu's for r, s, c
        self.full_encoder_avg = Model(x_in, [z_r_sample_avg, z_s_sample_avg, z_c_sample_avg])

    @staticmethod
    def _validation_data(x_valid, y_valid):
        """Return Keras ``validation_data`` with ``[x, y]`` as inputs and targets, or None if either is None."""
        if x_valid is None or y_valid is None:
            return None
        return [x_valid, y_valid], [x_valid, y_valid]

    def train_supervised(self, x_s_train, x_s_valid, y_s_train, y_s_valid,
                         x_c_train, x_c_valid, y_c_train, y_c_valid,
                         epochs, batch_size, callback_list=None):
        """Alternately train the shape and culture models.

        Each of the ``epochs`` rounds runs two one-epoch fits:

        * ``full_model_s`` on the shape data, with Keras epoch index ``2*e``;
        * ``full_model_c`` on the culture data, with Keras epoch index ``2*e + 1``.

        The interleaved indices give callbacks a strictly increasing epoch counter.
        ``[x, y]`` is passed as both model inputs and targets. The compiled losses ignore
        their target argument.

        Args:
            x_s_train: Shape training inputs.
            x_s_valid: Shape validation inputs. If this or ``y_s_valid`` is None, the shape
                model is not validated.
            y_s_train: Shape training labels.
            y_s_valid: Shape validation labels.
            x_c_train: Culture training inputs.
            x_c_valid: Culture validation inputs. If this or ``y_c_valid`` is None, the
                culture model is not validated.
            y_c_train: Culture training labels.
            y_c_valid: Culture validation labels.
            epochs: Number of alternation rounds.
            batch_size: Batch size passed to both ``fit`` calls.
            callback_list: Keras callbacks passed to both ``fit`` calls.
        """
        validation_s = self._validation_data(x_s_valid, y_s_valid)
        validation_c = self._validation_data(x_c_valid, y_c_valid)
        for epoch in range(epochs):
            print(f"Epoch {epoch+1} of {epochs}")
            print("Training on shapes")
            self.full_model_s.fit([x_s_train, y_s_train], [x_s_train, y_s_train], validation_data=validation_s,
                                  batch_size=batch_size, epochs=2*epoch + 1, initial_epoch=2*epoch,
                                  callbacks=callback_list)
            print("Training on cultures")
            self.full_model_c.fit([x_c_train, y_c_train], [x_c_train, y_c_train], validation_data=validation_c,
                                  batch_size=batch_size, epochs=2*epoch + 2, initial_epoch=2*epoch + 1,
                                  callbacks=callback_list)

    @property
    def architectures_function(self):
        """Return the encoder/decoder factory selected by ``self.enc_dec_architecture``.

        ``__init__`` calls the returned factory as ``factory(*input_shape)`` and unpacks the
        result as ``(encoder, decoder)``.

        Raises:
            ValueError: If ``enc_dec_architecture`` is not ``"vgg"``, ``"dis_lib"`` or ``"dense"``.
        """
        if self.enc_dec_architecture == "vgg":
            return architectures.encoder_decoder_vgglike_2d
        if self.enc_dec_architecture == "dis_lib":
            return architectures.encoder_decoder_dislib_2d
        if self.enc_dec_architecture == "dense":
            return architectures.encoder_decoder_dense
        # ValueError subclasses Exception, so handlers written for the old bare Exception still work.
        raise ValueError(
            f"Unknown enc_dec_architecture {self.enc_dec_architecture!r}; "
            "expected one of 'vgg', 'dis_lib', 'dense'")
