
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import flax.nnx as nnx
import optax

from jaxmodels_nnx import build_model
from reps.rep_model import RepModel

from utils.printarr import printarr

def sg(x):
    """Return ``x`` with ``jax.lax.stop_gradient`` applied to every pytree leaf."""
    def _freeze(leaf):
        return jax.lax.stop_gradient(leaf)
    return jax.tree.map(_freeze, x)

# class iVAE(RepModel):
#     def __init__(self, config, rngs):
#         super().__init__(config, rngs)
#         self.config = config
#         self.rngs = rngs
#         self.max_var = np.log(10)
#         self.min_var = np.log(0.01)
#         self.latent_dim = self.config.latent_dim
#         self.n_actions = self.config.n_actions
#         self.vars_per_factor = self.config.get('vars_per_factor', 1)
#         self.n_factors = self.latent_dim // self.vars_per_factor

#         # encoder 
#         if self.config.is_pixel:
#             self.config.encoder_pixel.input_shape = self.input_shape
#             self.config.decoder_pixel.output_shape = self.input_shape
#             self.config.decoder_pixel.input_dim = self.config.latent_dim
#             self.embed_obs = build_model(self.config.encoder_pixel, rngs)
#             self.decoder = build_model(self.config.decoder_pixel, rngs)
#         else:
#             self.config.encoder.input_dim = self.input_shape
#             self.config.decoder.output_dim = self.input_shape
#             self.config.decoder.input_dim = self.config.latent_dim
#             self.embed_obs = build_model(self.config.encoder, rngs)
#             self.decoder = build_model(self.config.decoder, rngs)

#         self.inference = build_model(self.config.inference, rngs)
#         self.noise_std = np.sqrt(0.01)

        
#     def forward(
#             self,
#             obs,
#             actions,
#             dones,
#             rewards,
#             rng,
#             states
#     ):
#         embed_obs = self.embed_obs(obs)
#         inference_input = jnp.concatenate([embed_obs[:, :-1], nnx.one_hot(actions, self.n_actions), embed_obs[:, 1:]], axis=-1)
        


class DMS(RepModel):
    """Latent dynamics model with learned, sparse causal masks.

    Observations are encoded into a diagonal Gaussian over ``latent_dim``
    latents. Every latent dimension ``i`` has its own transition head that
    predicts its next value from the masked current latents and the masked
    one-hot action. Binary masks are sampled per (batch, time) step from the
    learned logits ``g_z`` (``(latent_dim, latent_dim)``) and ``g_a``
    (``(latent_dim, n_actions)``); row ``i`` of each selects the inputs of
    head ``i``.

    The loss combines:
    - a Gaussian reconstruction NLL;
    - the KL between the next-step encoder posterior and the transition prior;
    - an L2 penalty on ``next_z``;
    - penalties on the number of active mask entries.
    """

    def __init__(self, config, rngs):
        super().__init__(config)
        self.config = config
        self.rngs = rngs
        # Soft bounds applied to every *log-std* in the model (see
        # ``_limit_logstd``), despite the "var" in the attribute names.
        self.max_var = np.log(10)
        self.min_var = np.log(0.01)
        self.latent_dim = self.config.latent_dim
        self.n_actions = self.config.n_actions
        self.vars_per_factor = 1
        self.n_factors = self.latent_dim // self.vars_per_factor

        # encoder / decoder
        if self.config.is_pixel:
            self.config.encoder_pixel.input_shape = self.config.obs_dim
            self.config.decoder_pixel.output_shape = self.config.obs_dim
            self.config.decoder_pixel.input_dim = self.config.latent_dim
            self.encoder = build_model(self.config.encoder_pixel, rngs)
            self.decoder = build_model(self.config.decoder_pixel, rngs)
        else:
            self.config.encoder.input_dim = self.config.obs_dim[0]
            # encoder emits [mean, logstd] concatenated on the last axis
            self.config.encoder.output_dim = self.config.latent_dim * 2
            self.config.decoder.output_dim = self.config.obs_dim[0]
            self.config.decoder.input_dim = self.config.latent_dim
            self.encoder = build_model(self.config.encoder, rngs)
            self.decoder = build_model(self.config.decoder, rngs)

        # single shared log-std for the decoder: isotropic gaussian
        self.decoder_logstd = nnx.Param(jnp.ones((1,)))

        # transition: one head per latent dim, each mapping
        # [latents, one-hot action] -> a single scalar
        self.transition_logstd = nnx.Param(jax.random.normal(rngs(), (self.latent_dim,)))
        self.config.transition.input_dim = self.latent_dim + self.n_actions
        self.config.transition.output_dim = 1
        self.transition_mean = nnx.vmap(
            build_model,
            in_axes=(None, 0, 0)
        )(config.transition, rngs, jnp.arange(self.latent_dim))

        # mask logits: g_a[i, a] gates action a -> latent i,
        # g_z[i, j] gates latent j -> latent i
        initializer = jax.nn.initializers.lecun_uniform()
        self.g_a = nnx.Param(initializer(rngs(), (self.latent_dim, self.n_actions), dtype=jnp.float32))
        self.g_z = nnx.Param(initializer(rngs(), (self.latent_dim, self.latent_dim), dtype=jnp.float32))

    def _limit_logstd(self, logstd):
        """Softly clamp ``logstd`` into roughly ``[min_var, max_var]``.

        First bound from above by ``max_var``, then from below by ``min_var``.
        Both steps use softplus, so the clamp is smooth, unlike a hard clip.
        """
        logstd = self.max_var - jax.nn.softplus(self.max_var - logstd)
        return self.min_var + jax.nn.softplus(logstd - self.min_var)

    def sample_gumbel_mask(self, rng, gumbel_logits, hard=True):
        """Sample a relaxed Bernoulli (binary Concrete) mask from ``gumbel_logits``.

        Args:
            rng: PRNG key for the logistic noise.
            gumbel_logits: ``nnx.Param`` of mask logits.
            hard: if True, threshold at 0.5 and use the straight-through
                estimator. The forward value is then exactly 0/1, while
                gradients flow through the sigmoid relaxation.

        Returns:
            Array with the same shape as ``gumbel_logits``.
        """
        temp = self.config.params.get('gumbel_temp', 1.)
        sample = (gumbel_logits + sg(jax.random.logistic(rng, gumbel_logits.value.shape))) / temp
        sample_soft = jax.nn.sigmoid(sample)
        sample = sample_soft
        if hard:
            sample = (sample_soft > 0.5).astype(jnp.float32)
            sample = sg(sample) - sg(sample_soft) + sample_soft
        return sample

    def encode(self, obs, states=None):
        """Return the posterior mean latent for ``obs``.

        If ``config.use_ground_truth_states`` is set, ``states`` is returned
        unchanged instead.

        Raises:
            ValueError: if ground-truth states are requested but ``states``
                is None.
        """
        if self.config.get('use_ground_truth_states', False):
            if states is None:
                raise ValueError('States must be provided if use_ground_truth_states is True')
            return states
        else:
            return self.encoder(obs)[..., :self.latent_dim]

    def forward(self, obs, actions, dones, rewards, rng, states):
        """Encode a trajectory batch, sample causal masks and predict transitions.

        Assumes ``obs`` is (batch, time, ...) and ``actions`` holds integer
        actions aligned with ``obs[:, :-1]`` (TODO: confirm against the data
        pipeline). ``dones``, ``rewards`` and ``states`` are not used here.

        Returns:
            ``(z, actions, next_z)`` and a tuple of:
            - ``transition_dist``: the transition prior ``(mean, logstd)``,
              with ``mean`` shaped like ``next_z``;
            - ``next_z_dist``: the next-step encoder posterior ``(mean, logstd)``;
            - ``mask_z`` and ``mask_a``: the sampled masks;
            - the reconstruction mean and std;
            - the preprocessed ``obs``;
            - an RNG key.

            The order matches the arguments of the loss from ``loss_fn``.
        """
        _, rng_preprocess, rng_embed, _rng_transition, rng_ga, rng_gz, rng_recons = jax.random.split(rng, 7)
        obs = self.preprocess(obs, rng_preprocess)
        embed_obs_mean, embed_obs_logstd = jnp.split(self.encoder(obs), 2, axis=-1)
        embed_obs_logstd = self._limit_logstd(embed_obs_logstd)

        # reparameterised sample from the encoder posterior
        embed_obs = embed_obs_mean + jnp.exp(embed_obs_logstd) * jax.random.normal(rng_embed, embed_obs_mean.shape)

        z = embed_obs[:, :-1]
        next_z = embed_obs[:, 1:]
        next_z_dist = (embed_obs_mean[:, 1:], embed_obs_logstd[:, 1:])

        # sample an independent mask for every (batch, time) step
        batch = z.shape[0] * z.shape[1]
        mask_z = nnx.vmap(self.sample_gumbel_mask, in_axes=(0, None, None))(jax.random.split(rng_gz, batch), self.g_z, True)
        mask_z = mask_z.reshape(*z.shape[:2], *mask_z.shape[-2:])
        mask_a = nnx.vmap(self.sample_gumbel_mask, in_axes=(0, None, None))(jax.random.split(rng_ga, batch), self.g_a, True)
        mask_a = mask_a.reshape(*z.shape[:2], *mask_a.shape[-2:])

        # transition
        one_hot_actions = nnx.one_hot(actions, self.n_actions)

        def _transition(transition_fn, row_mask_z, row_mask_a):
            # inputs of a single head: its row of each mask gates z / actions
            transition_input = jnp.concatenate([z * row_mask_z, one_hot_actions * row_mask_a], axis=-1)
            return transition_fn(transition_input)

        # vmap over heads stacks their outputs on a new LEADING axis. Each
        # head has output_dim == 1, so the result is (latent_dim, ..., 1).
        # Drop the singleton and move the head axis last so the prior mean
        # lines up element-wise with next_z_mean. Otherwise the KL broadcasts
        # every head against every latent dimension.
        transition_mean = nnx.vmap(_transition, in_axes=(0, 2, 2))(self.transition_mean, mask_z, mask_a)
        transition_mean = jnp.moveaxis(jnp.squeeze(transition_mean, axis=-1), 0, -1)
        transition_logstd = self._limit_logstd(self.transition_logstd.value)
        transition_dist = (transition_mean, transition_logstd)

        recons_mean = self.decoder(embed_obs)
        recons_std = jnp.exp(self._limit_logstd(self.decoder_logstd.value))

        return (z, actions, next_z), (transition_dist, next_z_dist, mask_z, mask_a, recons_mean, recons_std, obs, rng_recons)

    def loss_fn(self, prioritized=False):
        """Build the training loss.

        Args:
            prioritized: currently unused. The ``importance_weights`` argument
                of the returned loss is likewise ignored.

        Returns:
            ``(loss, {})``. ``loss`` maps the outputs of :meth:`forward` (plus
            ``importance_weights``) to ``(scalar_loss, logs)``.
        """
        def _loss(
                z,
                actions,
                next_z,
                transition_dist,
                next_z_dist,
                mask_z,
                mask_a,
                recons_mean,
                recons_std,
                obs,
                rng,
                importance_weights,
            ):

            # reconstruction NLL: log N(x; mu, sigma) = -(x-mu)^2/(2 sigma^2)
            # - log(sigma), dropping the constant -0.5*log(2*pi)
            recons_ll = -(recons_mean - obs) ** 2 / (2 * recons_std ** 2) - jnp.log(recons_std)
            recons_loss = -jnp.mean(jnp.sum(recons_ll, axis=-1))

            def kl_gaussians(mean1, logstd1, mean2, logstd2):
                """Return KL(N(mean1, e^{2*logstd1}) || N(mean2, e^{2*logstd2})).

                The KL is summed over the last axis. Inputs are log standard
                deviations, not log variances.
                """
                var_ratio = jnp.exp(2 * (logstd1 - logstd2))
                mean_term = (mean1 - mean2) ** 2 / jnp.exp(2 * logstd2)
                return 0.5 * jnp.sum(2 * (logstd2 - logstd1) + var_ratio + mean_term - 1, axis=-1)

            # KL from the next-step posterior to the transition prior
            transition_mean, transition_logstd = transition_dist
            next_z_mean, next_z_logstd = next_z_dist
            kl_loss = kl_gaussians(next_z_mean, next_z_logstd, transition_mean, transition_logstd)
            kl_loss = jnp.mean(kl_loss)

            # sparsity graph losses: number of active mask entries per step
            ga_loss = mask_a.sum(-1).sum(-1).mean()
            gt_loss = mask_z.sum(-1).sum(-1).mean()

            elbo = (recons_loss + kl_loss)
            # total loss
            loss = self.config.params.elbo_const*elbo.mean() + \
              self.config.params.l2_reg_const*(next_z**2).sum(-1).mean() + \
              self.config.params.g_action_const*ga_loss + \
              self.config.params.g_time_const*gt_loss

            logs = {
                'scalars': {
                    'norm': (z ** 2).sum(-1).mean(),
                    # NOTE(review): logged with a negative sign; confirm intended
                    'kl_loss': -kl_loss,
                    'recons_loss': recons_loss,
                    'elbo': elbo,
                    'ga_loss': ga_loss,
                    'gt_loss': gt_loss,
                    'loss': loss
                }
            }
            return loss, logs

        return _loss, {}
