# Shapley Regularized Neural Granger Causality (SRNGC)

Official implementation of **Shapley Regularized Neural Granger Causality (SRNGC)**, a framework for learning Granger-causal relationships in multivariate time series with neural networks regularized by Shapley values.

## Overview

This repository implements a neural network approach to Granger causality that uses Shapley values to measure and regularize feature importance. The method:

- **Models complex nonlinear dynamics** using residual MLP architectures
- **Regularizes with Shapley values** to identify causal relationships
- **Supports multiple importance measures**: Shapley values (exact and fast approximations) and Jacobian-based methods
- **Handles diverse datasets**: Synthetic (VAR, Lorenz96) and real-world (DREAM3 gene regulatory networks, fMRI, CausalTime)
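To make the Shapley idea concrete: fix one target series and treat each candidate input series as a player. The value of a coalition S is the model's prediction when the series in S keep their observed values and all other series are replaced by a baseline (zero here). The sketch below computes exact Shapley values by enumerating every coalition. It is a minimal illustration, not the repository's code. The toy model, the zero baseline, and the name `exact_shapley` are assumptions, and each series is reduced to one scalar for brevity (in the Granger setting a "player" would be a whole lagged window of one series). `model/imp_measure.py` and `model/penalty.py` may define the value function differently. Exact enumeration needs O(2^p) model evaluations, which is presumably why a fast approximation (`Fast_Shap`) is also offered.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence


def exact_shapley(f: Callable[[List[float]], float],
                  x: Sequence[float],
                  baseline: Sequence[float]) -> List[float]:
    """Exact Shapley value of each input for the prediction f(x).

    Inputs outside a coalition are replaced by their baseline value.
    Enumerates all 2^(p-1) coalitions per input, so only practical for small p.
    """
    p = len(x)

    def value(coalition: set) -> float:
        # Keep inputs in the coalition; mask the others with the baseline.
        z = [x[i] if i in coalition else baseline[i] for i in range(p)]
        return f(z)

    phi = [0.0] * p
    for i in range(p):
        others = [j for j in range(p) if j != i]
        for k in range(len(others) + 1):
            # Shapley weight for coalitions of size k: k! (p - k - 1)! / p!
            weight = factorial(k) * factorial(p - k - 1) / factorial(p)
            for subset in combinations(others, k):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


def toy_model(z: List[float]) -> float:
    # Hypothetical predictor for one target: uses series 0 and 1
    # (with an interaction) and ignores series 2.
    return 2.0 * z[0] + z[0] * z[1]


x = [1.0, 3.0, 5.0]
phi = exact_shapley(toy_model, x, baseline=[0.0, 0.0, 0.0])
print(phi)  # approximately [3.5, 1.5, 0.0]; sums to toy_model(x) - toy_model(baseline) = 5.0
```

Series 2 receives exactly zero attribution because the model never uses it, which is the property a Shapley-based penalty can exploit: driving a candidate cause's attribution toward zero amounts to declaring "no Granger-causal edge". The precise penalty used for training is defined in `model/penalty.py`.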

## Installation

### Prerequisites
- Python 3.7+
- PyTorch 1.12+
- CUDA 12.1+ (optional, for GPU acceleration)

### Setup Environment

```bash
# Clone the repository
git clone <repository-url>
cd SRNGC_code

# Install dependencies
bash conda_env.sh
```

Or install manually:
```bash
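# PyTorch with CUDA 12.1 wheels (via the NJU mirror; the official index is
# https://download.pytorch.org/whl/cu121)
pip3 install torch torchvision --index-url https://mirrors.nju.edu.cn/pytorch/whl/cu121
# Remaining dependencies
pip3 install numpy pandas scipy matplotlib scikit-learn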
```
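After installing, you can confirm the prerequisites by checking the Python version, whether each dependency imports, and whether PyTorch sees a GPU. This is a generic sketch that is not shipped with the repository; `check_environment` is a hypothetical helper name. Note that scikit-learn is imported as `sklearn`.

```python
import importlib
import sys


def check_environment(packages=("torch", "numpy", "pandas", "scipy",
                                 "matplotlib", "sklearn")):
    """Report the Python version, package versions, and CUDA availability."""
    report = {"python": sys.version.split()[0],
              "python_ok": sys.version_info >= (3, 7)}
    for name in packages:
        try:
            module = importlib.import_module(name)
            report[name] = getattr(module, "__version__", "unknown")
        except Exception:
            report[name] = None  # not installed, or failed to import
    # CUDA is optional; only query it if torch imported successfully.
    if report.get("torch"):
        import torch
        report["cuda_available"] = torch.cuda.is_available()
    else:
        report["cuda_available"] = False
    return report


if __name__ == "__main__":
    for key, val in check_environment().items():
        print(f"{key}: {val}")
```

A `None` entry means the package is missing; `cuda_available: False` is expected on CPU-only machines.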

## Project Structure

```
SRNGC_code/
├── src/                          # Main source code
│   ├── train.py                  # Core training loop with grid search
│   ├── simulation.py             # Run experiments on synthetic data
│   ├── real_data.py              # Run experiments on real-world datasets
│   ├── ablation.py               # Ablation studies (model architectures)
│   ├── sensitivity.py            # Sensitivity analysis
│   └── FDR.py                    # Knockoff-based FDR control
├── model/                        # Model implementations
│   ├── models.py                 # Neural network architectures (MLP, ResidualMLP, LSTM)
│   ├── penalty.py                # Shapley and Jacobian penalty functions
│   ├── imp_measure.py            # Importance measure computations
│   └── comp_efficiency.py        # Computational efficiency utilities
├── data/                         # Data loading and generation
│   ├── Data_generator.py         # Synthetic data generators (VAR, Lorenz96)
│   ├── Dataset.py                # Dataset classes and data loaders
│   └── DREAM3/                   # DREAM3 gene regulatory network data
├── utils/                        # Utility functions
│   ├── configs.py                # Hyperparameter configurations
│   ├── setups.py                 # Model and penalty setup utilities
│   └── cls_performance.py        # Performance metrics (AUROC, AUPRC)
├── assets/                       # Pre-computed assets
├── results/                      # Experiment results (auto-created)
├── conda_env.sh                  # Environment setup script
└── reproduce_real_data.sh        # Reproduction script for real data experiments
```
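`src/FDR.py` is described as knockoff-based FDR control. For orientation, the standard knockoff+ rule (Barber and Candès) selects features from precomputed statistics W_j, where a large positive W_j means feature j looks more important than its knockoff copy (for example, the importance of an original series minus that of its knockoff). The sketch below implements only this thresholding step. How `FDR.py` builds the knockoffs and the W_j is not shown and may differ; the function names are hypothetical.

```python
from typing import List, Sequence


def knockoff_plus_threshold(W: Sequence[float], q: float) -> float:
    """Smallest t in {|W_j| : W_j != 0} with
    (1 + #{j: W_j <= -t}) / max(1, #{j: W_j >= t}) <= q.

    Returns float('inf') when no such t exists (nothing is selected).
    """
    candidates = sorted({abs(w) for w in W if w != 0})
    for t in candidates:
        # Negative statistics beyond -t estimate false discoveries above t.
        false_est = 1 + sum(1 for w in W if w <= -t)
        n_selected = max(1, sum(1 for w in W if w >= t))
        if false_est / n_selected <= q:
            return t
    return float("inf")


def knockoff_select(W: Sequence[float], q: float) -> List[int]:
    """Indices of features whose statistic clears the knockoff+ threshold."""
    t = knockoff_plus_threshold(W, q)
    return [j for j, w in enumerate(W) if w >= t]


# Ten strong positive statistics, one small negative, one near zero.
W = [9, 8, 7, 6, 5, 4, 3, 2, 1.5, 1.2, -0.3, 0.1]
print(knockoff_select(W, q=0.15))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

The "+1" in the numerator is what distinguishes knockoff+ from the plain knockoff filter; it makes the threshold slightly more conservative in exchange for exact FDR control at level q.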

## Usage

### 1. Simulation Experiments (Synthetic Data)

Run experiments on synthetic datasets (VAR3, Lorenz96):

```bash
python src/simulation.py \
    --dataset VAR3 \
    --series 1 \
    --subject 1 \
    --lag 3 \
    --seed 2025 \
    --penalty_type Fast_Shap
```

**Arguments:**
- `--dataset`: Dataset name (`VAR3`, `Lorenz96`)
- `--series`: Series index
- `--subject`: Subject index (for multi-subject datasets)
- `--lag`: Time lag for autoregressive modeling
- `--seed`: Random seed
- `--penalty_type`: Regularization method (`Fast_Shap`, `Shapley`, `Jacob_F`, `Jacob_L1`)
- `--num_workers`: Number of parallel workers for hyperparameter search
- `--exec_idx`: Worker index (for parallel execution)
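Neural Granger models typically predict each series at time t from a window of the previous values of every series. The sketch below builds such (input window, target) pairs from a T x p multivariate series. It assumes `--lag` is the window length and that each window is laid out as `lag` rows of p values, oldest first; `make_lagged_pairs` is a hypothetical name, and the construction in `data/Dataset.py` may order or flatten the window differently.

```python
from typing import List, Sequence, Tuple

Window = List[List[float]]  # lag rows, each holding one value per series


def make_lagged_pairs(series: Sequence[Sequence[float]],
                      lag: int) -> List[Tuple[Window, List[float]]]:
    """Split a T x p series into (past window, next value) pairs.

    For each t in [lag, T), the input is rows t-lag .. t-1 (oldest first)
    and the target is row t, giving T - lag pairs.
    """
    T = len(series)
    if lag < 1 or lag >= T:
        raise ValueError("need 1 <= lag < number of time steps")
    pairs = []
    for t in range(lag, T):
        window = [list(series[s]) for s in range(t - lag, t)]
        target = list(series[t])
        pairs.append((window, target))
    return pairs


# 5 time steps, 2 series; with lag=3 there are 2 training pairs.
X = [[0.0, 10.0], [1.0, 11.0], [2.0, 12.0], [3.0, 13.0], [4.0, 14.0]]
pairs = make_lagged_pairs(X, lag=3)
window, target = pairs[0]
print(window)  # [[0.0, 10.0], [1.0, 11.0], [2.0, 12.0]]
print(target)  # [3.0, 13.0]
```

With this layout, series j is a Granger cause of target i in the fitted model only if the predictor for i depends on column j of the window at some lag, which is what the importance measures (Shapley or Jacobian) quantify.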

### 2. Real Data Experiments

Run experiments on real-world datasets (DREAM3, fMRI, CausalTime):

```bash
python src/real_data.py \
    --dataset DREAM3 \
    --series 1 \
    --seed 2025 \
    --penalty_type Fast_Shap
```

**Using pre-tuned hyperparameters:**
```bash
python src/real_data.py \
    --dataset DREAM3 \
    --series 1 \
    --seed 2025 \
    --penalty_type Fast_Shap \
    --use_best \
    --best_metric AUROC
```

**Arguments:**
- `--dataset`: Dataset name (`DREAM3`, `DREAM4`, `fMRI`, `CausalTime`)
- `--use_best`: Use pre-tuned hyperparameters instead of grid search
- `--best_metric`: Metric for selecting best hyperparameters (`AUROC`, `AUPRC`)
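Both selection metrics compare a p x p matrix of edge scores (for example, importance values) against the ground-truth graph. The sketch below shows what they measure: AUROC is the probability that a true edge outranks a non-edge (ties count one half), and average precision is a common estimator of AUPRC. Excluding self-loops, the orientation of the matrices, and this particular AUPRC estimator are assumptions; `utils/cls_performance.py` may follow other conventions. The function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def flatten_offdiag(scores: Sequence[Sequence[float]],
                    truth: Sequence[Sequence[int]]) -> Tuple[List[float], List[int]]:
    """Flatten p x p score and label matrices, dropping self-loops (i == j)."""
    s, y = [], []
    p = len(scores)
    for i in range(p):
        for j in range(p):
            if i != j:
                s.append(scores[i][j])
                y.append(truth[i][j])
    return s, y


def auroc(s: Sequence[float], y: Sequence[int]) -> float:
    """P(random true edge scores above random non-edge); ties count 1/2."""
    pos = [a for a, lab in zip(s, y) if lab == 1]
    neg = [a for a, lab in zip(s, y) if lab == 0]
    if not pos or not neg:
        raise ValueError("need at least one edge and one non-edge")
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    return wins / (len(pos) * len(neg))


def average_precision(s: Sequence[float], y: Sequence[int]) -> float:
    """Mean precision@k over the ranks k where a true edge appears.

    Tied scores are broken by position, so ties are not handled specially.
    """
    order = sorted(range(len(s)), key=lambda k: -s[k])
    hits, total = 0, 0.0
    for rank, k in enumerate(order, start=1):
        if y[k] == 1:
            hits += 1
            total += hits / rank
    return total / max(1, sum(y))


# 3 series; truth[i][j] = 1 marks an edge from series j into target i
# (orientation is an assumption; the metrics do not depend on it).
truth = [[1, 0, 0],
         [1, 1, 0],
         [0, 1, 1]]
scores = [[0.9, 0.1, 0.2],
          [0.8, 0.9, 0.75],
          [0.05, 0.7, 0.9]]
s, y = flatten_offdiag(scores, truth)
print(auroc(s, y))                         # 0.875
print(round(average_precision(s, y), 4))  # 0.8333
```

Here one non-edge (score 0.75) outranks one true edge (score 0.7), which costs one of the eight edge/non-edge comparisons in AUROC and pushes the second hit to rank 3 in average precision. AUPRC is usually the more informative of the two on sparse graphs, where non-edges vastly outnumber edges.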

### 3. Reproduce Paper Results

To reproduce the real-data results reported in the paper:

```bash
bash reproduce_real_data.sh
```

This script runs all DREAM3 experiments with pre-tuned hyperparameters for both AUROC and AUPRC metrics.
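To run only a subset of these experiments, you can generate the individual `src/real_data.py` commands using just the flags documented above. The series indices 1 to 5 (assuming `--series` indexes the five DREAM3 networks), the seed, and the penalty type are assumptions; check `reproduce_real_data.sh` for the exact settings behind the paper's numbers. `build_commands` is a hypothetical helper.

```python
from typing import Iterable, List


def build_commands(series_ids: Iterable[int] = range(1, 6),
                   metrics: Iterable[str] = ("AUROC", "AUPRC"),
                   seed: int = 2025,
                   penalty_type: str = "Fast_Shap") -> List[List[str]]:
    """One src/real_data.py invocation per (series, metric) pair."""
    cmds = []
    for series in series_ids:
        for metric in metrics:
            cmds.append([
                "python", "src/real_data.py",
                "--dataset", "DREAM3",
                "--series", str(series),
                "--seed", str(seed),
                "--penalty_type", penalty_type,
                "--use_best",
                "--best_metric", metric,
            ])
    return cmds


if __name__ == "__main__":
    # Print the commands; pipe the output to `bash` from the repository root to run them.
    for cmd in build_commands():
        print(" ".join(cmd))
```

Restricting `series_ids` or `metrics` (for example, `build_commands(series_ids=[1], metrics=["AUPRC"])`) gives a quick single-run sanity check before launching the full set.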
