
import copy
from functools import partial

import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
import chex

from jaxmodels_nnx import build_model
from reps.rep_model import RepModel

from utils.printarr import printarr

def sg(x):
    """Return ``x`` with ``jax.lax.stop_gradient`` applied to every pytree leaf."""
    return jax.tree.map(jax.lax.stop_gradient, x)

class ACFRepresentation(RepModel):
    """Action-conditioned factored (ACF) latent representation.

    Observations (or ground-truth states) are mapped to a latent vector of
    ``latent_dim`` entries, viewed as ``n_factors`` factors of
    ``vars_per_factor`` variables. Per-factor heads score latent transitions
    for every action, and ``loss_fn`` combines inverse, forward (InfoNCE),
    reconstruction, policy and grounding objectives over those scores.
    """

    def __init__(self, config, rngs, external_encoder=None):
        """Build encoder/decoder, per-factor energy heads, policy and projector.

        Args:
            config: model config. Sub-configs (``encoder``/``encoder_pixel``,
                ``decoder``/``decoder_pixel``, ``energy``, ``pi``, ``projector``)
                get their input/output sizes written in place before building.
                In inner-product mode ``embedding_dim`` (default 32) sets the
                size of the context / next-state embeddings.
            rngs: RNG streams forwarded to ``build_model``.
            external_encoder: optional module applied to observations before the
                internal non-pixel encoder. Its output size is assumed to be
                ``latent_dim`` (TODO confirm against the caller).
        """
        super().__init__(config, rngs)
        self.config = config
        self.external_encoder = external_encoder

        self.rngs = rngs
        self.latent_dim = self.config.latent_dim
        self.n_factors = self.latent_dim // self.config.vars_per_factor
        self.n_vars = self.config.vars_per_factor
        self.batch_size = self.config.batch_size

        if not self.config.get('use_ground_truth_states', False):
            if self.config.is_pixel and external_encoder is None:
                self.config.encoder_pixel.input_shape = self.config.obs_dim
                self.config.decoder_pixel.output_shape = self.config.obs_dim
                self.config.decoder_pixel.input_dim = self.config.latent_dim
                self.encoder = build_model(self.config.encoder_pixel, rngs)
                self.decoder = build_model(self.config.decoder_pixel, rngs)
            else:
                # The decoder still reconstructs the raw observation type.
                if self.config.is_pixel:
                    self.config.decoder_pixel.output_shape = self.config.obs_dim
                    self.config.decoder_pixel.input_dim = self.config.latent_dim
                    self.decoder = build_model(self.config.decoder_pixel, rngs)
                else:
                    self.config.decoder.output_dim = self.config.obs_dim[0]
                    self.config.decoder.input_dim = self.config.latent_dim
                    self.decoder = build_model(self.config.decoder, rngs)

                # The encoder is always the non-pixel one here. With an external
                # encoder it consumes that encoder's output (assumed latent-sized).
                if external_encoder is not None:
                    self.config.encoder.input_dim = self.config.latent_dim
                else:
                    self.config.encoder.input_dim = self.config.obs_dim[0]

                self.encoder = build_model(self.config.encoder, rngs)

        if not self.config.get('inner_product', False):
            # One energy head per factor: input is [delta of that factor, full z],
            # output is one score per action.
            self.config.energy.input_dim = self.config.latent_dim + self.config.vars_per_factor
            self.config.energy.output_dim = self.config.n_actions
            self.energies = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.energy, rngs, jnp.arange(self.n_factors))
        else:
            embedding_dim = self.config.get('embedding_dim', 32)
            # Independent copies of the energy config. Assigning the same object to
            # both keys would alias them, so the next_state_rep sizes set below
            # would overwrite the context_rep sizes before context_rep is built.
            context_cfg = copy.deepcopy(self.config.energy)
            context_cfg.input_dim = self.config.latent_dim
            context_cfg.output_dim = embedding_dim * self.config.n_actions
            self.config.context_rep = context_cfg

            next_state_cfg = copy.deepcopy(self.config.energy)
            next_state_cfg.input_dim = self.config.vars_per_factor
            next_state_cfg.output_dim = embedding_dim
            self.config.next_state_rep = next_state_cfg

            self.context_rep = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.context_rep, rngs, jnp.arange(self.n_factors))
            self.next_state_rep = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.next_state_rep, rngs, jnp.arange(self.n_factors))

        self.config.pi.input_dim = self.latent_dim
        self.pi = build_model(self.config.pi, rngs)

        self.config.projector.input_dim = self.latent_dim
        self.config.projector.output_dim = self.latent_dim
        self.projector = build_model(self.config.projector, rngs)

    def get_energies(
            self,
            z,
            next_z
    ):
        '''
            Non batched energy computation
            vmap it to work with batches.

            z, next_z: single latent vectors of size latent_dim.
            Returns per-factor, per-action scores of shape (F, A).
        '''
        delta_z = (next_z - z).reshape(self.n_factors, self.n_vars)
        # Each factor head sees its own delta plus the full current latent.
        energy_inputs = jnp.concatenate([delta_z, jnp.tile(z[None], (self.n_factors, 1))], axis=-1)
        energies = nnx.vmap(lambda energy, input: energy(input), in_axes=(0,0))(self.energies, energy_inputs)
        return energies #(F, A)

    def encode(self, obs, states=None):
        """Map observations to latents (or pass through ground-truth states).

        Raises:
            ValueError: if ``use_ground_truth_states`` is set and ``states`` is None.
        """
        if self.config.get('use_ground_truth_states', False):
            if states is None:
                raise ValueError('States must be provided if use_ground_truth_states is True')
            return states
        if self.external_encoder is not None:
            # External encoder first, then the internal encoder.
            return self.encoder(self.external_encoder(obs))
        return self.encoder(obs)

    def infer_actions(
            self,
            obs,
            states=None,
            rng=None
    ):
        '''
            Infer action from single transition
            vmap it to work with batches.
            obs: (T, D)

            ``rng`` is unused and kept only for interface compatibility.
            Returns (softmax over actions, sigmoid of each action's score minus
            action 0's score), each computed per transition.
        '''
        z = self.encode(obs, states)
        next_z = z[1:]
        z = z[:-1]
        if not self.config.get('inner_product', False):
            energies = nnx.vmap(self.get_energies, in_axes=(0, 0))(z, next_z).sum(1) # (T-1, A), summed over factors
        else:
            contexts = nnx.vmap(lambda context: context(z).reshape(z.shape[0], self.config.n_actions, -1), in_axes=(0,))(self.context_rep).swapaxes(0,1)
            next_states = nnx.vmap(lambda factor, input: factor(input), in_axes=(0, 1))(self.next_state_rep, next_z.reshape(-1, self.n_factors, self.n_vars)).swapaxes(0,1)
            energies = jnp.einsum('bfd, bfad -> bfa', next_states, contexts).sum(1)

        actions_probs = jax.nn.softmax(energies) # full inverse
        binary_classifiers = jax.nn.sigmoid(
            energies[:, 1:] - energies[:, :1]
        ) # binary classifiers
        return actions_probs, binary_classifiers

    def forward(
            self,
            obs,
            actions,
            rewards,
            dones,
            rng,
            states=None
        ):
        """Encode a batch of sequences and score every (z, z') pair.

        Args:
            obs: observation sequences; axis 1 is time (sliced ``[:, 1:]``).
            actions: actions per transition, flattened to one per latent pair.
            rewards: unused.
            dones: done flags per transition; ``mask`` is ``dones == 0``.
            rng: PRNG key for latent noise, observation preprocessing and sampling.
            states: ground-truth states, required when
                ``use_ground_truth_states`` is set.

        Returns:
            ``(z, actions, next_z)`` and
            ``(next_obs, flat_actions, energies, recons_next_x, action_log_probs,
            idx, mask, rng_sample)``. ``energies[i, j]`` scores
            ``(z_j, z'_i)``: shape (N, N, F, A) with N flattened transitions.

        Raises:
            NotImplementedError: ground-truth mode without ``states``.
        """
        rng_z, rng_next_z, rng_sample, rng_obs_noise = jax.random.split(rng, 4)
        obs = self.preprocess(obs, rng_obs_noise)
        use_ground_truth = self.config.get('use_ground_truth_states', False)

        if use_ground_truth:
            if states is None:
                raise NotImplementedError("Ground truth states not provided in forward pass")
            z = states[:, :-1] + jax.random.normal(rng_z, states[:, :-1].shape)*self.config.noise_std
            next_z = states[:, 1:] + jax.random.normal(rng_next_z, z.shape)*self.config.noise_std
        else:
            if self.external_encoder is not None:
                # The external encoder is frozen: no gradients flow into it.
                z = self.encoder(sg(self.external_encoder(obs)))
            else:
                z = self.encoder(obs)
            z = z + jax.random.normal(rng_z, z.shape)*self.config.noise_std
            next_z = z[..., 1:, :]
            z = z[..., :-1, :]

        next_obs = obs[:, 1:]

        if self.config.params.recons_const > 0:
            recons_next_x = self.decoder(next_z) if not use_ground_truth else next_obs
        else:
            recons_next_x = 0.
        action_probs = nnx.log_softmax(self.pi(z))

        latent_dim = z.shape[-1]
        _z = z.reshape(-1, latent_dim)
        _next_z = next_z.reshape(-1, latent_dim)
        _actions = actions.reshape(-1)
        mask = dones.reshape(-1) == 0
        # Identity indices; the loss uses them to gather per-sample quantities.
        idx = jnp.arange(_z.shape[0])

        # Score every z' against every z (negatives for the forward InfoNCE).
        if not self.config.get('inner_product', False):
            energies = nnx.vmap(
                    nnx.vmap(
                        self.get_energies,
                        in_axes=(0, None)
                    ),
                    in_axes=(None, 0)
                )(_z, _next_z) # (z', z, F, A)
        else:
            contexts = nnx.vmap(lambda context: context(_z).reshape(_z.shape[0], self.config.n_actions, -1), in_axes=(0,))(self.context_rep).swapaxes(0,1)
            next_states = nnx.vmap(lambda factor, input: factor(input), in_axes=(0, 1))(self.next_state_rep, _next_z.reshape(-1, self.n_factors, self.n_vars)).swapaxes(0,1)
            energies = jnp.einsum('bfd, cfad -> bcfa', next_states, contexts)

        return (z, actions, next_z), (next_obs, _actions, energies, recons_next_x, action_probs, idx, mask, rng_sample)

    @staticmethod
    def _select_inverse_bce(logits, action_mask, is_noop):
        """Binary inverse loss contrasting every action with action 0.

        Args:
            logits: scores relative to action 0; last axis indexes actions.
            action_mask: one-hot of the taken action, broadcastable to ``logits``.
            is_noop: True where the taken action is 0; shape ``logits.shape[:-1]``.

        Returns:
            For no-op samples, the summed BCE of every non-zero action's
            classifier against label 0; otherwise the BCE of the taken action's
            classifier against label 1. Shape ``logits.shape[:-1]``.
        """
        bce = optax.sigmoid_binary_cross_entropy(logits, action_mask)[..., 1:]
        taken = action_mask[..., 1:]
        bce_true = jnp.where(taken, bce, 0.).sum(-1)
        bce_false = jnp.where(1 - taken, bce, 0.).sum(-1)
        return jax.lax.select(is_noop, bce_false, bce_true)

    def loss_fn(self, prioritized=False):
        """Build the training loss.

        Args:
            prioritized: when True, ``_loss`` also returns per-sample priorities.

        Returns:
            ``(_loss, dummy_logs)``. ``_loss`` consumes the outputs of
            ``forward`` plus per-sequence ``importance_weights`` and returns
            ``(loss, logs)`` or ``(loss, logs, priorities)``. ``dummy_logs``
            has the same scalar keys as ``logs``, all zero.
        """
        def _loss(
                z,
                actions,
                next_z,
                next_obs,
                sel_actions,
                energies,
                recons_next_x,
                action_prob_logits,
                indices,
                mask,
                rng,
                importance_weights,
            ):
            params = self.config.params
            n_actions = self.config.n_actions
            use_action_weights = self.config.use_action_weights

            batch_dims = z.shape[:-1]
            # One weight per sequence -> one per transition of that sequence.
            importance_weights = importance_weights[:, None].repeat(batch_dims[-1], axis=-1)
            if params.recons_const > 0:
                recons_loss = jnp.mean(optax.l2_loss(next_obs, recons_next_x).reshape(*batch_dims, -1).mean(-1) * importance_weights)
            else:
                recons_loss = 0.

            action_prob_loss = jnp.mean(
                    optax.softmax_cross_entropy_with_integer_labels(
                    action_prob_logits,
                    actions.astype(jnp.int32)
                ) * importance_weights
            )

            smoothness = jnp.linalg.norm(next_z-z, axis=-1).mean()
            _importance_weights = importance_weights.reshape(-1)[indices]

            # Diagonal of the (z', z) grid = scores of the real transitions, (N, F, A).
            r_energies = jnp.diagonal(energies).transpose(2, 0, 1)
            _action_probs = action_prob_logits.reshape(-1, action_prob_logits.shape[-1])[indices]

            inverse_model_cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
                r_energies.sum(1) + jax.lax.stop_gradient(_action_probs * float(use_action_weights)),
                sel_actions
            )
            inverse_model_loss = jnp.mean(inverse_model_cross_entropy * _importance_weights)

            # Scores relative to action 0, optionally shifted by the (frozen)
            # policy log-probabilities relative to action 0.
            relative_energies = r_energies - r_energies[..., :1]  # (N, F, A)
            if use_action_weights:
                policy_offsets = jax.lax.stop_gradient(-(_action_probs - _action_probs[..., :1]))  # (N, A)
            else:
                policy_offsets = None

            # Per-factor binary inverse loss.
            per_factor_logits = relative_energies if policy_offsets is None else relative_energies + policy_offsets[:, None, :]
            per_factor_bce = self._select_inverse_bce(
                per_factor_logits,
                jax.nn.one_hot(sel_actions, n_actions)[:, None],
                (sel_actions == 0)[:, None].repeat(self.n_factors, axis=1),
            )  # (N, F)
            inverse_loss_per_factor = jnp.mean(per_factor_bce * _importance_weights[:, None])

            if self.config.get('simple_acf', True):
                summed = relative_energies.sum(1)
                global_inverse = summed if policy_offsets is None else summed + policy_offsets  # (N, A)
                inverse_bce = self._select_inverse_bce(
                    global_inverse,
                    jax.nn.one_hot(sel_actions, n_actions),
                    sel_actions == 0,
                )  # (N,)
                inverse_loss = jnp.mean(inverse_bce * _importance_weights)
                priority_bce = inverse_bce
            else:
                summed_energies = r_energies.sum(1)  # (N, A)
                taken_energies = jnp.take_along_axis(summed_energies, sel_actions[:, None], axis=-1)
                if not use_action_weights:
                    global_inverse = taken_energies - summed_energies
                else:
                    taken_probs = jnp.take_along_axis(_action_probs, sel_actions[:, None], axis=-1)
                    offsets = jax.lax.stop_gradient(-(taken_probs - _action_probs))
                    global_inverse = (taken_energies - summed_energies + offsets)  # (N, A)

                inverse_loss = jnp.mean(optax.sigmoid_binary_cross_entropy(global_inverse, jnp.ones_like(global_inverse)) * mask[:, None] * _importance_weights[:, None])
                priority_bce = per_factor_bce.sum(-1)

            # Forward InfoNCE on factor-summed scores, grid shape (N', N, A).
            summed_grid = energies.sum(2)
            diag = jnp.diagonal(summed_grid).transpose(1, 0)  # (N, A)
            # Score of each real transition under its taken action, (N,).
            real_energies = jnp.take_along_axis(diag, sel_actions[:, None], axis=-1)[..., -1]
            # Score of z'_i against every z_j under z_j's own action, (N', N).
            candidate_energies = jnp.take_along_axis(summed_grid, sel_actions[None, :, None], axis=-1)[..., -1]

            if self.config.get('original_forward_loss', True) and not self.config.get('inner_product', False):
                forward_ll = diag - jax.nn.logsumexp(summed_grid, axis=1)
                forward_ll = jnp.take_along_axis(forward_ll, sel_actions[:, None], axis=-1)[:, -1]
            else:
                forward_ll = real_energies - jax.nn.logsumexp(candidate_energies, axis=-1)

            # Per-action InfoNCE: negatives restricted to samples with the same
            # action; importance weights are applied per sample.
            per_action_forward_loss = 0.
            for a in range(n_actions):
                action_mask = (sel_actions == a)
                action_denominator = jax.nn.logsumexp(jnp.where(action_mask[None], candidate_energies, -1e12), axis=-1)
                action_mask_values = mask & action_mask
                action_forward = real_energies - action_denominator
                action_loss = -jnp.sum(action_forward * action_mask_values * _importance_weights) / (jnp.sum(action_mask_values) + 1e-8)
                per_action_forward_loss = per_action_forward_loss + action_loss
            per_action_forward_loss = jnp.mean(per_action_forward_loss / n_actions)

            forward_loss = -jnp.mean(forward_ll * mask * _importance_weights)

            # Grounding: contrastive spread of latents (self-pair as positive).
            _z = z.reshape(-1, z.shape[-1])
            grounding_loss = -((_z[None] - _z[:, None]) ** 2).sum(-1) / 0.1**2
            grounding_loss = -(jnp.diagonal(grounding_loss) - jax.nn.logsumexp(grounding_loss, axis=-1))
            grounding_loss = jnp.mean(grounding_loss * _importance_weights)

            loss = params.recons_const*recons_loss + \
                    params.inverse_const*inverse_loss + \
                    params.get('inverse_model_const', 0.)*inverse_model_loss + \
                    params.get('inverse_per_factor_const', 0.)*inverse_loss_per_factor + \
                    params.forward_const*forward_loss + \
                    (float(use_action_weights)*params.policy_const)*action_prob_loss + \
                    self.config.get('grounding_const', 0.)*grounding_loss + \
                    params.get('per_action_forward_const', 0.) * per_action_forward_loss

            logs = {
                'scalars': {
                    'norm': (z**2).sum(-1).mean(),
                    'smoothness': smoothness,
                    'recons_loss': recons_loss,
                    'inverse_loss': inverse_loss,
                    'inverse_per_factor_loss': inverse_loss_per_factor,
                    'inverse_model_loss': inverse_model_loss,
                    'forward_loss': forward_loss,
                    'policy_loss': action_prob_loss,
                    'grounding_loss': grounding_loss,
                    'per_action_forward_loss': per_action_forward_loss,
                    'loss': loss,
                },
            }
            if not prioritized:
                return loss, logs

            # One priority per sample. ``priority_bce`` is already (N,), so the
            # old ``(x * mask[:, None]).sum(1)`` form (which broadcast an (N,)
            # loss to (N, N) and gave every sample the batch total) is avoided.
            priorities = priority_bce * mask
            if self.config.get('full_inverse_priority', False):
                print('Using full inverse priority')
                priorities = priorities + inverse_model_cross_entropy
            return loss, logs, priorities

        dummy_logs = {
                'scalars': {
                        'norm': 0.,
                        'smoothness': 0.,
                        'recons_loss': 0.,
                        'inverse_loss': 0.,
                        'inverse_per_factor_loss': 0.,
                        'inverse_model_loss': 0.,
                        'forward_loss': 0.,
                        'policy_loss': 0.,
                        'grounding_loss': 0.,
                        'per_action_forward_loss': 0.,
                        'loss': 0.,
                    },
            }

        return _loss, dummy_logs
    

class MultistepACFRepresentation(RepModel):
    """ACF representation with an extra multi-step inverse objective.

    Besides the one-step ACF losses, a classifier predicts the action taken at
    step t from ``(z_t, z_{t+k}, embed(k - 1))``. The offset k is sampled
    within the same trajectory and below ``embed.max_offset``.
    """

    def __init__(self, config, rngs):
        """Build encoder/decoder, energy heads, policy, projector and the
        offset-conditioned multi-step classifier.

        ``config.acf_config`` holds the one-step ACF settings. Its sub-configs
        get their sizes written in place.
        """
        super().__init__(config, rngs)
        config.acf_config.is_pixel = config.is_pixel
        config.acf_config.obs_dim = config.obs_dim
        self.config = config

        self.rngs = rngs
        self.latent_dim = self.config.acf_config.latent_dim
        self.n_factors = self.latent_dim // self.config.acf_config.vars_per_factor
        self.n_vars = self.config.acf_config.vars_per_factor
        self.batch_size = self.config.acf_config.batch_size
        # NOTE(review): the sizes below read top-level config.latent_dim /
        # vars_per_factor while n_factors uses acf_config; confirm they agree.
        if self.config.is_pixel:
            self.config.acf_config.encoder_pixel.input_shape = self.config.obs_dim
            self.config.acf_config.decoder_pixel.output_shape = self.config.obs_dim
            self.config.acf_config.decoder_pixel.input_dim = self.config.latent_dim
        else:
            self.config.acf_config.encoder.input_dim = self.config.obs_dim[0]
            self.config.acf_config.decoder.output_dim = self.config.obs_dim[0]
            self.config.acf_config.decoder.input_dim = self.config.latent_dim
        self.encoder = build_model(self.config.acf_config.encoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.acf_config.encoder, rngs)
        self.decoder = build_model(self.config.acf_config.decoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.acf_config.decoder, rngs)

        self.config.acf_config.energy.input_dim = self.config.latent_dim + self.config.vars_per_factor
        self.config.acf_config.pi.input_dim = self.config.latent_dim

        self.energies = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.acf_config.energy, rngs, jnp.arange(self.n_factors))
        self.pi = build_model(self.config.acf_config.pi, rngs)

        self.config.acf_config.projector.input_dim = self.latent_dim
        self.config.acf_config.projector.output_dim = self.latent_dim
        self.projector = build_model(self.config.acf_config.projector, rngs)

        # Embedding of (offset - 1); valid offsets lie in [1, max_offset).
        self.offset_embedding = nnx.Embed(self.config.embed.max_offset, self.config.embed.dim, rngs=rngs)
        self.config.multistep_classifier.input_dim = self.config.embed.dim + self.config.latent_dim*2
        self.multistep_classifier = build_model(self.config.multistep_classifier, rngs)

    def get_energies(
            self,
            z,
            next_z
    ):
        '''
            Non batched energy computation
            vmap it to work with batches.

            Returns per-factor, per-action scores of shape (F, A).
        '''
        delta_z = (next_z - z).reshape(self.n_factors, self.n_vars)
        # Each factor head sees its own delta plus the full current latent.
        energy_inputs = jnp.concatenate([delta_z, jnp.tile(z[None], (self.n_factors, 1))], axis=-1)
        energies = nnx.vmap(lambda energy, input: energy(input), in_axes=(0,0))(self.energies, energy_inputs)
        return energies #(F, A)

    def infer_actions(
            self,
            obs,
            states=None,
            rng=None
    ):
        '''
            Infer action from single transition
            vmap it to work with batches.
            obs: (T, D)

            ``rng`` is unused and kept only for interface compatibility.
            NOTE(review): relies on an ``encode`` method not defined in this
            class (presumably inherited from RepModel) — confirm.
        '''
        z = self.encode(obs, states)
        next_z = z[1:]
        z = z[:-1]
        energies = nnx.vmap(self.get_energies, in_axes=(0, 0))(z, next_z).sum(1) # (T-1, A), summed over factors
        actions_probs = jax.nn.softmax(energies) # full inverse
        binary_classifiers = jax.nn.sigmoid(
            energies[:, 1:] - energies[:, :1]
        ) # binary classifiers
        return actions_probs, binary_classifiers

    def forward(self, obs, actions, rewards, dones, rng, states=None):
        """Encode sequences, score all (z, z') pairs and sample multi-step pairs.

        Args:
            obs: observation sequences; axis 1 is time.
            actions: (B, T) actions per transition (indexed ``[b, t]``).
            rewards, states: unused.
            dones: (B, T) done flags per transition (0 = not done).
            rng: PRNG key.

        Returns:
            ``(z, actions, next_z)`` and
            ``(multistep_logits, k_step_actions, acf_forward_outs)``.
        """
        # Separate keys for the one-step ACF part and for offset sampling.
        rng_acf, rng_offsets = jax.random.split(rng, 2)
        rng_z, rng_next_z, rng_sample = jax.random.split(rng_acf, 3)

        z = self.encoder(obs)
        z = z + jax.random.normal(rng_z, z.shape)*self.config.acf_config.noise_std
        next_z = z[..., 1:, :]
        z = z[..., :-1, :]

        recons_next_x = 0.  # reconstruction is disabled for this model
        action_probs = nnx.log_softmax(self.pi(z))

        latent_dim = z.shape[-1]
        _z = z.reshape(-1, latent_dim)
        _next_z = next_z.reshape(-1, latent_dim)
        _actions = actions.reshape(-1)
        mask = 1 - dones.reshape(-1)
        idx = jnp.arange(_z.shape[0])

        # Score every z' against every z: (z', z, F, A).
        energies = nnx.vmap(
                nnx.vmap(
                    self.get_energies,
                    in_axes=(0, None)
                ),
                in_axes=(None, 0)
            )(_z, _next_z)

        acf_forward_outs = (obs[:, 1:], _actions, energies, recons_next_x, action_probs, idx, mask, rng_sample)

        B, T = z.shape[:2]
        max_offset = self.config.embed.max_offset
        # Trajectory id per step = exclusive cumulative sum of done flags: a done
        # at step t starts a new trajectory at step t + 1. (A bitwise right shift
        # of 0/1 flags would be all zeros and ignore episode boundaries.)
        done_flags = dones.astype(jnp.int32)
        traj_ids = jnp.cumsum(done_flags, axis=-1) - done_flags

        def _sample_offset_and_t(z_seq, traj_seq, key):
            """Sample t and a future index in the same trajectory, < max_offset ahead."""
            _, rng_offset, rng_t = jax.random.split(key, 3)
            t = jax.random.randint(rng_t, (1,), 0, T-1)[0]
            positions = jnp.arange(T)
            valid = (traj_seq[t] == traj_seq) & \
                    (positions > t) & \
                    (positions < t + max_offset)
            # If the trajectory ends at t, fall back to t + 1 so the sample
            # stays well defined (such pairs are not masked out of the loss).
            valid = jnp.where(valid.any(), valid, positions == t + 1)
            # categorical takes logits: invalid positions must get -inf.
            future_t = jax.random.categorical(rng_offset, jnp.where(valid, 0., -jnp.inf))
            return t, future_t - t, z_seq[t], z_seq[future_t]

        t, offset, current_z, future_z = nnx.vmap(
            _sample_offset_and_t,
            in_axes=(0, 0, 0)
        )(z, traj_ids, jax.random.split(rng_offsets, B))

        embed_offset = self.offset_embedding(jax.lax.stop_gradient(offset-1))
        multistep_input = jnp.concatenate([current_z, future_z, embed_offset], axis=-1)
        multistep_logits = self.multistep_classifier(multistep_input)
        k_step_actions = actions[jnp.arange(B), jax.lax.stop_gradient(t)]
        return (z, actions, next_z), (multistep_logits, k_step_actions, acf_forward_outs)

    def loss_fn(self, prioritized=False):
        """Build the training loss.

        ``prioritized`` is accepted for interface parity with
        ``ACFRepresentation.loss_fn`` but ignored: ``_loss`` always returns
        ``(loss, logs)``.

        Returns:
            ``(_loss, dummy_logs)``; ``dummy_logs`` mirrors the logged keys.
        """

        def _loss(
            z,
            actions,
            next_z,
            multistep_logits,
            k_step_actions,
            acf_forward_outs,
            importance_weights,
        ):
            (next_obs, sel_actions, energies, recons_next_x, action_probs, indices, mask, _rng) = acf_forward_outs
            acf = self.config.acf_config

            batch_dims = z.shape[:-1]
            # One weight per sequence -> one per transition of that sequence.
            importance_weights = importance_weights[:, None].repeat(batch_dims[-1], axis=-1)
            recons_loss = 0.  # reconstruction is disabled for this model

            action_prob_loss = jnp.mean(
                    optax.softmax_cross_entropy_with_integer_labels(
                    action_probs,
                    actions.astype(jnp.int32)
                ) * importance_weights
            )

            smoothness = jnp.linalg.norm(next_z-z, axis=-1).mean()

            # Diagonal of the (z', z) grid = scores of real transitions, (N, F, A).
            r_energies = jnp.diagonal(energies).transpose(2, 0, 1)
            inverse_model_loss = (optax.softmax_cross_entropy_with_integer_labels(
                r_energies.sum(1),
                sel_actions
            )).mean()

            if not acf.use_action_weights:
                global_inverse = (r_energies - r_energies[..., :1]).sum(1)
            else:
                _action_probs = action_probs.reshape(-1, action_probs.shape[-1])[indices]
                weights = jax.lax.stop_gradient(-(_action_probs - _action_probs[..., :1]))
                global_inverse = ((r_energies - r_energies[..., :1]).sum(1) + weights)  # (N, A)

            # Forward InfoNCE on factor-summed scores, grid shape (N', N, A).
            summed_grid = energies.sum(2)
            forward_loss = jnp.diagonal(summed_grid).transpose(1,0) - \
                                jax.nn.logsumexp(summed_grid, axis=1)
            forward_loss = jnp.take_along_axis(forward_loss, sel_actions[:, None], axis=-1)[:, -1]
            forward_loss = -jnp.mean(forward_loss * mask)

            action_mask = jax.nn.one_hot(sel_actions, acf.n_actions)  # (N, A)
            inverse_loss = (optax.sigmoid_binary_cross_entropy(global_inverse, action_mask) * mask[:, None]).mean()

            loss = acf.params.recons_const*recons_loss + \
                    acf.params.inverse_const*inverse_loss + \
                    acf.params.get('inverse_model_const', 0.)*inverse_model_loss + \
                    acf.params.forward_const*forward_loss + \
                    (acf.use_action_weights*acf.params.policy_const)*action_prob_loss

            multistep_loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(
                multistep_logits,
                k_step_actions.astype(jnp.int32)
            ))

            loss = loss + multistep_loss * self.config.params.multistep_inv_const
            logs = {
                'scalars': {
                    'norm': (z**2).sum(-1).mean(),
                    'smoothness': smoothness,
                    'recons_loss': recons_loss,
                    'inverse_loss': inverse_loss,
                    'inverse_model_loss': inverse_model_loss,
                    'forward_loss': forward_loss,
                    'policy_loss': action_prob_loss,
                    'multistep_inv_loss': multistep_loss,
                    'loss': loss,
                },
            }
            return loss, logs

        dummy_logs = {
                'scalars': {
                        'norm': 0.,
                        'smoothness': 0.,
                        'recons_loss': 0.,
                        'inverse_loss': 0.,
                        'inverse_model_loss': 0.,
                        'forward_loss': 0.,
                        'policy_loss': 0.,
                        'loss': 0.,
                        'multistep_inv_loss': 0.,
                    },
            }

        return _loss, dummy_logs
            
# class PerFactorACFRepresentation(RepModel):

#     def __init__(self, config, rngs):
#         super().__init__(config, rngs)
#         self.config = config

#         self.rngs = rngs
#         self.latent_dim = self.config.latent_dim
#         self.n_factors = self.latent_dim // self.config.vars_per_factor
#         self.n_vars = self.config.vars_per_factor
#         self.batch_size = self.config.batch_size
#         if self.config.is_pixel:
#             self.config.encoder_pixel.input_shape = self.config.obs_dim
#             self.config.decoder_pixel.output_shape = self.config.obs_dim
#             self.config.decoder_pixel.input_dim = self.config.latent_dim
#         else:
#             self.config.encoder.input_dim = self.config.obs_dim[0]
#             self.config.decoder.output_dim = self.config.obs_dim[0]
#             self.config.decoder.input_dim = self.config.latent_dim
#         self.encoder = build_model(self.config.encoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.encoder, rngs)
#         self.decoder = build_model(self.config.decoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.decoder, rngs)
        
#         self.config.energy.input_dim = self.config.latent_dim + self.config.vars_per_factor
#         self.config.pi.input_dim = self.config.latent_dim

#         self.energies = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.energy, rngs, jnp.arange(self.n_factors))
#         self.pi = build_model(self.config.pi, rngs)

#         self.config.projector.input_dim = self.latent_dim
#         self.config.projector.output_dim = self.latent_dim
#         self.projector = build_model(self.config.projector, rngs)

#     def get_energies(
#             self,
#             z,
#             next_z
#     ):
#         '''
#             Non batched energy computation
#             vmap it to work with batches.
#         '''
#         delta_z = (next_z - z).reshape(self.n_factors, self.n_vars)
#         energy_inputs = jnp.concatenate([delta_z, jnp.tile(z[None], (self.n_factors, 1))], axis=-1)
#         energies = nnx.vmap(lambda energy, input: energy(input), in_axes=(0,0))(self.energies, energy_inputs)
#         return energies #(F, A)

#     def get_energies_2(
#             self,
#             z,
#             action,
#             next_z
#     ):
#         '''
#             Non batched energy computation
#             vmap it to work with batches.
#         '''
#         delta_z = (next_z - z).reshape(self.n_factors, self.n_vars)
#         energy_inputs = jnp.concatenate([delta_z, jnp.tile(z[None], (self.n_factors, 1))], axis=-1)
#         energies = nnx.vmap(lambda energy, input: energy(input), in_axes=(0,0))(self.energies, energy_inputs)
#         return energies[..., action] #(F,)

#     def forward(
#             self,
#             obs,
#             actions,
#             rewards,
#             next_obs,
#             dones,
#             rng,
#             states=None
#         ):

#         rng_z, rng_next_z, rng_sample, rng_perm = jax.random.split(rng, 4)
#         z = self.encoder(obs)
#         z = z + jax.random.normal(rng_z, z.shape)*self.config.noise_std
#         next_z = self.encoder(next_obs) + jax.random.normal(rng_next_z, z.shape)*self.config.noise_std
        
#         if len(z.shape) == 1: # add batch dim
#             z = z[None]
#             next_z = next_z[None]

#         recons_next_x = self.decoder(next_z)
#         # recons_x = self.decoder(z)
#         recons_x = recons_next_x
#         action_probs = nnx.log_softmax(self.pi(z))

#         batch_dims, latent_dim = z.shape[:-1], z.shape[-1]
#         batch_size = self.batch_size
#         _z = z.reshape(-1, latent_dim)
#         _next_z = next_z.reshape(-1, latent_dim)
#         _actions = actions.reshape(-1)
#         dones = dones.reshape(-1)
#         mask = 1-dones
#         # idx = jax.random.categorical(
#         #     rng_sample,
#         #     jnp.log((1-dones)/(1-dones).sum() + 1e-12), # don't sample done elements because z->z' is invalid
#         #     axis=0,
#         #     shape=(batch_size,)
#         # )

#         batch_size = _z.shape[0]
#         idx = jnp.arange(_z.shape[0])

#         _z = _z[idx]
#         _next_z = _next_z[idx]
#         _actions = _actions.reshape(-1)[idx]
#         # randomize the z' to estimate the forward dynamics.
#         # energies = nnx.vmap(
#         #         nnx.vmap(
#         #             self.get_energies,
#         #             in_axes=(0, None)
#         #         ),
#         #         in_axes=(None, 0)
#         #     )(_z, _next_z)

#         energies = nnx.vmap(
#                     self.get_energies,
#                     in_axes=(0, 0)
#                 )(_z, _next_z)

#         return (z, actions, next_z), (obs, next_obs, _actions, energies, recons_x, recons_next_x, action_probs, idx, mask, rng_sample)

#     def loss_fn(self):
#         def _loss(
#                 z,
#                 actions,
#                 next_z,
#                 obs,
#                 next_obs,
#                 sel_actions,
#                 energies,
#                 reconst_x,
#                 recons_next_x,
#                 action_probs,
#                 indices,
#                 mask,
#                 rng,
#                 importance_weights,
#             ):
#             batch_dims = z.shape[:-1]
#             importance_weights = importance_weights[:, None].repeat(batch_dims[-1], axis=-1)
#             recons_loss = jnp.mean(optax.l2_loss(next_obs, recons_next_x).reshape(*batch_dims, -1).mean(-1))
#                             #jnp.mean(optax.l2_loss(obs, reconst_x).reshape(*batch_dims, -1).mean(-1))

#             action_prob_loss = jnp.mean(
#                     optax.softmax_cross_entropy_with_integer_labels(
#                     action_probs,
#                     actions.astype(jnp.int32)
#                 ) * importance_weights
#             )

#             smoothness = jnp.linalg.norm(next_z-z, axis=-1).mean()
#             _importance_weights = importance_weights.reshape(-1)[indices]

#             # r_energies = jnp.diagonal(energies).transpose(2, 0, 1) # energies real samples
#             r_energies = energies
#             if not self.config.use_action_weights:
#                 global_inverse = (r_energies - r_energies[..., :1])
#                 global_inverse = r_energies[..., None] - r_energies[..., None, :] # B, F, A, A
#             else:
#                 _action_probs = action_probs.reshape(-1, action_probs.shape[-1])[indices]
#                 weights = jax.lax.stop_gradient(-(_action_probs - _action_probs[..., :1]))
#                 global_inverse = ((r_energies - r_energies[..., :1]) + weights)# B, A

#             # action_mask = jax.nn.one_hot(sel_actions, self.config.n_actions)[:, None].repeat(global_inverse.shape[1], axis=1) # BxFxA
#             # action_mask = action_mask[..., None].repeat(global_inverse.shape[-1], axis=-1)
#             action_mask = jnp.tile(jax.nn.one_hot(sel_actions, self.config.n_actions)[:, None, ..., None], (1, self.n_factors, 1, self.config.n_actions))
#             printarr(global_inverse, action_mask)
#             inverse_loss = (optax.sigmoid_binary_cross_entropy(global_inverse, action_mask)).mean(axis=0).mean()
            

#             loss = self.config.params.recons_const*recons_loss + \
#                     self.config.params.inverse_const*inverse_loss + \
#                     (self.config.use_action_weights*self.config.params.policy_const)*action_prob_loss
#                     # self.config.params.forward_const*forward_loss + \
            
#             logs = {
#                 'scalars': {
#                     'norm': (z**2).sum(-1).mean(),
#                     'smoothness': smoothness,
#                     'recons_loss': recons_loss,
#                     'inverse_loss': inverse_loss,
#                     'policy_loss': action_prob_loss,
#                     'loss': loss,
#                 },
#                     # 'histograms': {
#                     #     **{'action_weights/action_{i}': weights[..., i] for i in range(self.config.n_actions)}
#                     # }
#             }
#             return loss, logs
        
#         dummy_logs = {
#                 'scalars': {
#                         'norm': 0.,
#                         'smoothness': 0.,
#                         'recons_loss': 0.,
#                         'inverse_loss': 0.,
#                         'forward_loss': 0.,
#                         'policy_loss': 0.,
#                         'loss': 0.,
#                     },
#                     # 'histograms': {
#                     #     **{'action_weights/action_{i}': 0. for i in range(self.config.n_actions)}
#                     # }
#             }
        
#         return _loss, dummy_logs

# class SparseACFRepresentation(RepModel):

#     def __init__(self, config, rngs):
#         super().__init__(config, rngs)
#         self.config = config

#         self.rngs = rngs
#         self.latent_dim = self.config.latent_dim
#         self.n_factors = self.latent_dim // self.config.vars_per_factor
#         self.n_vars = self.config.vars_per_factor
#         self.batch_size = self.config.batch_size
#         if self.config.is_pixel:
#             self.config.encoder_pixel.input_shape = self.config.obs_dim
#             self.config.decoder_pixel.output_shape = self.config.obs_dim
#             self.config.decoder_pixel.input_dim = self.config.latent_dim
#         else:
#             self.config.encoder.input_dim = self.config.obs_dim[0]
#             self.config.decoder.output_dim = self.config.obs_dim[0]
#             self.config.decoder.input_dim = self.config.latent_dim
#         self.encoder = build_model(self.config.encoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.encoder, rngs)
#         self.decoder = build_model(self.config.decoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.decoder, rngs)
        
#         self.config.energy.input_dim = self.config.latent_dim + self.config.vars_per_factor
#         self.config.pi.input_dim = self.config.latent_dim

#         self.energies = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.energy, rngs, jnp.arange(self.n_factors))
#         self.pi = build_model(self.config.pi, rngs)

#         self.config.projector.input_dim = self.latent_dim
#         self.config.projector.output_dim = self.latent_dim
#         self.projector = build_model(self.config.projector, rngs)

#     def get_energies(
#             self,
#             z,
#             next_z
#     ):
#         '''
#             Non batched energy computation
#             vmap it to work with batches.
#         '''
#         # delta_z = (next_z - z).reshape(self.n_factors, self.n_vars)
#         delta_z = next_z.reshape(self.n_factors, self.n_vars)
#         energy_inputs = jnp.concatenate([delta_z, jnp.tile(z[None], (self.n_factors, 1))], axis=-1)
#         energies = nnx.vmap(lambda energy, input: energy(input), in_axes=(0,0))(self.energies, energy_inputs)
#         return energies ** 2 / 2 #(F, A)

#     def get_score(
#          self,
#          z,
#          action,
#          next_z
#     ):
#         # delta_z = (next_z - z).reshape(self.n_factors, self.n_vars)
#         delta_z = next_z.reshape(self.n_factors, self.n_vars)
#         energy_inputs = jnp.concatenate([delta_z, jnp.tile(z[None], (self.n_factors, 1))], axis=-1)
#         energies = nnx.vmap(lambda energy, input: energy(input), in_axes=(0,0))(self.energies, energy_inputs)
#         return energies[:, action]
        

#     def sample_obs_step(self,
#                    x,
#                    z,
#                    action,
#                    rng,
#                    alpha=5.
#                 ):
#         next_z, vjp_fn = jax.vjp(self.encoder, x)
#         score = vjp_fn(self.get_score(z, action, next_z))
#         noise = jax.random.normal(rng, x.shape)
#         next_x = x + alpha * score[0] + jnp.sqrt(2*alpha) * noise
#         return next_x

#     def forward(
#             self,
#             obs,
#             actions,
#             rewards,
#             dones,
#             rng,
#             states=None
#         ):

#         rng_z, rng_next_z, rng_sample, rng_perm = jax.random.split(rng, 4)
#         z = self.encoder(obs)
#         z = z + jax.random.normal(rng_z, z.shape)*self.config.noise_std
#         next_z = z[..., 1:, :]
#         z = z[..., :-1, :]
        
#         if len(z.shape) == 1: # add batch dim
#             z = z[None]
#             next_z = next_z[None]

#         recons_next_x = self.decoder(next_z)
#         action_probs = nnx.log_softmax(self.pi(z))
        

#         batch_dims, latent_dim = z.shape[:-1], z.shape[-1]
#         batch_size = self.batch_size
#         _z = z.reshape(-1, latent_dim)
#         _next_z = next_z.reshape(-1, latent_dim)
#         _actions = actions.reshape(-1)
#         dones = dones.reshape(-1)
#         mask = 1-dones

#         batch_size = _z.shape[0]
#         idx = jnp.arange(_z.shape[0])

#         _z = _z[idx]
#         _next_z = _next_z[idx]
#         _actions = _actions.reshape(-1)[idx]


#         _obs = obs.reshape(-1, *obs.shape[2:])

#         def sample_obs(state, unused):
#             _obs, rng, model = state
#             model = nnx.merge(*model)
#             rng, rng_next = jax.random.split(rng)
#             next_obs = nnx.vmap(model.sample_obs_step, in_axes=(0, 0, 0, 0))(sg(_obs), sg(_z), _actions, jax.random.split(rng_next, _z.shape[0]))

#             return (next_obs, rng, nnx.split(model)), next_obs

#         (next_x_sample, _, _), _ =  jax.lax.scan(
#             sample_obs,
#             (_obs, rng_perm, nnx.split(self)),
#             jnp.arange(3)
#         )

#         # next_x_sample = nnx.vmap(self.sample_obs_step, in_axes=(0, 0, 0, 0))(_obs, _z, _actions, jax.random.split(rng_perm, _z.shape[0]))

#         energies = nnx.vmap(
#                 nnx.vmap(
#                     self.get_energies,
#                     in_axes=(0, None)
#                 ),
#                 in_axes=(None, 0)
#             )(_z, _next_z)

#         return (z, actions, next_z), (next_obs, _actions, energies, recons_next_x, action_probs, idx, mask, rng_sample, next_x_sample)

#     def loss_fn(self):
#         def _loss(
#                 z,
#                 actions,
#                 next_z,
#                 next_obs,
#                 sel_actions,
#                 energies,
#                 recons_next_x,
#                 action_probs,
#                 indices,
#                 mask,
#                 rng,
#                 next_x_sample,
#                 importance_weights,
#             ):
#             batch_dims = z.shape[:-1]
#             importance_weights = importance_weights[:, None].repeat(batch_dims[-1], axis=-1)
#             recons_loss = jnp.mean(optax.l2_loss(next_obs, recons_next_x).reshape(*batch_dims, -1).mean(-1))
#             score_loss = jnp.mean(optax.l2_loss(next_obs.reshape(*next_x_sample.shape), next_x_sample).reshape(*batch_dims, -1).mean(-1))


#             action_prob_loss = jnp.mean(
#                     optax.softmax_cross_entropy_with_integer_labels(
#                     action_probs,
#                     actions.astype(jnp.int32)
#                 ) * importance_weights
#             )

#             smoothness = jnp.linalg.norm(next_z-z, axis=-1).mean()
#             _importance_weights = importance_weights.reshape(-1)[indices]

#             r_energies = jnp.diagonal(energies).transpose(2, 0, 1) # energies real samples
#             sparsity_loss = jnp.abs(jnp.take_along_axis(r_energies, sel_actions[:,None,None], axis=-1)[:, -1] - r_energies[..., 0]).mean()

#             if not self.config.use_action_weights:
#                 global_inverse = (r_energies - r_energies[..., :1]).sum(1)
#             else:
#                 _action_probs = action_probs.reshape(-1, action_probs.shape[-1])[indices]
#                 weights = jax.lax.stop_gradient(-(_action_probs - _action_probs[..., :1]))
#                 global_inverse = ((r_energies - r_energies[..., :1]).sum(1) + weights)# B, A

#             r_energies = energies.sum(2)
#             forward_loss = jnp.diagonal(r_energies).transpose(1,0) - \
#                                 jax.nn.logsumexp(r_energies, axis=1)
#             forward_loss = jnp.take_along_axis(forward_loss, sel_actions[:, None], axis=-1)[:, -1]

#             forward_loss = -jnp.mean(forward_loss * mask)
                
#             action_mask = jax.nn.one_hot(sel_actions, self.config.n_actions) # BxAx1
#             inverse_loss = (optax.sigmoid_binary_cross_entropy(global_inverse, action_mask) * mask[:, None]).mean()

#             loss = self.config.params.recons_const*recons_loss + \
#                     self.config.params.inverse_const*inverse_loss + \
#                     self.config.params.forward_const*forward_loss + \
#                     (self.config.use_action_weights*self.config.params.policy_const)*action_prob_loss + \
#                     1*score_loss + \
#                     0.01*sparsity_loss
            
#             logs = {
#                 'scalars': {
#                     'norm': (z**2).sum(-1).mean(),
#                     'smoothness': smoothness,
#                     'recons_loss': recons_loss,
#                     'inverse_loss': inverse_loss,
#                     'forward_loss': forward_loss,
#                     'policy_loss': action_prob_loss,
#                     'score_loss': score_loss,
#                     'sparsity_loss': sparsity_loss,
#                     'loss': loss,
#                 },
#                     # 'histograms': {
#                     #     **{'action_weights/action_{i}': weights[..., i] for i in range(self.config.n_actions)}
#                     # }
#             }
#             return loss, logs
        
#         dummy_logs = {
#                 'scalars': {
#                         'norm': 0.,
#                         'smoothness': 0.,
#                         'recons_loss': 0.,
#                         'inverse_loss': 0.,
#                         'forward_loss': 0.,
#                         'policy_loss': 0.,
#                         'loss': 0.,
#                     },
#                     # 'histograms': {
#                     #     **{'action_weights/action_{i}': 0. for i in range(self.config.n_actions)}
#                     # }
#             }
        
#         return _loss, dummy_logs

class EMACFRepresentation(RepModel):
    """Factored, action-conditioned latent representation with a Langevin-style predictor.

    The latent vector of size ``config.latent_dim`` is split into
    ``n_factors = latent_dim // vars_per_factor`` factors of ``n_vars`` entries
    each. One energy network per factor is built and stacked with ``nnx.vmap``.
    ``loss_fn`` returns two losses:

    * ``m_loss``: reconstruction + inverse-model + contrastive forward
      (+ optional policy) objective over encoded transitions ``z -> next_z``.
    * ``e_loss``: fits the energy networks so that a few updates
      ``delta_z += alpha * score_a + sqrt(2 * alpha) * noise`` (started from
      Gaussian noise) land on the observed latent change ``next_z - z``.

    ``predict`` / ``batch_predict`` run that same update rule at inference time.
    """

    def __init__(self, config, rngs):
        super().__init__(config, rngs)
        self.config = config

        self.rngs = rngs
        self.latent_dim = self.config.latent_dim
        self.n_factors = self.latent_dim // self.config.vars_per_factor
        self.n_vars = self.config.vars_per_factor
        self.batch_size = self.config.batch_size

        # Write the observation / latent sizes into the sub-configs before building.
        if self.config.is_pixel:
            self.config.encoder_pixel.input_shape = self.config.obs_dim
            self.config.decoder_pixel.output_shape = self.config.obs_dim
            self.config.decoder_pixel.input_dim = self.config.latent_dim
        else:
            self.config.encoder.input_dim = self.config.obs_dim[0]
            self.config.decoder.output_dim = self.config.obs_dim[0]
            self.config.decoder.input_dim = self.config.latent_dim
        self.encoder = build_model(self.config.encoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.encoder, rngs)
        self.decoder = build_model(self.config.decoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.decoder, rngs)

        # Each energy network sees one factor's delta (n_vars values) plus the full z.
        self.config.energy.input_dim = self.config.latent_dim + self.config.vars_per_factor
        self.config.pi.input_dim = self.config.latent_dim

        # One energy network per factor, stacked along a leading factor axis.
        self.energies = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.energy, rngs, jnp.arange(self.n_factors))
        self.config.dynamics.input_dim = self.config.latent_dim
        # Per-factor dynamics networks; only referenced by disabled code in this class.
        self.dynamics = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.dynamics, rngs, jnp.arange(self.n_factors))
        self.pi = build_model(self.config.pi, rngs)

        self.config.projector.input_dim = self.latent_dim
        self.config.projector.output_dim = self.latent_dim
        self.projector = build_model(self.config.projector, rngs)
        # Square linear map; only referenced by disabled code in get_energies.
        self.W = nnx.Linear(self.config.latent_dim, self.config.latent_dim, use_bias=False, rngs=rngs)

    @staticmethod
    def _as_key(rng):
        """Return ``rng`` as a PRNG key, converting a plain Python ``int`` seed."""
        if isinstance(rng, int):
            return jax.random.PRNGKey(rng)
        return rng

    @staticmethod
    def _drop_first_step(x, time_axis):
        """Return ``x`` without its first element along ``time_axis``."""
        index = (slice(None),) * time_axis + (slice(1, None),)
        return x[index]

    def get_energies(
            self,
            z,
            next_z
    ):
        '''
            Non batched energy computation
            vmap it to work with batches.

            Splits ``next_z - z`` into ``(n_factors, n_vars)``, appends the full
            ``z`` to every factor row, evaluates each factor's energy network
            on its row and returns ``0.5 * output ** 2`` (stacked over factors).
        '''
        delta_z = (next_z - z).reshape(self.n_factors, self.n_vars)
        energy_inputs = jnp.concatenate([delta_z, jnp.tile(z[None], (self.n_factors, 1))], axis=-1)
        energies = nnx.vmap(lambda energy, inp: energy(inp), in_axes=(0, 0))(self.energies, energy_inputs)
        return 0.5 * energies ** 2  #(F, A)

        # Alternative (disabled) dynamics-based energy:
        # delta_z = next_z-z
        # sigma = 5e-2
        # delta_z_pred = nnx.vmap(lambda dyn, input: dyn(input), in_axes=(0,None))(self.dynamics, z)
        # # energies = ((delta_z[:, None] - delta_z_pred) ** 2) / (2*sigma**2)
        # energies = self.W(delta_z)[:, None] * delta_z_pred
        # return energies

    def predict(
            self,
            z,
            action,
            alpha=1e-2,
            steps=10,
            rng=0,
    ):
        """Predict the latent change for a single (unbatched) latent ``z``.

        Starting from Gaussian noise, runs ``steps`` updates of the same rule
        ``e_loss`` trains with:
        ``delta_z += alpha * score[:, action] + sqrt(2 * alpha) * noise``,
        where ``score`` is the stacked per-factor energy-network output.

        Args:
            z: latent vector of length ``latent_dim``.
            action: integer action index selecting the score column.
            alpha: update step size.
            steps: number of updates.
            rng: PRNG key, or a Python ``int`` seed.

        Returns:
            Predicted ``delta_z`` with the same shape as ``z``.
        """
        # next_z_a = nnx.vmap(lambda dyn, input: dyn(input), in_axes=(0, None))(self.dynamics, z) # (F, A)
        # return next_z_a[:, action]
        rng = self._as_key(rng)
        rng_init, rng = jax.random.split(rng)
        delta_z_pred = jax.random.normal(rng_init, z.shape)
        # The conditioning context is the same at every step; build it once.
        z_context = jnp.tile(z[None], (self.n_factors, 1))

        def langevin_step(delta_z, step_rng):
            delta_z = delta_z.reshape(self.n_factors, self.n_vars)
            score_input = jnp.concatenate([delta_z, z_context], axis=-1)
            score = nnx.vmap(lambda energy, inp: energy(inp), in_axes=(0, 0))(self.energies, score_input)
            noise = jax.random.normal(step_rng, delta_z.shape)
            # (F, 1) column so it broadcasts over each factor's n_vars entries.
            score_a = score[:, action][:, None]
            delta_z = delta_z + alpha * score_a + jnp.sqrt(2 * alpha) * noise
            # lax.scan needs the carry to keep its initial shape.
            return delta_z.reshape(z.shape), None

        delta_z_pred, _ = jax.lax.scan(
            langevin_step,
            delta_z_pred,
            jax.random.split(rng, steps)
        )

        return delta_z_pred

    def batch_predict(
            self,
            z,
            action,
            alpha=1e-2,
            steps=10,
            rng=0,
    ):
        """Batched version of ``predict``.

        Args:
            z: latents, shape ``(B, latent_dim)``.
            action: integer action(s); a scalar or shape ``(B,)``.
            alpha: update step size.
            steps: number of updates.
            rng: PRNG key, or a Python ``int`` seed.

        Returns:
            Predicted ``delta_z`` with the same shape as ``z``.
        """
        rng = self._as_key(rng)
        rng_init, rng = jax.random.split(rng)
        delta_z_pred = jax.random.normal(rng_init, z.shape)
        # Per-sample action indices shaped for take_along_axis over (B, F, A).
        action_index = jnp.broadcast_to(jnp.asarray(action).astype(jnp.int32), z.shape[:1])[:, None, None]
        z_context = jnp.tile(z[:, None], (1, self.n_factors, 1))

        def langevin_step(delta_z, step_rng):
            delta_z = delta_z.reshape(-1, self.n_factors, self.n_vars)
            score_input = jnp.concatenate([delta_z, z_context], axis=-1)
            # vmap over the factor axis of score_input, then move batch first.
            score = nnx.vmap(lambda energy, inp: energy(inp), in_axes=(0, 1))(self.energies, score_input).swapaxes(0, 1)
            noise = jax.random.normal(step_rng, delta_z.shape)
            score_a = jnp.take_along_axis(score, action_index, axis=-1)
            delta_z = delta_z + alpha * score_a + jnp.sqrt(2 * alpha) * noise
            # lax.scan needs the carry to keep its initial shape.
            return delta_z.reshape(z.shape), None

        delta_z_pred, _ = jax.lax.scan(
            langevin_step,
            delta_z_pred,
            jax.random.split(rng, steps)
        )

        return delta_z_pred

    def forward(
            self,
            obs,
            actions,
            rewards,
            dones,
            rng,
            states=None
        ):
        """Encode an observation sequence into consecutive latent transitions.

        ``z`` / ``next_z`` are the noisy encodings of all-but-last / all-but-first
        time steps. ``next_obs`` is ``obs`` with its first time step dropped, so it
        lines up with ``decoder(next_z)``. ``rewards`` and ``states`` are unused.

        Returns:
            ``(z, actions, next_z)`` and
            ``(next_obs, action_probs, dones, recons_next_x, rng_sample)``,
            the latter being the extra arguments of the losses from ``loss_fn``.
        """
        # Split into 4 to keep the same key streams as before; two are unused.
        rng_z, _, rng_sample, _ = jax.random.split(rng, 4)
        z = self.encoder(obs)
        z = z + jax.random.normal(rng_z, z.shape)*self.config.noise_std
        # Assumes the encoder keeps obs's leading (batch..., time) dims and maps
        # the observation dims to one latent axis, so obs's time axis is z.ndim - 2.
        next_obs = self._drop_first_step(obs, z.ndim - 2)
        next_z = z[..., 1:, :]
        z = z[..., :-1, :]

        recons_next_x = self.decoder(next_z)
        action_probs = nnx.log_softmax(self.pi(z))

        dones = dones.reshape(-1)

        return (z, actions, next_z), (next_obs, action_probs, dones, recons_next_x, rng_sample)

    def loss_fn(self):
        """Return ``((m_loss, e_loss), dummy_logs)``.

        Both losses take ``model`` followed by the outputs of ``forward``
        (flattened) and an ``importance_weights`` argument they do not use.
        """
        def m_loss(
                model,
                z,
                actions,
                next_z,
                next_obs,
                action_probs,
                dones,
                recons_next_x,
                rng,
                importance_weights,
            ):
            """Reconstruction + inverse + contrastive forward (+ policy) loss.

            ``model.get_energies`` is called directly, so ``model`` must be a
            merged module. ``rng`` and ``importance_weights`` are unused.
            """
            batch_dims = z.shape[:-1]
            # importance_weights = importance_weights[:, None].repeat(batch_dims[-1], axis=-1)
            # take_along_axis requires integer indices.
            sel_actions = actions.reshape(-1).astype(jnp.int32)
            mask = 1-dones
            recons_loss = jnp.mean(optax.l2_loss(next_obs, recons_next_x).reshape(*batch_dims, -1).mean(-1))

            action_prob_loss = jnp.mean(
                    optax.softmax_cross_entropy_with_integer_labels(
                    action_probs.reshape(-1, action_probs.shape[-1]),
                    sel_actions
                )
            )

            z = z.reshape(-1, z.shape[-1])
            next_z = next_z.reshape(-1, z.shape[-1])
            smoothness = jnp.linalg.norm(next_z-z, axis=-1).mean()
            # energies[j, i] = get_energies(z[i], next_z[j]) for every pair (j, i).
            energies = nnx.vmap(
                nnx.vmap(
                    model.get_energies,
                    in_axes=(0, None)
                ),
                in_axes=(None, 0)
            )(z, next_z)

            # Matching pairs (z_b, next_z_b): diagonal over the two batch axes.
            r_energies = jnp.diagonal(energies).transpose(2, 0, 1)
            if not self.config.use_action_weights:
                global_inverse = (r_energies - r_energies[..., :1]).sum(1)
            else:
                _action_probs = action_probs.reshape(-1, action_probs.shape[-1])
                weights = jax.lax.stop_gradient(-(_action_probs - _action_probs[..., :1]))
                global_inverse = ((r_energies - r_energies[..., :1]).sum(1) + weights)# B, A

            # Contrastive forward term: for each next_z_j, the matching pair is
            # normalised (logsumexp) against every candidate z_i.
            r_energies = energies.sum(2)
            forward_loss = jnp.diagonal(r_energies).transpose(1,0) - \
                                jax.nn.logsumexp(r_energies, axis=1)
            forward_loss = jnp.take_along_axis(forward_loss, sel_actions[:, None], axis=-1)[:, -1]
            forward_loss = -jnp.mean(forward_loss * mask)
            action_mask = jax.nn.one_hot(sel_actions, self.config.n_actions) # (B, A)
            inverse_loss = (optax.sigmoid_binary_cross_entropy(global_inverse, action_mask) * mask[:, None]).mean()

            loss = self.config.params.recons_const*recons_loss + \
                    self.config.params.inverse_const*inverse_loss + \
                    self.config.params.forward_const*forward_loss + \
                    0.0*(jnp.abs(z).sum(-1)).mean() + \
                    (self.config.use_action_weights*self.config.params.policy_const)*action_prob_loss

            logs = {
                'scalars': {
                    'norm': (z**2).sum(-1).mean(),
                    'smoothness': smoothness,
                    'recons_loss': recons_loss,
                    'inverse_loss': inverse_loss,
                    'forward_loss': forward_loss,
                    'policy_loss': action_prob_loss,
                    'loss': loss,
                },
            }
            return loss, logs


        def e_loss(
                model,
                z,
                actions,
                next_z,
                next_obs,
                action_probs,
                dones,
                recons_next_x,
                rng,
                importance_weights,
            ):
            """Langevin prediction loss for the per-factor energy networks.

            ``model`` is unpacked with ``nnx.merge(*model)`` inside the scan, so it
            is expected to be an ``nnx.split`` pair. Latents go through ``sg``, so
            only ``model.energies`` gets gradient; non-terminal transitions
            (``1 - dones``) are averaged.
            """
            z = z.reshape(-1, z.shape[-1])
            next_z = next_z.reshape(-1, z.shape[-1])
            # take_along_axis requires integer indices.
            actions = actions.reshape(-1).astype(jnp.int32)

            # next_z_pred = nnx.vmap(
            #     lambda z, a, rng: model.predict(z, a, rng=rng),
            #     in_axes=(0, 0, 0)
            # )(sg(z), actions, jax.random.split(rng, z.shape[0]))

            rng_init, rng = jax.random.split(rng)
            delta_z_pred = jax.random.normal(rng_init, z.shape)

            steps=10
            alpha=1.

            # Loop-invariant conditioning context and action indices.
            z_context = jnp.tile(sg(z)[:, None], (1, self.n_factors, 1))
            action_index = actions[:, None, None]

            def _langevin_step(model, delta_z, rng):
                delta_z = delta_z.reshape(-1, self.n_factors, self.n_vars)
                score_input = jnp.concatenate([delta_z, z_context], axis=-1)
                model = nnx.merge(*model)
                score = nnx.vmap(lambda energy, inp: energy(inp), in_axes=(0,1))(model.energies, score_input).swapaxes(0,1)
                noise = jax.random.normal(rng, delta_z.shape)
                score_a = jnp.take_along_axis(score, action_index, axis=-1)
                delta_z = delta_z + alpha * score_a + jnp.sqrt(2*alpha) * noise
                # The per-step history is discarded, so don't stack it.
                return delta_z.reshape(-1, self.latent_dim), None

            delta_z_pred, _ = jax.lax.scan(
                lambda delta_z, rng: _langevin_step(model, delta_z, rng),
                delta_z_pred,
                jax.random.split(rng, steps)
            )

            prediction_loss = (optax.l2_loss(sg(next_z-z), delta_z_pred).sum(-1) * (1-dones)).mean()

            logs = {
                'scalars': {
                    'prediction_loss': prediction_loss
                }
            }

            return prediction_loss, logs


        dummy_logs = {
                'scalars': {
                        'norm': 0.,
                        'smoothness': 0.,
                        'recons_loss': 0.,
                        'inverse_loss': 0.,
                        'forward_loss': 0.,
                        'policy_loss': 0.,
                        'loss': 0.,
                    },
            }

        return (m_loss, e_loss), dummy_logs

# class DiscreteACFRepresentation(RepModel):

#     def __init__(self, config, rngs):
#         super().__init__(config, rngs)
#         self.config = config

#         self.rngs = rngs
#         self.latent_dim = self.config.latent_dim
#         self.n_factors = self.latent_dim // self.config.vars_per_factor
#         self.n_vars = self.config.vars_per_factor
#         self.batch_size = self.config.batch_size
#         if self.config.is_pixel:
#             self.config.encoder_pixel.input_shape = self.config.obs_dim
#             self.config.decoder_pixel.output_shape = self.config.obs_dim
#             self.config.decoder_pixel.input_dim = self.config.latent_dim
#         else:
#             self.config.encoder.input_dim = self.config.obs_dim[0]
#             self.config.decoder.output_dim = self.config.obs_dim[0]
#             self.config.decoder.input_dim = self.config.latent_dim
#         self.encoder = build_model(self.config.encoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.encoder, rngs)
#         self.decoder = build_model(self.config.decoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.decoder, rngs)
#         self.config.projector.input_dim = self.latent_dim
#         self.config.projector.output_dim = self.latent_dim
#         self.projector = build_model(self.config.projector, rngs)
#         self.inverse = build_model(self.config.inverse, rngs)

#         self.config.energy.input_dim = self.config.latent_dim
#         self.config.pi.input_dim = self.config.latent_dim

#         self.energies = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.energy, rngs, jnp.arange(self.n_factors))
#         self.pi = build_model(self.config.pi, rngs)

#     def __call__(self, x):
#         return self.quantize(self.encoder(x))

#     def get_energies(
#             self,
#             z,
#             next_z
#     ):
#         '''
#             Non batched energy computation
#             vmap it to work with batches.
#         '''
#         energies = nnx.vmap(lambda energy, input: energy(input).reshape(self.config.n_values, self.config.n_actions), in_axes=(0,None))(self.energies, self.projector(z))
#         energies = nnx.log_softmax(energies, axis=1)
#         idx = ((next_z+1) * (self.config.n_values-1)/2).astype(jnp.int32)
#         return energies[jnp.arange(energies.shape[0]), idx] #(F, V, A)

#     def quantize(self, z):
#         z_q = (z + 1) * (self.config.n_values - 1)/2
#         z_q = sg(2 / (self.config.n_values - 1) * jnp.round(z_q) - 1) # [0,1]
#         return z_q + z - sg(z)

#     def forward(
#             self,
#             obs,
#             actions,
#             rewards,
#             dones,
#             rng,
#             states=None
#         ):

#         rng_z, rng_next_z, rng_sample = jax.random.split(rng, 3)
#         z = self.quantize(self.encoder(obs))
#         next_z = z[..., 1:, :]
#         z = z[..., :-1, :]
        
#         if len(z.shape) == 1: # add batch dim
#             z = z[None]
#             next_z = next_z[None]

        
#         action_probs = nnx.log_softmax(self.pi(z))

#         batch_dims, latent_dim = z.shape[:-1], z.shape[-1]
#         batch_size = self.batch_size
#         _z = z.reshape(-1, latent_dim)
#         _next_z = next_z.reshape(-1, latent_dim)
#         _actions = actions.reshape(-1)
#         dones = dones.reshape(-1)
#         mask = 1-dones

#         idx = jax.random.categorical(
#             rng_sample,
#             jnp.log((1-dones)/(1-dones).sum() + 1e-12), # don't sample done elements because z->z' is invalid
#             axis=0,
#             shape=(batch_size,)
#         )

#         _z = _z[idx]
#         _next_z = _next_z[idx]
#         _actions = _actions.reshape(-1)[idx]
#         recons_next_x = self.decoder(_next_z)
       
#         # randomize the z' to estimate the forward dynamics.
#         energies = nnx.vmap(
#             nnx.vmap(
#                 self.get_energies,
#                 in_axes=(0,None)
#             ),
#             in_axes=(None, 0)
#         )(_z, _next_z)
#         # energies = nnx.vmap(self.get_energies, in_axes=(0,0))(_z, _next_z)
#         next_obs = next_obs.reshape(-1, *next_obs.shape[2:])[idx]
#         forward_energies = nnx.vmap(self.get_energies, in_axes=(0,0))(sg(_z), sg(_next_z))
#         return (z, actions, next_z), (next_obs, _actions, energies, forward_energies, recons_next_x, action_probs, idx, mask)

#     def loss_fn(self):
#         def _loss(
#                 z,
#                 actions,
#                 next_z,
#                 next_obs,
#                 sel_actions,
#                 energies,
#                 forward_energies,
#                 recons_next_x,
#                 action_probs,
#                 indices,
#                 mask,
#                 importance_weights
#             ):
#             batch_dims = z.shape[:-1]
#             importance_weights = importance_weights[:, None].repeat(batch_dims[-1], axis=-1)
#             recons_loss = jnp.mean(optax.l2_loss(next_obs, recons_next_x).reshape(*batch_dims, -1).mean(-1) * importance_weights)

#             action_prob_loss = jnp.mean(
#                     optax.softmax_cross_entropy_with_integer_labels(
#                     action_probs,
#                     actions.astype(jnp.int32)
#                 ) * importance_weights
#             )

#             smoothness = jnp.linalg.norm(next_z-z, axis=-1).mean()
#             _importance_weights = importance_weights.reshape(-1)[indices]

#             # r_energies = energies.sum(1) # energies real samples (B, A)
#             r_energies = jnp.diagonal(energies.sum(2)).transpose(1,0)
#             if not self.config.use_action_weights:
#                 # global_inverse = (r_energies - r_energies[..., :1])
#                 global_inverse = r_energies[jnp.arange(r_energies.shape[0]), sel_actions.reshape(-1)][..., None] - r_energies #(B, A)
#                 print(global_inverse.shape)
#             else:
#                 _action_probs = action_probs.reshape(-1, action_probs.shape[-1])[indices]
#                 weights = jax.lax.stop_gradient(-(_action_probs - _action_probs[..., :1]))
#                 global_inverse = ((r_energies - r_energies[..., :1]).sum(1) + weights)# B, A
                
#             inverse_loss = -(nnx.log_sigmoid(global_inverse).mean(-1)).mean()
#             r_energies = energies.sum(2)
#             # forward_loss = -jnp.mean((
#             #                     jnp.diagonal(r_energies).transpose(1,0) - \
#             #                     jax.nn.logsumexp(r_energies, axis=0)
#             #                 ).take(
#             #                     sel_actions.astype(jnp.int32),
#             #                     axis=-1
#             #                 ) * _importance_weights)
#             forward_loss = 0.
#             # forward_loss = -forward_energies.sum(1)[jnp.arange(forward_energies.shape[0]), sel_actions].mean()
#             loss = self.config.params.recons_const*recons_loss + \
#                     self.config.params.inverse_const*inverse_loss + \
#                     self.config.params.forward_const*forward_loss + \
#                     (self.config.use_action_weights*self.config.params.policy_const)*action_prob_loss #+ \
                    
            
#             logs = {
#                 'scalars': {
#                     'norm': (z**2).sum(-1).mean(),
#                     'smoothness': smoothness,
#                     'recons_loss': recons_loss,
#                     'inverse_loss': inverse_loss,
#                     'forward_loss': forward_loss,
#                     'policy_loss': action_prob_loss,
#                     'loss': loss,
#                 },
#             }
#             return loss, logs
        
#         dummy_logs = {
#                 'scalars': {
#                         'norm': 0.,
#                         'smoothness': 0.,
#                         'recons_loss': 0.,
#                         'inverse_loss': 0.,
#                         'forward_loss': 0.,
#                         'policy_loss': 0.,
#                         'loss': 0.,
#                     },
#                     # 'histograms': {
#                     #     **{'action_weights/action_{i}': 0. for i in range(self.config.n_actions)}
#                     # }
#             }
        
#         return _loss, dummy_logs

# class DetACFRepresentation(RepModel):

#     def __init__(self, config, rngs):
#         self.config = config
#         self.rngs = rngs
#         self.latent_dim = self.config.latent_dim
#         self.n_factors = self.latent_dim // self.config.vars_per_factor
#         self.n_vars = self.config.vars_per_factor
#         self.batch_size = self.config.batch_size
#         if self.config.is_pixel:
#             self.config.encoder_pixel.input_shape = self.config.obs_dim
#             self.config.decoder_pixel.output_shape = self.config.obs_dim
#             self.config.decoder_pixel.input_dim = self.config.latent_dim
#         else:
#             self.config.encoder.input_dim = self.config.obs_dim[0]
#             self.config.decoder.output_dim = self.config.obs_dim[0]
#             self.config.decoder.input_dim = self.config.latent_dim
#         self.encoder = build_model(self.config.encoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.encoder, rngs)
#         self.decoder = build_model(self.config.decoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.decoder, rngs)
#         self.slow_encoder =  build_model(self.config.encoder_pixel, rngs) if self.config.is_pixel else build_model(self.config.encoder, rngs)
#         graphdef, params, _ = nnx.split(self.encoder, nnx.Param, ...)
#         graphdef, _, rest = nnx.split(self.slow_encoder, nnx.Param, ...)
#         self.slow_encoder = nnx.merge(graphdef, jax.tree.map(jnp.copy, params), rest)

#         self.min_z = jnp.zeros(self.config.latent_dim)
#         self.max_z = jnp.ones(self.config.latent_dim)
       
#         self.config.dynamics.input_dim = self.config.latent_dim
#         self.config.pi.input_dim = self.config.latent_dim

#         self.dynamics = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.dynamics, rngs, jnp.arange(self.n_factors))
#         self.pi = build_model(self.config.pi, rngs)
    
#     def __call__(self, states):
#         return self.encoder(states)

#     def get_energies(
#             self,
#             z,
#             next_z,
#             sigma=0.1
#     ):
#         '''
#             Non batched energy computation
#             vmap it to work with batches.
#         '''
#         delta_z = (next_z-z).reshape(self.n_factors, self.n_vars)
#         delta_z_mean = nnx.vmap(lambda energy, input: energy(input), in_axes=(0,None))(self.dynamics, z) # (F, A*N_VARS)
#         delta_z_mean = delta_z_mean.reshape(self.n_factors, self.n_vars, -1) # (F, V, A)
#         energies = -((delta_z[..., None] - delta_z_mean) ** 2).sum(1) / (2*sigma**2)
#         return energies # (F, A)

#     def predict(
#             self,
#             z,
#             next_z,
#             sigma=0.1
#     ):
#         '''
#             Non batched energy computation
#             vmap it to work with batches.
#         '''
#         delta_z = (next_z-sg(z)).reshape(self.n_factors, self.n_vars)
#         delta_z_mean = nnx.vmap(lambda energy, input: energy(input), in_axes=(0,None))(self.dynamics, z) # (F, A*N_VARS)
#         delta_z_mean = delta_z_mean.reshape(self.n_factors, self.n_vars, -1) # (F, V, A)
#         energies = -((delta_z[..., None] - delta_z_mean) ** 2).sum(1) / (2*sigma**2)
#         return energies # (F, A)

#     def normalize(self, z):
#         # jax.debug.print("{}", (self.min_z, self.max_z))
#         # return (z - sg(self.min_z)) / jnp.maximum(1e-3, sg(self.max_z-self.min_z))
#         return z
    
#     def forward(
#             self,
#             obs,
#             actions,
#             rewards,
#             dones,
#             rng,
#             states=None
#         ):

#         rng_z, rng_next_z, rng_sample, rng_next_z_slow, rng_z_slow = jax.random.split(rng, 5)
#         z = self.encoder(obs)
#         z = z + jax.random.normal(rng_z, z.shape)*self.config.noise_std
#         next_z = z[..., 1:, :]
#         z = z[..., :-1, :]
#         next_obs = obs[:, 1:]
#         obs = obs[:, :-1]

#         slow_next_z = self.slow_encoder(next_obs) + jax.random.normal(rng_next_z_slow, z.shape)*self.config.noise_std
#         slow_z = self.slow_encoder(obs) + jax.random.normal(rng_z_slow, z.shape)*self.config.noise_std
       
       
#         if len(z.shape) == 1: # add batch dim
#             z = z[None]
#             next_z = next_z[None]

#         recons_next_x = self.decoder(self.normalize(next_z))
#         action_probs = nnx.log_softmax(self.pi(z))
#         mask = 1-dones

#         batch_dims, latent_dim = z.shape[:-1], z.shape[-1]
#         batch_size = self.batch_size
#         _z = z.reshape(-1, latent_dim)
#         _next_z = next_z.reshape(-1, latent_dim)
#         _actions = actions.reshape(-1)
#         dones = dones.reshape(-1)

#         idx = jax.random.categorical(
#             rng_sample,
#             jnp.log((1-dones)/(1-dones).sum() + 1e-12), # don't sample done elements because z->z' is invalid
#             axis=0,
#             shape=(batch_size,)
#         )

#         _z = self.normalize(_z[idx])
#         _next_z = self.normalize(_next_z[idx])
#         _actions = _actions.reshape(-1)[idx]
        
#         # randomize the z' to estimate the forward dynamics.
#         # energies = nnx.vmap(
#         #     nnx.vmap(
#         #         self.get_energies,
#         #         in_axes=(0,None)
#         #     ),
#         #     in_axes=(None, 0)
#         # )(_z, _next_z)

#         energies = nnx.vmap(
#                 self.get_energies,
#                 in_axes=(0,0)
#             )(sg(_z), _next_z)

#         slow_z = sg(self.normalize(slow_z.reshape(-1, self.latent_dim)[idx]))
#         slow_next_z = sg(self.normalize(slow_next_z.reshape(-1, self.latent_dim)[idx]))


#         forward_energies = nnx.vmap(
#                 self.predict,
#                 in_axes=(0,0)
#             )(sg(_z), sg(_next_z))

#         return (z, actions, next_z), (next_obs, _actions, energies, forward_energies, recons_next_x, action_probs, mask, sg(z))

#     def loss_fn(self):
#         def _loss(
#                 z,
#                 actions,
#                 next_z,
#                 next_obs,
#                 sel_actions,
#                 energies,
#                 forward_energies,
#                 recons_next_x,
#                 action_probs,
#                 mask,
#                 unnorm_z,
#                 importance_weights
#             ):

#             recons_loss = optax.l2_loss(next_obs, recons_next_x).mean()
#             action_prob_loss = optax.softmax_cross_entropy_with_integer_labels(action_probs, actions.astype(jnp.int32)).mean()
#             smoothness = jnp.linalg.norm(next_z-z, axis=-1).mean()

#             # r_energies = jnp.diagonal(energies).transpose(2, 0, 1) # energies real samples
#             r_energies = energies.sum(1)
#             if not self.config.use_action_weights:
#                 global_inverse = r_energies[jnp.arange(sel_actions.shape[0]), sel_actions][..., None] - r_energies # (B, A)
#             else:
#                 weights = jax.lax.stop_gradient(-(action_probs - action_probs[..., :1]))
#                 global_inverse = ((r_energies - r_energies[..., :1]).sum(1) + weights)# B, A
            
#             inverse_loss = -nnx.log_sigmoid(global_inverse).mean(-1).mean()

#             # r_energies = energies.sum(2)
#             # forward_loss = 0.5 * (
#             #                 -(
#             #                     jnp.diagonal(r_energies).transpose(1,0) - \
#             #                     jax.nn.logsumexp(r_energies, axis=0)
#             #                 ).take(
#             #                     sel_actions.astype(jnp.int32),
#             #                     axis=-1
#             #                 ).mean() -(
#             #                     jnp.diagonal(r_energies).transpose(1,0) - \
#             #                     jax.nn.logsumexp(r_energies, axis=1)
#             #                 ).take(
#             #                     sel_actions.astype(jnp.int32),
#             #                     axis=-1
#             #                 ).mean()
#             #             )
#             # forward_loss = -jnp.diagonal(r_energies).transpose(1,0)[jnp.arange(sel_actions.shape[0]), sel_actions].mean()# -log T(z'|z,a)
#             forward_loss = -forward_energies.sum(1)[jnp.arange(sel_actions.shape[0]), sel_actions].mean()

#             loss = self.config.params.recons_const*recons_loss + \
#                     self.config.params.inverse_const*inverse_loss + \
#                     self.config.params.forward_const*forward_loss + \
#                     (self.config.use_action_weights*self.config.params.policy_const)*action_prob_loss
            
#             logs = {
#                 'scalars': {
#                     'norm': (z**2).sum(-1).mean(),
#                     'smoothness': smoothness,
#                     'recons_loss': recons_loss,
#                     'inverse_loss': inverse_loss,
#                     'forward_loss': forward_loss,
#                     'policy_loss': action_prob_loss,
#                     'loss': loss,
#                 },
#             }
#             return loss, logs
        
#         dummy_logs = {
#                 'scalars': {
#                         'norm': 0.,
#                         'smoothness': 0.,
#                         'recons_loss': 0.,
#                         'inverse_loss': 0.,
#                         'forward_loss': 0.,
#                         'policy_loss': 0.,
#                         'loss': 0.,
#                     },
#             }
        
#         return _loss, dummy_logs

#     def update_slow_params(self, z):
#         z = z.reshape(-1, z.shape[-1])
#         z_min, z_max = jnp.percentile(z, jnp.array([5, 95]), axis=0)

#         self.min_z, self.max_z = optax.incremental_update(
#             (z_min, z_max),
#             (self.min_z, self.max_z),
#             0.01
#         )
#         nnx.update(
#             self.slow_encoder,
#             optax.incremental_update(
#                 nnx.state(self.encoder, nnx.Param),
#                 nnx.state(self.slow_encoder, nnx.Param),
#                 0.02
#             )
#         )

@chex.dataclass(frozen=True)
class RecurrentState:
    '''
        Immutable recurrent carry used by ``RecurrentACFRepresentation``:
        the latent ``z`` together with the recurrent memory ``h``.
    '''
    z: chex.Array  # latent vector(s)
    h: chex.Array  # recurrent memory vector(s)

class RecurrentACFRepresentation(RepModel):
    '''
        Recurrent variant of the ACF representation.

        Observations are embedded by ``embed_obs``. A bank of recurrent cells
        (``gru``, one per action) updates a memory ``h``; at every step the
        output of the cell of the action actually taken is kept, and
        ``posterior`` maps ``concat([embedding, h])`` to the latent ``z``.
        One energy model per factor (``energies``) scores the change of that
        factor's latent variables given the memory produced by each action,
        and ``pi`` predicts the action from ``concat([h, z])``.
    '''
    def __init__(self, config, rngs, external_encoder=None):
        super().__init__(config, rngs)
        self.config = config
        # Kept on the instance; not referenced by the methods of this class.
        self.external_encoder = external_encoder

        self.rngs = rngs
        self.latent_dim = self.config.latent_dim
        self.n_factors = self.latent_dim // self.config.vars_per_factor
        self.n_vars = self.config.vars_per_factor
        self.batch_size = self.config.batch_size
        # Flag marking this representation as recurrent (presumably read by callers).
        self.recurrent = True
        self.n_actions = self.config.n_actions

        if not self.config.get('use_ground_truth_states', False):
            if self.config.is_pixel and external_encoder is None:
                self.config.encoder_pixel.input_shape = self.config.obs_dim
                self.config.decoder_pixel.output_shape = self.config.obs_dim
                self.config.decoder_pixel.input_dim = self.config.latent_dim + self.config.hidden_dim
                self.embed_obs = build_model(self.config.encoder_pixel, rngs)
                self.decoder = build_model(self.config.decoder_pixel, rngs)
            else:
                # When external_encoder is provided, use non-pixel encoder
                # even if observations are pixels
                if self.config.is_pixel:
                    self.config.decoder_pixel.output_shape = self.config.obs_dim
                    # ``forward`` always decodes concat([z, h]), so the decoder
                    # input must include the memory, as in the other branches.
                    self.config.decoder_pixel.input_dim = self.config.latent_dim + self.config.hidden_dim
                    self.decoder = build_model(self.config.decoder_pixel, rngs)
                else:
                    self.config.decoder.output_dim = self.config.obs_dim[0]
                    self.config.decoder.input_dim = self.config.latent_dim + self.config.hidden_dim
                    self.decoder = build_model(self.config.decoder, rngs)
                # For the encoder, always use the non-pixel one when external_encoder is provided
                if external_encoder is not None:
                    # The input dimension depends on what the external encoder outputs
                    # Assuming the external encoder's output matches our latent dimension requirements
                    self.config.encoder.input_dim = self.config.latent_dim
                else:
                    self.config.encoder.input_dim = self.config.obs_dim[0]

                self.embed_obs = build_model(self.config.encoder, rngs)

        self.config.posterior.input_dim = self.config.hidden_dim + self.config.hidden_dim # embed + memory
        self.config.posterior.output_dim = self.config.latent_dim
        self.posterior = build_model(self.config.posterior, rngs)

        # Each factor's energy model sees that factor's variables plus the memory.
        self.config.energy.input_dim = self.config.hidden_dim + self.config.vars_per_factor
        self.config.energy.output_dim = 1

        # One memory cell per action (stacked along a leading action axis).
        self.gru = nnx.vmap(build_model, in_axes=(None, None, 0))(self.config.memory, rngs, jnp.arange(self.n_actions))

        # One energy model per factor (stacked along a leading factor axis).
        self.energies = nnx.vmap(
            build_model,
            in_axes=(None, None, 0)
        )(self.config.energy, rngs, jnp.arange(self.n_factors))

        self.config.pi.input_dim = self.config.hidden_dim + self.config.latent_dim
        self.pi = build_model(self.config.pi, rngs)

    def get_energies(
            self,
            next_z,
            h
    ):
        '''
            Non-batched per-factor energies.

            next_z: (latent_dim,) vector, split into n_factors groups of
                vars_per_factor variables (``infer_actions`` passes the latent
                change z' - z here, as the training energies use).
            h: (H,) memory vector, appended to every factor's variables.

            Returns one energy-model output per factor: (F, output_dim),
            assuming the energy model's last axis is its output_dim (1 here).
        '''
        _next_z = next_z.reshape(self.n_factors, self.n_vars)
        energy_inputs = jnp.concatenate([_next_z, jnp.tile(h[None], (_next_z.shape[0], 1))], axis=-1)
        energies = nnx.vmap(lambda energy, input: energy(input), in_axes=(0,0))(self.energies, energy_inputs)
        return energies

    def encode(
            self,
            obs,
            actions,
            dones,
            states=None
        ):
        '''
            Encode a batch of trajectories.

            obs: (B, T+1, ...) observations; actions, dones: (B, T).
            Returns the latents, (B, T+1, latent_dim). When
            ``use_ground_truth_states`` is set, ``states`` is returned as is
            (and must be provided).

            The rollout matches ``forward`` without noise. When ``dones[t-1]``
            is set, the memory produced at step t is replaced by zeros after
            the per-action update, so the latent that follows is
            ``posterior(embedding, 0)``, like the first latent of the batch.
        '''
        if self.config.get('use_ground_truth_states', False):
            if states is None:
                raise ValueError('States must be provided if use_ground_truth_states is True')
            return states

        batch = obs.shape[0]
        embed_obs = self.embed_obs(obs)
        h0 = jnp.zeros((batch, self.config.hidden_dim,))
        z0 = self.posterior(jnp.concatenate([embed_obs[:,0], h0], axis=-1))
        s0 = RecurrentState(
            z=z0,
            h=h0
        )

        def _scan_step(gru, recurrent_state, inputs):
            embed, action, is_first = inputs
            # TODO index gru by action instead of computing it for all actions
            h_a = nnx.vmap(
                lambda _gru: _gru(recurrent_state.z, recurrent_state.h)[0], in_axes=(0,)
            )(nnx.merge(*gru))  # (A, B, H)
            h_a = jnp.transpose(h_a, (1, 2, 0))  # (B, H, A)
            h = jnp.take_along_axis(h_a, action[:, None, None], axis=-1)[..., 0]
            # Reset after the update, exactly as in ``forward``.
            h = jnp.where(is_first[:, None], jnp.zeros_like(h), h)
            z = self.posterior(jnp.concatenate([embed, h], axis=-1))
            recurrent_state = RecurrentState(z=z, h=h)
            return recurrent_state, recurrent_state

        # Step t starts a new episode when the previous transition ended one.
        is_first = jnp.concatenate([jnp.zeros((batch, 1)), dones[:, :-1]], axis=-1).astype(jnp.bool_)
        _, ss = jax.lax.scan(
            partial(_scan_step, nnx.split(self.gru)),
            s0,
            jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), (embed_obs[:, 1:], actions, is_first))
        )
        ss = ss.replace(
            z=jnp.concatenate([z0[None], ss.z], axis=0)
        )
        ss = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), ss)
        return ss.z

    def observe(
            self,
            embed_obs,
            action,
            is_first,
            recurrent_state
    ):
        '''
            Observe a single (batched) transition.

            embed_obs: (B, E) embedded observation; action: (B,) integer
            actions; is_first: (B,) episode-start flags; recurrent_state:
            the previous ``RecurrentState``.

            Mirrors one step of ``forward`` without noise: the memory of the
            taken action is selected and, where ``is_first`` is set, replaced
            by zeros before computing the posterior latent.
        '''
        h_a = nnx.vmap(
            lambda gru: gru(recurrent_state.z, recurrent_state.h)[0], in_axes=(0,)
        )(self.gru)  # (A, B, H)
        h_a = jnp.transpose(h_a, (1, 2, 0))  # (B, H, A)
        h = jnp.take_along_axis(h_a, action[:, None, None], axis=-1)[..., 0]
        h = jnp.where(is_first[:, None], jnp.zeros_like(h), h)
        z = self.posterior(jnp.concatenate([embed_obs, h], axis=-1))
        return RecurrentState(z=z, h=h)

    def infer_actions(
            self,
            obs,
            actions,
            dones,
            states=None,
            rng=None
    ):
        '''
            Infer the action taken at each transition of a single trajectory.

            Not batched; vmap it to work with batches.
            obs: (T+1, ...) observations; actions: (T,) integer actions;
            dones: (T,) episode-end flags. ``states`` and ``rng`` are accepted
            for interface compatibility and ignored.

            Latents are rolled out as in ``forward`` (without the training
            noise). The energy of each factor of z_{t+1} - z_t is scored under
            the memory produced by every action, then summed over factors.

            Returns:
                action_probs: (T, A) softmax over actions of the summed
                    energies (the objective of ``inverse_model_loss``).
                binary_classifiers: (T, A-1) sigmoid of the summed energy of
                    each action 1..A-1 relative to action 0. The
                    ``use_action_weights`` prior correction is not applied.
        '''
        embed_obs = self.embed_obs(obs)  # (T+1, E)
        # Keep a batch axis of one so the memory cells and the posterior see
        # the same (B, ...) shapes as in ``forward``.
        h0 = jnp.zeros((1, self.config.hidden_dim))
        z0 = self.posterior(jnp.concatenate([embed_obs[:1], h0], axis=-1))  # (1, D)

        def _scan_step(gru, carry, inputs):
            z, h, is_first = carry
            embed, action, done = inputs
            gru = nnx.merge(*gru)
            h_action = nnx.vmap(
                lambda _gru: _gru(z, h)[0],
                in_axes=(0,)
            )(gru)  # (A, 1, H)
            h_action = jnp.transpose(h_action, (1, 2, 0))  # (1, H, A)
            h = jnp.where(is_first, jnp.zeros_like(h), h_action[..., action])
            next_z = self.posterior(jnp.concatenate([embed[None], h], axis=-1))
            return (next_z, h, done == 1), (next_z[0], h_action[0])

        _, (next_zs, h_actions) = jax.lax.scan(
            partial(_scan_step, nnx.split(self.gru)),
            (z0, h0, jnp.zeros((), dtype=jnp.bool_)),
            (embed_obs[1:], actions, dones)
        )  # next_zs: (T, D), h_actions: (T, H, A)

        zs = jnp.concatenate([z0, next_zs], axis=0)  # (T+1, D)
        delta_z = next_zs - zs[:-1]  # (T, D), the quantity the training energies score

        energies = jax.vmap(
            jax.vmap(
                lambda dz, h: self.get_energies(dz, h)[..., 0],
                in_axes=(None, 1),
                out_axes=1
            ),
            in_axes=(0, 0)
        )(delta_z, h_actions)  # (T, F, A)

        summed = energies.sum(axis=1)  # (T, A)
        actions_probs = jax.nn.softmax(summed, axis=-1) # full inverse
        binary_classifiers = jax.nn.sigmoid(
            summed[:, 1:] - summed[:, :1]
        ) # binary classifiers vs action 0
        return actions_probs, binary_classifiers

    def preprocess_obs(self, obs, rng):
        '''Add zero-mean Gaussian noise with std 5/255 to the observations.'''
        obs = obs + jax.random.normal(rng, obs.shape) *  5/255
        return obs

    def forward(
            self,
            obs,
            actions,
            rewards,
            dones,
            rng,
            states=None
        ):
        '''
            Roll out the latent dynamics on a batch of trajectories.

            obs: (B, T+1, ...); actions, rewards, dones: (B, T).
            Gaussian noise is added to the observations (``preprocess_obs``)
            and to every latent (``config.noise_std``).

            Returns ``(zs[:, :-1], actions, zs[:, 1:])`` and
            ``(next_obs, flattened actions, energies (B*T, B*T, F, A),
            recons_next_x, policy outputs (B, T, A), mask, rng)``.
        '''
        rng, rng_noise, rng_obs = jax.random.split(rng, 3)

        obs = self.preprocess_obs(obs, rng_obs)
        embed_obs = self.embed_obs(obs)

        # initialize memory
        h = jnp.zeros((embed_obs.shape[0], self.config.hidden_dim))
        z0 = self.posterior(jnp.concatenate([embed_obs[:, 0], h], axis=-1))
        rng_scan, rng_noise = jax.random.split(rng_noise, 2)
        z0 = z0 + jax.random.normal(rng_noise, z0.shape) * self.config.noise_std

        def _scan_step(gru, carry, inputs):
            z, h, is_first, rng = carry
            rng, rng_noise = jax.random.split(rng)
            gru = nnx.merge(*gru)
            embed_obs, action, reward, done = inputs
            # update memory with every action's cell
            h_action = nnx.vmap(
                lambda _gru: _gru(z, h)[0],
                in_axes=(0,)
            )(gru)
            h_action = jnp.transpose(h_action, (1, 2, 0))  # (B, H, A)
            # keep the taken action's memory; reset after an episode end
            h = jnp.where(
                is_first[:, None],
                jnp.zeros_like(h),
                jnp.take_along_axis(h_action, action[:, None, None], axis=-1)[..., 0]
            )
            # update z
            next_z = self.posterior(jnp.concatenate([embed_obs, h], axis=-1))
            next_z = next_z + jax.random.normal(rng_noise, next_z.shape) * self.config.noise_std
            carry = (next_z, h, done==1, rng)

            return carry, (next_z, h_action, h)

        _, (zs, h_actions, hs) = jax.lax.scan(
            partial(_scan_step, nnx.split(self.gru)),
            (z0, h, jnp.zeros_like(h[:, 0]).astype(jnp.bool_), rng_scan),
            jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), (embed_obs[:, 1:], actions, rewards, dones))
        )

        zs = jnp.swapaxes(zs, 0, 1) # (B, T, D)
        h_actions = jnp.swapaxes(h_actions, 0, 1) # (B, T, H, A)
        hs = jnp.swapaxes(hs, 0, 1) # (B, T, H)

        zs = jnp.concatenate([z0[:, None], zs], axis=1) # (B, T+1, D)

        # The policy must see the state *before* action t: the memory entering
        # step t (the initial zeros, then the memory produced at step t-1).
        # hs[:, t] was produced by the cell of action t and would leak the label.
        h_prev = jnp.concatenate([h[:, None], hs[:, :-1]], axis=1) # (B, T, H)
        action_probs = self.pi(jnp.concatenate([h_prev, zs[:, :-1]], axis=-1)) # (B, T, A)

        next_zs = zs[:, 1:].reshape(-1, self.latent_dim) # (B*T, D)
        _h_actions = h_actions.reshape(-1, self.config.memory.hidden_dim, self.config.n_actions) # (B*T, H, A)
        _actions = actions.reshape(-1)
        recons_next_x = self.decoder(
            jnp.concatenate(
                [
                    next_zs,
                    jnp.take_along_axis(_h_actions, _actions[..., None, None], axis=-1)[..., 0]
                ], axis=-1)
        ) # (B*T, D)

        # compute energies for all actions
        @partial(jax.vmap, in_axes=(None, 1), out_axes=1)
        def _energy(z, h):
            energies = nnx.vmap(
                lambda energy, z_i: energy(jnp.concatenate([z_i, h], axis=-1))[..., 0],
                in_axes=(0, 0)
            )(self.energies, z)
            return energies

        # compute energies for all (memory, latent change) pairs and actions to contrast them
        energies = jax.vmap(
                    jax.vmap(
                        _energy,
                        in_axes=(0, None)
                    ),
                    in_axes=(None, 0)
                )((next_zs-zs[:, :-1].reshape(-1, zs.shape[-1])).reshape(-1, self.n_factors, self.n_vars), _h_actions) # (B*T, B*T, F, A)

        # NOTE(review): the reconstruction target is the noise-augmented obs.
        next_obs = obs[:, 1:].reshape(-1, *self.config.obs_dim)
        mask = (dones == 0).reshape(-1)

        return (zs[:, :-1], actions, zs[:, 1:]), (next_obs, _actions, energies,recons_next_x, action_probs, mask, rng)

    def loss_fn(self):
        '''
            Build the training loss.

            Returns ``(_loss, dummy_logs)``. ``_loss`` takes the tensors
            returned by ``forward`` (``z, actions, next_z, next_obs,
            sel_actions, energies, recons_next_x, action_probs, mask, rng``)
            followed by (B,) ``importance_weights`` and returns
            ``(loss, logs)``. ``dummy_logs`` has the same keys as ``logs``,
            all set to zero.
        '''
        def _loss(
                z,
                actions,
                next_z,
                next_obs,
                sel_actions,
                energies,
                recons_next_x,
                action_probs,
                mask,
                rng,
                importance_weights,
            ):
            batch_dims = z.shape[:-1]
            importance_weights = importance_weights[:, None].repeat(batch_dims[-1], axis=-1)
            recons_loss = jnp.mean(optax.l2_loss(next_obs, recons_next_x).reshape(*batch_dims, -1).mean(-1))

            action_prob_loss = jnp.mean(
                    optax.softmax_cross_entropy_with_integer_labels(
                    action_probs,
                    actions.astype(jnp.int32)
                ) * importance_weights
            )

            smoothness = jnp.linalg.norm(next_z-z, axis=-1).mean()

            r_energies = jnp.diagonal(energies).transpose(2, 0, 1) # energies of real pairs (N, F, A)
            # Masked like the other transition losses: z -> z' across an
            # episode end is not a valid transition.
            inverse_model_loss = (optax.softmax_cross_entropy_with_integer_labels(
                r_energies.sum(1),
                sel_actions
            ) * mask).mean()

            if not self.config.use_action_weights:
                global_inverse = (r_energies - r_energies[..., :1]).sum(1)
            else:
                _action_probs = action_probs.reshape(-1, action_probs.shape[-1])
                weights = jax.lax.stop_gradient(-(_action_probs - _action_probs[..., :1]))
                global_inverse = ((r_energies - r_energies[..., :1]).sum(1) + weights)# B, A

            r_energies = energies.sum(2) # (N_h, N_z, A)
            if self.config.get('original_forward_loss', True):
                forward_loss = jnp.diagonal(r_energies).transpose(1,0) - \
                                    jax.nn.logsumexp(r_energies, axis=1)
                forward_loss = jnp.take_along_axis(forward_loss, sel_actions[:, None], axis=-1)[:, -1]
            else:
                # Energies are (z', z, a)
                real_energies = jnp.take_along_axis(jnp.diagonal(r_energies).transpose(1,0), sel_actions[:, None], axis=-1)[..., -1]
                denominator = jnp.take_along_axis(r_energies, sel_actions[None, :, None], axis=-1)
                denominator = jax.nn.logsumexp(denominator[..., -1], axis=-1)
                forward_loss = real_energies - denominator

            # Per-action InfoNCE loss: negatives for a sample are restricted
            # to samples that took the same action.
            per_action_forward_loss = 0
            real_energies = jnp.take_along_axis(jnp.diagonal(r_energies).transpose(1,0), sel_actions[:, None], axis=-1)[..., -1]
            denominator = jnp.take_along_axis(r_energies, sel_actions[None, :, None], axis=-1)[..., -1]
            for a in range(self.config.n_actions):
                # Create mask for samples with this action
                action_mask = (sel_actions == a)
                action_real_energies = real_energies
                action_denominator = jax.nn.logsumexp(jnp.where(action_mask[None], denominator, -1e12), axis=-1)
                action_mask_values = mask & action_mask

                # Compute InfoNCE loss for this action
                action_forward = action_real_energies - action_denominator
                action_loss = -jnp.sum(action_forward * action_mask_values) / (jnp.sum(action_mask_values) + 1e-8)

                per_action_forward_loss = per_action_forward_loss + action_loss

            per_action_forward_loss = per_action_forward_loss / self.config.n_actions

            forward_loss = -jnp.mean(forward_loss * mask)

            action_mask = jax.nn.one_hot(sel_actions, self.config.n_actions) # (N, A)
            inverse_loss = (optax.sigmoid_binary_cross_entropy(global_inverse, action_mask) * mask[:, None]).mean()

            _z = z.reshape(-1, z.shape[-1])
            grounding_loss = -((_z[None] - _z[:, None]) ** 2).sum(-1) / 0.1**2
            grounding_loss = -(jnp.diagonal(grounding_loss) - jax.nn.logsumexp(grounding_loss, axis=-1))
            grounding_loss = jnp.mean(grounding_loss)

            loss = self.config.params.recons_const*recons_loss + \
                    self.config.params.inverse_const*inverse_loss + \
                    self.config.params.get('inverse_model_const', 0.)*inverse_model_loss + \
                    self.config.params.forward_const*forward_loss + \
                    (self.config.use_action_weights*self.config.params.policy_const)*action_prob_loss + \
                    self.config.get('grounding_const', 0.)*grounding_loss + \
                    self.config.params.get('per_action_forward_const', 0.) * per_action_forward_loss.mean()

            logs = {
                'scalars': {
                    'norm': (z**2).sum(-1).mean(),
                    'smoothness': smoothness,
                    'recons_loss': recons_loss,
                    'inverse_loss': inverse_loss,
                    'inverse_model_loss': inverse_model_loss,
                    'forward_loss': forward_loss,
                    'policy_loss': action_prob_loss,
                    'grounding_loss': grounding_loss,
                    'per_action_forward_loss': per_action_forward_loss.mean(),
                    'loss': loss,
                },
            }
            return loss, logs

        # Same keys as ``logs`` so both share one tree structure.
        dummy_logs = {
                'scalars': {
                        'norm': 0.,
                        'smoothness': 0.,
                        'recons_loss': 0.,
                        'inverse_loss': 0.,
                        'inverse_model_loss': 0.,
                        'forward_loss': 0.,
                        'policy_loss': 0.,
                        'grounding_loss': 0.,
                        'per_action_forward_loss': 0.,
                        'loss': 0.,
                    },
            }

        return _loss, dummy_logs
