# JEPA-Reasoner and Talker Model

This repository contains the complete implementation of the JEPA-Reasoner and Talker models described in our paper. Follow the instructions below to reproduce our experiments.

## Environment Setup

1. **GPU Requirements**: Ensure you have a CUDA-compatible device installed and recognized in your working environment.

2. **Install Dependencies**: Create and activate the conda environment:
   
   > Note: You must explicitly install JAX with `pip install -U "jax[cuda12]"` to get a CUDA-enabled build.
   
   ```bash
   conda create -n jepa-reasoner python=3.13
   conda activate jepa-reasoner
   pip install -r requirements.txt
   pip install -U "jax[cuda12]"
   pip install flax
   ```

3. **Verify Installation**: Check that JAX detects your CUDA device. Note that this codebase is designed to run on a single CUDA device.
   
   ```python
   import jax
   print(jax.devices())
   # This should display your CUDA device.
   ```
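The check above only prints the device list. A slightly stricter sketch counts the CUDA devices JAX sees, relying on JAX reporting them with the platform name `"gpu"`. The helper `count_cuda_devices` is hypothetical and not part of this repository. Because the codebase targets a single device, it also suggests pinning one GPU with the standard `CUDA_VISIBLE_DEVICES` environment variable when several are visible.

```python
import os


def count_cuda_devices() -> int:
    """Return how many CUDA devices JAX can see (0 if JAX is missing or CPU-only)."""
    try:
        import jax
    except ImportError:
        return 0
    # JAX labels CUDA devices with the platform string "gpu"; CPU devices are "cpu".
    return sum(1 for d in jax.devices() if d.platform == "gpu")


if __name__ == "__main__":
    n = count_cuda_devices()
    if n == 0:
        print("No CUDA device visible to JAX; check the jax[cuda12] install.")
    elif n > 1 and "CUDA_VISIBLE_DEVICES" not in os.environ:
        # Hypothetical advice: restrict the process to one GPU for this single-device codebase.
        print(f"{n} CUDA devices found; consider e.g. CUDA_VISIBLE_DEVICES=0 python <script>.py")
    else:
        print(f"OK: {n} CUDA device(s) visible.")
```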

## Project Structure

- `model.py`: Core model implementations (JEPA-Reasoner, Mono/Dual-Talker, and the Transformer baseline)
- `./layers/`: Neural network layer implementations
- `./exp_utils/`: Tree/CFG generation and tokenization utilities
- `snapshot.py`: Model checkpointing utilities
- `./configs/`: Model configuration JSON files
- `./data/`: Datasets used in the experiments
- `dataloader.py`: Multithreaded data loader class used to load natural language data
- `./results/`: Plotting utilities and benchmark results
- `./snapshot/`: Pretrained weights of the JEPA-Reasoner, Talker, and Transformer baseline models
- Training scripts: `pretrain.py`, `sst_*.py`, `train_*.py`
- Testing scripts: `test_*.py`, `benchmark_*.py`

## Tree Search Experiment

This experiment uses JSON configuration files to build models:

- `reasoner.json`: Configuration for the JEPA-Reasoner model
- `talker.json`: Configuration for the Mono-Talker model

### Usage

#### JEPA-Reasoner Pretraining

```bash
python pretrain_tree.py
```

#### JEPA-Reasoner SST

```bash
python sst_reasoner_tree.py
```

#### Training the Talker

```bash
python train_talker_tree.py
```
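The three commands above are listed in training order. To run them unattended, a small driver like the following stops at the first failing stage, so a later stage never starts without the snapshot it needs. It is a sketch that assumes the scripts take no arguments (as shown above) and are launched from the repository root; `TREE_PIPELINE` and `run_pipeline` are hypothetical names.

```python
import subprocess
import sys

# Order taken from the steps above: pretraining, SST, then Talker training.
TREE_PIPELINE = [
    ["pretrain_tree.py"],
    ["sst_reasoner_tree.py"],
    ["train_talker_tree.py"],
]


def run_pipeline(steps, dry_run=False):
    """Run each script with the current interpreter, stopping at the first failure.

    Returns the commands (as argument lists) that were, or would be, executed.
    """
    commands = [[sys.executable, *step] for step in steps]
    for cmd in commands:
        print("$", " ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError if a stage exits non-zero.
            subprocess.run(cmd, check=True)
    return commands
```

Calling `run_pipeline(TREE_PIPELINE, dry_run=True)` first prints the commands without running anything.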

#### Testing

To examine the mixed-latent behavior and visualize the results:

```bash
python test_mixed_latent.py
python plot_2d_latent_space.py
```

To test the end-to-end tree-search ability:

```bash
python test_tree-search.py
```

## CFG Experiments

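> If you don't want to train the models from scratch, skip to the "Token-Level Error Experiment" and "Latent Space Error Experiment" sections, which use the pretrained weights provided in `./snapshot`.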

### Training Models from Scratch

#### **Update MLflow Tracking URL**

Before executing any training script, open it in your editor and replace the placeholder with your MLflow tracking URL.
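If you would rather not edit each script by hand, one option is to read the URI from MLflow's standard `MLFLOW_TRACKING_URI` environment variable. This is an assumption about how you might adapt the scripts, not something they currently do; `resolve_tracking_uri` is a hypothetical helper, and the default `http://localhost:5000` is only the address a local `mlflow server` listens on by default.

```python
import os


def resolve_tracking_uri(default: str = "http://localhost:5000") -> str:
    """Return the MLflow tracking URI from MLFLOW_TRACKING_URI, else `default`."""
    uri = os.environ.get("MLFLOW_TRACKING_URI", "").strip()
    return uri or default


# Inside a training script, in place of the hard-coded placeholder, one might write:
#   import mlflow
#   mlflow.set_tracking_uri(resolve_tracking_uri())
```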

#### **Generate CFG Dataset**

Navigate to the experiment utilities directory and generate CFG sequences:

```bash
cd ./exp_utils
python cfg_generator.py
cd ..   # return to the repository root before running the training scripts
```

This will create two training datasets and one test dataset. The generated data is saved to `./data` by default.

#### **Pretrain the Transformer Models**

To obtain the weights used to initialize other models, run:

```bash
python train_transformer_cfg_pretrain.py small
python train_transformer_cfg_pretrain.py middle
python train_transformer_cfg_pretrain.py large
```

#### **Posttrain the Transformer Models**

Use the post-training script to obtain $T_{small}$, $T_{middle}$, and $T_{large}$, which are used in the token-level error test.

```bash
# Note: pass the pretrained model's snapshot path WITHOUT the ".safetensors" extension.
# For example, if the snapshot is "snapshot/cfg_completion/transformer/cfg_pretrain_large.safetensors",
# pass snapshot/cfg_completion/transformer/cfg_pretrain_large
python train_transformer_cfg_posttrain.py small [PATH_TO_SMALL_PRETRAINED_TRANSFORMER]
python train_transformer_cfg_posttrain.py middle [PATH_TO_MIDDLE_PRETRAINED_TRANSFORMER]
python train_transformer_cfg_posttrain.py large [PATH_TO_LARGE_PRETRAINED_TRANSFORMER]
```
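The `.safetensors` extension is easy to leave in when copying a path from a file browser. A tiny helper can strip it and build the post-training command; `snapshot_arg` and `posttrain_cmd` are hypothetical names, not part of the repository.

```python
import shlex


def snapshot_arg(path: str) -> str:
    """Return `path` without a trailing ".safetensors"; other paths are unchanged."""
    return path.removesuffix(".safetensors")  # str.removesuffix needs Python 3.9+


def posttrain_cmd(size: str, pretrained_snapshot: str) -> list[str]:
    """Build the train_transformer_cfg_posttrain.py command for one model size."""
    return ["python", "train_transformer_cfg_posttrain.py", size, snapshot_arg(pretrained_snapshot)]
```

For example, `shlex.join(posttrain_cmd("large", "snapshot/cfg_completion/transformer/cfg_pretrain_large.safetensors"))` yields a command line you can paste into a shell.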

#### **Obtain the Initial Weights For JEPA-Reasoner**

Use `weight_adapter.py` to split the weights of a pretrained Transformer model into the initial weights of JEPA-Reasoner:

```bash
python weight_adapter.py small [PARENT_PATH_OF_PRETRAINED_TRANSFORMER] [NAME_OF_SNAPSHOT]
python weight_adapter.py middle [PARENT_PATH_OF_PRETRAINED_TRANSFORMER] [NAME_OF_SNAPSHOT]
python weight_adapter.py large [PARENT_PATH_OF_PRETRAINED_TRANSFORMER] [NAME_OF_SNAPSHOT]
# Example: python weight_adapter.py small snapshot/cfg_completion/transformer cfg_pretrain_small
```

This produces `init_small.safetensors`, `init_middle.safetensors`, and `init_large.safetensors` in `./snapshot/cfg_completion/reasoner/`.
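Unlike the post-training script, `weight_adapter.py` takes the parent directory and the snapshot name as two separate arguments, as in the example above. The sketch below derives both from one full snapshot path; `weight_adapter_cmd` is a hypothetical helper.

```python
from pathlib import PurePosixPath

SIZES = ("small", "middle", "large")


def weight_adapter_cmd(size: str, pretrained_snapshot: str) -> list[str]:
    """Build the weight_adapter.py command from a full snapshot path.

    The ".safetensors" extension is dropped, then the path is split into
    its parent directory and file name.
    """
    if size not in SIZES:
        raise ValueError(f"unknown model size: {size!r}")
    p = PurePosixPath(pretrained_snapshot.removesuffix(".safetensors"))
    return ["python", "weight_adapter.py", size, str(p.parent), p.name]
```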

#### **Train the JEPA-Reasoner Models**

Use the SST script to train the JEPA-Reasoner models:

```bash
python sst_reasoner_cfg.py small
python sst_reasoner_cfg.py middle
python sst_reasoner_cfg.py large
```
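`sst_reasoner_cfg.py` takes only the model size, so it presumably loads the matching `init_<size>.safetensors` produced in the previous step (an assumption; the README does not state this). A quick pre-flight check with the hypothetical helper `missing_init_weights` can catch a skipped adaptation step before a long run starts:

```python
from pathlib import Path

SIZES = ("small", "middle", "large")


def missing_init_weights(root=".", sizes=SIZES):
    """Return the sizes whose init_<size>.safetensors is absent from
    <root>/snapshot/cfg_completion/reasoner/, where weight_adapter.py writes them."""
    reasoner_dir = Path(root) / "snapshot" / "cfg_completion" / "reasoner"
    return [s for s in sizes if not (reasoner_dir / f"init_{s}.safetensors").is_file()]
```

Run from the repository root, an empty list means all three initial snapshots are in place.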

#### **Train the COCONUT Model**

Use pretrained Transformer weights to train the COCONUT model:

```bash
python train_transformer_cfg_coco.py small [PATH_TO_PRETRAINED_SMALL_TRANSFORMER]
python train_transformer_cfg_coco.py middle [PATH_TO_PRETRAINED_MIDDLE_TRANSFORMER]
python train_transformer_cfg_coco.py large [PATH_TO_PRETRAINED_LARGE_TRANSFORMER]
```

### Token-Level Error Experiment

> If you trained the models from scratch, replace the paths to the provided weights below with the paths to your own trained weights.

#### **Evaluate JEPA-Reasoner and Talker Models**

The trained weights described in our paper are located in `./snapshot/cfg_completion/token_level_err_test`. To evaluate the accuracy of both JEPA-Reasoner and Talker models, run:

```bash
# Evaluate all three model sizes
python benchmark_reasoner_cfg.py small snapshot/cfg_completion/token_level_err_test/reasoner_cfg3f_small snapshot/cfg_completion/token_level_err_test/talker_cfg3f_small
python benchmark_reasoner_cfg.py middle snapshot/cfg_completion/token_level_err_test/reasoner_cfg3f_middle snapshot/cfg_completion/token_level_err_test/talker_cfg3f_middle
python benchmark_reasoner_cfg.py large snapshot/cfg_completion/token_level_err_test/reasoner_cfg3f_large snapshot/cfg_completion/token_level_err_test/talker_cfg3f_large
```

Each script saves its results to `./results`.
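The invocations above, and those in the Latent Space Error Experiment, follow one naming pattern: `reasoner_cfg3f_<size>` and `talker_cfg3f_<size>` under an experiment-specific directory. The sketch below rebuilds those commands from the model size; `BENCHMARK_SCRIPTS` and `reasoner_benchmark_cmd` are hypothetical names, and the paths point at the provided weights, so adjust `base` if you trained your own.

```python
import shlex

# Experiment directory under snapshot/cfg_completion -> benchmark script, as listed in this README.
BENCHMARK_SCRIPTS = {
    "token_level_err_test": "benchmark_reasoner_cfg.py",
    "latent_space_err_test": "benchmark_reasoner_cfg_gaussian.py",
}


def reasoner_benchmark_cmd(size: str, experiment: str = "token_level_err_test") -> list[str]:
    """Build the JEPA-Reasoner + Talker benchmark command for one model size."""
    base = f"snapshot/cfg_completion/{experiment}"
    return [
        "python",
        BENCHMARK_SCRIPTS[experiment],
        size,
        f"{base}/reasoner_cfg3f_{size}",
        f"{base}/talker_cfg3f_{size}",
    ]


if __name__ == "__main__":
    # Print every command; pipe the output into a shell to run them in sequence.
    for exp in BENCHMARK_SCRIPTS:
        for size in ("small", "middle", "large"):
            print(shlex.join(reasoner_benchmark_cmd(size, exp)))
```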

#### **Evaluate Baseline Transformer Models**

The trained weights for the baseline CFG Transformers are also located in `./snapshot/cfg_completion/token_level_err_test`. Run the evaluation scripts:

```bash
# Evaluate all three baseline model sizes
python benchmark_transformer_cfg.py small snapshot/cfg_completion/token_level_err_test/cfg_posttrain_small
python benchmark_transformer_cfg.py middle snapshot/cfg_completion/token_level_err_test/cfg_posttrain_middle
python benchmark_transformer_cfg.py large snapshot/cfg_completion/token_level_err_test/cfg_posttrain_large
```

Each script saves its results to `./results`.

#### **Generate Performance Plots**

Append the contents of the six `.txt` files generated above (three from the JEPA-Reasoner/Talker benchmarks and three from the baseline benchmarks) to `./results/all_results.txt`. Then run the plotting script to generate the relative-performance visualization:

```bash
cd ./results
python plot.py
```
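The append step before `plot.py` can be scripted as well. This sketch assumes the six benchmark outputs are the only other `.txt` files in `./results` (the README does not name them) and appends them in alphabetical order. It is not idempotent: running it twice duplicates the entries. `collect_results` is a hypothetical helper.

```python
from pathlib import Path


def collect_results(results_dir="results", target="all_results.txt"):
    """Append every other .txt file in results_dir to target, in sorted name order.

    Returns the names of the files that were appended.
    """
    results = Path(results_dir)
    out = results / target
    # Build the source list before opening the target, and never append the target to itself.
    sources = sorted(p for p in results.glob("*.txt") if p.name != target)
    with out.open("a", encoding="utf-8") as f:
        for src in sources:
            text = src.read_text(encoding="utf-8")
            if text and not text.endswith("\n"):
                text += "\n"  # keep each file's content on its own lines
            f.write(text)
    return [p.name for p in sources]
```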

### Latent Space Error Experiment

> If you trained the models from scratch, replace the paths to the provided weights below with the paths to your own trained weights.

#### **Evaluate JEPA-Reasoner and Talker Models**

The trained weights described in our paper are located in `./snapshot/cfg_completion/latent_space_err_test`. To evaluate the accuracy of both JEPA-Reasoner and Talker models, run:

```bash
# Evaluate all three model sizes
python benchmark_reasoner_cfg_gaussian.py small snapshot/cfg_completion/latent_space_err_test/reasoner_cfg3f_small snapshot/cfg_completion/latent_space_err_test/talker_cfg3f_small
python benchmark_reasoner_cfg_gaussian.py middle snapshot/cfg_completion/latent_space_err_test/reasoner_cfg3f_middle snapshot/cfg_completion/latent_space_err_test/talker_cfg3f_middle
python benchmark_reasoner_cfg_gaussian.py large snapshot/cfg_completion/latent_space_err_test/reasoner_cfg3f_large snapshot/cfg_completion/latent_space_err_test/talker_cfg3f_large
```

#### **Evaluate COCONUT Transformer Models**

The trained weights for the COCONUT models are located in `./snapshot/cfg_completion/latent_space_err_test`. Run the evaluation scripts:

```bash
# Evaluate all three COCONUT model sizes
python benchmark_coco_cfg_gaussian.py small snapshot/cfg_completion/latent_space_err_test/cfg_posttrain_small
python benchmark_coco_cfg_gaussian.py middle snapshot/cfg_completion/latent_space_err_test/cfg_posttrain_middle
python benchmark_coco_cfg_gaussian.py large snapshot/cfg_completion/latent_space_err_test/cfg_posttrain_large
```

## Talker Ablation Study

Run the script:

```bash
python talker_ablation_study.py
```