
# Discrete Distributions are Effective Neural Network Outputs for Event Prediction

This repository is the official implementation of [Discrete Distributions are Effective Neural Network Outputs for Event Prediction](anonymized). 

## Requirements

To install requirements:

```setup
pip install -r requirements.txt

pip install -e ./kdai
pip install -e ./kdtpp
```

Depending on which scripts you run, you may also need the development requirements:
```
pip install -r requirements_devel.txt
```

Alternatively, using Docker:

```setup
docker build -t tppimage --build-arg USER_ID=$(id -u) --build-arg GROUP_ID=$(id -g) ./
```


## Data
To download the data, run:

```data
wget "https://pub-d74474ef66bb47838654ed06c6166aa7.r2.dev/cyclic.tar.gz"
wget "https://pub-d74474ef66bb47838654ed06c6166aa7.r2.dev/rand_process.tar.gz"
wget "https://pub-d74474ef66bb47838654ed06c6166aa7.r2.dev/metropolis.tar.gz"
wget "https://pub-d74474ef66bb47838654ed06c6166aa7.r2.dev/nyc_taxi_2013.tar.gz"
wget "https://pub-d74474ef66bb47838654ed06c6166aa7.r2.dev/chicken_2021.tar.gz"
wget "https://pub-d74474ef66bb47838654ed06c6166aa7.r2.dev/baseline.tar.gz"
```


Then extract the archives into `./data`:

```data
mkdir -p ./data  # tar -C needs the target directory to exist
tar -xvzf cyclic.tar.gz -C ./data
tar -xvzf rand_process.tar.gz -C ./data
tar -xvzf metropolis.tar.gz -C ./data
tar -xvzf nyc_taxi_2013.tar.gz -C ./data
tar -xvzf chicken_2021.tar.gz -C ./data
tar -xvzf baseline.tar.gz -C ./data
```

The instructions hereafter assume the following directory structure:

```
data/chicken_2021/dataset.json
data/rand_process/dataset.json
data/cyclic/10dim_1obj_1024event.npz
data/cyclic/9dim_1obj_1024event.npz
data/cyclic/8dim_1obj_1024event.npz
data/cyclic/7dim_1obj_1024event.npz
data/cyclic/6dim_1obj_1024event.npz
data/cyclic/5dim_1obj_1024event.npz
data/cyclic/4dim_1obj_1024event.npz
data/cyclic/3dim_1obj_1024event.npz
data/cyclic/2dim_1obj_1024event.npz
data/cyclic/1dim_1obj_1024event.npz
data/nyc_taxi_2013/taxi_dataset/split_details.json
data/nyc_taxi_2013/taxi_dataset/test.json
data/nyc_taxi_2013/taxi_dataset/train.json
data/nyc_taxi_2013/taxi_dataset/val.json
data/metropolis/events.npy
data/baseline/lastfm.pkl
data/baseline/pubg.pkl
data/baseline/reddit_askscience_comments.pkl
data/baseline/reddit_politics_submissions.pkl
data/baseline/twitter.pkl
data/baseline/wikipedia.npz
data/baseline/yelp_airport.pkl
data/baseline/yelp_mississauga.pkl
data/baseline/yelp_toronto.npz
data/baseline/mooc.npz
data/baseline/easytpp/amazon/test.json
data/baseline/easytpp/amazon/train.json
data/baseline/easytpp/amazon/val.json
data/baseline/easytpp/taobao/test.json
data/baseline/easytpp/taobao/train.json
data/baseline/easytpp/taobao/val.json
data/lrmap/combined_lrmap.csv
```
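To confirm that extraction produced this layout, a small check along the following lines may help. It is a minimal sketch, not part of the repository: the helper name `missing_files` is hypothetical, and the path list is simply copied from the listing above.

```python
from pathlib import Path

# Paths copied from the directory listing above. The cyclic files follow
# the pattern <n>dim_1obj_1024event.npz for n = 1..10.
EXPECTED = [
    "chicken_2021/dataset.json",
    "rand_process/dataset.json",
    *[f"cyclic/{n}dim_1obj_1024event.npz" for n in range(1, 11)],
    *[f"nyc_taxi_2013/taxi_dataset/{name}.json"
      for name in ("split_details", "test", "train", "val")],
    "metropolis/events.npy",
    "baseline/lastfm.pkl",
    "baseline/pubg.pkl",
    "baseline/reddit_askscience_comments.pkl",
    "baseline/reddit_politics_submissions.pkl",
    "baseline/twitter.pkl",
    "baseline/wikipedia.npz",
    "baseline/yelp_airport.pkl",
    "baseline/yelp_mississauga.pkl",
    "baseline/yelp_toronto.npz",
    "baseline/mooc.npz",
    *[f"baseline/easytpp/{ds}/{split}.json"
      for ds in ("amazon", "taobao")
      for split in ("test", "train", "val")],
    "lrmap/combined_lrmap.csv",
]


def missing_files(data_root, expected=EXPECTED):
    """Return the expected relative paths that are not files under data_root."""
    root = Path(data_root)
    return [rel for rel in expected if not (root / rel).is_file()]


# Example (hypothetical usage, run from the repository root):
#   for rel in missing_files("data"):
#       print("missing:", rel)
```

An empty result means every file in the listing is present.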



## Experiments

The experiments are organised into four groups:

 - `./versioned_runs/0` for the real-world datasets (Section 3)
 - `./versioned_runs/1` for the synthetic and real-world datasets at multiple training lengths (Sections 4 and 5)
 - `./versioned_runs/2` for the RGC spike prediction task (Section 6)
 - `./versioned_runs/3` for the modulo datasets (Section 7)

For `./versioned_runs/0`, there are four parts:

  - 0/0: `lrsweep.py`. Carries out a learning rate sweep.
  - 0/1: `lrcalc.py`. Chooses learning rates based on the sweeps.
  - 0/2: `train.py`. Trains all model-dataset combinations.
  - 0/3: `eval.py`. Evaluates all trained models.

The learning rate sweep and calculation (`lrsweep.py` and `lrcalc.py`) can be
skipped by using the provided learning rate map, which is the default. The same
applies to `./versioned_runs/1` and `./versioned_runs/3`. For
`./versioned_runs/2`, all models use lr=5e-4, so no sweep or calculation is needed.
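
Before relying on the provided learning rate map, it can be useful to look at what it contains. The sketch below assumes the map is the `data/lrmap/combined_lrmap.csv` file from the directory listing. Its column names are not documented here, so the hypothetical helper `peek_csv` just returns whatever header and first rows it finds.

```python
import csv
from itertools import islice


def peek_csv(path, n_rows=5):
    """Return the header and the first n_rows data rows of a CSV file."""
    with open(path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader, [])          # empty list if the file is empty
        rows = list(islice(reader, n_rows))
    return header, rows


# Example (assumed path, run from the repository root):
#   header, rows = peek_csv("data/lrmap/combined_lrmap.csv")
#   print(header)
#   for row in rows:
#       print(row)
```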

Run each part as follows:

```run
python ./versioned_runs/0/0/lrsweep.py
python ./versioned_runs/0/1/lrcalc.py # Requires output from lrsweep
python ./versioned_runs/0/2/train.py
python ./versioned_runs/0/3/eval.py # Requires output from train
```
Each script reports its output directory, and that path must be used as
input to the next step.
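
Since the reported output directory is needed later, it may be convenient to keep a copy of each script's output. The following is a minimal sketch under stated assumptions: the helper name `run_and_log` and the `logs` directory are hypothetical, and the sketch does not cover how the next script receives the path, only how to keep a record of what was printed.

```python
import subprocess
import sys
from pathlib import Path


def run_and_log(script, log_dir="logs"):
    """Run one script, echoing its output live and saving a copy to a log file.

    Returns the script's exit code and the path of the log file.
    """
    log_root = Path(log_dir)
    log_root.mkdir(parents=True, exist_ok=True)
    log_path = log_root / (Path(script).stem + ".log")
    with open(log_path, "w") as log, subprocess.Popen(
        [sys.executable, str(script)],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the same stream
        text=True,
    ) as proc:
        for line in proc.stdout:
            sys.stdout.write(line)  # show progress while the script runs
            log.write(line)         # keep a copy for later reference
    return proc.returncode, log_path


# Example (script path taken from the commands above):
#   code, log_path = run_and_log("./versioned_runs/0/2/train.py")
#   print("exit code:", code, "log:", log_path)
```

The saved log can then be searched for the reported output directory before starting the next step.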

The `eval.py` scripts output dataframes containing the results that
are used to generate the paper figures.

## Stack Overflow 
To use the Stack Overflow dataset, request a data dump from
[Stack Overflow](https://stackoverflow.com/help/data-dumps). The instructions
above do not use the Stack Overflow dataset.

## Code organisation
All entry points are in `kdtpp/experiments.py`. Models are in `kdtpp/models.py`. Models are wrapped by
Trainables, which live in `kdtpp/trainables.py`. Datasets and dataloaders are in `kdtpp/datasets.py`.
Code for spike prediction lives in `kdtpp/disttrainable.py`, `kdtpp/mea.py` and `kdtpp/inferspikes.py`.
The `kdai` package contains general training and evaluation code that is used by, but not specific to, this project.

## Tests
Tests are included, but some of them will fail because their test resources
were excluded from the archive to reduce its size.


## License
[BSD 3-Clause](https://opensource.org/license/bsd-3-clause)

