

This directory contains a minimal, self-contained set of scripts and modules for training and evaluating representation learning models using JAX, Flax, and related libraries.

## Directory Structure

- `train.py`: Script to train representation models.
- `eval.py`: Script to evaluate trained models and generate metrics and visualizations.
- `jaxmodels_nnx.py`: Definition of neural network architectures.
- `permutation_optimizer.py`: Utilities for permuting and evaluating factorization matrices.
- `policies/`: Policy definitions for various environments.
- `envs/`: Environment builders and wrappers (e.g., Gymnax, Navix, Taxi).
- `datasets/`: Functions and classes for generating and loading datasets.
- `reps/`: Representation building and evaluation modules.
- `utils/`: Helper functions for logging, checkpointing, state processing, and more.

## Installation

1. Create a Python virtual environment (Python 3.9):
   ```bash
   python3 -m venv venv
   source venv/bin/activate
   ```

2. Install core dependencies:
   ```bash
   pip install -r requirements-minimal.txt
   ```

3. (Optional) Install environment-specific dependencies:
   ```bash
   pip install gym gymnax navix minigrid opencv-python
   ```

> **Note**: For GPU support with JAX, refer to the [JAX installation instructions](https://github.com/google/jax#installation) to install the appropriate version of `jax` and `jaxlib`.
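After step 3, a small shell helper like the one below can confirm which optional environment packages are importable. This is a minimal sketch and not part of the repository. The function name `check_optional_deps` is hypothetical, and it assumes the virtual environment is active so that `python` resolves to its interpreter. The `opencv-python` package is imported as `cv2`, so that is the name it checks.

```bash
# Hypothetical helper (not part of this repository): report which optional
# environment packages from step 3 can be imported in the active environment.
check_optional_deps() {
  # Import names can differ from pip names: opencv-python is imported as cv2.
  for mod in gym gymnax navix minigrid cv2; do
    if python -c "import $mod" >/dev/null 2>&1; then
      echo "ok      $mod"
    else
      echo "missing $mod"
    fi
  done
}

check_optional_deps
```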

## Usage

### Training

Run training with a given configuration:
```bash
python train.py --config train_rep.yaml
```

Additional flags:
- `--outdir`: Output directory for checkpoints and logs.
- `--seed`: Random seed.

Run `python train.py --help` for the full list of options.
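The flags above can be combined to train one configuration over several seeds, with each run writing to its own output directory. This is a minimal sketch: the function name `train_seeds` and the `<base_dir>/seed_<N>` layout are illustrative assumptions, not conventions defined by this repository. The helper creates each output directory itself rather than assuming that `train.py` does.

```bash
# Hypothetical helper: train a single config once per seed, giving each run
# its own output directory so checkpoints and logs do not overwrite each other.
train_seeds() {
  local config="$1"     # path to the YAML config
  local base_dir="$2"   # parent directory for all runs
  shift 2               # remaining arguments are the seeds
  for seed in "$@"; do
    local outdir="${base_dir}/seed_${seed}"
    mkdir -p "$outdir"
    # Stop at the first failed run instead of silently continuing.
    python train.py --config "$config" --seed "$seed" --outdir "$outdir" || return 1
  done
}

# Example: three seeds of the config used above.
# train_seeds train_rep.yaml runs/train_rep 0 1 2
```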


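## Reproducing Results

To reproduce an experiment, pass its config file to `train.py`, replacing the `domain` and `method` placeholders in the path with the target domain and method:

```bash
python train.py --config ./configs/domain/method/config.yaml
```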
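When several configs are available, a dry run that lists every `config.yaml` following the `configs/<domain>/<method>/` layout, together with the command that would train it, can help pick the one to reproduce. This is a sketch that assumes exactly that directory layout; the function name `list_repro_commands` is hypothetical, and it only prints commands without executing them.

```bash
# Hypothetical helper: print (but do not run) the training command for every
# config laid out as <root>/<domain>/<method>/config.yaml.
list_repro_commands() {
  local root="${1:-./configs}"
  # Depth 3 relative to root matches exactly <domain>/<method>/config.yaml.
  find "$root" -mindepth 3 -maxdepth 3 -type f -name config.yaml | sort |
    while IFS= read -r cfg; do
      echo "python train.py --config $cfg"
    done
}

# Example:
# list_repro_commands ./configs
```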
