# Measuring and Controlling Solution Degeneracy across Task-trained RNNs

This repository contains code for the paper “Measuring and Controlling Solution Degeneracy across Task-trained RNNs.”

## Experimental Workflow

Each experiment follows these steps:

1. **Train** RNNs until they reach an asymptotic loss threshold

2. **Compute dynamical, weight, and behavioral degeneracy** on the converged networks:  
   1. **Dynamical degeneracy** (`dsa.py`)
   2. **Weight degeneracy** (`weight_degeneracy.py`)
   3. **Behavioral degeneracy** (`eval_OOD.py`)
3. Compute additional measures, e.g., representational degeneracy via CCA (`cca.py`).
4. Plot figures in Jupyter notebooks using the utility functions in `utils.py`. For an example analysis of how network size affects degeneracy, see `analysis/example.ipynb`.
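
On a Slurm cluster, the ordering above (train every seed before computing metrics) can be enforced with job dependencies. The function below is a hypothetical sketch, not part of this repository. It assumes you have wrapped each step in its own sbatch script that reads `$SLURM_ARRAY_TASK_ID`; the script names in the example are placeholders.

```bash
# Hypothetical helper (not part of this repository): submit training as a Slurm
# array, then submit each metric script as an array that starts only after
# every training task has completed successfully.
submit_pipeline() {
  local n_seeds="$1" train_script="$2"
  shift 2  # remaining arguments: metric sbatch scripts
  local out train_id metric_script
  # --parsable prints "jobid" or "jobid;cluster"; keep only the job id.
  out="$(sbatch --parsable --array=0-$((n_seeds - 1)) "$train_script")" || return 1
  train_id="${out%%;*}"
  for metric_script in "$@"; do
    # afterok on the array job id waits for all of its tasks to succeed.
    sbatch --dependency=afterok:"$train_id" --array=0-$((n_seeds - 1)) "$metric_script"
  done
}

# Example (placeholder script names; seeds 0-49 assumed):
#   submit_pipeline 50 train_array.sh dsa_array.sh weight_array.sh ood_array.sh
```

Using `afterok` rather than `afterany` means a failed training seed holds back the metric jobs instead of letting them run on an incomplete set of networks.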

## Example: Network Training
We use Slurm array jobs to train 50 seeds in parallel. If you are not using Slurm, replace `$SLURM_ARRAY_TASK_ID` with another variable (or an integer) specifying the seed to run.
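
Without Slurm, one option is to generate the per-seed commands and pipe them to `bash` (sequential) or GNU `parallel` (concurrent). The helper below is a hypothetical sketch, not part of this repository; it assumes seeds are numbered `0` to `n_seeds - 1`, as with an array spec such as `--array=0-49`.

```bash
# Hypothetical helper: print one main.py command per seed (0 .. n_seeds-1).
# Nothing is executed here; the output is meant to be piped to a shell.
print_seed_cmds() {
  local n_seeds="$1"
  shift  # remaining arguments: task-specific main.py flags
  local seed
  for ((seed = 0; seed < n_seeds; seed++)); do
    printf 'python degeneracy/main.py'
    # %q shell-quotes each flag so the printed line can be re-executed safely.
    if (( $# > 0 )); then printf ' %q' "$@"; fi
    printf ' --seed %d\n' "$seed"
  done
}

# Example (run from the repository root), N-BFF flags from below:
#   print_seed_cmds 50 --task_name 3bff_local --n_bits 3 --early_stopping_threshold 0.001 --epoch 300 | bash
#   print_seed_cmds 50 --task_name 3bff_local --n_bits 3 --early_stopping_threshold 0.001 --epoch 300 | parallel -j 4
```

Printing the commands instead of running them lets you inspect the exact invocations before launching anything.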
### N-BFF
```bash
python degeneracy/main.py \
  --task_name 3bff_{custom_folder_name} \
  --n_bits 3 \
  --early_stopping_threshold 0.001 \
  --epoch 300 \
  --seed $SLURM_ARRAY_TASK_ID
```

### Delayed Discrimination
```bash
python degeneracy/main.py \
  --task_name delayed_discrimination_{custom_folder_name} \
  --lr_scheduler cosine \
  --dd_max_delay 20 \
  --epoch 300 \
  --dim_input 1 \
  --dim_output 1 \
  --early_stopping_threshold 0.01 \
  --seed $SLURM_ARRAY_TASK_ID
```

### Sinewave Generation
```bash
python degeneracy/main.py \
  --task_name sinewave_{custom_folder_name} \
  --epoch 1000 \
  --early_stopping_threshold 0.05 \
  --seed $SLURM_ARRAY_TASK_ID
```

### Path Integration
```bash
python degeneracy/main.py \
  --task_name path_integration_{custom_folder_name} \
  --n_batch 64 \
  --epoch 1000 \
  --lr_scheduler reduce \
  --early_stopping_threshold 0.05 \
  --path_dim 2 \
  --seed $SLURM_ARRAY_TASK_ID
```

## Optional Flags
Append any of the following flags to the `main.py` training commands above.
### Using muP parameterization at a fixed network size (to study the effect of feature learning)
```bash
--muP_param \
--gamma {your_gamma_value}
```
### Using muP and changing network size
```bash
--muP_param \
--gamma 1 \
--n_hidden {new_network_size}
```
### Adding low-rank regularization
```bash
--W_rank_reg {penalty_strength}
```
### Adding L1 regularization
```bash
--W_l1_reg {penalty_strength}
```
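
A sweep over one of these flags repeats training for every value and every seed. The helper below is a hypothetical sketch, not part of this repository: it prints one `main.py` command per (value, seed) pair and, as an assumption, appends the flag value to `--task_name` so that each setting gets its own output folder. The values in the examples are placeholders, not the settings used in the paper.

```bash
# Hypothetical helper: print main.py commands sweeping one flag over several values.
# Assumption: a distinct --task_name per value keeps the runs' outputs separate.
print_flag_sweep() {
  local task_name="$1" flag="$2" values="$3" n_seeds="$4"
  shift 4  # remaining arguments: flags shared by every run
  local value seed
  for value in $values; do  # unquoted on purpose: split the value list on spaces
    for ((seed = 0; seed < n_seeds; seed++)); do
      # e.g. task "3bff_mup" + flag "--gamma" + value "2" -> "3bff_mup_gamma2"
      printf 'python degeneracy/main.py --task_name %q' "${task_name}_${flag#--}${value}"
      printf ' %q %q' "$flag" "$value"
      if (( $# > 0 )); then printf ' %q' "$@"; fi
      printf ' --seed %d\n' "$seed"
    done
  done
}

# Examples (placeholder values; append the other task flags from the examples above):
#   print_flag_sweep 3bff_mup --gamma "0.5 1 2" 50 --n_bits 3 --muP_param | bash
#   print_flag_sweep 3bff_size --n_hidden "64 128 256" 50 --n_bits 3 --muP_param --gamma 1 | bash
#   print_flag_sweep 3bff_lowrank --W_rank_reg "0.001 0.01" 50 --n_bits 3 | bash
```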


## Example: Compute degeneracy metrics
### Dynamical degeneracy
```bash
python $HOME/Degeneracy/degeneracy/dsa.py \
  --task {task_name} \
  --n_networks 50 \
  --idx_network $SLURM_ARRAY_TASK_ID
```
### Weight degeneracy
```bash
python $HOME/Degeneracy/degeneracy/weight_degeneracy.py \
  --task {task_name} \
  --idx $SLURM_ARRAY_TASK_ID
```

### Behavioral degeneracy
```bash
python $HOME/Degeneracy/degeneracy/eval_OOD.py \
  {all the same args you passed to main.py}
```
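
Because `eval_OOD.py` takes the same arguments as `main.py`, keeping them in one bash array prevents the two commands from drifting apart. A minimal sketch for one N-BFF seed; `train_and_eval_ood`, `REPO_DIR`, and the folder name `3bff_example` are hypothetical, and the repository is assumed to live at `$HOME/Degeneracy` as in the commands above.

```bash
# Assumption: repository location, matching the $HOME/Degeneracy paths above.
REPO_DIR="${REPO_DIR:-$HOME/Degeneracy}"

# Hypothetical helper: train one seed, then (only if training succeeded)
# run the behavioral-degeneracy evaluation with identical arguments.
train_and_eval_ood() {
  local seed="$1"
  # N-BFF flags from the training example; "3bff_example" is a placeholder folder name.
  local args=(--task_name 3bff_example --n_bits 3 --early_stopping_threshold 0.001
              --epoch 300 --seed "$seed")
  python "$REPO_DIR/degeneracy/main.py" "${args[@]}" &&
    python "$REPO_DIR/degeneracy/eval_OOD.py" "${args[@]}"
}

# Example inside a Slurm array job:
#   train_and_eval_ood "$SLURM_ARRAY_TASK_ID"
```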

### (Optional) Representational degeneracy
```bash
python $HOME/Degeneracy/degeneracy/cca.py \
  --task {task_name} \
  --idx_network $SLURM_ARRAY_TASK_ID
```
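
The per-network scripts above can also be run for a single index from one function, e.g. inside a Slurm array task or a plain loop. A sketch under stated assumptions: `metrics_for_index` and `REPO_DIR` are hypothetical, the first argument is whatever you pass as `{task_name}`, and network indices are assumed to run from `0` to `n_networks - 1`.

```bash
# Assumption: repository location, matching the $HOME/Degeneracy paths above.
REPO_DIR="${REPO_DIR:-$HOME/Degeneracy}"

# Hypothetical helper: compute dynamical, weight, and (optional) representational
# degeneracy for one network index, using the same flags as the commands above.
metrics_for_index() {
  local task="$1" idx="$2" n_networks="${3:-50}"
  local dir="$REPO_DIR/degeneracy"
  python "$dir/dsa.py" --task "$task" --n_networks "$n_networks" --idx_network "$idx"
  python "$dir/weight_degeneracy.py" --task "$task" --idx "$idx"
  python "$dir/cca.py" --task "$task" --idx_network "$idx"  # optional CCA step
}

# Example without Slurm (indices 0-49 assumed):
#   for i in $(seq 0 49); do metrics_for_index 3bff_example "$i"; done
```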