# GraGR: Gradient-Guided Graph Reasoner

[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch](https://img.shields.io/badge/PyTorch-1.9+-red.svg)](https://pytorch.org/)
[![PyTorch Geometric](https://img.shields.io/badge/PyTorch%20Geometric-2.0+-green.svg)](https://pytorch-geometric.readthedocs.io/)

**GraGR** is a gradient-guided graph reasoning framework that addresses gradient conflicts in graph neural networks through conflict detection, conflict resolution, and adaptive reasoning. This repository contains implementations of both GraGR Core and GraGR++, along with experiments, benchmarks, and interpretability tools.

## GraGR Activation in Gradient Slope

https://github.com/user-attachments/assets/15561379-8a04-4700-bcfe-b5908632398e


## GraGR resolving conflicts (Cora Dataset)

https://github.com/user-attachments/assets/37d7e7b7-ac2c-45e1-af35-c9129c90a739


## Architecture
<img width="717" height="531" alt="image" src="https://github.com/user-attachments/assets/f47f9799-3322-4d7b-bc53-e18639416e77" />


## 🚀 Quick Start

### Installation

```bash
# Clone the repository
git clone https://github.com/yourusername/gragr.git
cd gragr

# Install dependencies
pip install -r requirements.txt

# Download datasets (supports 20+ datasets)
python datasets/download/download_datasets.py

# Run quick demo (tested and working)
python demo/quick_start/demo.py --compare --epochs 100

# Run comprehensive experiments (150 epochs default)
python experiments/run_experiments.py --epochs 150

# Run benchmarks
python benchmarks/comprehensive_benchmark.py
```

### Basic Usage

```python
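import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid

from src.core.gragr_complete import GraGRCore, GraGRPlusPlus

# Load Cora (2708 nodes, 1433 features, 7 classes); the root path is only an example
data = Planetoid(root='data/Planetoid', name='Cora')[0]

# Create model
model = GraGRCore('gcn', in_dim=1433, hidden_dim=64, out_dim=7, num_nodes=2708)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# One training step: forward pass with reasoning, then backpropagation
model.train()
optimizer.zero_grad()
logits, signals = model.forward_with_reasoning(data.x, data.edge_index, epoch=0, total_epochs=100)
loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()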
```

## 📁 Repository Structure

<img width="885" height="209" alt="image" src="https://github.com/user-attachments/assets/be44123d-1e7b-4c0b-8171-ea95bc266552" />


## 🔬 Methodology

### GraGR Core (Components 1-4)

1. **Gradient Conflict Detection**: Identifies conflicting gradients between connected nodes
2. **Topology-Informed Gradient Alignment**: Applies Laplacian smoothing to resolve conflicts
3. **Gradient-Based Attention**: Uses gradient information to guide attention mechanisms
4. **Meta-Gradient Modulation**: Learns optimal task weighting for multi-task scenarios
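
Components 1 and 2 can be illustrated with a small framework-free sketch. It assumes that a conflict is an edge whose endpoint gradients have negative cosine similarity, and that smoothing moves each node's gradient toward the mean of its neighbours' gradients (one step of random-walk Laplacian smoothing). Neither criterion is confirmed by this README, and the function names `detect_conflicts` and `laplacian_smooth` as well as the `threshold` and `alpha` parameters are hypothetical, not part of the GraGR API.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def detect_conflicts(grads, edges, threshold=0.0):
    """Return the edges whose endpoint gradients point in opposing directions.

    Assumption: "conflict" means cosine similarity below `threshold`.
    """
    return [(i, j) for i, j in edges if cosine(grads[i], grads[j]) < threshold]


def laplacian_smooth(grads, edges, alpha=0.5):
    """One smoothing step: g_i <- (1 - alpha) * g_i + alpha * mean(neighbour gradients)."""
    neighbors = {i: [] for i in range(len(grads))}
    for i, j in edges:
        neighbors[i].append(j)
        neighbors[j].append(i)

    smoothed = []
    for i, g in enumerate(grads):
        if not neighbors[i]:
            smoothed.append(list(g))  # isolated node: gradient unchanged
            continue
        mean = [
            sum(grads[j][d] for j in neighbors[i]) / len(neighbors[i])
            for d in range(len(g))
        ]
        smoothed.append([(1 - alpha) * a + alpha * m for a, m in zip(g, mean)])
    return smoothed


# Toy graph: a path 0 - 1 - 2 with two-dimensional per-node gradients.
grads = [[1.0, 0.2], [-0.5, 0.2], [1.0, 0.2]]
edges = [(0, 1), (1, 2)]

conflicts = detect_conflicts(grads, edges)       # edges in conflict before smoothing
smoothed = laplacian_smooth(grads, edges, alpha=0.5)
remaining = detect_conflicts(smoothed, edges)    # edges still in conflict afterwards
```

On this toy path graph both edges start out in conflict, and a single smoothing step with `alpha=0.5` aligns all three gradients so that no conflicting edge remains.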

### GraGR++ (All 6 Components)

GraGR++ includes the four Core components above and adds two more:

5. **Multiple Pathways Framework**: Employs multiple reasoning pathways with adaptive gating
6. **Adaptive Scheduling**: Dynamically activates reasoning based on training progress
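
As a rough illustration of component 5, the sketch below combines the outputs of several pathways with softmax-normalised gate weights. This is one common way to implement adaptive gating and is an assumption here; the README does not specify how GraGR++ computes or applies its gates, and `gate_pathways` and `gate_scores` are hypothetical names.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def gate_pathways(pathway_outputs, gate_scores):
    """Weighted sum of per-pathway output vectors.

    Assumption: gate scores are turned into weights with a softmax,
    so the weights are positive and sum to one.
    """
    weights = softmax(gate_scores)
    dim = len(pathway_outputs[0])
    return [
        sum(w * out[d] for w, out in zip(weights, pathway_outputs))
        for d in range(dim)
    ]


# Three pathways producing two-dimensional outputs for a single node.
pathway_outputs = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
gate_scores = [0.0, 0.0, 0.0]  # equal scores give uniform weights

combined = gate_pathways(pathway_outputs, gate_scores)
```

With equal gate scores the result is the plain average of the pathway outputs; in a trained model the scores would typically be learned so that the weighting adapts per node or per training stage.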

## 📊 Supported Datasets

### Node Classification
- **Citation Networks**: Cora, CiteSeer, PubMed
- **WebKB Datasets**: Texas, Cornell, Wisconsin
- **Structural Graphs**: WikiCS
- **Coauthor Networks**: CS, Physics
- **Amazon Networks**: Computers, Photo
- **Social Networks**: Reddit

### Graph Classification
- **Molecular Datasets**: OGB-MolHIV, MUTAG, PROTEINS, QM9
- **TU Datasets**: ENZYMES, NCI1, NCI109, COLLAB, IMDB-BINARY, IMDB-MULTI
- **Reddit Datasets**: REDDIT-BINARY, REDDIT-MULTI-5K, REDDIT-MULTI-12K

### Multi-Task Learning
- **Molecular Properties**: OGB-MolHIV multi-task properties
- **Synthetic Multi-Task**: Custom multi-task datasets
- **Real-World Multi-Task**: Various real-world multi-task scenarios

**Total**: 20+ datasets across different categories for comprehensive evaluation

## 🚀 Getting Started

### 1. Environment Setup

```bash
# Create virtual environment
python -m venv gragr_env
source gragr_env/bin/activate  # On Windows: gragr_env\Scripts\activate

# Install dependencies
pip install torch torch-geometric
pip install ogb
pip install matplotlib seaborn
pip install scikit-learn
pip install pandas numpy
```

### 2. Download Datasets

```bash
# Download all datasets
python datasets/download/download_datasets.py

# Download specific dataset categories
python datasets/download/download_datasets.py --datasets citation webkb
```

### 3. Run Experiments

```bash
# Run comprehensive experiments
python experiments/run_experiments.py --epochs 100

# Run specific experiment types
python experiments/run_experiments.py --skip_ablation
```

### 4. Run Benchmarks

```bash
# Run benchmark comparisons
python benchmarks/comprehensive_benchmark.py

# Run ablation studies
python ablation_studies/enhanced_ablation_study.py
```


## 🔧 Model Configuration

### GraGR Core

```python
model = GraGRCore(
    backbone_type='gcn',           # 'gcn', 'gat', 'gin', 'sage'
    in_dim=1433,                   # Input feature dimension
    hidden_dim=64,                 # Hidden dimension
    out_dim=7,                     # Output classes
    num_nodes=2708,                # Number of nodes
    num_tasks=1,                   # Number of tasks
    dropout=0.5,                   # Dropout rate
    dataset_name='cora'            # Dataset-specific tuning
)
```

### GraGR++

```python
model = GraGRPlusPlus(
    backbone_type='gcn',
    in_dim=1433,
    hidden_dim=64,
    out_dim=7,
    num_nodes=2708,
    num_tasks=1,
    dropout=0.5,
    num_pathways=3,                # Number of reasoning pathways
    eta_thresh=1e-4,               # Plateau detection threshold
    t_min=10,                      # Minimum epochs before reasoning
    dataset_name='cora'
)
```
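
To make the roles of `eta_thresh` and `t_min` more concrete, here is a minimal sketch of a plateau-based activation rule. It assumes reasoning switches on once at least `t_min` epochs have elapsed and the training loss has improved by less than `eta_thresh` over a short window. That reading is inferred from the parameter comments above, not from the GraGR++ source, and the `reasoning_active` function and its `window` parameter are hypothetical.

```python
def reasoning_active(loss_history, eta_thresh=1e-4, t_min=10, window=3):
    """Return True once training has run for t_min epochs and the loss has plateaued.

    Assumption: a plateau means the loss dropped by less than `eta_thresh`
    over the last `window` epochs.
    """
    epoch = len(loss_history)
    if epoch < max(t_min, window + 1):
        return False  # too early, or not enough history to measure a plateau
    recent_improvement = loss_history[-window - 1] - loss_history[-1]
    return recent_improvement < eta_thresh


flat = [0.5] * 12                                  # loss stuck at 0.5 for 12 epochs
too_short = [0.5] * 5                              # plateau, but before t_min
improving = [1.0 - 0.05 * i for i in range(12)]    # loss still dropping quickly

active_on_flat = reasoning_active(flat)
active_on_short = reasoning_active(too_short)
active_on_improving = reasoning_active(improving)
```

Under these assumptions reasoning is activated only for the flat history: the short history fails the `t_min` check, and the improving history is still making progress above `eta_thresh`.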

## 📈 Performance

### Node Classification Results

| Dataset  | Baseline GCN | GraGR Core | GraGR++ | Improvement (GraGR++ vs. GCN) |
|----------|--------------|------------|---------|-------------------------------|
| Cora     | 81.5%        | 83.2%      | 84.1%   | +2.6 pp                       |
| CiteSeer | 71.4%        | 74.8%      | 76.2%   | +4.8 pp                       |
| PubMed   | 79.0%        | 81.3%      | 82.7%   | +3.7 pp                       |
| WikiCS   | 75.2%        | 78.9%      | 80.1%   | +4.9 pp                       |

### Key Features

- **Systematic Conflict Resolution**: Addresses gradient conflicts in graph learning
- **Adaptive Reasoning**: Dynamically activates reasoning based on training progress
- **Multi-Pathway Framework**: Employs multiple reasoning pathways for robust learning
- **Interpretable Components**: Provides gradient-based interpretability features
- **Scalable Architecture**: Works with various GNN backbones (GCN, GAT, GIN, SAGE)
- **Comprehensive Benchmarks**: 15+ benchmark scripts comparing against state-of-the-art methods
- **Extensive Visualizations**: 10+ visualization scripts for publication-quality plots
- **LLM Integration**: Natural language explanations using Groq API
- **Multi-Task Support**: Comprehensive multi-task learning capabilities
- **Ablation Studies**: Detailed component-wise analysis and validation

## 🔍 Interpretability

GraGR provides comprehensive interpretability through:

1. **Gradient Feature Extraction**: Six-dimensional gradient features for analysis
2. **Conflict Visualization**: Visual representation of gradient conflicts
3. **LLM Explanations**: Natural language explanations using gradient features
4. **Attention Analysis**: Gradient-guided attention weight visualization

```python
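import os

from interpretability.gragr_interpretability_complete import InterpretableGraGRWrapper

# Enable interpretability; `model` is a GraGR model and `data` is its graph data object
interpretable_model = InterpretableGraGRWrapper(
    gragr_model=model,
    data_obj=data,
    enable_interpretability=True,
    groq_api_key=os.environ["GROQ_API_KEY"],  # read the key from the environment instead of hard-coding it
)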

# Extract gradient features
gradient_features = interpretable_model.extract_gradient_features()

# Generate LLM explanations
explanations = interpretable_model.generate_explanations()
```

## 🧪 Ablation Studies

Comprehensive ablation studies are available for:

- **Core Components**: Individual component analysis
- **GraGR++ Components**: Multiple pathways and adaptive scheduling
- **Backbone Analysis**: Performance across different GNN architectures
- **Dataset Analysis**: Performance across different dataset types

```bash
# Run ablation studies
python ablation_studies/enhanced_ablation_study.py
python ablation_studies/final_ablation_study.py
python ablation_studies/run_multi_dataset_ablation.py
```


## 🤝 Contributing

We welcome contributions! Please see our [Contributing Guidelines](CONTRIBUTING.md) for details.

1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests if applicable
5. Submit a pull request

---

**GraGR**: Empowering Graph Neural Networks with Gradient-Guided Reasoning 🚀
