# Multimodal Graph Meta-Transformer (MGMT)

MGMT is a PyTorch-based framework for **multimodal graph learning**, integrating **Graph Transformers**, **attention-based node selection**, and **meta-graph construction** to support both accurate prediction and interpretability.

---

![MGMT Architecture](figures/magnet.jpeg)

**Figure**: Architecture of the Multimodal Graph Meta-Transformer (MGMT).  
(A) Depth-Aware Graph Transformer layers process individual modalities, extracting super-nodes to form a meta-graph.  
(B) Core components of a Graph Transformer layer, including the self-attention mechanism used for node aggregation and representation learning.

---

## 🌐 Overview

MGMT performs **graph-level classification** over multiple modalities, each represented as a graph with node features and edges. Key innovations include:

- **Modality-Specific Graph Transformers**: Deep self-attention over nodes within each graph  
- **Depth-Aware Fusion**: Layer-wise confidence-weighted aggregation  
- **Super-node Extraction**: Node importance inferred from attention scores  
- **Meta-Graph Construction**: Combines intra- and inter-modality relationships via cosine similarity  
- **Final Graph-Level Predictor**: Operates on the fused meta-graph
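
To make the meta-graph step more concrete, here is a minimal sketch of how super-nodes could be linked by cosine similarity. It is written in plain Python for readability and is not the repository's implementation: the function names `cosine_similarity` and `build_meta_graph`, the `(modality_id, embedding)` input format, and the fixed similarity threshold are all assumptions made for illustration.

```python
import math
from itertools import combinations


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0  # treat zero vectors as dissimilar to everything
    return dot / (norm_u * norm_v)


def build_meta_graph(super_nodes, threshold=0.5):
    """Connect super-nodes whose embeddings are sufficiently similar.

    super_nodes: list of (modality_id, embedding) pairs (hypothetical format).
    Returns a list of (i, j, kind) edges, where kind is "intra" or "inter".
    """
    edges = []
    for (i, (mod_i, emb_i)), (j, (mod_j, emb_j)) in combinations(
        enumerate(super_nodes), 2
    ):
        if cosine_similarity(emb_i, emb_j) >= threshold:
            kind = "intra" if mod_i == mod_j else "inter"
            edges.append((i, j, kind))
    return edges


# Toy super-nodes from two hypothetical modalities "A" and "B".
super_nodes = [
    ("A", [1.0, 0.0]),
    ("A", [0.9, 0.1]),
    ("B", [1.0, 0.1]),
    ("B", [0.0, 1.0]),
]
meta_edges = build_meta_graph(super_nodes, threshold=0.9)
```

Tagging each edge as `intra` or `inter` keeps the two kinds of relationship distinguishable in the fused graph. In a PyTorch pipeline the pairwise similarities would more likely be computed on tensors in a single batched operation.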

---

## 📦 Project Structure

```text
src/
├── __init__.py                    # Package overview and module documentation
├── main.py                        # Entry point for running the pipeline
├── lfp_dataset.py                 # Real LFP graph preprocessing functions (data not included here)
├── dataset_simulated_setting1.py  # Synthetic graph generation (random noise)
├── dataset_simulated_setting2.py  # Synthetic graph generation (GP-based)
├── train.py                       # Final training and evaluation using best params
├── train_utils.py                 # Helper functions for loss, accuracy, weight update
├── eval_utils.py                  # Accuracy and confusion matrix computation
├── objective.py                   # Optuna objective for hyperparameter tuning
├── utils.py                       # Seed setting, normalization, init, etc.
```

---

## 📈 Evaluation

**Real Dataset (LFP):**
- Due to file size limitations, LFP data is not included in this repository.
- You can access the full version (including the LFP dataset) through the following anonymous link provided in the main paper:  
  👉 [https://anonymous.4open.science/r/new_submission-33A6](https://anonymous.4open.science/r/new_submission-33A6)
- Evaluation is performed using **5-Fold Cross-Validation** with Optuna-based hyperparameter tuning. Final performance is reported on the held-out test fold.
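
As a rough illustration of the 5-fold protocol, the sketch below generates train/test index splits in plain Python. The helper `k_fold_indices` is hypothetical, and the plain shuffled split is an assumption: this README does not say whether folds are stratified by label or grouped by subject. In practice, scikit-learn's `KFold` or `StratifiedKFold` would be the usual tool.

```python
import random


def k_fold_indices(n_samples, k=5, seed=0):
    """Yield (train_idx, test_idx) pairs for k-fold cross-validation."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Slicing with stride k spreads samples as evenly as possible over k folds.
    folds = [indices[i::k] for i in range(k)]
    for held_out in range(k):
        test_idx = folds[held_out]
        train_idx = [
            idx for f, fold in enumerate(folds) if f != held_out for idx in fold
        ]
        yield train_idx, test_idx


# Each sample appears in exactly one held-out test fold.
splits = list(k_fold_indices(n_samples=23, k=5, seed=42))
```

Hyperparameter tuning would then run on the training portion of each split only, so the held-out fold stays untouched until the final evaluation.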

**Synthetic Settings (Setting 1 & 2):**
- We perform **n = 50 independent repetitions** with random train/val/test splits (no cross-validation).
- For each repetition:
  1. Randomly split the synthetic dataset.
  2. Train the model with Optuna hyperparameter search.
  3. Evaluate on the test set using the best validation model.
- Final performance is reported as the **mean and standard error (SE)** across 50 runs.
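
The repetition protocol above can be sketched as follows, again in plain Python. The split fractions, the helper names `random_split` and `mean_and_se`, and the toy accuracy values are assumptions for illustration only; they are not results or settings from the paper. The standard error is computed as the sample standard deviation divided by the square root of the number of runs.

```python
import math
import random
import statistics


def random_split(n_samples, val_frac=0.15, test_frac=0.15, seed=0):
    """Return disjoint (train, val, test) index lists from one random shuffle."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    n_test = round(n_samples * test_frac)
    n_val = round(n_samples * val_frac)
    test_idx = indices[:n_test]
    val_idx = indices[n_test:n_test + n_val]
    train_idx = indices[n_test + n_val:]
    return train_idx, val_idx, test_idx


def mean_and_se(values):
    """Mean and standard error (sample std / sqrt(n)) of a list of scores."""
    mean = statistics.fmean(values)
    se = statistics.stdev(values) / math.sqrt(len(values))
    return mean, se


# One independent split per repetition, each with its own seed.
splits = [random_split(n_samples=200, seed=rep) for rep in range(50)]

# Placeholder accuracies standing in for real per-repetition test results.
toy_accuracies = [0.80, 0.82, 0.78, 0.84, 0.76]
mean_acc, se_acc = mean_and_se(toy_accuracies)
```

Seeding each repetition separately keeps the 50 splits reproducible while still varying from run to run.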

---

## 🧪 Datasets

* **Real Data** (not included): Balanced LFP graphs from 5 subjects (available [here](https://anonymous.4open.science/r/new_submission-33A6))
* **Synthetic Data**:
  * **Setting 1**: Multimodal graphs with modality-specific noise
  * **Setting 2**: Gaussian Process-based feature generation with shared label logic

---

## 🚀 Running the Code

### Requirements

The code depends on the following packages, which you can list in a `requirements.txt`:

```text
torch
torch-geometric
optuna
scikit-learn
numpy
GPy
```

### Usage

```bash
# Default: Use real LFP data (data not included in this repo)
python -m src.main

# Use simulated setting 1 (random noise graphs)
python -m src.main --dataset sim1

# Use simulated setting 2 (GP-based graphs)
python -m src.main --dataset sim2
```

💡 Note: Dataset switching is controlled in `main.py` by commenting or uncommenting the relevant dataset lines.

---

## 🧠 Interpretation Component

MGMT provides interpretability through:

* Node-level attention scores
* Thresholded super-node selection
* Inter-modality edge analysis in the meta-graph

These enable deeper insights into both the **feature-level** and **modality-level** structure of predictions.
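
As an illustration of thresholded super-node selection, the sketch below keeps the nodes whose normalized attention score reaches a cutoff. The function name `select_super_nodes`, the min-max normalization, and the default threshold are assumptions; the repository may score and threshold nodes differently.

```python
def select_super_nodes(attention_scores, threshold=0.5):
    """Return indices of nodes whose normalized attention is at least threshold.

    Scores are min-max normalized to [0, 1] so the threshold does not depend
    on the raw attention scale.
    """
    lo, hi = min(attention_scores), max(attention_scores)
    if hi == lo:
        # All nodes are equally important, so keep every one of them.
        return list(range(len(attention_scores)))
    normalized = [(s - lo) / (hi - lo) for s in attention_scores]
    return [i for i, s in enumerate(normalized) if s >= threshold]


# Hypothetical per-node attention scores for one modality graph.
scores = [0.05, 0.40, 0.10, 0.35, 0.10]
super_node_ids = select_super_nodes(scores, threshold=0.5)
```

The returned indices can be mapped back to the original nodes (for example, recording channels) to see which ones drove a prediction.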

---

