
# Soft QD Optimization with SQUAD

This repository contains the implementation of the models and experiments described in our submission.  
It is written in Python and builds on top of [JAX](https://github.com/google/jax) and [Flax](https://github.com/google/flax).

The code is structured to allow reproduction of the reported experiments as well as straightforward adaptation to other tasks and datasets.

## Installation

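We recommend using a clean Python 3.10 virtual environment. The packages below are installed in separate steps, so check that their versions are compatible with each other; in particular, make sure the installed NumPy version matches what your JAX and Flax versions expect.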

1. Install JAX with CUDA support:
```bash
pip install "jax[cuda]"
```
2. Install Flax:
```bash
pip install flax
```
3. Install the local `flaxmodels` library (a modified version included alongside this repo):
```bash
cd flaxmodels
pip install -e .
cd ..
```
4. Install the required Transformers version (used only in LSI experiments):
```bash
pip install transformers==4.54.0
```
5. Install the remaining Python dependencies:
```bash
pip install -r requirements.txt
```
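To confirm that the installation left a consistent environment, a small check like the following can help. It is a sketch, not part of the repository: it only reads installed package metadata, compares `transformers` against the 4.54.0 pin, and assumes the local library is installed under the distribution name `flaxmodels`.

```python
import importlib.metadata as metadata

# Distributions touched by the installation steps. "flaxmodels" is assumed to be
# the distribution name of the locally installed (editable) library.
PACKAGES = ["jax", "jaxlib", "flax", "numpy", "transformers", "flaxmodels"]

# Exact pin from the installation instructions.
PINNED = {"transformers": "4.54.0"}


def installed_versions(packages=PACKAGES):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def pin_mismatches(versions, pinned=PINNED):
    """Return {name: (installed, expected)} for every pinned package that differs."""
    return {
        name: (versions.get(name), expected)
        for name, expected in pinned.items()
        if versions.get(name) != expected
    }


if __name__ == "__main__":
    versions = installed_versions()
    for name, version in versions.items():
        print(f"{name:>12}: {version or 'NOT INSTALLED'}")
    for name, (found, expected) in pin_mismatches(versions).items():
        print(f"WARNING: {name} is {found}, expected {expected}")
    try:
        import jax
        # With a working CUDA install this should list GPU devices, not only CPU.
        print("JAX devices:", jax.devices())
    except ImportError:
        print("JAX is not importable")
```

If `jax.devices()` reports only CPU devices, the CUDA variant of JAX was most likely not picked up.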


## Repository Structure

* `src/`
  Core source code including model definitions, training loops, evaluation utilities, and logging.

* `configs/`
  Hydra YAML configuration files for the experiments.

* `scripts/`
  Helper scripts to run training and evaluation with different configurations.

* `flaxmodels/`
  Locally modified version of the `flaxmodels` library, required for the LSI experiments.


## Running New Experiments

To run training with a specific configuration:
```bash
python -m src.main --config-name <ALGO_CONFIG> task=<TASK_CONFIG> [PARAM_OVERRIDES]
```
where
- `<ALGO_CONFIG>`: The algorithm's configuration file (e.g., `config.yaml` for SoftQD, `cma_mae.yaml` for CMA-MAE).
- `<TASK_CONFIG>`: The task's configuration file (e.g., `image_rendering.yaml`).
- `[PARAM_OVERRIDES]`: Any Hydra parameter overrides (e.g., `seed=123`).
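When launching many runs from Python (for example, one per seed), it can be convenient to build the command programmatically. This is a minimal sketch that assumes Hydra's usual convention that `--config-name` and config-group overrides take names without the `.yaml` extension. `build_command` is a hypothetical helper, not part of `src/`.

```python
import shlex
import sys


def build_command(algo_config, task_config, overrides=None):
    """Assemble the `python -m src.main` invocation shown above as an argv list.

    algo_config / task_config are Hydra config names without the .yaml
    extension (e.g. "config", "image_rendering"); overrides maps keys to values.
    """
    cmd = [sys.executable, "-m", "src.main", "--config-name", algo_config, f"task={task_config}"]
    for key, value in (overrides or {}).items():
        cmd.append(f"{key}={value}")
    return cmd


if __name__ == "__main__":
    # Print (not run) one command per seed; pass a list to subprocess.run(cmd, check=True) to launch it.
    for seed in range(3):
        print(shlex.join(build_command("config", "image_rendering", {"seed": seed})))
```

Returning an argv list rather than a single string avoids shell-quoting issues when the list is handed to `subprocess`.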

## Reproducing Main Experiments
We provide scripts to reproduce the main experimental results from the paper.

1. Image Rendering Task
```bash
./scripts/baselines_rendering.sh
```
This will run SoftQD (SQUAD), CMA-MAE (Sep-CMA-MAE), CMA-MEGA, CMA-MAEGA, and PGA-ME (GA-ME) for 10 seeds each.

2. Latent Space Illumination (LSI) Task
```bash
./scripts/baselines_lsi.sh
```
3. Rastrigin Task
```bash
python scripts/baselines_rastrigin.py
```
This runs multiple experiments in parallel to save time.
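If you want to parallelize your own set of runs, one possible approach is sketched below; it is not necessarily how `baselines_rastrigin.py` does it. `run_parallel` is a hypothetical helper that launches each command as a subprocess from a thread pool. It also sets `XLA_PYTHON_CLIENT_PREALLOCATE=false`, because by default each JAX process preallocates a large fraction of GPU memory, which can make several processes sharing one GPU run out of memory.

```python
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor


def run_parallel(commands, max_workers=4, disable_jax_preallocation=True):
    """Run each argv list in its own subprocess, at most `max_workers` at a time.

    Returns the exit codes in the same order as `commands`.
    """
    env = dict(os.environ)
    if disable_jax_preallocation:
        # Stop each JAX process from grabbing most of the GPU memory up front.
        env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    def _run(cmd):
        return subprocess.run(cmd, env=env).returncode

    # The threads only wait on child processes, so the GIL is not a bottleneck.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(_run, commands))
```

For example, from the repository root, `run_parallel([[sys.executable, "-m", "src.main", "--config-name", "config", "task=image_rendering", f"seed={s}"] for s in range(3)], max_workers=2)` would run three seeds, two at a time. Threads are sufficient here because each experiment already runs in its own process.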

## Evaluation
After running an experiment, you can compute QD metrics using the `src/evaluate.py` script. The shell scripts in `scripts/` already include an evaluation loop that processes all runs in the `outputs/` directory.

To evaluate a single run manually:
```bash
python -m src.evaluate "path/to/your/output/directory"
```

This will create an `evaluation_results/` subdirectory containing `metrics.json` and plots of the solution archive.
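Once several runs have been evaluated, their `metrics.json` files can be gathered for comparison. A minimal sketch, assuming each evaluated run directory sits somewhere under `outputs/` and contains `evaluation_results/metrics.json` as described above; `load_all_metrics` is a hypothetical helper and makes no assumption about which metrics the files contain.

```python
import json
from pathlib import Path


def load_all_metrics(outputs_dir="outputs"):
    """Collect every evaluation_results/metrics.json found below `outputs_dir`.

    Returns {run directory (str): parsed JSON}. No particular metric names are assumed.
    """
    root = Path(outputs_dir)
    if not root.is_dir():
        return {}
    results = {}
    for metrics_file in sorted(root.rglob("metrics.json")):
        # Only keep files at <run_dir>/evaluation_results/metrics.json.
        if metrics_file.parent.name != "evaluation_results":
            continue
        run_dir = metrics_file.parent.parent
        with metrics_file.open() as f:
            results[str(run_dir)] = json.load(f)
    return results


if __name__ == "__main__":
    for run_dir, metrics in load_all_metrics().items():
        print(run_dir, metrics)
```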

## Extending the Codebase
The framework is designed to be extensible for new algorithms and tasks.

### Adding a New Task
1. Create a new file in `src/tasks/`.
2. Implement it as a class that conforms to the `src.tasks.base.Task` protocol. This requires defining properties like `solution_size` and `descriptor_dim`, and implementing an `evaluate` method.
3. The `evaluate` method should return an `EvalOutput` tuple containing fitnesses, descriptors, and their respective gradients.
4. Create a corresponding configuration file `configs/task/my_task.yaml`.
5. Update the task factory function in `src/tasks/utils.py` to include your new task.
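As an illustration of these steps, here is a toy task written against stand-ins for the repository's types. The field names of `EvalOutput`, the `TaskLike` protocol, the batch-first array shapes, and the `evaluate(solutions)` signature are all assumptions made for this sketch; check `src.tasks.base` for the real definitions. Gradients are computed analytically with NumPy to keep the example self-contained; a JAX-based task could obtain them with `jax.grad` or `jax.jacrev` instead.

```python
from typing import NamedTuple, Protocol, runtime_checkable

import numpy as np


class EvalOutput(NamedTuple):
    """Hypothetical stand-in for src.tasks.base.EvalOutput (field names assumed)."""
    fitnesses: np.ndarray         # (batch,)
    descriptors: np.ndarray       # (batch, descriptor_dim)
    fitness_grads: np.ndarray     # (batch, solution_size)
    descriptor_grads: np.ndarray  # (batch, descriptor_dim, solution_size)


@runtime_checkable
class TaskLike(Protocol):
    """Hypothetical mirror of the members the Task protocol is described as requiring."""
    @property
    def solution_size(self) -> int: ...
    @property
    def descriptor_dim(self) -> int: ...
    def evaluate(self, solutions: np.ndarray) -> EvalOutput: ...


class SphereTask:
    """Toy task: fitness = -||x||^2, descriptors = the first descriptor_dim coordinates."""

    def __init__(self, solution_size: int = 10, descriptor_dim: int = 2):
        assert descriptor_dim <= solution_size
        self._solution_size = solution_size
        self._descriptor_dim = descriptor_dim

    @property
    def solution_size(self) -> int:
        return self._solution_size

    @property
    def descriptor_dim(self) -> int:
        return self._descriptor_dim

    def evaluate(self, solutions) -> EvalOutput:
        x = np.asarray(solutions, dtype=np.float64)  # (batch, solution_size)
        fitnesses = -np.sum(x**2, axis=1)
        fitness_grads = -2.0 * x
        descriptors = x[:, : self._descriptor_dim]
        # d(descriptor_j)/d(x_k) is 1 if j == k else 0, identical for every solution.
        eye = np.eye(self._descriptor_dim, self._solution_size)
        descriptor_grads = np.broadcast_to(eye, (x.shape[0],) + eye.shape).copy()
        return EvalOutput(fitnesses, descriptors, fitness_grads, descriptor_grads)
```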

### Adding a New Algorithm
1. Create a new file in `src/qd/`.
2. Implement the main training loop for your algorithm. This function should accept a Hydra config object and a `Task` instance.
3. Add a new configuration file in `configs/`.
4. Update `src/main.py` to add a condition to call your new algorithm's training loop when `algo_name` matches.
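The two pieces described above, a training loop that takes a config and a task and a dispatch on `algo_name`, might look roughly like this. The example uses plain gradient-free MAP-Elites on a grid archive purely for illustration. `train_map_elites`, `ALGORITHMS`, `run`, `ToyTask`, the config keys (`seed`, `cells_per_dim`, `iterations`, `batch_size`, `sigma`), and the assumption that descriptors lie in [0, 1] are all hypothetical; `SimpleNamespace` stands in for Hydra's config object, which also supports attribute access.

```python
from types import SimpleNamespace

import numpy as np


def train_map_elites(cfg, task):
    """Hypothetical MAP-Elites loop; assumes task descriptors are scaled to [0, 1]."""
    rng = np.random.default_rng(cfg.seed)
    cells = cfg.cells_per_dim
    shape = (cells,) * task.descriptor_dim
    best_fit = np.full(shape, -np.inf)  # -inf marks an empty cell
    best_sol = np.zeros(shape + (task.solution_size,))
    for _ in range(cfg.iterations):
        filled = np.argwhere(np.isfinite(best_fit))
        if len(filled) == 0:
            parents = rng.normal(size=(cfg.batch_size, task.solution_size))
        else:
            # Pick parents uniformly from the occupied archive cells.
            idx = filled[rng.integers(len(filled), size=cfg.batch_size)]
            parents = best_sol[tuple(idx.T)]
        children = parents + cfg.sigma * rng.normal(size=parents.shape)
        out = task.evaluate(children)
        cell_idx = np.clip((out.descriptors * cells).astype(int), 0, cells - 1)
        for child, fit, cell in zip(children, out.fitnesses, cell_idx):
            cell = tuple(cell)
            if fit > best_fit[cell]:  # keep the elite of each cell
                best_fit[cell] = fit
                best_sol[cell] = child
    return best_fit, best_sol


# Dispatch table keyed by algo_name (an alternative to an if/elif chain in main).
ALGORITHMS = {"map_elites": train_map_elites}


def run(cfg, task):
    try:
        train_fn = ALGORITHMS[cfg.algo_name]
    except KeyError:
        raise ValueError(f"Unknown algo_name: {cfg.algo_name!r}") from None
    return train_fn(cfg, task)


class ToyTask:
    """Toy task for the usage example: descriptors are sigmoids of the first two coordinates."""
    solution_size = 3
    descriptor_dim = 2

    def evaluate(self, x):
        return SimpleNamespace(
            fitnesses=-np.sum(x**2, axis=1),
            descriptors=1.0 / (1.0 + np.exp(-x[:, :2])),
        )


cfg = SimpleNamespace(algo_name="map_elites", seed=0, cells_per_dim=5,
                      iterations=50, batch_size=16, sigma=0.3)
best_fit, best_sol = run(cfg, ToyTask())
```

A dictionary lookup keeps `main` short as algorithms are added, but an explicit condition per `algo_name` works just as well.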

## Citation
If you use this code in your research, please consider citing our paper:  
TODO

