# Early learning of the optimal constant solution in neural networks and humans
This repository holds the code for the NeurIPS 2024 submission
"Early learning of the optimal constant solution in neural networks and humans".

## Description

The repository contains code for plotting the main results in the paper.
It also contains code to run the main experiments reported in the paper.


## File structure
The layout described here is intended to be self-explanatory and to make the results easy to reproduce.

The folder `data/` should contain the datasets used in the neural network experiments (such as MNIST and CIFAR-10). These are not included to save space and need to be downloaded separately.
The folder `human-data/` contains the experimental results from human subjects. It has separate folders for the main cohort of subjects in `data/train-data/` and for the online replication in `data/train-data-all-online/`. The human data is saved in `.npy` files for ease of plotting.
<br>
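
Since the human data is stored as `.npy` files, it can be read directly with NumPy. The following is a minimal sketch, not code from this repository: the helper `load_npy_folder` and the file name `example_subject.npy` are hypothetical, and the demo writes a small array to a temporary folder so that it runs without the real data.

```python
import tempfile
from pathlib import Path

import numpy as np


def load_npy_folder(folder):
    """Load every .npy file in `folder` into a dict keyed by file stem.

    Hypothetical helper for illustration; it is not part of `src/`.
    """
    return {p.stem: np.load(p) for p in sorted(Path(folder).glob("*.npy"))}


# Demo on a temporary folder with a made-up file name and made-up contents.
with tempfile.TemporaryDirectory() as tmp:
    np.save(Path(tmp) / "example_subject.npy", np.arange(6).reshape(2, 3))
    arrays = load_npy_folder(tmp)
    for name, arr in arrays.items():
        print(name, arr.shape, arr.dtype)
```

To inspect the real data, pass the path of one of the human-data folders instead of the temporary folder.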

Further, `notebooks/` contains `.ipynb` files used for plotting. `results/model_runs_linear` holds the results of the network training runs. `scripts/` contains `.py` files used for the training of the models. `src/` is exclusively used for reusable modules for import in other files and notebooks.

```bash
├── README.md
├── data
│   ├── train-data
│   └── train-data-all-online
├── environment.yml
├── notebooks
├── results
│   └── model_runs_linear
├── scripts
├── setup.py
└── src
    ├── __init__.py
    └── ...
```

## Dependencies

* All analyses were run on macOS. Dependencies are listed in the `environment.yml` file. We used JAX for our neural network experiments.

## Getting Started
The `environment.yml` file installs the CPU-only version of JAX. This is sufficient to run the linear network experiments without issue, and these experiments should take a few minutes to run.

For experiments with CNNs it might be advantageous to use the GPU version of JAX. Information about GPU installations of JAX can be found at https://github.com/google/jax#installation. As we are interested in the early stages of learning, we run the CNNs for a relatively small number of epochs. Experiments with CNNs took less than 30 minutes to complete on a GPU. They should also run locally on CPU without major issues.

1. Recreate conda environment:
   ```
   conda env create --file environment.yml
   ```
   Then run
   ```
   conda activate environment_ocs
   ```
2. Install code in `src/` as editable package (run this in the root folder):
   ```
   pip install -e .
   ```
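
After activating the environment, it can be useful to confirm which backend JAX picked up before starting the CNN experiments. The sketch below is an optional check and not part of the repository; the function name `describe_jax_backend` is hypothetical. It only uses `jax.default_backend()` and `jax.devices()`, and falls back to a message if JAX cannot be imported.

```python
def describe_jax_backend():
    """Return a short description of the JAX backend, or a note if JAX is missing."""
    try:
        import jax
    except ImportError:
        return "JAX is not installed in this environment"
    # Platform of each visible device, e.g. "cpu" or "gpu".
    platforms = [device.platform for device in jax.devices()]
    return f"backend={jax.default_backend()}, devices={platforms}"


print(describe_jax_backend())
```

With the CPU-only install from `environment.yml`, the reported backend is expected to be the CPU one.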

## Usage
Relevant plots can be generated using the `.ipynb` files in the `notebooks/` folder. <br>
The experiments can be run with the following commands.

To run the deep linear network experiment with bias terms run:
   ```
   python scripts/linear_net_training.py --include_bias_input --model_name "linear_net_bias"
   ```
To run the deep linear network experiment without bias term run:
   ```
   python scripts/linear_net_training.py --model_name "linear_net_no_bias"
   ```

To run the shallow linear network experiment with bias term run:
   ```
   python scripts/shallow_net_training.py --include_bias_input --model_name "shallow_net_bias"
   ```

To run the shallow linear network experiment without bias term run:
   ```
   python scripts/shallow_net_training.py --model_name "shallow_net_no_bias"
   ```

For the Hierarchical MNIST experiment run:

   ```
   python scripts/cnn_training_scripts/cnn_training_hieararchy.py --use_hidden_layer_bias
   ```
For the "orthogonalised" MNIST experiment run:
   ```
   python scripts/cnn_training_scripts/cnn_training_hieararchy.py --orthogonalise
   ```

For the Hierarchical CIFAR-10 experiment run:

   ```
   python scripts/cnn_training_scripts/cnn_training_hieararchy.py --use_hidden_layer_bias --dataset_name "cifar10"
   ```

For convenience, all experiments are set to 3 runs only. To change this, adapt the config files in `src/neural_nets/configs/` or use the relevant command line arguments.
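
To run several experiments in sequence, the commands from the Usage section can be collected in a small driver script. This is a sketch under stated assumptions rather than a script shipped with the repository: the dictionary keys for the CNN experiments and the function `run_experiments` are hypothetical labels, while the script paths and flags are copied from the commands in this README. By default it only prints the commands (dry run).

```python
import shlex
import subprocess
import sys

# Script paths and flags copied from the Usage section (relative to the repo root).
# The dictionary keys are illustrative labels only.
EXPERIMENTS = {
    "linear_net_bias": "scripts/linear_net_training.py --include_bias_input --model_name linear_net_bias",
    "linear_net_no_bias": "scripts/linear_net_training.py --model_name linear_net_no_bias",
    "shallow_net_bias": "scripts/shallow_net_training.py --include_bias_input --model_name shallow_net_bias",
    "shallow_net_no_bias": "scripts/shallow_net_training.py --model_name shallow_net_no_bias",
    "mnist_hierarchy": "scripts/cnn_training_scripts/cnn_training_hieararchy.py --use_hidden_layer_bias",
    "mnist_orthogonalised": "scripts/cnn_training_scripts/cnn_training_hieararchy.py --orthogonalise",
    "cifar10_hierarchy": "scripts/cnn_training_scripts/cnn_training_hieararchy.py --use_hidden_layer_bias --dataset_name cifar10",
}


def run_experiments(names=None, dry_run=True):
    """Build the command for each selected experiment and optionally execute it."""
    selected = names if names is not None else list(EXPERIMENTS)
    commands = [[sys.executable, *shlex.split(EXPERIMENTS[name])] for name in selected]
    for cmd in commands:
        print(" ".join(shlex.quote(part) for part in cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing experiment
    return commands


# Dry run: print every command without executing anything.
commands = run_experiments(dry_run=True)
```

Calling `run_experiments(["linear_net_bias"], dry_run=False)` from the repository root would execute only the deep linear network experiment with bias terms.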

