# 🧬 SupGCL: Supervised Graph Contrastive Learning for Gene Regulatory Networks

This repository contains implementations of various graph learning methods for gene regulatory networks, specifically focusing on TCGA (The Cancer Genome Atlas) data.

## 🤖 What is SupGCL?
**SupGCL (Supervised Graph Contrastive Learning)** is a graph contrastive learning framework for gene regulatory networks (GRNs). Unlike traditional GCL methods that rely on artificial perturbations (e.g., node dropping, edge perturbation), SupGCL uses biologically meaningful perturbations, namely gene knockdown experiments, as its supervision signal.

**🔑 Key features:**
- **🧫 Biological supervision:** Incorporates real gene knockdown data to guide contrastive learning, aligning virtual perturbations with actual experimental effects.  
- **🎲 Mathematical extension:** Extends existing GCL objectives to a probabilistic model that integrates biological perturbations directly into the learning process.
 


## Get Dataset Here!

The large-scale dataset used in this study is hosted at the following link.  
Please download it as needed.

[📂 Download the dataset](https://doi.org/10.5281/zenodo.15496012)

Anonymous. (2025). Gene Regulatory Networks of LINCS and TCGA [Data set]. Zenodo. https://doi.org/10.5281/zenodo.15496012

## 🚀 Implemented Methods

1. **GAE (Graph AutoEncoder)**
   - Implementation of a graph autoencoder for unsupervised learning of gene regulatory networks
   - Uses reconstruction loss for node and edge features

2. **GraphCL (Graph Contrastive Learning)**
   - **Graph-level contrastive method:** maximizes agreement between representations of the same graph under different augmentations  
   - Implements virtual node knockdown for data augmentation
   - Uses NT-Xent loss for contrastive learning

3. **GRACE (Deep Graph Contrastive Representation Learning)**
   - **Node‐level contrastive method:** maximizes agreement between embeddings of the same node across views, using all other nodes (intra- and inter-view) as negatives
   - Implements symmetric node contrast loss
   - Uses virtual node knockdown for augmentation

4. **SGRL (Scattering Graph Representation Learning)**
   - **Representation Scattering Mechanism (RSM):** pushes node embeddings away from a learned center to encourage global uniformity  
   - **Topology-based Constraint Mechanism (TCM):** uses multi-hop neighborhood aggregation to preserve local graph structure during scattering  
   - Dual-encoder design (online & target networks) with EMA updates  
   - Eliminates need for manual negative sampling by combining RSM and TCM into a unified self-supervised framework  


5. **SupGCL (Supervised Graph Contrastive Learning)**
   - **Supervised contrastive learning approach**
   - **Dual-level objectives:**  
     - **Node-level:** aligns embeddings of the same gene under patient-specific vs. knockdown graphs  
     - **Augmentation-level:** minimizes distributional gap between virtual and real perturbations
   - Integrates LINCS knockdown data for supervision
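
The contrastive methods listed above share one core computation: a temperature-scaled softmax that scores a positive pair against all other embeddings. The following is a minimal, dependency-free sketch of the NT-Xent form of that loss, not the repository's implementation (which presumably operates on PyTorch tensors). The function names `cosine` and `nt_xent` are hypothetical, and the default `tau=0.25` simply mirrors the value used in the example command in the Usage section.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def nt_xent(view_a, view_b, tau=0.25):
    """Mean NT-Xent loss over the 2N embeddings of two views.

    view_a[i] and view_b[i] embed the same item (the positive pair);
    every other embedding in either view acts as a negative.
    """
    z = list(view_a) + list(view_b)
    n = len(view_a)
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)  # index of the positive partner
        logits = [cosine(z[i], z[k]) / tau for k in range(2 * n) if k != i]
        pos_logit = cosine(z[i], z[pos]) / tau
        # -log softmax of the positive logit, with a shifted log-sum-exp
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_denom - pos_logit
    return total / (2 * n)

# Matching views give a low loss; mismatched views give a high one.
aligned = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned < swapped)
```

Lower temperatures sharpen the softmax, so hard negatives dominate the gradient; higher temperatures weight all negatives more evenly.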

## Requirements

- Python 3.9.18
- PyTorch 2.5.1+cu121
- PyTorch Geometric 2.6.1
- TensorBoard 2.18.0

## Project Structure

```
.
├── Pretrain/              # Pretraining code for each model
│   ├── GAE/               # GAE pretraining scripts
│   ├── GRACE/             # GRACE pretraining scripts
│   ├── GraphCL/           # GraphCL pretraining scripts
│   ├── SGRL/              # SGRL pretraining scripts
│   └── SupGCL/            # SupGCL pretraining scripts
│ 
└── Finetuning/            # Downstream task scripts
    ├── BP_class.py         # Biological Process classification
    ├── CC_class.py         # Cellular Component classification
    ├── cancer_rel_class.py # Cancer relation classification
    ├── hazard_pred.py      # Patient hazard prediction
    └── subtype_class.py    # Cancer subtype classification

```

Each method directory under `Pretrain/` contains:
- `train.py`: Training script
- `main.py`: Entry point
- `models.py`: Model architectures
- `data.py`: Data loading and processing
- `utils.py`: Utility functions

## Usage

**[ Pretrain ]**

Each method can be run independently. For example, to run SupGCL pretraining:

```bash
cd SupGCL/Pretrain/SupGCL
python main.py \
  --tcga_path path/to/tcga_graphs.pkl \
  --lincs_graphs path/to/LINCS_KD_graphs.pkl \
  --lincs_meta_graphs data/meta_data/Breast/Breast_LINCS_KD_graphs_metadata.pkl \
  --lincs_meta_kd    data/meta_data/Breast/LINCS_sampleID_KDgene_metadata.pkl \
  --batch_size 4 \
  --epochs 3000 \
  --lr 1e-4 \
  --hid 64 \
  --out 64 \
  --tau_nce 0.25 \
  --tau_aug 0.25
```


**Main arguments**

- `--tcga_path` : Path to TCGA graph data (`.pkl`)  
- `--lincs_graphs` : Path to LINCS knockdown graph data (`.pkl`)  
- `--lincs_meta_graphs` : Path to LINCS graph-order metadata (`.pkl`)  
- `--lincs_meta_kd` : Path to LINCS sample→gene metadata (`.pkl`)  
- `--batch_size` : Training batch size  
- `--epochs` : Number of training epochs  
- `--lr` : Learning rate  
- `--hid` : Hidden GNN channel size  
- `--out` : Output embedding dimension  
- `--tau_nce` : Temperature for contrastive (InfoNCE) loss  
- `--tau_aug` : Temperature for augmentation alignment loss  
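
For readers who want to wrap or script the pretraining entry point, the flags above could be declared roughly as follows. This is an illustrative sketch only: the flag names come from the list above, but the `build_parser` helper is hypothetical, and the defaults are copied from the example command rather than from the actual `main.py`.

```python
import argparse

def build_parser():
    """Build a parser mirroring the documented SupGCL pretraining flags."""
    parser = argparse.ArgumentParser(description="SupGCL pretraining (illustrative sketch)")
    # Required data paths
    parser.add_argument("--tcga_path", required=True, help="TCGA graph data (.pkl)")
    parser.add_argument("--lincs_graphs", required=True, help="LINCS knockdown graphs (.pkl)")
    parser.add_argument("--lincs_meta_graphs", required=True, help="LINCS graph-order metadata (.pkl)")
    parser.add_argument("--lincs_meta_kd", required=True, help="LINCS sample-to-gene metadata (.pkl)")
    # Defaults below are assumptions taken from the example command,
    # not confirmed defaults of the real script.
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=3000)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--hid", type=int, default=64)
    parser.add_argument("--out", type=int, default=64)
    parser.add_argument("--tau_nce", type=float, default=0.25)
    parser.add_argument("--tau_aug", type=float, default=0.25)
    return parser

# Parse an explicit argument list instead of sys.argv so the sketch is runnable anywhere.
args = build_parser().parse_args([
    "--tcga_path", "path/to/tcga_graphs.pkl",
    "--lincs_graphs", "path/to/LINCS_KD_graphs.pkl",
    "--lincs_meta_graphs", "data/meta_data/Breast/Breast_LINCS_KD_graphs_metadata.pkl",
    "--lincs_meta_kd", "data/meta_data/Breast/LINCS_sampleID_KDgene_metadata.pkl",
    "--lr", "3e-4",
])
print(args.lr, args.batch_size, args.tau_nce)
```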

**[ Finetuning ]**

Each downstream task can be run with a unified script. For example, to run hazard prediction:

```bash
cd SupGCL/Finetuning
python hazard_pred.py \
  --method supgcl \
  --ckpt path/to/supgcl_checkpoint.pt \
  --data path/to/tcga_graphs.pkl \
  --meta path/to/tcga_graphs_metadata.pkl \
  --surv path/to/tcga_survival_time.csv
```

**Main arguments**

- `--method`     : Pretraining method to load (`supgcl`, `gae`, `grace`, `graphcl`, `sgrl`, or `w_o_pretrain` to train from scratch)  
- `--ckpt`       : Path to pretrained checkpoint (`.pt`) (ignored if `--method w_o_pretrain`)  
- `--data`       : Path to pickled graphs (`.pkl`)  
- `--meta`       : Path to pickled metadata (patient ID list, `.pkl`)  
- `--surv`       : Path to survival CSV (with OS time & event)  
- `--batch-size` : Batch size for finetuning (default: 8)  
- `--lr`         : Learning rate (default: 1e-3)  
- `--epochs`     : Number of finetuning epochs (default: 500)  
- `--patience`   : Early stopping patience (default: 5)  
- `--n-splits`   : Number of CV folds (default: 10)  
- `--hid-ch`     : Hidden channel size in GNN (default: 64)  
- `--out-ch`     : Output embedding dimension (default: 64)  
- `--heads`      : Number of attention heads (default: 8)  
- `--seed`       : Random seed (default: 0)  
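
The `--patience` flag controls early stopping. A minimal sketch of how patience-based stopping typically works is shown below. The `EarlyStopper` class is hypothetical, and monitoring a validation loss (lower is better) is an assumption; the actual scripts may track a different metric.

```python
class EarlyStopper:
    """Signals a stop after `patience` consecutive epochs without improvement."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Hypothetical validation losses: improvement stalls after the third epoch.
losses = [0.90, 0.80, 0.70, 0.71, 0.72, 0.70, 0.73, 0.74, 0.75]
stopper = EarlyStopper(patience=5)
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)
```

With the default patience of 5, training in this toy run halts at the fifth consecutive epoch that fails to beat the best loss seen so far.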

## Data

### 1. GRN Dataset & Preprocessing  
All graph data objects live under `data/GRN_dataset/` and are generated by the preprocessing script `mk_GRN.py`:

```plaintext
SupGCL/
└── data/
    └── GRN_dataset/
        ├── Breast/
        │   ├── Edge_feature/
        │   │   ├── Breast_tcga_ecv.csv      # TCGA patient-specific edge contribution values (ECv)
        │   │   └── Breast_lincs_kd_ecv.csv  # LINCS knockdown edge contribution values
        │   ├── Node_feature/
        │   │   ├── Breast_TCGA_exp.csv       # TCGA patient-specific gene expression levels
        │   │   └── Breast_LINCS_KD_exp.csv   # LINCS knockdown gene expression levels
        │   └── make_GRN_dataset/
        │       └── mk_GRN.py                 # builds per-sample PyG graphs and pickles them
        ├── Colorectal/…                       # same structure for colorectal cancer
        └── Lung/…                             # same structure for lung cancer
```

- `Edge_feature/`: 1D scalar “edge contribution values” (ECv) per sample
- `Node_feature/`: 1D scalar gene expression per sample
- `mk_GRN.py`: combines node and edge features into `torch_geometric.data.Data` objects and serializes them.
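
To make the per-sample graph layout concrete, the sketch below assembles one graph per sample from a node table and an edge table, then round-trips the result through `pickle`. It is not the logic of `mk_GRN.py`: the CSV layout (one row per gene or edge, one column per sample), the column names, and the `build_graphs` helper are all assumptions, and plain dictionaries stand in for PyG `Data` objects so the example runs without PyTorch.

```python
import csv
import io
import pickle

# Tiny inline stand-ins for the node and edge CSVs; the column layout
# (one row per gene or edge, one column per sample) is an assumption.
NODE_CSV = """gene,S1,S2
TP53,1.2,0.4
BRCA1,0.7,0.9
MYC,2.1,1.8
"""
EDGE_CSV = """source,target,S1,S2
TP53,MYC,0.30,0.10
BRCA1,TP53,-0.20,0.05
"""

def build_graphs(node_csv, edge_csv):
    """Return {sample_id: graph dict} with scalar node and edge features."""
    node_rows = list(csv.DictReader(io.StringIO(node_csv)))
    edge_rows = list(csv.DictReader(io.StringIO(edge_csv)))
    gene_index = {row["gene"]: i for i, row in enumerate(node_rows)}
    samples = [col for col in node_rows[0] if col != "gene"]
    # Shared topology: [source indices, target indices], i.e. shape 2 x E.
    edge_index = [
        [gene_index[r["source"]] for r in edge_rows],
        [gene_index[r["target"]] for r in edge_rows],
    ]
    graphs = {}
    for s in samples:
        graphs[s] = {
            "x": [[float(r[s])] for r in node_rows],          # N x 1 expression
            "edge_index": edge_index,
            "edge_attr": [[float(r[s])] for r in edge_rows],  # E x 1 ECv
        }
        # With PyTorch Geometric installed, these fields would map to
        # torch_geometric.data.Data(x=..., edge_index=..., edge_attr=...)
        # after converting the lists to tensors.
    return graphs

graphs = build_graphs(NODE_CSV, EDGE_CSV)
restored = pickle.loads(pickle.dumps(graphs))  # round-trip, as a .pkl file would
print(restored["S1"]["x"], restored["S1"]["edge_index"])
```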

### 2. Label Data for Finetuning
All task labels live under `data/labels/`:

```plaintext
SupGCL/
└── data/
    └── labels/
        ├── BP_data/
        │   └── gene_with_BP_multilabels.csv    # GO-BP multilabels (shared)
        ├── CC_data/
        │   └── gene_with_CC_multilabels.csv    # GO-CC multilabels (shared)
        ├── Cancer_rel_data/
        │   └── gene_with_cancer_relation.csv   # Cancer-relation labels (shared)
        ├── Subtype_data/
        │   └── Breast/subtype.csv              # Breast cancer subtype per patient
        └── Survival_data/
            ├── Breast/tcga_survival_time.csv   # OS time & event for hazard prediction
            ├── Colorectal/…
            └── Lung/…
```

### 3. Metadata
Helper files for mapping and filtering samples, under `data/meta_data/`:

```plaintext
SupGCL/
└── data/
    └── meta_data/
        └── Breast/
            ├── Breast_LINCS_KD_graphs_metadata.pkl   # Order of LINCS KD graphs
            ├── LINCS_sampleID_KDgene_metadata.pkl    # Map LINCS sample → knocked-down gene
            └── Breast_tcga_graphs_metadata.pkl       # Order of TCGA patient graphs
```
- LINCS metadata: used by SupGCL pretraining to match teacher (knockdown) graphs
- TCGA metadata: used in finetuning to keep only patients with survival/subtype annotations (passed via `--meta`)
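
The filtering step described above can be sketched as follows: the metadata supplies patient IDs in graph order, and only graphs whose patient appears in the survival table are kept. The column names (`patient_id`, `OS_time`, `OS_event`), the sample IDs, and the `align_graphs_with_survival` helper are hypothetical; check the real CSV headers before adapting this.

```python
import csv
import io

# Hypothetical inputs: ordered patient IDs (as the metadata pickle would hold),
# graphs in the same order, and a small survival table.
patient_ids = ["TCGA-A1", "TCGA-B2", "TCGA-C3"]
graphs = ["graph_A1", "graph_B2", "graph_C3"]  # placeholders for graph objects
SURVIVAL_CSV = """patient_id,OS_time,OS_event
TCGA-A1,1200,0
TCGA-C3,340,1
"""

def align_graphs_with_survival(patient_ids, graphs, survival_csv):
    """Keep only graphs whose patient has a survival record, in metadata order."""
    survival = {
        row["patient_id"]: (float(row["OS_time"]), int(row["OS_event"]))
        for row in csv.DictReader(io.StringIO(survival_csv))
    }
    kept = []
    for pid, graph in zip(patient_ids, graphs):
        if pid in survival:
            time, event = survival[pid]
            kept.append((pid, graph, time, event))
    return kept

aligned = align_graphs_with_survival(patient_ids, graphs, SURVIVAL_CSV)
print([pid for pid, *_ in aligned])
```

Iterating in metadata order keeps each graph paired with its own label even when some patients are dropped.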

## Citation

If you use this code in your research, please cite the original papers of the respective methods.

## License
The Cancer Genome Atlas (TCGA) Research Network.
TCGA TARGET GTEx data were accessed through the UCSC Xena Browser:  
Derived data available at: https://xenabrowser.net/datapages/?cohort=TCGA%2520TARGET%2520GTEx&removeHub=https%253A%252F%252Fxena.treehouse.gi.ucsc.edu%253A443

Subramanian A, et al. "A Next Generation Connectivity Map: L1000 Platform
and the First 1,000,000 Profiles." Cell 2017.
LINCS Program, NIH. L1000 data available at GEO: GSE92742
