import os

import orbax.checkpoint as ocp
import flax.nnx as nnx
from reps.jax_reps_nnx import RepresentationBuilder


def load_model(
        path,
        config,
        env_config,
        rng
):
    """Load a trained representation model from a checkpoint.

    The representation type is not passed explicitly: it is inferred from
    the final component of ``path``, taking the text after its last ``'_'``
    and looking it up in ``RepresentationBuilder.REP_TO_ID``.

    Args:
        path (str | os.PathLike): Path to the model checkpoint. Both plain
            strings and ``pathlib.Path`` objects are accepted.
        config: Configuration object forwarded to ``RepresentationBuilder.build``
        env_config: Environment configuration object forwarded to
            ``RepresentationBuilder.build``
        rng (jax.random.PRNGKey): Random number generator key used to seed
            ``nnx.Rngs`` when building the model skeleton

    Returns:
        The model produced by ``RepresentationBuilder.build`` (presumably an
        ``nnx.Module``) with its ``nnx.Param`` state replaced by the restored
        checkpoint values.

    Raises:
        KeyError: If the suffix of the checkpoint name is not a key of
            ``RepresentationBuilder.REP_TO_ID``.
    """
    # Normalise once so that str and path-like inputs are handled the same way;
    # the original relied on ``path.name`` and therefore broke on plain strings.
    abs_path = os.path.abspath(os.fspath(path))

    # The representation name is the last '_'-separated token of the
    # checkpoint's final path component.
    rep = os.path.basename(abs_path).split('_')[-1]
    try:
        rep_id = RepresentationBuilder.REP_TO_ID[rep]
    except KeyError as err:
        # Same exception type as before, but with enough context to spot a
        # mis-named checkpoint directory.
        raise KeyError(
            f"Cannot infer representation from checkpoint path {abs_path!r}: "
            f"suffix {rep!r} is not a key of RepresentationBuilder.REP_TO_ID"
        ) from err

    # Build a freshly initialised model; it supplies the structure that the
    # restored parameters are loaded into.
    rep_model_builder = RepresentationBuilder.build(
        rep_id,
        config,
        env_config,
        rngs=nnx.Rngs(rng)
    )

    # Separate the trainable parameters (nnx.Param) from every other piece of
    # state so that only the parameters are read from the checkpoint.
    graphdef, abstract_state, rest = nnx.split(rep_model_builder, nnx.Param, ...)

    # The parameter state is passed as the restore target for the checkpoint.
    checkpointer = ocp.StandardCheckpointer()
    restored = checkpointer.restore(abs_path, abstract_state)

    # Reassemble the model from its graph definition, the restored parameters
    # and the untouched remaining state.
    model = nnx.merge(graphdef, restored, rest)

    return model

def build_model(
        config,
        env_config,
        rng
):
    """Construct the representation model and fetch its loss function.

    The representation type is selected by ``config.rep``, which is looked up
    in ``RepresentationBuilder.REP_TO_ID``.

    Args:
        config: Configuration object; must expose a ``rep`` attribute and is
            forwarded to ``RepresentationBuilder.build``
        env_config: Environment configuration object forwarded to
            ``RepresentationBuilder.build``
        rng (jax.random.PRNGKey): Random number generator key used to seed
            ``nnx.Rngs``

    Returns:
        tuple: ``(rep_model, loss_fn)`` where ``rep_model`` is the built model
        and ``loss_fn`` is the first element returned by ``rep_model.loss_fn()``
    """
    representation_id = RepresentationBuilder.REP_TO_ID[config.rep]
    model_rngs = nnx.Rngs(rng)
    model = RepresentationBuilder.build(
        representation_id,
        config,
        env_config,
        rngs=model_rngs,
    )

    # ``loss_fn()`` yields a pair; only the first element is handed back to
    # the caller, the second is intentionally discarded.
    loss, _unused_logs = model.loss_fn()
    return model, loss