# A Closer Look at the Application of Causal Inference in Graph Representation Learning

## Overview

This repository implements several graph neural network (GNN) architectures augmented with causal filtering mechanisms, with the goal of improving model robustness and generalization.

## Key Features

- **Causal Graph Neural Networks**: Multiple GNN architectures with causal reasoning capabilities (GCN, GIN, ChebNet, DIR-GNN, CRCG, CaNet)
- **Confounded Dataset Generation**: Synthetic dataset generators that create graphs with several types of confounding bias
- **Causal Filter Modules**: Filtering mechanisms that mitigate spurious correlations and confounding effects
- **Multi-domain Evaluation**: Support for node classification, graph classification, and molecular property prediction tasks

## Project Structure

```
GRL/
├── models/                          # GNN model implementations
│   ├── CaNet/                      # CaNet implementation
│   ├── chebnet-main/              # Chebyshev Graph Neural Network
│   ├── CRCG-main/                 # CRCG implementation
│   ├── DIR-GNN-main/              # DIR-GNN implementation
│   ├── GCN/                        # Graph Convolutional Network
│   └── GIN/                        # Graph Isomorphism Network
├── gen_datasets/                   # Dataset generation utilities
│   ├── molecular/                 # Generate molecular-based synthetic graph datasets
│   ├── motif/                     # Generate motif-based synthetic graph datasets
│   └── paper/                     # Generate paper citation network datasets
└── requirements.txt               # Python dependencies
```

## Core Components

### 1. Causal Filter Mechanisms

#### ImprovedCausalFilter
- **Purpose**: Mitigate spurious correlations and confounding effects in graph features
- **Key Components**:
  - **Node Scoring**: MLP for robust node importance estimation
  - **Channel-wise Gating**: Feature channel selection based on causal relevance
  - **Adaptive Residual Connections**: Dynamic residual weights based on filtering intensity
  - **Temperature Annealing**: Gradual sparsity increase during training
  - **Sparsity Control**: Target-based sparsity regulation
  - **Output Calibration**: Statistical moment matching for stable training
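
To make the temperature annealing and sparsity control ideas above more concrete, here is a minimal pure-Python sketch. It is not the code in this repository: the function names (`annealed_temperature`, `soft_gates`, `closed_fraction`, `sparsity_penalty`), the exponential schedule, the default temperatures, and the 0.1 threshold are all assumptions chosen for illustration. The actual `ImprovedCausalFilter` presumably operates on PyTorch tensors and may use a different schedule.

```python
import math


def annealed_temperature(epoch, total_epochs, t_start=5.0, t_end=0.5):
    """Exponentially interpolate from t_start to t_end (assumed schedule)."""
    if total_epochs <= 1:
        return t_end
    frac = min(max(epoch / (total_epochs - 1), 0.0), 1.0)
    return t_start * (t_end / t_start) ** frac


def soft_gates(scores, temperature):
    """One sigmoid gate per node; a lower temperature sharpens the gates."""
    return [1.0 / (1.0 + math.exp(-s / temperature)) for s in scores]


def closed_fraction(gates, threshold=0.1):
    """Share of gates that are effectively closed (below the threshold)."""
    return sum(g < threshold for g in gates) / len(gates)


def sparsity_penalty(gates, target_sparsity):
    """Squared gap between the mean gate and the desired keep ratio."""
    keep_ratio = sum(gates) / len(gates)
    return (keep_ratio - (1.0 - target_sparsity)) ** 2


# Hypothetical node scores, as an MLP scorer might produce them.
scores = [2.0, 0.5, -0.5, -2.0]
for epoch in (0, 5, 9):
    t = annealed_temperature(epoch, total_epochs=10)
    gates = soft_gates(scores, t)
    print(epoch, round(t, 3), [round(g, 3) for g in gates])
```

Early in training the high temperature keeps every gate close to 0.5, so little is filtered out. As the temperature drops, low-scoring nodes are pushed toward zero, which is one way to read "gradual sparsity increase".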

#### Advanced Features
- **Adversarial Robustness**: Gradient-based adversarial training components
- **Feature Decorrelation**: Independence constraints between filtered features
- **Spectral Normalization**: Weight normalization for training stability
- **Gradient Penalty**: Regularization for smooth filtering behavior
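
As an illustration of what a feature decorrelation constraint can look like, the sketch below penalizes the mean squared Pearson correlation between pairs of feature channels. This is one common formulation, assumed here for illustration only. The helper names (`pearson`, `decorrelation_penalty`) are hypothetical, and the repository may implement the constraint differently.

```python
import math


def pearson(x, y, eps=1e-12):
    """Pearson correlation of two equal-length lists (eps avoids 0/0)."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y + eps)


def decorrelation_penalty(features):
    """Mean squared correlation over all distinct channel pairs.

    `features` is a list of rows (one per node), each a list of channels.
    """
    num_channels = len(features[0])
    columns = [[row[j] for row in features] for j in range(num_channels)]
    total, count = 0.0, 0
    for i in range(num_channels):
        for j in range(i + 1, num_channels):
            total += pearson(columns[i], columns[j]) ** 2
            count += 1
    return total / count if count else 0.0


# Two perfectly correlated channels give a penalty close to 1.
redundant = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0], [4.0, 8.0]]
# Two uncorrelated channels give a penalty of 0.
independent = [[1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]]
print(decorrelation_penalty(redundant), decorrelation_penalty(independent))
```

Added to the training loss with a small weight, a term like this discourages the filtered channels from encoding the same information twice.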

### 2. Dataset Generation Framework

#### Confounded Graph Generation
The framework generates synthetic datasets with controlled confounding effects:

- **Dataset Types**:
  - `conf_X.X`: Datasets with confounding probability X.X
  - `int_X.X`: Intervention-based datasets
  - `causal_X_Y`: Causal relationship datasets
  - `intervened_X.X`: Post-intervention datasets
  - `element_confound_X_others_causal`: Element-specific confounding
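
The sketch below shows one plausible reading of the confounding probability in `conf_X.X`: with that probability a spurious attribute (for example, a background motif) is tied to the class label, and otherwise it is drawn at random. The generators in `gen_datasets/` are not reproduced here, so the function names, the uniform fallback, and the seed handling are all assumptions.

```python
import random


def sample_confounded_pairs(num_graphs, num_classes, conf_prob, seed=0):
    """Return (label, spurious_attribute) pairs under an assumed scheme.

    With probability conf_prob the spurious attribute copies the label;
    otherwise it is sampled uniformly from all classes.
    """
    rng = random.Random(seed)
    pairs = []
    for _ in range(num_graphs):
        label = rng.randrange(num_classes)
        if rng.random() < conf_prob:
            spurious = label
        else:
            spurious = rng.randrange(num_classes)
        pairs.append((label, spurious))
    return pairs


def agreement_rate(pairs):
    """Fraction of samples whose spurious attribute equals the label."""
    return sum(label == spurious for label, spurious in pairs) / len(pairs)


for conf_prob in (0.0, 0.5, 1.0):
    pairs = sample_confounded_pairs(2000, num_classes=3, conf_prob=conf_prob)
    print(f"conf_{conf_prob}: agreement = {agreement_rate(pairs):.2f}")
```

A model trained on a high-`conf` split can score well by reading the spurious attribute alone, which is why evaluating on a split with a different confounding probability is informative.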

## Dataset

The pre-generated datasets used in this project are available at:
**[https://anonymous.4open.science/r/GRL_data-iclr2026](https://anonymous.4open.science/r/GRL_data-iclr2026)**

The linked data repository contains the synthetic and real-world graph datasets used for evaluation, covering various confounding patterns and causal structures.

## Installation

### Prerequisites
- Python 3.8+
- CUDA 11.8 (for GPU support)
- PyTorch 2.5.1
- PyTorch Geometric 2.6.1

### Setup
```bash
# Clone the repository
git clone <repository-url>
cd GRL

# Install dependencies
pip install -r requirements.txt

# For GPU support, install the PyTorch build for CUDA 11.8
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118
```
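
After installation, a short script such as the following can confirm that the installed versions match the prerequisites listed above. It is a convenience sketch rather than part of the repository. The distribution name `torch-geometric` and the helper name `check_versions` are assumptions, and the script only reads package metadata, so it does not verify that CUDA itself works.

```python
from importlib import metadata

# Versions taken from the Prerequisites list; distribution names are assumed.
EXPECTED = {"torch": "2.5.1", "torch-geometric": "2.6.1"}


def check_versions(expected=EXPECTED):
    """Map each package to (installed_version_or_None, matches_expected)."""
    report = {}
    for package, wanted in expected.items():
        try:
            found = metadata.version(package)
        except metadata.PackageNotFoundError:
            found = None
        # Local build tags such as "+cu118" are ignored for the comparison.
        matches = found is not None and found.split("+")[0] == wanted
        report[package] = (found, matches)
    return report


for package, (found, ok) in check_versions().items():
    status = "ok" if ok else "mismatch or missing"
    print(f"{package}: found {found}, expected {EXPECTED[package]} -> {status}")
```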