import jax
import flax.nnx as nnx

from reps.rep_model import EncoderRepresentation
from reps.acf import ACFRepresentation
from reps.contrastive import MarkovRepresentation, GCLRepresentation
from reps.disentangled import DMS

def sg(x):
    """Apply ``jax.lax.stop_gradient`` to every leaf of the pytree ``x``.

    The result has the same tree structure and leaf values as ``x``, but no
    gradient flows back through it.
    """
    return jax.tree.map(jax.lax.stop_gradient, x)

class RepresentationBuilder:
    """Registry and factory for representation models.

    Maps short string names (e.g. ``'markov'``) to representation classes and
    assigns each name an integer id derived from its position in
    ``REP_BUILDERS``.
    """

    # Name -> representation class. Dict insertion order determines the integer
    # ids below, so reordering entries changes every id after the change.
    REP_BUILDERS = {
        'base': EncoderRepresentation,
        'markov': MarkovRepresentation,
        'acf': ACFRepresentation,
        'gcl': GCLRepresentation,
        'dms': DMS,
    }

    # Bidirectional name <-> integer id lookup tables.
    REP_TO_ID = {name: idx for idx, name in enumerate(REP_BUILDERS)}
    ID_TO_REP = {idx: name for name, idx in REP_TO_ID.items()}

    @staticmethod
    def build(rep_id, config, env_config, rngs=None):
        """Instantiate the representation registered under ``rep_id``.

        Side effects: sets ``config.reps.n_actions`` and writes ``is_pixel``,
        ``obs_dim`` and ``use_ground_truth_states`` onto the sub-config
        ``config.reps[<rep name>]`` before passing that sub-config to the
        representation class.

        Args:
            rep_id: Integer id of the representation (see ``get_id``).
            config: Experiment config. Must expose ``reps`` with a sub-config
                for the selected representation; attribute-style access and
                ``.get`` are used (presumably an ``ml_collections.ConfigDict``
                — confirm).
            env_config: Environment config exposing ``n_actions`` and
                ``obs.shape``.
            rngs: ``nnx.Rngs`` passed to the representation constructor.
                Defaults to a fresh ``nnx.Rngs(0)`` on every call.

        Returns:
            The constructed representation instance.

        Raises:
            KeyError: If ``rep_id`` is unknown, or ``config.reps`` has no
                sub-config for the selected representation.
            ValueError: If the observation shape is neither rank 1 nor rank 3.
        """
        # Creating the default here (not in the signature) avoids a single
        # stateful Rngs object built at import time and shared by all calls.
        if rngs is None:
            rngs = nnx.Rngs(0)

        try:
            rep_type = RepresentationBuilder.ID_TO_REP[rep_id]
        except KeyError:
            raise KeyError(
                f'Unknown representation id {rep_id!r}; expected one of '
                f'{sorted(RepresentationBuilder.ID_TO_REP)}.'
            ) from None
        builder = RepresentationBuilder.REP_BUILDERS[rep_type]

        obs_shape = env_config.obs.shape
        # Rank 3 is treated as pixel input, rank 1 as flat vector input.
        # Validate before touching ``config`` so a bad shape leaves it intact.
        if len(obs_shape) not in (1, 3):
            raise ValueError(
                f'Unsupported observation shape {obs_shape}: expected rank 1 '
                f'or rank 3 observations.'
            )

        config.reps.n_actions = int(env_config.n_actions)

        # A plain ``{}`` fallback cannot take attribute assignment below, so a
        # missing sub-config is reported explicitly instead.
        rep_config = config.reps.get(rep_type)
        if rep_config is None:
            raise KeyError(
                f'config.reps has no sub-config for representation '
                f'{rep_type!r}.'
            )

        rep_config.is_pixel = len(obs_shape) == 3
        rep_config.obs_dim = obs_shape
        rep_config.use_ground_truth_states = config.get(
            'use_ground_truth_states', False
        )
        return builder(rep_config, rngs=rngs)

    @staticmethod
    def get_id(name: str) -> int:
        """Return the integer id registered for representation ``name``.

        Raises:
            KeyError: If ``name`` is not a key of ``REP_BUILDERS``.
        """
        try:
            return RepresentationBuilder.REP_TO_ID[name]
        except KeyError:
            raise KeyError(
                f'Unknown representation {name!r}; expected one of '
                f'{list(RepresentationBuilder.REP_TO_ID)}.'
            ) from None


