# Continuous Multinomial Logistic Regression (CMLR) for Neural Decoding

## Introduction

This repository is the official implementation of the **Continuous Multinomial Logistic Regression (CMLR) model**. It is implemented in [Python 3.12.8](https://www.python.org/downloads/release/python-3128/) using the [PyTorch 2.5.1](https://pytorch.org/) framework.

**Motivation:** Multinomial logistic regression (MLR) is a classic model for multi-class classification that has been widely used for neural decoding. However, MLR requires a finite set of discrete output classes, limiting its applicability to settings with continuous-valued outputs (e.g., time, orientation, velocity, or spatial position). To address this limitation, we propose Continuous Multinomial Logistic Regression (CMLR), a generalization of MLR to continuous output spaces. CMLR defines a novel exponential-family model for conditional density estimation (CDE), mapping neural population activity to a full probability density over external covariates. It captures the influence of each neuron’s activity on the decoded variable through a smooth, interpretable tuning function, regularized by a Gaussian process prior. The resulting nonparametric decoding model flexibly captures a wide variety of conditional densities, including multimodal, asymmetric, and circular distributions.

**Inference:** We introduce an efficient, [PyTorch](https://pytorch.org/)-based stochastic variational inference procedure in the frequency domain for model training, allowing CMLR to scale to high-dimensional datasets.
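
One standard way to represent a GP prior in the frequency domain is to expand each weight function in a Fourier basis whose coefficients are independent Gaussians, with variances read off the kernel's spectral density. For an RBF kernel $k(\tau) = \rho \exp(-\tau^2 / 2\ell^2)$ the spectral density is $S(\omega) = \rho \sqrt{2\pi}\, \ell \exp(-\ell^2 \omega^2 / 2)$. The NumPy sketch below assumes a periodic domain of length `period` and a truncated set of `n_freqs` frequencies; the function names are hypothetical and this is not the code in `utils/fourier_tools.py`.

```python
import numpy as np

def rbf_fourier_variances(rho, ell, period, n_freqs):
    """Prior variances of Fourier coefficients for an RBF-kernel GP on a periodic domain.

    Evaluates S(omega) = rho * sqrt(2*pi) * ell * exp(-(ell*omega)^2 / 2)
    at the discrete angular frequencies omega_k = 2*pi*k / period.
    """
    omega = 2 * np.pi * np.arange(n_freqs) / period
    spectral = rho * np.sqrt(2 * np.pi) * ell * np.exp(-0.5 * (ell * omega) ** 2)
    var = 2.0 * spectral / period      # variance of each cos/sin coefficient for k >= 1
    var[0] = spectral[0] / period      # the constant term (k = 0) has no sine partner
    return omega, var

def implied_covariance(tau, omega, var):
    """Covariance between w(y) and w(y + tau) implied by the truncated Fourier prior."""
    return np.sum(var * np.cos(omega * tau))

def sample_weight_function(y, rho, ell, period, n_freqs, rng):
    """Draw one weight function w(y) from the Fourier-domain GP prior."""
    omega, var = rbf_fourier_variances(rho, ell, period, n_freqs)
    a = rng.normal(size=n_freqs) * np.sqrt(var)   # cosine coefficients
    b = rng.normal(size=n_freqs) * np.sqrt(var)   # sine coefficients (b[0] multiplies sin(0) = 0)
    phase = np.outer(y, omega)                    # shape (len(y), n_freqs)
    return np.cos(phase) @ a + np.sin(phase) @ b

# Usage: the implied variance at lag 0 should be close to rho = 2.0
omega, var = rbf_fourier_variances(rho=2.0, ell=0.5, period=20.0, n_freqs=200)
print(implied_covariance(0.0, omega, var))
```

Because the coefficient variances decay like $\exp(-\ell^2\omega^2/2)$, smooth weight functions (large $\ell$) are dominated by a few low frequencies, which is one reason a frequency-domain representation can stay compact.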

**Results:** We evaluated the model's performance on held-out data across a variety of neural decoding tasks with continuous sensory or motor output variables, and found that it consistently outperformed other decoding methods, including Naive Bayes (NB), deep neural networks (DNN), Extreme Gradient Boosting (XGBoost), and FlexCode, a leading nonparametric CDE method.

**Outline of the CMLR model:** Given an input feature vector $\mathbf{x} \in \mathbb{R}^D$ and a continuous output $y \in \Omega$, the model defines a multinomial logistic density over the output space $\Omega$ using decoding weight functions $\mathbf{w}(y) = \left\{w_d(y)\right\}_{d=1}^D$. The probability density for each possible output value $y$ given input $\mathbf{x}$ is proportional to $\exp(\mathbf{w}(y)^\top \mathbf{x})$, i.e.,

$$p(y \mid \mathbf{x}) = \frac{\exp\left(\mathbf{w}(y)^\top \mathbf{x}\right)}{\int_\Omega \exp\left(\mathbf{w}(y')^\top \mathbf{x}\right) dy'}.$$

The weight function $w_d(y)$ for each feature $d$ is assigned an independent Gaussian Process (GP) prior with a radial basis function (RBF) covariance function $k_d(y, y') = \rho_d \exp\left(-\frac{\lVert y - y' \rVert^2}{2\ell_d^2}\right)$, parameterized by marginal variance $\rho_d$ and lengthscale $\ell_d$. The goal of inference is to jointly learn the $D$-dimensional weight function $\mathbf{w}(y)$ and the GP hyperparameters $\theta = \left\{ \rho_{d}, \ell_{d}\right\}_{d=1}^D$ from observed input-output pairs $\left\{\mathbf{x}_{n},y_n\right\}_{n=1}^N$.
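
The CMLR density is defined only up to the normalizer $\int_\Omega \exp(\mathbf{w}(y')^\top \mathbf{x})\,dy'$. A simple way to evaluate it numerically is to tabulate each $w_d$ on a uniform grid over $\Omega$, which turns the model into an ordinary softmax over grid points. The sketch below illustrates that idea under this grid assumption; `cmlr_density_on_grid` is a hypothetical name, not a function of the `CMLR` package. With $w_1(y) = \cos y$ and $w_2(y) = \sin y$ on a circle, the resulting density is a von Mises distribution, one example of the circular densities the model can express.

```python
import numpy as np

def cmlr_density_on_grid(W, x, y_grid):
    """Evaluate p(y | x) proportional to exp(w(y)^T x) on a uniform 1D grid.

    W      : (G, D) array; row g holds w(y_g) = [w_1(y_g), ..., w_D(y_g)]
    x      : (D,) input feature vector (e.g. activity of D neurons)
    y_grid : (G,) uniformly spaced points covering the output space
    """
    logits = W @ x                        # w(y_g)^T x at every grid point
    logits = logits - logits.max()        # subtract the max before exp for numerical stability
    unnorm = np.exp(logits)
    dy = y_grid[1] - y_grid[0]
    return unnorm / (unnorm.sum() * dy)   # Riemann sum approximates the normalizing integral

# Example: circular output with w_1(y) = cos(y), w_2(y) = sin(y)
y_grid = np.linspace(0.0, 2 * np.pi, 360, endpoint=False)
W = np.stack([np.cos(y_grid), np.sin(y_grid)], axis=1)   # shape (360, 2)
p = cmlr_density_on_grid(W, np.array([2.0, 2.0]), y_grid)
# The mode sits at atan2(2, 2) = pi/4, the direction of x
```

For 2D outputs the same code applies after flattening a 2D grid into `G` rows, with `dy` replaced by the area of one grid cell.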

![Structure of the CMLR model](attachment:image.png)

## What is included

### Methods:
- **CMLR**: Our proposed Continuous Multinomial Logistic Regression (CMLR) model and inference framework
- **Naive Bayes**: A baseline model widely used in neural decoding, which assumes conditional independence across features. For a fair comparison, we introduce a version that places continuous Gaussian Process (GP) priors on the decoding weights for each feature and performs inference in the frequency domain, as in the CMLR model.
- **DNN & XGBoost**: Additional baselines adopted directly from the [Neural Decoding Package](https://github.com/KordingLab/Neural_Decoding).
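
Conditional independence means the Naive Bayes log-posterior over $y$ is a sum of per-feature log-likelihoods plus a log-prior. The sketch below assumes Poisson spike counts with tuning curves $f_d(y)$ tabulated on a grid; that likelihood is an illustrative choice, and the repository's `NaiveBayes` module may use a different observation model. `naive_bayes_posterior` is a hypothetical name.

```python
import numpy as np

def naive_bayes_posterior(x, tuning, log_prior):
    """Grid posterior p(y_g | x) for a Poisson Naive Bayes decoder (illustrative assumption).

    x         : (D,) spike counts
    tuning    : (G, D) positive firing rates f_d(y_g)
    log_prior : (G,) log p(y_g), up to an additive constant
    """
    # Sum of independent Poisson log-likelihoods; log(x_d!) does not depend on y and is dropped
    loglik = np.log(tuning) @ x - tuning.sum(axis=1)
    logpost = loglik + log_prior
    logpost = logpost - logpost.max()    # stabilize before exponentiating
    post = np.exp(logpost)
    return post / post.sum()

# Usage: one neuron firing 10 spikes favors the grid point whose tuning rate is 10
post = naive_bayes_posterior(np.array([10.0]), np.array([[1.0], [10.0]]), np.zeros(2))
```

Under this Poisson assumption the log-posterior is linear in $\mathbf{x}$, with per-feature weights $\log f_d(y)$ plus an $\mathbf{x}$-independent offset, so it has a form similar to CMLR's $\mathbf{w}(y)^\top \mathbf{x}$; the difference is that the NB weights come from fitting generative tuning curves rather than from modeling $p(y \mid \mathbf{x})$ directly.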

### Results: 
End‐to‐end, ready-to-run Jupyter notebooks demonstrate:
- Simulation studies showing that our proposed stochastic variational inference framework can accurately recover both parameters and hyperparameters in 1D and 2D output spaces.
- Applications of our model to real neural datasets, including decoding grating orientation (1D) from mouse primary visual cortex (V1) (data from [Stringer et al., 2021](https://pubmed.ncbi.nlm.nih.gov/33857423/)) and 2D velocity from monkey motor cortex (MC) (data from [Glaser et al., 2020](https://www.eneuro.org/content/7/4/ENEURO.0506-19.2020)).
- Superior decoding performance of CMLR compared to the baseline models across both datasets.
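
Decoding held-out data with a density model often ends by reducing the full conditional density to a point estimate. For a linear variable such as velocity, the posterior mean is a natural choice; for a circular variable such as grating orientation, a naive average of angles fails near the wrap-around, so a circular mean is more appropriate. The helpers below are a hedged NumPy sketch with hypothetical names (not the repository's `CMLR_decode` function), assuming the density is tabulated on a grid and angles are in radians.

```python
import numpy as np

def posterior_mean(p, grid):
    """Posterior mean of y; grid is (G,) for 1D outputs or (G, 2) for a flattened 2D grid."""
    w = p / p.sum()
    return np.tensordot(w, grid, axes=1)   # weighted average over grid points

def circular_mean(p, theta_grid):
    """Circular mean of an angular density tabulated on theta_grid, returned in [0, 2*pi)."""
    w = p / p.sum()
    angle = np.arctan2(np.sum(w * np.sin(theta_grid)), np.sum(w * np.cos(theta_grid)))
    return np.mod(angle, 2 * np.pi)

def circular_abs_error(theta_hat, theta_true):
    """Absolute angular error wrapped into [0, pi]."""
    d = np.mod(theta_hat - theta_true + np.pi, 2 * np.pi) - np.pi
    return np.abs(d)

# Example: two equal bumps at 350 and 10 degrees
theta = np.deg2rad(np.arange(360.0))
p = np.zeros(360)
p[350] = 1.0
p[10] = 1.0
est = circular_mean(p, theta)        # lands at (or numerically next to) 0 degrees
naive = posterior_mean(p, theta)     # lands near pi, the opposite side of the circle
```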

#### Usage Examples via Notebooks

| Notebook                                    | Purpose                                                                                                          |
| ------------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
| **1D_simulation_validation_CMLR.ipynb**     | • A simulation example that shows that the CMLR inference framework faithfully recovers the underlying hyperparameters and decoding weights in 1-dimensional output spaces (Figure S1 in the manuscript) - this notebook demonstrates how to train the CMLR model for 1D outputs             |
| **2D_simulation_validation_CMLR.ipynb**     | • A simulation example that shows the CMLR inference framework faithfully recovers the underlying hyperparameters and decoding weights in 2-dimensional output spaces (Figure S6 in the manuscript) - this notebook demonstrates how to train the CMLR model for 2D outputs              |
| **1D_real_data_mouse_V1_CMLR.ipynb**          | • A step-by-step guide to applying CMLR to 1-dimensional real-data problems, specifically to data ([Stringer et al., 2021](https://pubmed.ncbi.nlm.nih.gov/33857423/)) recorded from mouse primary visual cortex (V1) under drifting grating stimuli (Figure 2 in the manuscript) - this notebook demonstrates how to train the CMLR model for 1D outputs and use it to decode held-out data |
| **2D_real_data_monkey_MC_CMLR.ipynb**          | • A step-by-step guide to applying CMLR to 2-dimensional real-data problems, specifically to data ([Glaser et al., 2020](https://www.eneuro.org/content/7/4/ENEURO.0506-19.2020)) recorded from the motor cortex (MC) of a monkey performing a target-reaching task (Figure 4 in the manuscript) - this notebook demonstrates how to train the CMLR model for 2D outputs and use it to decode held-out data |
| **1D_real_data_mouse_V1_all_models.ipynb**          | • A comprehensive implementation and performance comparison of all decoders on a 1D orientation decoding task, applied to data ([Stringer et al., 2021](https://pubmed.ncbi.nlm.nih.gov/33857423/)) recorded from mouse primary visual cortex (V1) under drifting grating stimuli (Figure 2 in the manuscript) |
| **2D_real_data_monkey_MC_all_models.ipynb**          | • A comprehensive implementation and performance comparison of all decoders on a 2D velocity decoding task, applied to data ([Glaser et al., 2020](https://www.eneuro.org/content/7/4/ENEURO.0506-19.2020)) recorded from the motor cortex (MC) of a monkey performing a target-reaching task (Figure 4 in the manuscript) |
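
The simulation notebooks check that inference recovers known weights and hyperparameters. Generating such data only requires sampling $y_n$ from $p(y \mid \mathbf{x}_n)$ for each trial; on a grid this is inverse-CDF sampling from a softmax over grid points. The sketch below assumes standard-normal inputs and a 1D grid; `simulate_cmlr_trials` is a hypothetical name and the repository's `CMLR/simulate_data.py` may generate inputs differently.

```python
import numpy as np

def simulate_cmlr_trials(W, y_grid, n_trials, rng):
    """Simulate (x_n, y_n) pairs from a grid-tabulated CMLR model.

    W      : (G, D) weights; row g holds w(y_g)
    y_grid : (G,) grid over the output space
    Returns X with shape (n_trials, D) and y with shape (n_trials,).
    """
    G, D = W.shape
    X = rng.normal(size=(n_trials, D))                 # assumed input distribution
    logits = X @ W.T                                   # (n_trials, G): w(y_g)^T x_n
    logits = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)          # softmax over grid points per trial
    # Inverse-CDF sampling: index of the first grid point whose CDF reaches a uniform draw
    cdf = np.cumsum(probs, axis=1)
    u = rng.random(n_trials)[:, None]
    idx = np.minimum((cdf < u).sum(axis=1), G - 1)     # guard against round-off near u = 1
    return X, y_grid[idx]

# Usage: with zero weights every grid point is equally likely
rng = np.random.default_rng(0)
y_grid = np.linspace(-1.0, 1.0, 5)
X, y = simulate_cmlr_trials(np.zeros((5, 3)), y_grid, 20000, rng)
```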


## Quick start


1. **Make a local copy** of this repository.

2. **Create & activate** the conda environment:
   ```bash
   conda env create -f environment.yml
   conda activate cmlr-env
   ```

   **Alternative (pip):**
   ```bash
   python3 -m venv venv
   source venv/bin/activate      # On Windows: venv\Scripts\activate
   pip install -r requirements.txt
   pip install -e .
   ```

3. **Launch JupyterLab:**
   ```bash
   make notebook
   # or
   jupyter lab
   ```

4. **Open a notebook** in the `notebooks/` folder and run the cells.

### Key dependencies

- python >= 3.12.8
- numpy >= 2.1.3
- pandas >= 2.2.3
- scikit-learn >= 1.6.1
- matplotlib >= 3.10.0
- pytorch >= 2.5.1 
- xgboost >= 3.0.0
- Neural-Decoding == 0.1.5

All required packages can be installed using the provided `environment.yml` file (or `requirements.txt` with pip).
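
After installation, one quick way to confirm that the environment meets the minimum versions listed above is to query installed distribution metadata with the standard-library `importlib.metadata` module. This is a hedged sketch: the helper names are hypothetical, the parser compares only the numeric release segment (pre-release tags are ignored), PyTorch is looked up under its PyPI distribution name `torch`, and the exactly pinned `Neural-Decoding` package is left out.

```python
import re
from importlib import metadata

MINIMUMS = {
    "numpy": "2.1.3",
    "pandas": "2.2.3",
    "scikit-learn": "1.6.1",
    "matplotlib": "3.10.0",
    "torch": "2.5.1",      # PyTorch's distribution name on PyPI
    "xgboost": "3.0.0",
}

def version_tuple(v):
    """Turn e.g. '2.5.1+cu121' into (2, 5, 1), keeping only leading numeric parts."""
    parts = []
    for piece in v.split("+")[0].split("."):
        m = re.match(r"\d+", piece)
        if not m:
            break
        parts.append(int(m.group()))
    return tuple(parts)

def check_minimums(minimums=MINIMUMS):
    """Return {distribution: problem} for packages that are missing or too old."""
    problems = {}
    for dist, minimum in minimums.items():
        try:
            installed = metadata.version(dist)
        except metadata.PackageNotFoundError:
            problems[dist] = "not installed"
            continue
        if version_tuple(installed) < version_tuple(minimum):
            problems[dist] = f"{installed} < {minimum}"
    return problems

if __name__ == "__main__":
    print(check_minimums() or "all minimum versions satisfied")
```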


## Repository Structure

```text
.
├── CMLR/                                   # Core CMLR model (training, decoding, plotting helpers)
│   ├── __init__.py
│   ├── CMLR_model.py                       # The CMLR model
│   ├── CMLR_train.py                       # fit_CMLR_model + training loops
│   ├── CMLR_decode.py                      # CMLR decode function
│   ├── plot_real_data_results.py           # Plotting the real data results
│   ├── plot_simulation_validations.py      # Plotting the simulation validations
│   └── simulate_data.py                    # Simulate observations from the CMLR model
├── NaiveBayes/                             # Naive Bayes GP baseline
│   ├── __init__.py
│   ├── Naive_Bayes_model.py                # The Naive_Bayes model
│   ├── Naive_Bayes_train.py                # fit_Naive_Bayes_model + training loops
│   ├── Naive_Bayes_decode.py               # Naive_Bayes decode function
│   ├── plot_real_data_results_NB.py        # Plotting the real data results for Naive_Bayes
│   ├── plot_simulation_validations_NB.py   # Plotting the simulation validations for Naive_Bayes
│   └── simulate_data_NB.py                 # Simulate observations from the Naive_Bayes model
├── Neural_Decoding/                        # DNN & XGBoost baselines (from the Neural_Decoding package)
│   ├── __init__.py
│   └── decoders.py                         # simple wrappers for PyTorch DNN, XGBoost
├── utils/                                  # Shared utilities
│   ├── fourier_tools.py                    # The tools used for the Frequency Domain implementation
│   ├── other_plot_functions.py             # Other shared plot functions
│   └── other_utils.py                      # other shared utils
├── data/                                   # Raw and processed data scripts
│   ├── monkeyMCdata/                       # monkey M1 recordings + labels
│   └── mouseV1data/                        # sample mouse V1 recordings + labels
├── notebooks/                              # example Jupyter notebooks demonstrating simulation, training, decoding, and analysis
│   ├── 1D_simulation_validation_CMLR.ipynb
│   ├── 2D_simulation_validation_CMLR.ipynb
│   ├── 1D_real_data_mouse_V1_CMLR.ipynb
│   ├── 2D_real_data_monkey_MC_CMLR.ipynb
│   ├── 1D_real_data_mouse_V1_all_models.ipynb
│   └── 2D_real_data_monkey_MC_all_models.ipynb
├── requirements.txt            # pip dependencies
├── environment.yml             # conda environment (alternative)
├── Makefile                    # Automate env/data/test/fig generation
├── LICENSE                     # MIT License
└── README.md                   # This file
```

