# GrapHist

This repository contains the code for the data processing, training, and evaluation steps of GrapHist.
The pre-trained model and constructed graph datasets will be made available upon publication.

## Data Processing

1. **Cell Segmentation** (`src/data/01_cell_segmentation.py`)
   - Uses a pre-trained StarDist2D model for cell detection
   - Extracts cell coordinates and segmentation probabilities
   - Supports multiple image formats (PNG, TIFF, SVS)

2. **Cell Feature Extraction** (`src/data/02_cell_feature_extraction.py`)
   - Computes morphological features
   - Extracts texture features using Gray-Level Co-occurrence Matrix (GLCM)
   - Calculates intensity statistics and Fourier descriptors

3. **Graph Construction** (`src/data/03_graph_construction.py`)
   - Builds spatial graphs using Delaunay triangulation
   - Filters edges based on distance threshold
   - Creates node and edge features for graph neural networks

4. **Dataset Preparation** (`src/data/04_save_dataset.py`)
   - Converts processed data to PyTorch Geometric format
   - Handles clinical data integration
   - Prepares train/validation/test splits
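The graph-construction idea (Delaunay triangulation over cell positions, then dropping edges longer than a distance threshold) can be sketched as follows. This is a minimal illustration, not the code in `src/data/03_graph_construction.py`: it assumes cell centroids arrive as an `(N, 2)` NumPy array in pixel coordinates, that `dist_threshold` uses the same units, and `delaunay_edges` is a hypothetical helper name.

```python
import numpy as np
from scipy.spatial import Delaunay


def delaunay_edges(coords: np.ndarray, dist_threshold: float) -> np.ndarray:
    """Hypothetical helper: undirected Delaunay edges no longer than dist_threshold.

    coords: (N, 2) array of cell centroids (assumed to be in pixels).
    Returns an (E, 2) int array of node index pairs with i < j.
    """
    tri = Delaunay(coords)
    edges = set()
    # Each triangle (simplex) contributes its three sides as candidate edges.
    for a, b, c in tri.simplices:
        for i, j in ((a, b), (b, c), (a, c)):
            edges.add((int(min(i, j)), int(max(i, j))))
    edges = np.array(sorted(edges), dtype=np.int64).reshape(-1, 2)
    # Drop long edges, which tend to connect cells that are not true neighbours.
    lengths = np.linalg.norm(coords[edges[:, 0]] - coords[edges[:, 1]], axis=1)
    return edges[lengths <= dist_threshold]


if __name__ == "__main__":
    cells = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0], [200.0, 200.0]])
    print(delaunay_edges(cells, dist_threshold=100))
```

In this toy example the triangulation has five edges, and the two that reach the distant fourth cell exceed the threshold and are removed. `Delaunay` needs at least three non-collinear points, so very sparse patches would need special handling.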

## Models, Training, and Evaluation

### Self-Supervised Pre-training

- **Architecture**: Encoder-Decoder with ACM-GIN (Adaptive Channel Mixing Graph Isomorphism Network)
- **Pre-training Task**: Masked node attribute prediction
- **Loss Function**: SCE (Scaled Cosine Error) loss
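The SCE loss, as introduced for the GraphMAE masked graph autoencoder, averages `(1 - cos(x_v, z_v)) ** gamma` over the masked nodes, where `x_v` is a node's original attribute vector, `z_v` its reconstruction, and `gamma >= 1` a scaling exponent. The NumPy sketch below (used instead of PyTorch so it runs standalone) illustrates that formula; the default `gamma = 2.0` is an assumption, not the value used by GrapHist.

```python
import numpy as np


def sce_loss(x: np.ndarray, z: np.ndarray, gamma: float = 2.0, eps: float = 1e-8) -> float:
    """Scaled cosine error between original (x) and reconstructed (z) features.

    x, z: (M, D) arrays holding only the masked nodes' rows.
    gamma: scaling exponent (>= 1); 2.0 is an assumed default, not the repo's setting.
    """
    # L2-normalise each row so the row-wise dot product is a cosine similarity.
    x_n = x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)
    z_n = z / np.maximum(np.linalg.norm(z, axis=1, keepdims=True), eps)
    cos = np.sum(x_n * z_n, axis=1)
    # Raising (1 - cos) to gamma down-weights nodes that are already well reconstructed.
    return float(np.mean((1.0 - cos) ** gamma))
```

A perfect reconstruction gives a loss of 0, an orthogonal one gives 1, and an opposite-direction one gives `2 ** gamma`.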

### Evaluation

- **Multiple Instance Learning (MIL)**: attention-based, additive, and conjunctive MIL classifiers for slide-level tasks
- **Linear Probing**: Logistic regression for patch-level tasks
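To make the slide-level setting concrete, here is a minimal NumPy sketch of attention-based MIL pooling in the style of Ilse et al. (2018): each instance embedding `h_k` (for example, one patch of a slide) receives a score `w^T tanh(V h_k)`, the scores are softmax-normalised, and the slide embedding is the weighted sum. It only illustrates the pooling idea; it is not the repository's classifier, the parameters `V` and `w` would be learned in practice, and the additive and conjunctive variants are not shown.

```python
import numpy as np


def attention_mil_pool(h: np.ndarray, V: np.ndarray, w: np.ndarray):
    """Attention-based MIL pooling (sketch).

    h: (K, D) instance embeddings, one row per instance in the bag.
    V: (L, D) and w: (L,) attention parameters (learned in a real model).
    Returns the (D,) bag embedding and the (K,) attention weights.
    """
    scores = np.tanh(h @ V.T) @ w          # (K,) unnormalised attention scores
    scores = scores - scores.max()         # shift for a numerically stable softmax
    weights = np.exp(scores) / np.exp(scores).sum()
    return weights @ h, weights


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    patches = rng.normal(size=(5, 8))
    bag, attn = attention_mil_pool(patches, rng.normal(size=(4, 8)), rng.normal(size=4))
    print(bag.shape, attn.round(3))
```

A slide-level classifier would then apply a linear layer to `bag`; the attention weights also indicate which instances drove the prediction.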

## Installation

### Environment Setup

```bash
# Create conda environment from the provided yaml file
conda env create -f env.yaml
conda activate graphist
```

### Dependencies

Key dependencies include:
- PyTorch 2.2.2 with CUDA 11.8 support
- PyTorch Geometric 2.5.2
- TensorFlow 2.x (for StarDist)
- StarDist for cell segmentation
- OpenSlide for WSI processing
- scikit-learn, pandas, numpy
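After creating the environment, a quick way to confirm that the key dependencies resolved is to query installed distribution versions with the standard library. The distribution names below are assumptions about how these packages are published (for example `openslide-python` for OpenSlide's Python bindings); a `None` entry means the lookup did not find the package.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names assumed for the dependencies listed above.
PACKAGES = ("torch", "torch_geometric", "tensorflow", "stardist",
            "openslide-python", "scikit-learn", "pandas", "numpy")


def installed_versions(packages=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:18s} {ver or 'NOT INSTALLED'}")
```

If PyTorch is installed, `torch.version.cuda` additionally reports the CUDA version it was built against.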

## Usage

### 1. Cell Segmentation

```bash
python src/data/01_cell_segmentation.py \
    --input_path /path/to/images \
    --output_path /path/to/segmentation_results
```

### 2. Cell Feature Extraction

```bash
python src/data/02_cell_feature_extraction.py \
    --input_path /path/to/images \
    --seg_path /path/to/segmentation_results \
    --output_path /path/to/features
```
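Among the features computed in this step are Fourier descriptors. One common definition is sketched below, assuming each cell contour is an ordered `(P, 2)` array of boundary points with a consistent orientation (for instance, a polygon derived from the segmentation). The helper name and `n_desc = 8` are hypothetical, and the script may use a different variant or normalisation.

```python
import numpy as np


def fourier_descriptors(contour: np.ndarray, n_desc: int = 8) -> np.ndarray:
    """Translation-, scale- and rotation-invariant Fourier descriptors (sketch).

    contour: (P, 2) array of boundary points in a consistent order and orientation.
    Returns the normalised magnitudes of harmonics 1..n_desc.
    """
    z = contour[:, 0] + 1j * contour[:, 1]   # boundary as a complex signal
    coeffs = np.fft.fft(z)
    coeffs[0] = 0.0                          # dropping the DC term removes translation
    mags = np.abs(coeffs)                    # magnitudes ignore rotation and start point
    mags = mags / np.linalg.norm(mags)       # unit norm removes scale (non-degenerate contour assumed)
    return mags[1:n_desc + 1]
```

Rotating the contour multiplies every coefficient by the same unit-modulus factor and shifting the start point only changes coefficient phases, so taking magnitudes makes the descriptors invariant to both.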

### 3. Graph Construction

```bash
python src/data/03_graph_construction.py \
    --feat_path /path/to/features \
    --output_path /path/to/graphs \
    --dist_threshold 100
```

### 4. Dataset Preparation

```bash
python src/data/04_save_dataset.py \
    --edge_path /path/to/graphs \
    --clinical_data_path /path/to/clinical_data \
    --output_path /path/to/dataset
```

### 5. Model Pre-training

```bash
python src/train/pretrain.py \
    --dataset TCGA_BRCA \
    --save_folder /path/to/checkpoints \
    --scale_vals_path /path/to/normalization.json \
    --sample_data_folder /path/to/dataset \
    --sample_split_folder /path/to/splits \
    --mask_rate 0.5 \
    --replace_rate 0.1 \
    --num_hidden 512 \
    --num_layers 5 \
    --max_epoch 100 \
    --batch_size 2048
```
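The `--mask_rate` and `--replace_rate` flags suggest a GraphMAE-style masking scheme; that reading is an assumption, not confirmed by this README. Under it, a `mask_rate` fraction of nodes is selected, a `replace_rate` fraction of those receive another node's features as noise, and the rest are blanked out. The simplified NumPy sketch below uses a hypothetical helper and zeroes the blanked rows where a real model would add a learnable mask token.

```python
import numpy as np


def mask_node_features(x, mask_rate, replace_rate, rng):
    """GraphMAE-style node masking (simplified sketch; assumed, not the repo's code).

    x: (N, D) node feature matrix.
    Returns the corrupted features and the indices of all masked nodes.
    """
    n = x.shape[0]
    mask_nodes = rng.permutation(n)[: int(mask_rate * n)]
    num_noise = int(replace_rate * len(mask_nodes))
    shuffled = rng.permutation(mask_nodes)
    noise_nodes, token_nodes = shuffled[:num_noise], shuffled[num_noise:]

    out = x.copy()
    # Token nodes are zeroed; a model would then add a learnable [MASK] embedding.
    out[token_nodes] = 0.0
    # Noise nodes receive the features of randomly chosen nodes instead.
    out[noise_nodes] = x[rng.permutation(n)[:num_noise]]
    return out, mask_nodes
```

For example, with `rng = np.random.default_rng(0)`, 100 nodes, `mask_rate=0.5`, and `replace_rate=0.1`, 50 nodes are masked: 45 are blanked and 5 receive noise. The reconstruction loss is then evaluated only on the rows in `mask_nodes`.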

### 6. Generate Embeddings

```bash
python src/train/generate_embs.py \
    --model_path /path/to/pretrained_model.pt \
    --data_path /path/to/graphs \
    --output_path /path/to/embeddings \
    --batch_size 1024
```

### 7. Classification

```bash
# Slide-level evaluation
python src/evaluate/main_slide.py \
    --dataset TCGA_BRCA \
    --label_path /path/to/labels.csv \
    --train_data_path /path/to/train_embeddings \
    --test_data_path /path/to/test_embeddings \
    --project_dir /path/to/project \
    --save_dir /path/to/results

# Patch-level evaluation (replace <organ> with the organ to evaluate)
python src/evaluate/main_patch.py --organ <organ>
```
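Patch-level evaluation uses linear probing: a logistic-regression classifier trained on frozen embeddings. A minimal scikit-learn sketch follows. Because the layout of the embedding files is not documented here, synthetic arrays stand in for real embeddings; the helper name and the choice of balanced accuracy as the metric are assumptions.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score


def linear_probe(X_train, y_train, X_test, y_test, C=1.0):
    """Fit a logistic-regression probe on frozen embeddings and score it."""
    clf = LogisticRegression(C=C, max_iter=1000)
    clf.fit(X_train, y_train)
    return balanced_accuracy_score(y_test, clf.predict(X_test))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic stand-in for patch embeddings: two well-separated classes.
    y_tr, y_te = np.array([0, 1] * 100), np.array([0, 1] * 50)
    X_tr = rng.normal(size=(200, 32)) + 3.0 * y_tr[:, None]
    X_te = rng.normal(size=(100, 32)) + 3.0 * y_te[:, None]
    print(f"balanced accuracy: {linear_probe(X_tr, y_tr, X_te, y_te):.3f}")
```

Keeping the encoder frozen and the probe linear means the score reflects how linearly separable the pre-trained embeddings are, rather than the capacity of the classifier.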

## Data Format

### Input Data Structure
```
data/
├── images/
│   ├── patient_001/
│   │   ├── patch_001.png
│   │   └── patch_002.png
│   └── patient_002/
└── clinical_data/
    └── sample_labels.csv
```
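A small helper can enumerate inputs laid out this way, one folder per patient with image patches inside. This is a hypothetical sketch: the accepted suffixes mirror the formats named in the segmentation step, and the pipeline scripts may discover files differently.

```python
from pathlib import Path

# Suffixes assumed from the formats listed for cell segmentation.
IMAGE_SUFFIXES = {".png", ".tif", ".tiff", ".svs"}


def list_patches(images_root):
    """Map patient ID -> sorted list of image paths under images_root/<patient>/."""
    patches = {}
    for patient_dir in sorted(p for p in Path(images_root).iterdir() if p.is_dir()):
        files = sorted(f for f in patient_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES)
        patches[patient_dir.name] = files
    return patches
```

For the tree above, `list_patches("data/images")` would map `patient_001` to its two PNG patches and `patient_002` to whatever images that folder contains.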

### Output Data Structure
```
results/
├── segmentation/
│   ├── patient_001/
│   │   ├── patch_001_coords.npy
│   │   └── patch_001_probs.csv
├── features/
│   ├── patient_001/
│   │   └── patch_001_features_20x.csv
├── graphs/
│   ├── patient_001/
│   │   └── patch_001_edges.csv
└── embeddings/
    ├── patient_001/
    │   └── embeddings.h5
```
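The naming scheme above can be captured in a hypothetical helper that maps an input patch to its expected outputs, which is handy for checking which pipeline steps have already run. The `20x` magnification suffix is taken from the example layout and is exposed as a parameter on the assumption that other magnifications follow the same pattern; embeddings are stored per patient rather than per patch.

```python
from pathlib import Path


def expected_outputs(patch_path, results_root, magnification="20x"):
    """Hypothetical helper: per-patch output paths following the layout above."""
    patch_path, results_root = Path(patch_path), Path(results_root)
    patient, stem = patch_path.parent.name, patch_path.stem
    return {
        "coords": results_root / "segmentation" / patient / f"{stem}_coords.npy",
        "probs": results_root / "segmentation" / patient / f"{stem}_probs.csv",
        "features": results_root / "features" / patient / f"{stem}_features_{magnification}.csv",
        "edges": results_root / "graphs" / patient / f"{stem}_edges.csv",
        "embeddings": results_root / "embeddings" / patient / "embeddings.h5",
    }


def missing_outputs(patch_path, results_root, magnification="20x"):
    """Return the names of expected outputs that do not exist yet."""
    paths = expected_outputs(patch_path, results_root, magnification)
    return [name for name, path in paths.items() if not path.exists()]
```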
