# TiDeSPL-VAE

Code for the paper ***Time-Dependent VAE for Building Latent Representations from Visual Neural Activity with Complex Dynamics***.

## Pre-requisites

To run this project, you will need:

- Python 3
- PyTorch
- The following packages: numpy, tqdm, scikit-learn, cebra
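Before running anything, you can confirm that the dependencies are importable with a small check like the one below. It is a sketch and not part of the repository. Note that for PyTorch and scikit-learn the import names (`torch`, `sklearn`) differ from the package names.

```python
import importlib.util

# Package name -> import name. PyTorch and scikit-learn are imported
# under names that differ from how they are usually referred to.
REQUIRED = {
    "torch": "torch",
    "numpy": "numpy",
    "tqdm": "tqdm",
    "scikit-learn": "sklearn",
    "cebra": "cebra",
}


def missing_packages(import_names):
    """Return the import names that cannot be found in the current environment."""
    return [name for name in import_names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED.values())
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All prerequisites found.")
```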

## Model Training and Testing

The code is in the `train` folder. It supports training and testing on a single GPU.
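The scripts take a single `--device` argument (e.g. `cuda:0`). If your machine has several GPUs, one common way to pick a specific physical GPU without changing that argument is the standard CUDA environment variable `CUDA_VISIBLE_DEVICES`: when it lists a single index, that GPU appears as `cuda:0` inside the process. The helper below is a hypothetical sketch, not part of the repository.

```python
import os


def single_gpu_env(gpu_index, base=None):
    """Return a copy of an environment that exposes only one physical GPU.

    With CUDA_VISIBLE_DEVICES set to a single index, that GPU is
    renumbered as cuda:0 in the child process, so `--device cuda:0`
    can be passed unchanged. The input environment is not modified.
    """
    env = dict(os.environ if base is None else base)
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    return env
```

For example, `subprocess.run(cmd, env=single_gpu_env(2), check=True)` would launch a command on physical GPU 2. Setting the variable in the shell (`CUDA_VISIBLE_DEVICES=2 python ...`) has the same effect.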

### Example for Synthetic Non-Temporal Data

We provide the preprocessed dataset, which can be downloaded from https://drive.google.com/file/d/1IIwakt2Ox87iTj9wLYOFtbZlV0r60gtM/view?usp=sharing.
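The datasets and checkpoints are shared as Google Drive links. If you prefer to fetch them from the command line, most tools need the file or folder ID embedded in the link. The sketch below extracts it for the two link shapes used in this README (`/file/d/<ID>/view...` for files and `/drive/folders/<ID>?...` for the temporal dataset folder). The helper is hypothetical and not part of the repository.

```python
import re

# Matches the ID in Google Drive share links of the form
# .../file/d/<ID>/view... and .../drive/folders/<ID>?...
_DRIVE_ID = re.compile(r"/(?:file/d|drive/folders)/([\w-]+)")


def drive_id(url):
    """Extract the file or folder ID from a Google Drive share link."""
    match = _DRIVE_ID.search(url)
    if match is None:
        raise ValueError(f"not a recognised Google Drive link: {url}")
    return match.group(1)
```

As one option, the third-party `gdown` tool (installed separately with `pip install gdown`; it is not a dependency of this project) accepts a file ID directly (`gdown <ID>`) and downloads shared folders with `gdown --folder <URL>`.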

Train:

```
python train_simulated_data.py --data-path Simulated_Data/ --data-dim 100 --classes 0 --version 3 --model-name tidespl_vae --latent-dim 32 --aug --epochs 500 --batch-size 320 --opt adam --lr 5e-4 --output-path logs/ --repeat 10 --device cuda:0
```
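If you want to launch the same command from Python (for example, to sweep one hyperparameter), a hypothetical helper can turn a dictionary of options into an argument list. It assumes that an option given as `True` is a bare flag, as `--aug` and `--seed` appear in the commands in this README, and that the scripts are run from inside the `train` folder; both are assumptions, not documented behaviour.

```python
import subprocess


def build_command(script, args, python="python"):
    """Turn a dict of options into an argv list for one of the train/test scripts.

    True emits a bare flag (e.g. `--aug`); False or None omits the option;
    any other value is emitted as `--name value`.
    """
    cmd = [python, script]
    for name, value in args.items():
        if value is None or value is False:
            continue
        cmd.append(f"--{name}")
        if value is not True:
            cmd.append(str(value))
    return cmd


# Options copied from the training command above (dict order is preserved).
TRAIN_SIMULATED = {
    "data-path": "Simulated_Data/",
    "data-dim": 100,
    "classes": 0,
    "version": 3,
    "model-name": "tidespl_vae",
    "latent-dim": 32,
    "aug": True,
    "epochs": 500,
    "batch-size": 320,
    "opt": "adam",
    "lr": "5e-4",
    "output-path": "logs/",
    "repeat": 10,
    "device": "cuda:0",
}


def run(script, args, cwd="train"):
    """Launch a script, assuming it should be run from the `train` folder."""
    subprocess.run(build_command(script, args), cwd=cwd, check=True)
```

For instance, `run("train_simulated_data.py", {**TRAIN_SIMULATED, "latent-dim": 16})` would repeat the training run above with a different latent dimension.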

We provide an example checkpoint of TiDeSPL-VAE trained on the synthetic non-temporal dataset, which can be downloaded from https://drive.google.com/file/d/1A-kfW6--S9S68VywtTIn84OvW7zcyBqz/view?usp=sharing.

Test:

```
python test_simulated_data.py --data-path Simulated_Data/ --data-dim 100 --classes 0 --version 3 --model-name tidespl_vae --latent-dim 32 --checkpoint-path checkpoint/ --data-split all --batch-size 320 --output-path logs/ --repeat 1 --device cuda:0 --seed
```

### Example for Synthetic Temporal Data

We provide the preprocessed dataset, which can be downloaded from https://drive.google.com/drive/folders/143aEiXEcsZEFm2ZfZpSyveA1ko5miHTT?usp=sharing.

Train:

```
python train_lorenz_data.py --data-path Simulated_Data/ --conditions 5 --model-name tidespl_vae --latent-dim 8 --aug 5 --seq-len 50 --epochs 100 --batch-size 500 --opt adam --lr 1e-3 --output-path logs/ --repeat 5 --device cuda:0
```
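The `--seq-len 50` option suggests that the temporal model consumes trajectories in fixed-length segments. How the repository's data loader actually forms those segments is not documented here, so the following is only a hypothetical sketch of one common approach: cutting a `(time, features)` array into windows of `seq_len` steps with NumPy.

```python
import numpy as np


def make_windows(x, seq_len, stride=None):
    """Cut a (time, features) array into windows of shape (seq_len, features).

    stride defaults to seq_len (non-overlapping windows); stride=1 gives
    fully overlapping windows. Trailing steps that do not fill a whole
    window are dropped. Returns an array of shape (n_windows, seq_len, features).
    """
    x = np.asarray(x)
    stride = seq_len if stride is None else stride
    if x.shape[0] < seq_len:
        raise ValueError("series is shorter than seq_len")
    starts = range(0, x.shape[0] - seq_len + 1, stride)
    return np.stack([x[s:s + seq_len] for s in starts])
```

Non-overlapping windows keep samples independent of each other, while a smaller stride yields more (but correlated) training sequences from the same trajectory.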

We provide an example checkpoint of TiDeSPL-VAE trained on the synthetic temporal dataset, which can be downloaded from https://drive.google.com/file/d/17Gc97u0AHarmtfBdUFdQfMOjnAYKR7Yl/view?usp=sharing.

Test:

```
python test_lorenz_data.py --data-path Simulated_Data/ --conditions 5 --model-name tidespl_vae --latent-dim 8 --checkpoint-path checkpoint/ --data-split test --batch-size 500 --output-path logs/ --repeat 1 --device cuda:0 --seed
```

### Example for Mouse Neural Data under Natural Scenes

We provide the preprocessed dataset of Mouse 1, which can be downloaded from https://drive.google.com/file/d/1V3xlJsPzekfTva9p9snN_sfGmaURIxQu/view?usp=sharing.

Train:

```
python train_mouse_scenes.py --data-path neural_dataset/ --stimulus allen_natural_scenes --time-step 25 --mouse-id 16 --classes 5 --model-name tidespl_vae --latent-dim 128 --aug 3 --seq-len 5 --epochs 250 --batch-size 250 --opt adam --lr 1e-4 --output-path logs/ --repeat 10 --device cuda:0
```
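The training commands pass `--repeat` (10 here). Assuming this runs independent repetitions and that you collect one metric value per repetition (for example, a score read from the files written under `logs/`, whose format is not described in this README), the mean and sample standard deviation across repetitions can be computed with the standard library. The helper name is hypothetical.

```python
import statistics


def summarize_repeats(values):
    """Return (mean, sample standard deviation) of one metric over repeated runs.

    A single run has no spread, so its standard deviation is reported as 0.0.
    """
    values = [float(v) for v in values]
    if not values:
        raise ValueError("need at least one value")
    spread = statistics.stdev(values) if len(values) > 1 else 0.0
    return statistics.mean(values), spread
```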

We provide an example checkpoint of TiDeSPL-VAE trained on the neural dataset of Mouse 1, which can be downloaded from https://drive.google.com/file/d/1lTFjBaJ1bplp5pf3_0aNY1WN2yeutYpZ/view?usp=sharing.

Test:

```
python test_mouse_scenes.py --data-path neural_dataset/ --stimulus allen_natural_scenes --time-step 25 --mouse-id 16 --classes 5 --model-name tidespl_vae --latent-dim 128 --checkpoint-path checkpoint/ --seq-len 5 --data-split test --batch-size 250 --output-path logs/ --repeat 1 --device cuda:0 --seed
```

### Example for Mouse Neural Data under Natural Movie

We provide the preprocessed dataset of Mouse 2, which can be downloaded from https://drive.google.com/file/d/1T3cqgckpr9Vu5_60WhlRsgxzzlHqtqsD/view?usp=sharing.

Train:

```
python train_mouse_movie.py --data-path neural_dataset/ --stimulus allen_natural_movie_one --time-step 4 --mouse-id 17 --model-name tidespl_vae --latent-dim 128 --aug 2 --seq-len 4 --epochs 200 --batch-size 288 --opt adam --lr 1e-4 --output-path logs/ --repeat 10 --device cuda:0
```

We provide an example checkpoint of TiDeSPL-VAE trained on the neural dataset of Mouse 2, which can be downloaded from https://drive.google.com/file/d/13d82SKgPiASy3hr-I_Hw5VukcE4rdOam/view?usp=sharing.

Test:

```
python test_mouse_movie.py --data-path neural_dataset/ --stimulus allen_natural_movie_one --time-step 4 --mouse-id 17 --model-name tidespl_vae --latent-dim 128 --checkpoint-path checkpoint/ --seq-len 4 --data-split test --batch-size 288 --output-path logs/ --repeat 1 --device cuda:0 --seed
```
