# Multimodal Graph Meta-Transformer (MGMT)

MGMT is a PyTorch-based framework for **multimodal graph learning**, integrating **Graph Transformers**, **attention-based node selection**, and **meta-graph construction** to support both accurate prediction and interpretability.

---

![MGMT Architecture](figures/magnet.jpeg)

**Figure**: Architecture of the Multimodal Graph Meta-Transformer (MGMT).  
(A) Depth-Aware Graph Transformer layers process individual modalities, extracting super-nodes to form a meta-graph.  
(B) Core components of Graph Transformer layers and self-attention mechanism for node aggregation and representation learning.

---

## 🌐 Overview

MGMT performs **graph-level classification** over multiple modalities, each represented as a graph with node features and edges. Key innovations include:

- **Modality-Specific Graph Transformers**: Deep self-attention over nodes within each graph  
- **Depth-Aware Fusion**: Layer-wise confidence-weighted aggregation  
- **Super-node Extraction**: Node importance inferred from attention scores  
- **Meta-Graph Construction**: Combines intra-modality edges with inter-modality edges inferred via cosine similarity  
- **Final Graph-Level Predictor**: Operates on the fused meta-graph

---

## 📦 Project Structure

```text
src/
├── __init__.py                   # Package overview and module documentation
├── main.py                       # Entry point for running the pipeline
├── lfp_dataset.py                # Real LFP graph data loading and preprocessing
├── dataset_simulated_setting1.py # Synthetic graph generation (random noise)
├── dataset_simulated_setting2.py # Synthetic graph generation (GP-based)
├── train.py                      # Final training and evaluation using best params
├── train_utils.py                # Helper functions for loss, accuracy, weight update
├── eval_utils.py                 # Accuracy and confusion matrix computation
├── objective.py                  # Optuna objective for hyperparameter tuning
├── utils.py                      # Seed setting, normalization, init, etc.
```

---

## 🔍 Methodological Foundations

### 1. Estimation Component

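Each modality-specific graph $\mathcal{G}_i = (\mathcal{V}_i, \mathcal{E}_i)$ undergoes $L$ layers of Graph Transformer updates. At layer $\ell$, with $\bm{Z}_i^{(\ell)}$ denoting the output of the self-attention sublayer, the node representations are updated as: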

$$
\bm{H}_{i}^{(\ell)} = \textnormal{LayerNorm}\left( \bm{Z}_i^{(\ell)} + \sigma(\textnormal{FFN}(\bm{Z}_i^{(\ell)})) \right)
$$
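The layer update above can be sketched in plain NumPy (already listed in the requirements) to make the data flow concrete. This is not the MGMT implementation: the single attention head, masking attention to graph edges plus self-loops, defining $\bm{Z}_i^{(\ell)}$ as `LayerNorm(H + Attn(H))`, and using ReLU for $\sigma$ are all assumptions of this sketch, and `graph_transformer_layer` is a hypothetical name.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each node's feature vector to zero mean and unit variance
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def graph_transformer_layer(H, adj, Wq, Wk, Wv, W1, W2):
    """Hypothetical single layer. H: (n x d) node features, adj: (n x n) 0/1.

    Assumptions (not from the MGMT code): one attention head, attention
    restricted to edges plus self-loops, ReLU used for sigma.
    """
    n = H.shape[0]
    Q, K, V = H @ Wq, H @ Wk, H @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[1])
    mask = (adj + np.eye(n)) > 0                    # attend along edges (+ self)
    scores = np.where(mask, scores, -np.inf)
    scores -= scores.max(axis=1, keepdims=True)     # numerical stability
    A = np.exp(scores)
    A /= A.sum(axis=1, keepdims=True)               # each row sums to 1
    Z = layer_norm(H + A @ V)                       # assumed attention sublayer output
    ffn = np.maximum(0.0, Z @ W1) @ W2              # two-layer feed-forward network
    H_out = layer_norm(Z + np.maximum(0.0, ffn))    # H = LN(Z + sigma(FFN(Z)))
    return H_out, A

# Toy usage: a 5-node path graph with random weights
rng = np.random.default_rng(0)
n, d, dh = 5, 8, 16
H = rng.normal(size=(n, d))
adj = np.zeros((n, n))
adj[[0, 1, 2, 3], [1, 2, 3, 4]] = 1
adj = np.maximum(adj, adj.T)                        # make the graph undirected
params = [rng.normal(size=s) * 0.1 for s in [(d, d), (d, d), (d, d), (d, dh), (dh, d)]]
H1, A = graph_transformer_layer(H, adj, *params)
```

Stacking $L$ such calls yields the per-layer outputs $\bm{H}_i^{(\ell)}$ and attention matrices that the depth-aware aggregation combines.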

The per-layer embeddings and attention weights are then aggregated across depths using layer-wise confidence weights $\Gamma^{(\ell)}$:

$$
\bm{H}_i = \sum_{\ell} \Gamma^{(\ell)} \bm{H}_{i}^{(\ell)}, \quad \alpha_{i,uv} = \sum_{\ell} \Gamma^{(\ell)} \cdot \text{Attention}_{uv}^{(\ell)}
$$
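A minimal NumPy sketch of the depth-aware aggregation, assuming the confidence weights $\Gamma^{(\ell)}$ are obtained by a softmax over one score per layer. That normalization is an assumption made here for illustration; MGMT may compute $\Gamma^{(\ell)}$ differently, and `depth_aggregate` is a hypothetical name.

```python
import numpy as np

def depth_aggregate(H_layers, A_layers, conf_logits):
    """Confidence-weighted aggregation across depths (hypothetical helper).

    H_layers: list of L arrays (n x d); A_layers: list of L arrays (n x n).
    conf_logits: one score per layer. Mapping them to Gamma with a softmax
    is an assumption of this sketch, not necessarily how MGMT does it.
    """
    logits = np.asarray(conf_logits, dtype=float)
    gamma = np.exp(logits - logits.max())
    gamma /= gamma.sum()                                  # Gamma >= 0, sums to 1
    H = sum(g * Hl for g, Hl in zip(gamma, H_layers))     # H_i = sum_l Gamma^(l) H_i^(l)
    alpha = sum(g * Al for g, Al in zip(gamma, A_layers)) # alpha_i = sum_l Gamma^(l) Attn^(l)
    return H, alpha, gamma
```

Because the weights form a convex combination here, row-normalized attention matrices stay row-normalized after aggregation, which keeps the outgoing attention mass of each node on a comparable scale across layers.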

### 2. Super-node Extraction

Nodes whose total outgoing attention weight reaches a threshold $\tau$ are retained as super-nodes:

$$
\mathcal{S}_i = \left\{ u \in \mathcal{V}_i \;\middle|\; \sum_{(u,v) \in \mathcal{E}_i} \alpha_{i,uv} \geq \tau \right\}
$$

The pruned graph $\mathcal{G}_i'$ is the subgraph of $\mathcal{G}_i$ induced by $\mathcal{S}_i$.
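The thresholding rule and the induced subgraph can be sketched with NumPy as follows. Edges are given as a `(2, E)` array of directed pairs $(u, v)$, in the style of PyTorch Geometric's `edge_index`; the function name and the relabeling of kept nodes to `0..k-1` are illustrative choices, not taken from the MGMT code.

```python
import numpy as np

def extract_super_nodes(alpha, edge_index, tau):
    """Hypothetical helper. alpha: (n x n) aggregated attention;
    edge_index: (2 x E) int array of directed edges (u, v); tau: threshold.
    Returns kept node ids, the induced subgraph's edges relabeled to
    0..k-1, and each node's outgoing attention mass."""
    n = alpha.shape[0]
    src, dst = edge_index
    # Outgoing attention mass of u: sum of alpha[u, v] over edges (u, v)
    out_mass = np.zeros(n)
    np.add.at(out_mass, src, alpha[src, dst])
    keep = np.flatnonzero(out_mass >= tau)
    # Induced subgraph: keep only edges whose endpoints are both super-nodes
    in_S = np.zeros(n, dtype=bool)
    in_S[keep] = True
    emask = in_S[src] & in_S[dst]
    remap = -np.ones(n, dtype=int)
    remap[keep] = np.arange(len(keep))
    sub_edges = remap[edge_index[:, emask]]
    return keep, sub_edges, out_mass
```

`np.add.at` is used instead of `out_mass[src] += ...` so that nodes with several outgoing edges accumulate every contribution rather than only the last one.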

### 3. Meta-Graph Construction

The meta-graph $\mathcal{G}_M = (\mathcal{S}_M, \mathcal{E}_M)$ is built over the super-nodes of all modalities, $\mathcal{S}_M = \bigcup_i \mathcal{S}_i$, and contains two kinds of edges:

* **Intra-modality**: edges from each $\mathcal{G}_i'$
* **Inter-modality**: cosine similarity between nodes $u, v$ from different modalities:

$$
e_{uv} = \frac{\bm{H}_u^\top \bm{H}_v}{\|\bm{H}_u\| \|\bm{H}_v\|}, \quad \text{if } e_{uv} > \gamma \text{ add edge } (u,v)
$$
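A NumPy sketch of the meta-graph construction, under the assumptions that each modality's super-node embeddings arrive as a `(k_i, d)` array with nonzero rows and that intra-modality edges use each modality's local node ids. `build_meta_graph` is a hypothetical name; the cosine-similarity rule and threshold $\gamma$ follow the equation above.

```python
import numpy as np

def build_meta_graph(embs, intra_edges, gamma):
    """Hypothetical helper. embs: list of (k_i x d) super-node embeddings,
    one per modality (rows assumed nonzero). intra_edges: list of (2 x E_i)
    int arrays in each modality's local ids. Returns stacked embeddings,
    the meta-graph edge_index (2 x E_M), and each node's modality id."""
    offsets = np.cumsum([0] + [e.shape[0] for e in embs])
    H = np.vstack(embs)
    modality = np.concatenate([np.full(e.shape[0], i) for i, e in enumerate(embs)])
    # Intra-modality edges: shift local ids into the global index space
    edges = [ie + offsets[i] for i, ie in enumerate(intra_edges)]
    # Inter-modality edges: cosine similarity above gamma, different modalities only
    Hn = H / np.linalg.norm(H, axis=1, keepdims=True)
    sim = Hn @ Hn.T
    cross = modality[:, None] != modality[None, :]
    u, v = np.nonzero(cross & (sim > gamma))
    edges.append(np.stack([u, v]))
    return H, np.concatenate(edges, axis=1), modality
```

Since cosine similarity is symmetric, every inter-modality edge is emitted in both directions, $(u, v)$ and $(v, u)$, which suits message passing that expects directed pairs.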

### 4. Final Prediction

Graph Transformer layers are then applied to the meta-graph to produce node embeddings $\bm{H}_M$, which are pooled into a single graph-level representation and passed to a classifier $f$:

$$
\hat{y} = f(\text{Pool}(\bm{H}_M))
$$
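To illustrate the readout, the sketch below assumes mean pooling for $\text{Pool}$ and a linear layer followed by a softmax for $f$. Both are assumptions chosen for simplicity; the actual pooling operator and classification head in MGMT may differ, and `predict` is a hypothetical name.

```python
import numpy as np

def predict(H_M, W, b):
    """Hypothetical readout. H_M: (n x d) meta-graph node embeddings;
    W: (d x C), b: (C,). Mean pooling and a linear softmax head are
    assumptions of this sketch."""
    g = H_M.mean(axis=0)                  # Pool: one vector for the whole graph
    logits = g @ W + b                    # f: linear classifier
    p = np.exp(logits - logits.max())     # softmax, shifted for stability
    p /= p.sum()
    return int(np.argmax(p)), p           # predicted class y_hat and probabilities
```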

---

## 🚀 Running the Code

### Requirements

Add these to your `requirements.txt`:

```text
torch
torch-geometric
optuna
scikit-learn
numpy
GPy
```

### Usage

```bash
# Default: Use real LFP data
python -m src.main

# Use simulated setting 1 (random noise graphs)
python -m src.main --dataset sim1

# Use simulated setting 2 (GP-based graphs)
python -m src.main --dataset sim2
```

💡 Alternatively, you can switch datasets by commenting/uncommenting the relevant lines in `main.py`.

---

## 📈 Evaluation

* **Real data**: 5-fold cross-validation
* Optuna-based hyperparameter tuning on a validation set
* Final evaluation on a held-out test set
* **Simulations**: **n = 50 independent repetitions** (no cross-validation); we report the **mean and standard error (SE)** of the test accuracy over these repetitions.
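For reference, the mean and standard error over repetitions can be computed as below. Using the sample standard deviation (`ddof=1`) divided by $\sqrt{n}$ is the usual definition of the standard error of the mean, but whether the reported numbers use exactly this convention is an assumption; `mean_and_se` is a hypothetical name.

```python
import numpy as np

def mean_and_se(accuracies):
    """Mean and standard error of per-repetition test accuracies."""
    acc = np.asarray(accuracies, dtype=float)
    # SE of the mean: sample standard deviation (ddof=1) divided by sqrt(n)
    return acc.mean(), acc.std(ddof=1) / np.sqrt(acc.size)
```

Calling `mean_and_se(test_accuracies)` on a list of the 50 per-repetition accuracies returns the two summary numbers.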

---

## 🧪 Datasets

* **Real Data**: Balanced LFP graphs from 5 rats in `lfp_data/`
* **Synthetic Data**:

  * **Setting 1**: Multimodal graphs with modality-specific noise
  * **Setting 2**: Gaussian Process-based feature generation with shared label logic

---

## 🧠 Interpretation Component

MGMT provides interpretability through:

* Node-level attention scores
* Thresholded super-node selection
* Inter-modality edge analysis in the meta-graph

These enable deeper insights into both the **feature-level** and **modality-level** structure of predictions.

---

