import numpy as np
import jax
from jax import numpy as jnp
import flax


def relu(x):
    """Rectified linear unit: elementwise ``max(0, x)``.

    Implemented with ``jnp.maximum`` so NaNs propagate and the gradient at
    exactly zero follows ``lax.max``'s tie-splitting rule.
    """
    return jnp.maximum(0, x)

def clip_grad_norm(grad, max_norm):
    """Rescale a gradient pytree so its global L2 norm does not exceed ``max_norm``.

    The global norm is the L2 norm taken over every element of every leaf of
    ``grad``. If it is below ``max_norm`` each leaf is returned unchanged;
    otherwise each leaf is multiplied by ``max_norm / (norm + 1e-6)``. The
    branch is selected with ``jnp.where``, so the function is traceable under
    ``jax.jit``.

    Args:
        grad: Pytree of array leaves, e.g. the gradient returned by
            ``jax.value_and_grad``.
        max_norm: Maximum allowed global L2 norm.

    Returns:
        A pytree with the same structure as ``grad``.
    """
    leaves = jax.tree_util.tree_leaves(grad)
    # sqrt(sum ||leaf||^2) is the norm of the per-leaf norms. Computing it this
    # way avoids handing a Python list to jnp.linalg.norm (rejected by current
    # JAX) and yields 0.0 for an empty pytree.
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))

    def _clip(leaf):
        # Small epsilon guards against division by a zero norm.
        return jnp.where(norm < max_norm, leaf, leaf * max_norm / (norm + 1e-6))

    return jax.tree_util.tree_map(_clip, grad)


class FullyConnectedNetwork(flax.nn.Module):
    """Stack of ``flax.nn.Dense`` layers with a ReLU between consecutive layers.

    The final layer is left linear (no activation). Layers are named via
    ``get_layer_name`` (``FC0``, ``FC1``, ...).
    """

    def apply(self, x, layer_sizes):
        """Apply one Dense layer per entry of ``layer_sizes``.

        Args:
            x: Input fed to the first Dense layer.
            layer_sizes: Output width of each successive Dense layer.

        Returns:
            Output of the last Dense layer, without an activation.
        """
        for idx, width in enumerate(layer_sizes):
            layer_name = self.get_layer_name(idx)
            x = flax.nn.Dense(x, features=width, name=layer_name)
            # Every layer except the last one is followed by a ReLU.
            if idx < len(layer_sizes) - 1:
                x = relu(x)
        return x

    def get_layer_name(self, i):
        """Return the parameter-scope name for the ``i``-th Dense layer."""
        return "FC{}".format(i)

    @staticmethod
    def train_step(**kwargs):
        """Not implemented for this network; always raises."""
        raise NotImplementedError()

    @staticmethod
    @jax.jit
    def evaluate(model, data):
        """Run ``model`` on ``data`` under ``jax.jit`` and return its output."""
        return model(data)

class inner_refine_VAE(flax.nn.Module):
    """VAE whose encoder log-variance is a learned, input-independent vector.

    Encoder path: ``Encoder`` (MLP over ``encoder_layer_sizes``) -> ReLU ->
    ``Last_Enc`` (one Dense layer of width ``encoder_layer_sizes[-1]``), whose
    output is the posterior mean. Decoder path: ``First_Dec`` (one Dense layer
    of width ``encoder_layer_sizes[-1]``) -> ReLU -> ``Decoder`` (MLP over
    ``decoder_layer_sizes``).
    """

    def apply(self, x, z1, z2, epsilon, encoder_layer_sizes, decoder_layer_sizes, latents=None, sampling=False, tunable_decoder_var = False):
        """Encode ``x``, draw a reparameterized latent and decode it.

        Args:
            x: Input batch (not used when ``sampling`` is True).
            z1: Noise for the latent reparameterization; its last dimension
                sizes the learned ``epsilon_p`` log-variance vector.
            z2: Noise added to the decoder output, scaled by ``exp(epsilon / 2)``.
            epsilon: Decoder log-variance. When ``tunable_decoder_var`` is True
                and ``sampling`` is False it is multiplied by a learned scalar.
            encoder_layer_sizes: Encoder MLP widths; the last entry is also the
                width of ``Last_Enc`` and ``First_Dec``.
            decoder_layer_sizes: Decoder MLP widths.
            latents: If not None, replaces the sampled latents before decoding.
            sampling: If True, skip the encoder and use ``mu = 0`` and
                ``logvar_e = 0``, so the latent is ``z1`` itself.
            tunable_decoder_var: If True, learn a scale ``'epsilon'`` for the
                decoder log-variance.

        Returns:
            Tuple ``(x_hat, mu, logvar_e, epsilon)`` where ``x_hat`` already
            includes the scaled ``z2`` noise and ``epsilon`` is the effective
            decoder log-variance.
        """
        if sampling:
            mu = 0
            logvar_e = 0
        else:
            enc = FullyConnectedNetwork(x, layer_sizes=encoder_layer_sizes, name="Encoder")
            enc_out = FullyConnectedNetwork(relu(enc), layer_sizes=encoder_layer_sizes[-1:], name="Last_Enc")
            mu = enc_out
            # Encoder log-variance is a free parameter, not a function of x.
            epsilon_p = self.param('epsilon_p', (z1.shape[-1],), jax.nn.initializers.ones)
            if tunable_decoder_var:
                # NOTE(review): only reached when sampling is False, so with
                # sampling=True the learned scale is not applied and `epsilon`
                # is used exactly as passed in -- confirm this is intended.
                epsilon = self.param('epsilon', (1,), jax.nn.initializers.ones) * epsilon
            logvar_e = epsilon_p
        # Reparameterization: mu + sigma * z1 with sigma = exp(logvar / 2).
        stdevs = jnp.exp(logvar_e / 2)
        samples = mu + stdevs * z1
        if latents is not None:
            # Under jax.jit this print runs only while tracing.
            print("set latents")
            samples = latents

        dec = FullyConnectedNetwork(samples, layer_sizes=encoder_layer_sizes[-1:], name="First_Dec")
        x_hat = FullyConnectedNetwork(relu(dec), layer_sizes=decoder_layer_sizes, name="Decoder")
        stdev = jnp.exp(epsilon / 2.)
        noise = z2 * stdev
        x_hat = x_hat + noise
        return x_hat, mu, logvar_e, epsilon

    @staticmethod
    def _loss_terms(model, batch, z1, z2):
        """Per-example loss terms shared by ``train_step`` and ``loss``.

        The loss is the closed-form KL divergence between N(mu, exp(logvar_e))
        and N(0, I), plus the Gaussian negative log-likelihood of ``batch``
        under N(x_hat, exp(epsilon)).

        NOTE(review): ``x_hat`` returned by ``apply`` already contains the
        ``z2 * exp(epsilon / 2)`` noise; confirm callers pass zero ``z2``
        during training if the NLL is meant to use the decoder mean.

        Returns:
            ``(total, kl, sq_err, logvar_e, epsilon)``. ``total``, ``kl`` and
            ``sq_err`` are summed over the last axis only.
        """
        x_hat, mu, logvar_e, epsilon = model(batch, z1, z2)
        Dkl = -0.5 * jnp.sum((1 + logvar_e - jnp.exp(logvar_e) - jnp.square(mu)), axis=-1)
        var_d = jnp.exp(epsilon)
        sq_err = jnp.square(x_hat - batch)
        mse = (0.5 * sq_err / var_d + 0.5 * (jnp.log(2. * jnp.pi) + epsilon)).sum(axis=-1)
        return Dkl + mse, Dkl, sq_err.sum(axis=-1), logvar_e, epsilon

    @staticmethod
    @jax.jit
    def train_step(optimizer, batch, z1, z2):
        """Take one optimizer step on the mean loss with clipped gradients.

        Returns:
            ``(new_optimizer, new_optimizer.target, loss)``, where ``loss`` is
            the mean loss evaluated at the parameters before the update.
        """
        def loss_fn(model):
            total, _, _, _, _ = inner_refine_VAE._loss_terms(model, batch, z1, z2)
            return total.mean()

        vae_loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
        # Clip the global gradient norm to 1000 before applying the update.
        grad = clip_grad_norm(grad, 1000)
        optimizer = optimizer.apply_gradient(grad)
        return optimizer, optimizer.target, vae_loss

    @staticmethod
    @jax.jit
    def loss(model, batch, z1, z2):
        """Evaluate the loss and its diagnostics without updating parameters.

        Returns:
            ``(mean_loss, mean_kl, mean_squared_error, logvar_e, epsilon)``;
            the squared error is summed over the last axis, then averaged.
        """
        total, Dkl, sq_err, logvar_e, epsilon = inner_refine_VAE._loss_terms(model, batch, z1, z2)
        return total.mean(), Dkl.mean(), sq_err.mean(), logvar_e, epsilon

class outer_refine_VAE(flax.nn.Module):
    """VAE whose encoder log-variance is a learned, input-independent vector.

    Compared with the sibling inner variant, the submodule order is swapped.

    Encoder path: ``Last_Enc`` (one Dense layer of width
    ``encoder_layer_sizes[-1]``) -> ReLU -> ``Encoder`` (MLP over
    ``encoder_layer_sizes``), whose output is the posterior mean.

    Decoder path: ``Decoder`` (MLP over ``decoder_layer_sizes``) -> ReLU ->
    ``First_Dec`` (one Dense layer of width ``encoder_layer_sizes[-1]``).
    The reconstruction therefore has width ``encoder_layer_sizes[-1]``, so
    ``batch`` presumably has that width too -- confirm against callers.
    """

    def apply(self, x, z1, z2, epsilon, encoder_layer_sizes, decoder_layer_sizes, latents=None, sampling=False, tunable_decoder_var = False):
        """Encode ``x``, draw a reparameterized latent and decode it.

        Args:
            x: Input batch (not used when ``sampling`` is True).
            z1: Noise for the latent reparameterization; its last dimension
                sizes the learned ``epsilon_p`` log-variance vector.
            z2: Noise added to the decoder output, scaled by ``exp(epsilon / 2)``.
            epsilon: Decoder log-variance. When ``tunable_decoder_var`` is True
                and ``sampling`` is False it is multiplied by a learned scalar.
            encoder_layer_sizes: Encoder MLP widths; the last entry is also the
                width of ``Last_Enc`` and ``First_Dec``.
            decoder_layer_sizes: Decoder MLP widths.
            latents: If not None, replaces the sampled latents before decoding.
            sampling: If True, skip the encoder and use ``mu = 0`` and
                ``logvar_e = 0``, so the latent is ``z1`` itself.
            tunable_decoder_var: If True, learn a scale ``'epsilon'`` for the
                decoder log-variance.

        Returns:
            Tuple ``(x_hat, mu, logvar_e, epsilon)`` where ``x_hat`` already
            includes the scaled ``z2`` noise and ``epsilon`` is the effective
            decoder log-variance.
        """
        if sampling:
            mu = 0
            logvar_e = 0
        else:
            enc = FullyConnectedNetwork(x, layer_sizes=encoder_layer_sizes[-1:], name="Last_Enc")
            enc_out = FullyConnectedNetwork(relu(enc), layer_sizes=encoder_layer_sizes, name="Encoder")
            mu = enc_out
            # Encoder log-variance is a free parameter, not a function of x.
            epsilon_p = self.param('epsilon_p', (z1.shape[-1],), jax.nn.initializers.ones)
            if tunable_decoder_var:
                # NOTE(review): only reached when sampling is False, so with
                # sampling=True the learned scale is not applied and `epsilon`
                # is used exactly as passed in -- confirm this is intended.
                epsilon = self.param('epsilon', (1,), jax.nn.initializers.ones) * epsilon
            logvar_e = epsilon_p
        # Reparameterization: mu + sigma * z1 with sigma = exp(logvar / 2).
        stdevs = jnp.exp(logvar_e / 2)
        samples = mu + stdevs * z1
        if latents is not None:
            # Under jax.jit this print runs only while tracing.
            print("set latents")
            samples = latents

        dec = FullyConnectedNetwork(samples, layer_sizes=decoder_layer_sizes, name="Decoder")
        x_hat = FullyConnectedNetwork(relu(dec), layer_sizes=encoder_layer_sizes[-1:], name="First_Dec")
        stdev = jnp.exp(epsilon / 2.)
        noise = z2 * stdev
        x_hat = x_hat + noise
        return x_hat, mu, logvar_e, epsilon

    @staticmethod
    def _loss_terms(model, batch, z1, z2):
        """Per-example loss terms shared by ``train_step`` and ``loss``.

        The loss is the closed-form KL divergence between N(mu, exp(logvar_e))
        and N(0, I), plus the Gaussian negative log-likelihood of ``batch``
        under N(x_hat, exp(epsilon)).

        NOTE(review): ``x_hat`` returned by ``apply`` already contains the
        ``z2 * exp(epsilon / 2)`` noise; confirm callers pass zero ``z2``
        during training if the NLL is meant to use the decoder mean.

        Returns:
            ``(total, kl, sq_err, logvar_e, epsilon)``. ``total``, ``kl`` and
            ``sq_err`` are summed over the last axis only.
        """
        x_hat, mu, logvar_e, epsilon = model(batch, z1, z2)
        Dkl = -0.5 * jnp.sum((1 + logvar_e - jnp.exp(logvar_e) - jnp.square(mu)), axis=-1)
        var_d = jnp.exp(epsilon)
        sq_err = jnp.square(x_hat - batch)
        mse = (0.5 * sq_err / var_d + 0.5 * (jnp.log(2. * jnp.pi) + epsilon)).sum(axis=-1)
        return Dkl + mse, Dkl, sq_err.sum(axis=-1), logvar_e, epsilon

    @staticmethod
    @jax.jit
    def train_step(optimizer, batch, z1, z2):
        """Take one optimizer step on the mean loss with clipped gradients.

        Returns:
            ``(new_optimizer, new_optimizer.target, loss)``, where ``loss`` is
            the mean loss evaluated at the parameters before the update.
        """
        def loss_fn(model):
            total, _, _, _, _ = outer_refine_VAE._loss_terms(model, batch, z1, z2)
            return total.mean()

        vae_loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
        # Clip the global gradient norm to 1000 before applying the update.
        grad = clip_grad_norm(grad, 1000)
        optimizer = optimizer.apply_gradient(grad)
        return optimizer, optimizer.target, vae_loss

    @staticmethod
    @jax.jit
    def loss(model, batch, z1, z2):
        """Evaluate the loss and its diagnostics without updating parameters.

        Returns:
            ``(mean_loss, mean_kl, mean_squared_error, logvar_e, epsilon)``;
            the squared error is summed over the last axis, then averaged.
        """
        total, Dkl, sq_err, logvar_e, epsilon = outer_refine_VAE._loss_terms(model, batch, z1, z2)
        return total.mean(), Dkl.mean(), sq_err.mean(), logvar_e, epsilon
