import time
import numpy as np
import tensorflow as tf
import tensorflow_probability.python.distributions as tfd

from keras.layers import *
from utils import Logger
from losses import elbo


class Dave(tf.keras.Model):
    """Autoencoder with a binary (Bernoulli) latent code and selectable gradient estimators.

    The encoder maps an input to ``latent_dim`` logits, a binary code is
    sampled from those logits, and the decoder maps the code to
    ``img_size ** 2`` linear outputs. Because sampling is not differentiable,
    the encoder gradient is obtained with one of three estimators selected by
    ``grad``: ``'icr'`` (:meth:`indecater_grads`), ``'gs'`` (:meth:`gs_grads`)
    or ``'rloo'`` (:meth:`rloo_grads`).
    """

    def __init__(self, latent_dim, lr, grad, anneal_rate, samples, val_samples, batch_size, temp, h1=384, h2=256, img_size=28):
        """Build the encoder/decoder networks, the optimizer and the logger.

        Args:
            latent_dim: Number of binary latent variables.
            lr: Learning rate of the Adam optimizer.
            grad: Name of the gradient estimator (``'icr'``, ``'gs'`` or ``'rloo'``).
            anneal_rate: The temperature is multiplied by ``exp(-anneal_rate)``
                at the start of every epoch in :meth:`train`.
            samples: Number of latent samples drawn per input during training.
            val_samples: Number of latent samples drawn per input for validation.
            batch_size: Stored on the instance; not otherwise used by this class.
            temp: Initial temperature of the relaxed Bernoulli used by ``'gs'``.
            h1: Width of the outer hidden layer.
            h2: Width of the inner hidden layer.
            img_size: Side length; the decoder emits ``img_size ** 2`` values.
        """
        super().__init__()
        self.dim = latent_dim
        self.lr = lr
        self.grad = grad
        self.anneal_rate = anneal_rate
        self.samples = samples
        self.val_samples = val_samples
        self.batch_size = batch_size
        # The temperature is read inside @tf.function code. A plain Python
        # number (or a re-bound eager tensor) would be frozen at trace time, so
        # annealing would silently have no effect; a variable is re-read on
        # every call.
        self.temp = tf.Variable(temp, trainable=False, dtype=tf.float32)
        self.img_size = img_size

        self.encoder = tf.keras.Sequential([
            Dense(h1, activation="relu"),
            Dense(h2, activation="relu"),
            Dense(self.dim, activation='linear'),
        ])

        self.decoder = tf.keras.Sequential([
            Dense(h2, activation="relu"),
            Dense(h1, activation="relu"),
            Dense(self.img_size ** 2, activation='linear'),
        ])

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr)
        if "icr" in self.grad:
            self.catsch_multiplier()
        self.logger = Logger()

    def catsch_multiplier(self):
        """Precompute the masks used by :meth:`indecater_grads`.

        Sets ``self.catsch_mult`` and ``self.catsch_replacement``, both int64
        constants of shape ``[2, dim, 1, 1, dim]``. For row ``i`` the mask is
        one only at latent position ``i``; the replacement holds the value
        written to that position: 0 in slice ``[0]`` and 1 in slice ``[1]``.
        The two singleton axes broadcast over the sample and batch axes, so
        the masks work for any batch size (including a smaller final batch).
        """
        # plane[i, 0, 0, j] == 1 exactly when i == j.
        plane = np.eye(self.dim, dtype=np.int64)[:, np.newaxis, np.newaxis, :]
        self.catsch_mult = tf.constant(np.stack([plane, plane]))
        self.catsch_replacement = tf.constant(np.stack([np.zeros_like(plane), plane]))

    @tf.function
    def sample(self, z, sample_size, soft=False):
        """Draw ``sample_size`` binary codes from the logits ``z``.

        Args:
            z: Logits produced by the encoder.
            sample_size: Number of samples; becomes the leading axis of the result.
            soft: If true, sample a relaxed Bernoulli at temperature
                ``self.temp`` and round it with a straight-through estimator
                (forward pass is the rounded value, backward pass uses the
                relaxed sample). Otherwise draw exact Bernoulli samples.

        Returns:
            The sampled codes; float32 when ``soft`` is true, otherwise the
            default dtype of ``tfd.Bernoulli.sample``.
        """
        if soft:
            relaxed = tfd.RelaxedBernoulli(self.temp, logits=z).sample(sample_size)
            rounded = tf.cast(tf.round(relaxed), dtype=tf.float32)
            # Straight-through: value of `rounded`, gradient of `relaxed`.
            return tf.stop_gradient(rounded - relaxed) + relaxed
        return tfd.Bernoulli(probs=tf.sigmoid(z)).sample(sample_size)

    @tf.function
    def call(self, x, sample_size, soft=False):
        """Encode ``x``, sample ``sample_size`` codes and decode them.

        Returns:
            A tuple ``(z, x_hat)`` of the encoder logits and the decoder output
            for every sampled code.
        """
        z = self.encoder(x)
        samples = self.sample(z, sample_size, soft)
        x_hat = self.decoder(samples)
        return z, x_hat

    @tf.function
    def indecater_grads(self, x):
        """Compute gradients with the ``'icr'`` estimator.

        For every latent dimension the sampled codes are copied twice, once
        with that dimension forced to 0 and once forced to 1. The loss of each
        copy is evaluated without gradient and weighted by the probability of
        the forced value, ``1 - sigmoid(theta)`` or ``sigmoid(theta)``; the
        gradient of that surrogate plus the ordinary loss is returned.

        Requires :meth:`catsch_multiplier` to have been called.

        Returns:
            A tuple ``(gradients, loss)`` where ``gradients`` is aligned with
            ``self.trainable_variables``.
        """
        with tf.GradientTape() as tape:
            theta = self.encoder(x)
            samples = tf.cast(self.sample(theta, self.samples), dtype=tf.int64)
            x_hat = self.decoder(samples)
            b_loss = tf.reduce_mean(elbo(x, x_hat, theta), axis=0)
            loss = tf.reduce_mean(b_loss)

            # Broadcast samples against the [2, dim, 1, 1, dim] masks: slice
            # [k, i] is `samples` with latent position i overwritten by k.
            outer_samples = (
                samples[tf.newaxis, tf.newaxis, ...] * (1 - self.catsch_mult)
                + self.catsch_replacement
            )
            outer_samples_1 = outer_samples[1]
            outer_samples_0 = outer_samples[0]

            outer_loss_1 = elbo(x, self.decoder(outer_samples_1), theta)
            outer_loss_0 = elbo(x, self.decoder(outer_samples_0), theta)
            # Average over axis 1 (presumably the sample axis - depends on the
            # shape elbo returns) and transpose so the result lines up with theta.
            variable_loss_1 = tf.transpose(tf.reduce_mean(outer_loss_1, axis=1))
            variable_loss_0 = tf.transpose(tf.reduce_mean(outer_loss_0, axis=1))

            probs = tf.math.sigmoid(theta)
            catsch_expression = tf.reduce_sum(tf.stop_gradient(variable_loss_1) * probs, axis=1)
            catsch_expression += tf.reduce_sum(tf.stop_gradient(variable_loss_0) * (1 - probs), axis=1)
            catsch_expression = tf.reduce_mean(catsch_expression)
            objective = loss + catsch_expression
        catsch_grad = tape.gradient(objective, self.trainable_variables)
        return catsch_grad, loss

    @tf.function
    def rloo_grads(self, x):
        """Compute gradients with the ``'rloo'`` (leave-one-out score function) estimator.

        Each sample's loss is centred by the mean loss over all samples,
        multiplied by the log-probability of that sample and summed with a
        ``1 / (samples - 1)`` factor.

        Returns:
            A tuple ``(gradients, loss)`` where ``gradients`` is aligned with
            ``self.trainable_variables``.

        Raises:
            ValueError: If ``self.samples`` is smaller than 2; the estimator
                would otherwise divide by zero and yield NaN gradients.
        """
        if self.samples < 2:
            raise ValueError(f"rloo requires at least 2 samples, got {self.samples}")
        with tf.GradientTape() as tape:
            theta = self.encoder(x)
            samples = tf.cast(self.sample(theta, self.samples), dtype=tf.int64)
            x_hat = self.decoder(samples)
            # log p(b=1) = log_sigmoid(theta), log p(b=0) = log_sigmoid(-theta).
            # Unlike log(sigmoid(.)) these stay finite when the sigmoid saturates.
            sample_logps = tf.reduce_sum(
                tf.where(
                    tf.equal(samples, 1),
                    tf.math.log_sigmoid(theta),
                    tf.math.log_sigmoid(-theta),
                ),
                axis=-1,
            )

            sample_loss = elbo(x, x_hat, theta)
            batch_loss = tf.reduce_mean(sample_loss, axis=0)
            loss = tf.reduce_mean(batch_loss)

            sample_rloo = tf.stop_gradient(sample_loss - batch_loss) * sample_logps
            rloo_expression = tf.reduce_sum(sample_rloo, axis=0) / (self.samples - 1)
            rloo_expression = tf.reduce_mean(rloo_expression)
            objective = loss + rloo_expression
        rloo_grad = tape.gradient(objective, self.trainable_variables)
        return rloo_grad, loss

    @tf.function
    def gs_grads(self, x):
        """Compute gradients with the ``'gs'`` (relaxed, straight-through) estimator.

        NaN entries in the gradients are replaced by zeros.

        Returns:
            A tuple ``(gradients, loss)`` where ``gradients`` is aligned with
            ``self.trainable_variables``.
        """
        with tf.GradientTape() as tape:
            theta = self.encoder(x)
            samples = self.sample(theta, self.samples, soft=True)
            x_hat = self.decoder(samples)
            loss = tf.reduce_mean(elbo(x, x_hat, theta), axis=0)
            loss = tf.reduce_mean(loss)
        gs_grad = tape.gradient(loss, self.trainable_variables)
        gs_grad = [tf.where(tf.math.is_nan(g), tf.zeros_like(g), g) for g in gs_grad]
        return gs_grad, loss

    def grads(self, x):
        """Dispatch to the gradient estimator named by ``self.grad``.

        Returns:
            The ``(gradients, loss)`` tuple of the selected estimator.

        Raises:
            ValueError: If ``self.grad`` is not ``'icr'``, ``'gs'`` or ``'rloo'``.
        """
        if self.grad == 'icr':
            return self.indecater_grads(x)
        if self.grad == 'gs':
            return self.gs_grads(x)
        if self.grad == 'rloo':
            return self.rloo_grads(x)
        raise ValueError(
            f"Unknown gradient estimator {self.grad!r}; expected 'icr', 'gs' or 'rloo'"
        )

    def _validation_loss(self, val_data):
        """Return the mean loss over ``val_data``, or ``None`` if there is nothing to evaluate."""
        if val_data is None:
            return None
        total = 0
        batches = 0
        for val_x in val_data:
            val_theta, val_x_hat = self.call(val_x, self.val_samples)
            total += tf.reduce_mean(tf.reduce_mean(elbo(val_x, val_x_hat, val_theta), axis=0))
            batches += 1
        if batches == 0:
            return None
        return total / batches

    def train(self, data, epochs, val_data=None, log_its=100):
        """Train the model and periodically report statistics.

        At the start of every epoch the temperature is multiplied by
        ``exp(-anneal_rate)`` and clipped to ``[0.1, 1]``. Every ``log_its``
        iterations the mean training loss, the validation loss (when
        ``val_data`` is given and non-empty), the elapsed time and the mean
        gradient variance since the previous report are printed and passed to
        ``self.logger``.

        Args:
            data: Iterable of training batches; iterated once per epoch.
            epochs: Number of passes over ``data``.
            val_data: Optional iterable of validation batches.
            log_its: Number of iterations between two reports.
        """
        counter = 1
        acc_loss = 0
        var_grads = []
        prev_time = time.time()
        # Cast so that integer or float64 rates combine with the float32 variable.
        decay = tf.exp(-tf.cast(self.anneal_rate, tf.float32))
        for epoch in range(epochs):
            # Update in place so traced functions observe the new temperature.
            self.temp.assign(tf.clip_by_value(self.temp * decay, 0.1, 1.0))
            for x in data:
                grad, loss = self.grads(x)
                acc_loss += loss
                var_grads.append(tf.reduce_mean([tf.math.reduce_std(g) ** 2 for g in grad]))
                self.optimizer.apply_gradients(zip(grad, self.trainable_variables))
                if counter % log_its == 0:
                    # Measured before validation so it covers training time only.
                    update_time = time.time() - prev_time
                    mean_loss = acc_loss / log_its
                    val_loss = self._validation_loss(val_data)
                    grad_variance = np.mean(var_grads)

                    report = [f"Epoch {epoch} iterations {counter}: {mean_loss}"]
                    if val_loss is not None:
                        report.append(f"Validation loss: {val_loss}")
                    report.append(f"Time (s): {update_time}")
                    report.append(f"Gradient variance: {grad_variance}")
                    print(*report)

                    self.logger.log("training_loss", counter, mean_loss)
                    if val_loss is not None:
                        self.logger.log("validation_loss", counter, val_loss)
                    self.logger.log("time", counter, update_time)
                    self.logger.log("gradient_variance", counter, grad_variance)
                    acc_loss = 0
                    var_grads = []
                    prev_time = time.time()
                counter += 1
