from functools import partial

import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax


from jaxmodels_nnx import build_model
from reps.rep_model import RepModel

from utils.printarr import printarr

def sg(x):
    '''
        Apply jax.lax.stop_gradient to every leaf of the pytree `x`
        and return a pytree with the same structure.
    '''
    return jax.tree.map(jax.lax.stop_gradient, x)

class MarkovRepresentation(RepModel):
    '''
        Representation model with three trainable parts: an encoder, an
        inverse-dynamics head (predicts the action from [z, next_z]) and a
        ratio head (classifies real [z, next_z] pairs against pairs whose
        next latent was shuffled across the batch).
    '''

    def __init__(self, config, rngs):
        super().__init__(config, rngs)
        # Both heads consume the concatenation [z, next_z].
        self.config.inverse.input_dim = self.config.latent_dim * 2
        self.config.inverse.output_dim = self.config.n_actions
        self.config.ratio.input_dim = self.config.latent_dim * 2
        self.config.ratio.output_dim = 1
        self.n_factors = self.config.latent_dim

        # Pick the encoder config matching the observation type.
        # `self.input_shape` is not set here (presumably by RepModel.__init__).
        if self.config.is_pixel:
            self.config.encoder_pixel.input_shape = self.input_shape
            encoder_config = self.config.encoder_pixel
        else:
            self.config.encoder.input_dim = self.input_shape
            encoder_config = self.config.encoder

        # Construction order (encoder, inverse, ratio) is kept so that `rngs`
        # is consumed in the same order as before.
        self.encoder = build_model(encoder_config, rngs=rngs)
        self.inverse = build_model(self.config.inverse, rngs=rngs)
        self.ratio = build_model(self.config.ratio, rngs=rngs)
        self.batch_size = self.config.ratio_batch_size

    def infer_actions(
            self,
            obs,
            states=None,
            rng=None
    ):
        '''
            Infer action probabilities for every consecutive pair of
            observations in a single trajectory.
            vmap it to work with batches.
            obs: (T, D)
            states: forwarded unchanged to self.encode
            rng: unused; kept for interface compatibility. The default is None
                 so that no PRNG key is created when the class is defined.
            returns: softmax over the inverse-model logits, one row per
                     transition (leading axis of length T-1).
        '''
        z = self.encode(obs, states)
        # Pair each latent with its successor along the leading axis.
        next_z = z[1:]
        z = z[:-1]

        inverse_logits = self.inverse(jnp.concatenate([z, next_z], axis=-1))
        actions_probs = jax.nn.softmax(inverse_logits)
        return actions_probs

    def forward(
            self,
            obs,
            actions,
            rewards,
            dones,
            rng,
            states=None
        ):
        '''
            Encode a batch of trajectories and compute the inverse and ratio
            logits consumed by the loss returned from `loss_fn`.

            obs: batched observation sequences; axis 1 is sliced into
                 (current, next) pairs, so it is presumably time
                 (assumes (B, T, ...) - TODO confirm against caller).
            actions: returned unchanged in the first output tuple.
            rewards, states: unused here.
            dones: must contain one entry per transition, i.e. be reshapeable
                   to z.shape[:-1] after the time slicing.
            rng: PRNG key; split for preprocessing noise and for the
                 negative-sample permutation.

            returns:
                (z, actions, next_z),
                (inverse_logits, ratio_logits, mask)
            where ratio_logits has 2N rows (N = number of transitions): the
            first N are real (z, next_z) pairs, the last N pair z with a
            next_z shuffled across the flattened batch.
        '''
        rng, rng_noise = jax.random.split(rng)
        obs = self.preprocess(obs, rng_noise)
        z = self.encoder(obs)
        next_z = z[:, 1:]
        z = z[:, :-1]
        inverse_logits = self.inverse(jnp.concatenate([z, next_z], axis=-1))

        batchdims, latent_dim = z.shape[:-1], z.shape[-1]
        mask = 1 - dones.reshape(-1)

        # Only the first key of this split is consumed; the split itself is
        # kept so the permutation uses the same key stream as before.
        rng_perm, _ = jax.random.split(rng)

        # Flatten all batch dims into one axis of transitions.
        # NOTE: the previous code assigned `next_z = z.reshape(...)`, which
        # replaced next_z by z: the ratio head was trained on (z, z) pairs and
        # the smoothness loss was identically zero.
        z = z.reshape(-1, latent_dim)
        next_z = next_z.reshape(-1, latent_dim)

        # NOTE: every transition is used for the ratio head; subsampling down
        # to self.batch_size non-terminal transitions is currently disabled.
        shuffled_next_z = jax.random.permutation(rng_perm, next_z, axis=0)  # permute over the batch
        ratio_next_z = jnp.concatenate([next_z, shuffled_next_z], axis=0)
        ratio_z = jnp.tile(z, (2, 1))
        ratio_logits = self.ratio(jnp.concatenate([ratio_z, ratio_next_z], axis=-1))

        return (z.reshape(*batchdims, -1), actions, next_z.reshape(*batchdims, -1)), (inverse_logits, ratio_logits, mask.reshape(*batchdims))

    def loss_fn(self, prioritized=False):
        '''
            Build the loss function for the outputs of `forward`.

            prioritized: accepted but not used by this model.
            returns: (_loss, dummy_logs) where dummy_logs has the same keys as
                     the logs produced by _loss, with zero values.
        '''
        def _loss(
            z,
            actions,
            next_z,
            inverse_logits,
            ratio_logits,
            mask,
            importance_weights
        ):
            # importance_weights is accepted but not used.

            # Hinge on the squared latent step: only steps longer than
            # smoothness_thresh are penalised.
            sq_step = jnp.sum((next_z - z) ** 2, axis=-1)
            smoothness_loss = jnp.mean(nnx.relu(sq_step - self.config.smoothness_thresh ** 2) * mask)

            inverse_loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(inverse_logits, actions) * mask)

            # First half of ratio_logits are real pairs (label 1), second half
            # are shuffled pairs (label 0); see `forward`.
            n_pairs = ratio_logits.shape[0] // 2
            ratio_labels = jnp.concatenate([jnp.ones((n_pairs,), dtype=jnp.int32), jnp.zeros((n_pairs,), dtype=jnp.int32)], axis=0)
            ratio_loss = jnp.mean(optax.sigmoid_binary_cross_entropy(ratio_logits[..., 0], ratio_labels))

            loss = self.config.params.inverse_const * inverse_loss + \
                    self.config.params.ratio_const * ratio_loss + \
                    self.config.params.smoothness_const * smoothness_loss
            logs = {
                'scalars': {
                    'smoothness_loss': smoothness_loss,
                    'inverse_loss': inverse_loss,
                    'ratio_loss': ratio_loss,
                    'loss': loss
                }
            }
            return loss, logs

        dummy_logs = {
                'scalars': {
                    'smoothness_loss': 0.,
                    'inverse_loss': 0.,
                    'ratio_loss': 0.,
                    'loss': 0.
                }
            }
        return _loss, dummy_logs




class GCLRepresentation(RepModel):
    '''
        Representation model with an encoder/decoder pair and one energy
        network per latent factor. The energies are trained to separate real
        (z, action, next_z) samples from samples whose (z, action) were
        shuffled across the batch.
    '''

    def __init__(self, config, rngs):
        super().__init__(config, rngs)
        self.config = config

        self.rngs = rngs
        self.latent_dim = self.config.latent_dim
        # The latent vector is split into n_factors groups of n_vars variables.
        self.n_factors = self.latent_dim // self.config.vars_per_factor
        self.n_vars = self.config.vars_per_factor
        self.batch_size = self.config.batch_size

        # Fill in encoder/decoder dimensions, then build them in the same
        # order as before (encoder first) so `rngs` is consumed identically.
        if self.config.is_pixel:
            self.config.encoder_pixel.input_shape = self.config.obs_dim
            self.config.decoder_pixel.output_shape = self.config.obs_dim
            self.config.decoder_pixel.input_dim = self.config.latent_dim
            encoder_config = self.config.encoder_pixel
            decoder_config = self.config.decoder_pixel
        else:
            self.config.encoder.input_dim = self.config.obs_dim[0]
            self.config.decoder.output_dim = self.config.obs_dim[0]
            self.config.decoder.input_dim = self.config.latent_dim
            encoder_config = self.config.encoder
            decoder_config = self.config.decoder
        self.encoder = build_model(encoder_config, rngs)
        self.decoder = build_model(decoder_config, rngs)

        # Each energy network sees the full latent z plus the delta of its
        # own factor (see get_energies).
        self.config.energy.input_dim = self.config.latent_dim + self.config.vars_per_factor
        # One energy network per factor, stacked along a leading axis.
        self.energies = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.energy, rngs, jnp.arange(self.n_factors))

    def get_energies(
            self,
            z,
            next_z
    ):
        '''
            Non batched energy computation
            vmap it to work with batches.
            z, next_z: single latent vectors of length
                       n_factors * n_vars.
            returns: stacked outputs of the per-factor energy networks,
                     one row per factor.
        '''
        # Per-factor change in the latent: (n_factors, n_vars).
        delta_z = (next_z - z).reshape(self.n_factors, self.n_vars)
        # Every factor's network gets its own delta followed by the full z.
        tiled_z = jnp.tile(z[None], (self.n_factors, 1))
        energy_inputs = jnp.concatenate([delta_z, tiled_z], axis=-1)
        energies = nnx.vmap(lambda energy, energy_input: energy(energy_input), in_axes=(0, 0))(self.energies, energy_inputs)
        return energies #(F, A)

    def forward(
            self,
            obs,
            actions,
            rewards,
            dones,
            rng,
            states=None
        ):
        '''
            Encode a batch of trajectories, reconstruct the next observations
            and compute energies for real and shuffled transitions.

            obs: batched observation sequences; axis 1 is sliced into
                 (current, next) pairs (assumes (B, T, ...) - TODO confirm).
            actions: integer actions, one per transition after flattening.
            rewards, states: unused here.
            dones: flattened and used to avoid sampling terminal transitions.
            rng: PRNG key; split for preprocessing noise, latent noise,
                 transition sampling and the negative-sample permutation.

            returns:
                (z, actions, next_z),
                (next_obs, sampled_actions, energies_real, energies_fake,
                 recons_next_x, idx)
            where idx holds the self.batch_size sampled flat transition
            indices.
        '''
        rng, rng_noise = jax.random.split(rng)
        obs = self.preprocess(obs, rng_noise)
        # The 5-way split is kept so each consumed key is unchanged; the 2nd
        # and 4th keys are currently unused.
        rng_z, _, rng_sample, _, rng_perm_z = jax.random.split(rng, 5)
        z = self.encoder(obs)
        z = z + jax.random.normal(rng_z, z.shape) * self.config.noise_std
        next_z = z[..., 1:, :]
        z = z[..., :-1, :]
        next_obs = obs[:, 1:]
        obs = obs[:, :-1]

        recons_next_x = self.decoder(next_z)

        latent_dim = z.shape[-1]
        batch_size = self.batch_size
        flat_z = z.reshape(-1, latent_dim)
        flat_next_z = next_z.reshape(-1, latent_dim)
        valid = 1 - dones.reshape(-1)

        idx = jax.random.categorical(
            rng_sample,
            jnp.log(valid / valid.sum() + 1e-12), # don't sample done elements because z->z' is invalid
            axis=0,
            shape=(batch_size,)
        )

        _z = flat_z[idx]
        _next_z = flat_next_z[idx]
        _actions = actions.reshape(-1)[idx]

        # NOTE(review): the same key is used for both permutations, which
        # presumably keeps each permuted z paired with its own action -
        # confirm this is intended rather than a missed use of a separate key.
        permuted_z = jax.random.permutation(rng_perm_z, _z, axis=0)
        permuted_a = jax.random.permutation(rng_perm_z, _actions, axis=0)

        # Sum energies over factors, then pick the entry of the taken action.
        rows = jnp.arange(batch_size)
        energies_real = nnx.vmap(self.get_energies, in_axes=(0, 0))(_z, _next_z).sum(1)[rows, _actions]
        energies_fake = nnx.vmap(self.get_energies, in_axes=(0, 0))(permuted_z, _next_z).sum(1)[rows, permuted_a]
        return (z, actions, next_z), (next_obs, _actions, energies_real, energies_fake, recons_next_x, idx)

    def loss_fn(self, prioritized=False):
        '''
            Build the loss function for the outputs of `forward`.

            prioritized: accepted but not used by this model.
            returns: (_loss, dummy_logs) where dummy_logs has the same keys as
                     the logs produced by _loss, with zero values.
        '''
        def _loss(
                z,
                actions,
                next_z,
                next_obs,
                sel_actions,
                energies_real,
                energies_fake,
                recons_next_x,
                indices,
                importance_weights
            ):
            # actions, sel_actions, indices and importance_weights are
            # accepted but not used.
            batch_dims = z.shape[:-1]
            recons_loss = jnp.mean(optax.l2_loss(next_obs, recons_next_x).reshape(*batch_dims, -1).mean(-1))

            # Logged only; not part of the optimised loss.
            smoothness = jnp.linalg.norm(next_z - z, axis=-1).mean()

            # Binary classification: real samples -> 1, shuffled samples -> 0,
            # averaged over both halves.
            real_bce = optax.sigmoid_binary_cross_entropy(energies_real, jnp.ones_like(energies_real)).sum()
            fake_bce = optax.sigmoid_binary_cross_entropy(energies_fake, jnp.zeros_like(energies_fake)).sum()
            energy_loss = (real_bce + fake_bce) / (energies_real.shape[0] * 2)

            loss = self.config.params.recons_const * recons_loss + \
                    self.config.params.energy_const * energy_loss

            logs = {
                'scalars': {
                    'norm': (z**2).sum(-1).mean(),
                    'smoothness': smoothness,
                    'recons_loss': recons_loss,
                    'energy_loss': energy_loss,
                    'loss': loss,
                },
            }
            return loss, logs

        # Same keys as `logs` above ('loss' was previously missing here).
        dummy_logs = {
                'scalars': {
                        'norm': 0.,
                        'smoothness': 0.,
                        'recons_loss': 0.,
                        'energy_loss': 0.,
                        'loss': 0.,
                    },
            }

        return _loss, dummy_logs
