import jax
import flax.nnx as nnx
import jax.numpy as jnp
from jaxmodels_nnx import build_model
class RepModel(nnx.Module):
    """Base representation model whose encoder returns observations unchanged.

    Most methods here are no-op or placeholder implementations; the only
    computation performed in this class is the additive noise in
    ``preprocess``.
    """

    def __init__(self, config, rngs=None):
        """Store the configuration and install the identity encoder.

        Args:
            config: Object read for ``obs_dim``, ``is_pixel`` and
                ``latent_dim``.
            rngs: Stored unchanged on ``self.rngs``; defaults to ``None``.
        """
        self.config = config
        cfg = self.config
        if cfg.is_pixel:
            # Pixel configurations keep the full ``obs_dim`` value.
            self.input_shape = cfg.obs_dim
        else:
            # Otherwise only the last entry of ``obs_dim`` is kept.
            self.input_shape = cfg.obs_dim[-1]
        self.rngs = rngs

        def _identity(obs):
            return obs

        self.encoder = _identity
        self.latent_dim = config.latent_dim
        self.discrete = False

    def __call__(self, obs):
        """Apply ``self.encoder`` to ``obs`` and return the result."""
        encoded = self.encoder(obs)
        return encoded

    def forward(self, obs, actions, rewards, dones, rng, states=None):
        """No-op in this class; always returns ``None``."""
        return None

    def loss_fn(self, prioritized=False):
        """Return a placeholder loss callable paired with an empty dict.

        Args:
            prioritized: Accepted but not used by this implementation.

        Returns:
            A tuple ``(fn, {})`` where ``fn`` accepts any arguments and
            returns ``(0., {})``.
        """
        def _zero_loss(*args, **kwargs):
            return 0., {}

        extras = {}
        return _zero_loss, extras

    def update_slow_params(self, *args, **kwargs):
        """No-op in this class; accepts any arguments and returns ``None``."""
        return None

    def encode(self, obs, states=None):
        """Apply ``self.encoder`` to ``obs``; ``states`` is not used here."""
        encoded = self.encoder(obs)
        return encoded

    def preprocess(self, obs, rng, noise_std=3 / 255):
        """Add scaled standard-normal noise to ``obs``.

        Args:
            obs: Array-like with a ``shape`` attribute; the noise is drawn
                with that same shape.
            rng: Key passed to ``jax.random.normal``.
            noise_std: Multiplier applied to the noise; defaults to
                ``3 / 255``.

        Returns:
            ``obs`` plus the scaled noise.
        """
        noise = jax.random.normal(rng, obs.shape)
        return obs + noise * noise_std

class EncoderRepresentation(RepModel):
    """Representation model whose encoder is built by ``build_model``.

    The encoder configuration is chosen by ``config.is_pixel``:
    ``config.encoder_pixel`` when true, ``config.encoder`` otherwise.
    """

    def __init__(self, config, rngs):
        """Build the encoder from the matching sub-configuration.

        Note that this writes the input size back onto the sub-config
        (``encoder_pixel.input_shape`` or ``encoder.input_dim``), so the
        passed-in ``config`` object is mutated.

        Args:
            config: Configuration object; see ``RepModel.__init__`` for the
                fields read there. Additionally read for ``encoder_pixel``
                or ``encoder``.
            rngs: Forwarded to ``build_model``.
        """
        # NOTE(review): ``rngs`` is not forwarded to the base initializer, so
        # ``self.rngs`` stays ``None`` on this class — confirm this is intended.
        super().__init__(config)
        if self.config.is_pixel:
            self.config.encoder_pixel.input_shape = self.input_shape
            self.encoder = build_model(self.config.encoder_pixel, rngs=rngs)
        else:
            self.config.encoder.input_dim = self.input_shape
            self.encoder = build_model(self.config.encoder, rngs=rngs)

    def forward(self, obs, actions, rewards, dones, rng, states=None):
        """Encode ``obs`` and split the result into current/next pairs.

        The encoded array is sliced along axis 1: ``z[:, :-1]`` is returned
        as the current latents and ``z[:, 1:]`` as the next latents
        (presumably axis 1 is the time axis — TODO confirm against callers).

        Args:
            obs: Input passed to ``self.encoder``.
            actions: Returned unchanged in the output tuple.
            rewards: Not used by this implementation.
            dones: Not used by this implementation.
            rng: Not used by this implementation.
            states: Not used by this implementation.

        Returns:
            ``((z, actions, next_z), ())``.
        """
        z = self.encoder(obs)
        next_z = z[:, 1:]
        z = z[:, :-1]
        return (z, actions, next_z), ()

    def loss_fn(self, prioritized=False):
        """Return a placeholder loss callable paired with an empty dict.

        The signature matches ``RepModel.loss_fn`` so that callers passing
        ``prioritized`` work with either class.

        Args:
            prioritized: Accepted for compatibility with the base class;
                not used by this implementation.

        Returns:
            A tuple ``(fn, {})`` where ``fn`` accepts any arguments and
            returns ``(0., {})``.
        """
        def _loss(*args, **kwargs):
            return 0., {}

        return _loss, {}