import numpy as np
import jax
from jax import numpy as jnp
import flax
from functools import partial
from jax import random, jit
from copy import deepcopy
from jax.scipy.stats import multivariate_normal
import itertools
# from ipdb import set_trace as db
import os
from abc import ABC, abstractmethod
from jax import numpy as jnp, random
from jax.scipy.stats import norm, logistic
import pickle as pkl
from tqdm import trange, tqdm
import numpy as np
from collections import defaultdict
from .refine_model import inner_refine_VAE, outer_refine_VAE, relu, FullyConnectedNetwork, clip_grad_norm



class VAE(flax.nn.Module):
    """Gaussian VAE: MLP encoder mean, learned latent log-variance, MLP decoder.

    The encoder log-variance is a learned per-latent-dimension vector
    ``epsilon_p`` that does not depend on the input. The decoder output is
    perturbed with Gaussian noise of log-variance ``epsilon``, which can
    optionally be scaled by a learned scalar parameter.
    """

    def apply(self, x, z1, z2, epsilon, encoder_layer_sizes, decoder_layer_sizes, latents=None, sampling=False, tunable_decoder_var = False):
        """Encode (unless sampling), reparameterise, decode and add decoder noise.

        Args:
            x: input batch. It is ignored when ``sampling`` is True.
            z1: noise for the latent reparameterisation. Its last axis fixes
                the size of the ``epsilon_p`` parameter.
            z2: noise added to the decoder output, scaled by ``exp(epsilon / 2)``.
            epsilon: decoder log-variance.
            encoder_layer_sizes: layer sizes for the "Encoder"
                ``FullyConnectedNetwork``.
            decoder_layer_sizes: layer sizes for the "Decoder"
                ``FullyConnectedNetwork``.
            latents: if given, decoded instead of the reparameterised sample.
            sampling: if True, skip the encoder and use ``mu = 0`` and
                ``logvar_e = 0``, so the latent sample is ``z1`` itself.
            tunable_decoder_var: if True (and not sampling), multiply
                ``epsilon`` by a learned scalar parameter named ``epsilon``.
                That parameter is initialised to ones.

        Returns:
            ``(x_hat, mu, logvar_e, epsilon)``, where:
            - ``x_hat`` is the noisy reconstruction,
            - ``mu`` is the encoder mean,
            - ``logvar_e`` is the encoder log-variance,
            - ``epsilon`` is the effective decoder log-variance.
        """
        if sampling:
            # Prior sampling: the encoder is not evaluated.
            mu = 0
            logvar_e = 0
        else:
            mu = FullyConnectedNetwork(x, layer_sizes=encoder_layer_sizes, name="Encoder")
            # Input-independent encoder log-variance, one entry per latent dimension.
            logvar_e = self.param('epsilon_p', (z1.shape[-1],), jax.nn.initializers.ones)
            if tunable_decoder_var:
                epsilon = self.param('epsilon', (1,), jax.nn.initializers.ones) * epsilon

        # Reparameterisation trick: mu + sigma * z1.
        samples = mu + jnp.exp(logvar_e / 2) * z1
        if latents is not None:
            print ("set latents")
            samples = latents

        x_hat = FullyConnectedNetwork(samples, layer_sizes=decoder_layer_sizes, name="Decoder")
        # Decoder noise has stdev exp(epsilon / 2), i.e. variance exp(epsilon).
        x_hat = x_hat + z2 * jnp.exp(epsilon / 2.)
        return x_hat, mu, logvar_e, epsilon

    @staticmethod
    def _loss_terms(model, batch, z1, z2):
        """Compute the per-example negative ELBO and its components.

        This is shared by ``train_step`` and ``loss`` so the two always
        optimise and report the same objective.

        Returns:
            ``(loss, dkl, sq_err, logvar_e, epsilon)``, where:
            - ``dkl`` is the KL(N(mu, exp(logvar_e)) || N(0, I)) term summed
              over the last axis,
            - ``sq_err`` is the element-wise squared reconstruction error,
            - ``loss = dkl + nll``, with ``nll`` the Gaussian negative
              log-likelihood of ``batch`` under variance ``exp(epsilon)``,
              summed over the last axis.
        """
        x_hat, mu, logvar_e, epsilon = model(batch, z1, z2)
        dkl = -0.5 * jnp.sum((1 + logvar_e - jnp.exp(logvar_e) - jnp.square(mu)), axis=-1)
        sq_err = jnp.square(x_hat - batch)
        var_d = jnp.exp(epsilon)
        nll = (0.5 * sq_err / var_d + 0.5 * (jnp.log(2. * jnp.pi) + epsilon)).sum(axis=-1)
        return dkl + nll, dkl, sq_err, logvar_e, epsilon

    @staticmethod
    @jax.jit
    def train_step(optimizer, batch, z1, z2):
        """Take one optimizer step on the mean negative ELBO.

        Returns:
            ``(new_optimizer, new_model, mean_loss_before_step)``.
        """
        def loss_fn(model):
            return VAE._loss_terms(model, batch, z1, z2)[0].mean()

        vae_loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)
        # Gradient clipping with threshold 1000; semantics are those of
        # refine_model.clip_grad_norm (presumably a global-norm clip).
        grad = clip_grad_norm(grad, 1000)
        optimizer = optimizer.apply_gradient(grad)
        return optimizer, optimizer.target, vae_loss

    @staticmethod
    @jax.jit
    def loss(model, batch, z1, z2):
        """Evaluate the objective without updating parameters.

        Returns:
            ``(mean_loss, mean_kl, mean_summed_sq_err, logvar_e, epsilon)``.
        """
        total, dkl, sq_err, logvar_e, epsilon = VAE._loss_terms(model, batch, z1, z2)
        return total.mean(), dkl.mean(), sq_err.sum(axis=-1).mean(), logvar_e, epsilon

class VAEModel(ABC):
    """Training, sampling and checkpointing wrapper around a (possibly refined) VAE.

    Holds the flax model and its optimizer, a JAX PRNG key that is split for
    every random draw, and a per-statistic history. That history is recorded
    by ``write_stats`` and written to disk by ``save``.
    """

    def __init__(self,
                 dirname,
                 num_epoch,
                 batch_size,
                 learning_rate,
                 layer_sizes,
                 encoder_layer_sizes,
                 epsilon,
                 tqdm,
                 dataset,
                 latent_dimension,
                 tunable_decoder_var = False,
                 n_save = 20,
                 data_size = 8,
                 load_model = "",
                 refine_type = "",
                 seed=0):
        """Build the VAE module, initialise or load its parameters, and create the optimizer.

        Args:
            dirname: directory for checkpoints, the ``losses`` file and
                relative ``load_model`` names.
            num_epoch: number of epochs run by ``train``.
            batch_size: stored on the instance; not used by this class's methods.
            learning_rate: Adam learning rate. When refining, this is the rate
                of the non-frozen parameter groups.
            layer_sizes: '|'-separated decoder hidden sizes, or "" for none.
                The data dimension is appended as the output size.
            encoder_layer_sizes: '|'-separated encoder hidden sizes, or "" for
                none. ``latent_dimension`` is appended.
            epsilon: decoder log-variance passed to the module.
            tqdm: stored on the instance; not used by this class's methods.
            dataset: source of batches. May be None, in which case
                ``data_size`` gives the data dimension.
            latent_dimension: size of the latent space.
            tunable_decoder_var: forwarded to the module. It scales
                ``epsilon`` by a learned scalar.
            n_save: stored on the instance; not used by this class's methods.
            data_size: data dimension used only when ``dataset`` is None.
            load_model: checkpoint to load ("" means random init). When set,
                per-parameter-group optimizers are used.
            refine_type: "", "inner", "outer" or "whole".
                - "inner" and "outer" select the corresponding refine modules.
                - When ``load_model`` is set, refine_type must be one of
                  "whole", "inner" or "outer".
            seed: PRNG seed.

        Raises:
            ValueError: if ``load_model`` is given together with an
                unsupported ``refine_type``.
        """
        super().__init__()
        if load_model and refine_type not in ("whole", "inner", "outer"):
            # Previously this surfaced later as an UnboundLocalError on deco_opt.
            raise ValueError(
                f"load_model requires refine_type in ('whole', 'inner', 'outer'), got {refine_type!r}")

        self.epsilon = epsilon
        self.current_epsilon = epsilon
        self.latent_dimension = latent_dimension
        self.dirname = dirname
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.tqdm = tqdm
        self.num_epoch = num_epoch
        self.n_save = n_save
        self.dataset = dataset
        self.stats = defaultdict(list)
        self.freq_stats = 5
        self.key = random.PRNGKey(seed)
        # Resolve the data dimension once. It is stored so that sample_latent
        # also works when no dataset is attached.
        data_size = self.dataset.dimension if self.dataset is not None else data_size
        self.data_size = data_size

        self.encoder_layer_sizes = self._parse_layer_sizes(encoder_layer_sizes) + [latent_dimension]
        self.decoder_layer_sizes = self._parse_layer_sizes(layer_sizes) + [data_size]

        vae_key, self.key = random.split(self.key)
        module_cls = {"inner": inner_refine_VAE, "outer": outer_refine_VAE}.get(refine_type, VAE)
        vae_module = module_cls.partial(epsilon=epsilon, encoder_layer_sizes=self.encoder_layer_sizes,
                                        decoder_layer_sizes=self.decoder_layer_sizes,
                                        tunable_decoder_var=tunable_decoder_var)
        _, initial_params = vae_module.init_by_shape(vae_key, [(data_size,), (latent_dimension,), (data_size,)])

        if load_model:
            initial_params = self.load(load_model, random_init = initial_params)
            self.epsilon = initial_params['epsilon'] * epsilon
            print ("initial epsilon: " + str(self.epsilon))

        self.model = flax.nn.Model(vae_module, initial_params)

        if load_model:
            print ("special optimizer")
            self.optimizer = self._refine_optimizer_def(refine_type, learning_rate).create(self.model)
        else:
            self.optimizer = flax.optim.Adam(learning_rate=self.learning_rate).create(self.model)

    @staticmethod
    def _parse_layer_sizes(spec):
        """Parse a '|'-separated list of ints. An empty string gives []."""
        return [int(size) for size in spec.split('|')] if spec != "" else []

    def _refine_optimizer_def(self, refine_type, learning_rate):
        """Return a MultiOptimizer with one Adam sub-optimizer per parameter group."""
        deco = flax.optim.ModelParamTraversal(lambda path, _: 'Decoder' in path)
        enco = flax.optim.ModelParamTraversal(lambda path, _: 'Encoder' in path)
        var = flax.optim.ModelParamTraversal(lambda path, _: 'epsilon_p' in path)
        # Match 'epsilon' as a whole path component. A substring test also
        # selects '/epsilon_p'. flax's MultiOptimizer writes the group results
        # back in order, so the frozen (lr=0) eps group would then undo the
        # epsilon_p update made by var_opt.
        eps = flax.optim.ModelParamTraversal(lambda path, _: 'epsilon' in path.split('/'))
        if refine_type == "whole":
            deco_opt, enco_opt, var_opt, eps_opt = self.whole_refine_opt(learning_rate)
            return flax.optim.MultiOptimizer((deco, deco_opt), (enco, enco_opt), (var, var_opt), (eps, eps_opt))

        # The names are historical: ldeco selects 'Last_Enc' and lenco selects
        # 'First_Dec'. Both get the same rate in inner/outer refinement.
        ldeco = flax.optim.ModelParamTraversal(lambda path, _: 'Last_Enc' in path)
        lenco = flax.optim.ModelParamTraversal(lambda path, _: 'First_Dec' in path)
        if refine_type == "inner":
            opts = self.inner_refine_opt(learning_rate)
        elif refine_type == "outer":
            opts = self.outer_refine_opt(learning_rate)
        else:
            raise ValueError(f"unsupported refine_type {refine_type!r}")
        deco_opt, enco_opt, ldeco_opt, lenco_opt, var_opt, eps_opt = opts
        return flax.optim.MultiOptimizer((deco, deco_opt), (enco, enco_opt), (var, var_opt), (eps, eps_opt),
                                         (ldeco, ldeco_opt), (lenco, lenco_opt))

    def _split_noise(self, batch_size):
        """Draw fresh noise and split it into latent (z1) and data (z2) parts."""
        z = self.sample_latent(batch_size)
        return z, z[..., :self.latent_dimension], z[..., self.latent_dimension:]

    def train_one_batch(self, batch):
        """Flatten ``batch`` to (batch, features) and take one optimizer step."""
        batch = batch.reshape((batch.shape[0], -1))
        _, z1, z2 = self._split_noise(batch.shape[0])
        self.optimizer, self.model, _ = VAE.train_step(self.optimizer, batch, z1, z2)

    def compute_model_stats(self, real_batch):
        """Evaluate the loss on ``real_batch`` and return a dict of scalar stats.

        As a side effect, ``self.epsilon`` is set to the model's effective
        decoder log-variance.
        """
        # Flatten the same way train_one_batch does.
        real_batch = real_batch.reshape((real_batch.shape[0], -1))
        _, z1, z2 = self._split_noise(real_batch.shape[0])
        vae_loss, dkl, mse, logvar_e, epsilon = VAE.loss(self.model, real_batch, z1, z2)
        self.epsilon = epsilon
        return {"VAE Loss": vae_loss, "KL divergence": dkl, "mse": mse, "epsilon": epsilon}

    def sample_batch(self, batch_size, latents=None):
        """Decode prior samples, or the given ``latents``, into data space.

        Returns:
            ``(x_hat, z)``, where ``z`` is the full noise array.
            ``z[..., :latent_dimension]`` is the latent part.
        """
        # Unused key; the split is kept so the PRNG stream is unchanged.
        _, self.key = random.split(self.key)
        z, z1, z2 = self._split_noise(batch_size)
        sampling_model = jax.jit(partial(self.model, sampling=True, epsilon=self.epsilon))
        x_hat, _, _, _ = sampling_model(None, z1, z2, latents=latents)
        return x_hat, z

    def train(self):
        """Train for ``num_epoch`` epochs, logging stats and checkpointing.

        After each epoch:
        - stats are computed on the last training batch and passed to
          ``write_stats``;
        - a checkpoint is saved every 1000 epochs and after the final epoch;
        - ``dataset.randomize()`` is called.
        """
        # Unused key; the split is kept so the PRNG stream is unchanged.
        _, self.key = random.split(self.key)
        for self.epochnum in range(self.num_epoch):
            for _ in trange(self.dataset.num_batch):
                batch = self.dataset.get_batch()
                self.train_one_batch(batch)
            # NOTE(review): the returned test batch is discarded and the stats
            # below use the last training batch. Confirm this is intended.
            self.dataset.get_test_batch()
            stats = self.compute_model_stats(batch)
            self.write_stats(stats)

            if self.epochnum % 1000 == 0 or self.epochnum == self.num_epoch-1:
                self.save(self.epochnum)
            self.dataset.randomize()

    def encode(self):
        """Encode one pass of training batches and return the stacked latent samples.

        Each sample is ``mu + exp(logvar_e / 2) * z1``, with fresh noise per
        batch. The result is a NumPy array concatenated along axis 0.
        """
        latents = []
        # Unused keys; the splits are kept so the PRNG stream is unchanged.
        _, self.key = random.split(self.key)
        for _ in trange(self.dataset.num_batch):
            batch = self.dataset.get_batch()
            # Flatten the same way train_one_batch does.
            batch = batch.reshape((batch.shape[0], -1))
            _, self.key = random.split(self.key)
            _, z1, z2 = self._split_noise(batch.shape[0])
            _, mu, logvar_e, _ = self.model(batch, z1, z2)
            latents.append(mu + jnp.exp(logvar_e / 2) * z1)
        return np.concatenate(latents, axis=0)

    def write_stats(self, stats):
        """Append every stat to the history and print the float-convertible ones."""
        message = f"Epoch | {self.epochnum}"
        for stat, val in stats.items():
            self.stats[stat].append(val)
            try:
                val = float(val)
            except Exception:
                # Non-scalar values are recorded once (above) but not printed.
                continue
            message = message + f" | {stat} | {val:.3f}"
        tqdm.write(message)

    def sample_latent(self, batch_size):
        """Draw standard-normal noise of shape (batch_size, latent_dimension + data_size)."""
        key, self.key = random.split(self.key)
        return random.normal(key, shape=(batch_size, self.latent_dimension + self.data_size))

    def save(self, epoch):
        """Save the optimizer state and the stats history under ``dirname``.

        - The optimizer state dict is pickled to ``model.<epoch>.pkl``.
        - The stats history is written with ``np.savez`` to ``losses``.
        """
        state_dict = flax.serialization.to_state_dict(self.optimizer)
        model_fn = os.path.join(self.dirname, "model."+str(epoch)+".pkl")
        with open(model_fn, "wb") as f:
            pkl.dump(state_dict, f)
        fn = os.path.join(self.dirname, "losses" )
        np.savez(fn, **deepcopy(self.stats))

    def load(self, model="", random_init = None, refine_type=None):
        """Load parameters from a pickled optimizer checkpoint.

        If ``model`` contains '/', it is treated as a path. In that case:
        - with ``random_init`` given, its "Encoder", "Decoder", "epsilon_p"
          and "epsilon" entries are overwritten in place from the checkpoint
          and ``random_init`` is returned;
        - otherwise the loaded params are returned.

        If ``model`` has no '/', it is resolved relative to ``dirname`` and
        the full params are returned. ``refine_type`` is unused.
        """
        print ("Load Model")
        if "/" in model:
            # NOTE(review): this unpickles the file; only load trusted checkpoints.
            loaded_params = jax.numpy.load(model, allow_pickle=True)['target']['params']
            if random_init is None:
                return loaded_params
            # Unused key; the split is kept so the PRNG stream is unchanged.
            _, self.key = random.split(self.key)
            params = random_init
            for name in ("Encoder", "Decoder", "epsilon_p", "epsilon"):
                params[name] = loaded_params[name]
            return params
        print ("Load Full Model")
        return jax.numpy.load(self.dirname + "/" + model, allow_pickle=True)['target']['params']

    def inner_refine_opt(self, learning_rate):
        """Optimizers for "inner" refinement.

        Only the Last_Enc/First_Dec groups and epsilon_p train; every other
        group uses lr=0 and is frozen.
        """
        deco_opt = flax.optim.Adam(learning_rate=0)
        enco_opt = flax.optim.Adam(learning_rate=0)
        ldeco_opt = flax.optim.Adam(learning_rate=learning_rate)
        lenco_opt = flax.optim.Adam(learning_rate=learning_rate)
        var_opt = flax.optim.Adam(learning_rate=learning_rate)
        eps_opt = flax.optim.Adam(learning_rate=0)
        return deco_opt, enco_opt, ldeco_opt, lenco_opt, var_opt, eps_opt

    def outer_refine_opt(self, learning_rate):
        """Optimizers for "outer" refinement.

        Only the Last_Enc/First_Dec groups train; every other group uses
        lr=0 and is frozen.
        """
        deco_opt = flax.optim.Adam(learning_rate=0)
        enco_opt = flax.optim.Adam(learning_rate=0)
        ldeco_opt = flax.optim.Adam(learning_rate=learning_rate)
        lenco_opt = flax.optim.Adam(learning_rate=learning_rate)
        var_opt = flax.optim.Adam(learning_rate=0)
        eps_opt = flax.optim.Adam(learning_rate=0)
        return deco_opt, enco_opt, ldeco_opt, lenco_opt, var_opt, eps_opt

    def whole_refine_opt(self, learning_rate):
        """Optimizers for "whole" refinement.

        Decoder, encoder and epsilon_p train; the epsilon group is frozen
        with lr=0.
        """
        deco_opt = flax.optim.Adam(learning_rate=learning_rate)
        enco_opt = flax.optim.Adam(learning_rate=learning_rate)
        var_opt = flax.optim.Adam(learning_rate=learning_rate)
        eps_opt = flax.optim.Adam(learning_rate=0)
        return deco_opt, enco_opt, var_opt, eps_opt

