# scUniGP

scUniGP is a multi-scale fusion model that combines a Transformer encoder with graph neural network (GNN) embeddings to predict gene regulatory networks from single-cell data.

## 🚀 Quick Start

### Installation

1. Clone the repository:
```bash
git clone https://github.com/yourusername/scUniGP.git
cd scUniGP
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

### Quick Run

```bash
# Step 1: Data preprocessing
cd Data_processing
python pre-processing.py

# Step 2: Generate GNN embeddings
cd ../GNN_embedding_generate
bash run_all.sh gat

# Step 3: Train scUniGP model
cd ..
bash run.sh
```

## 📋 Requirements

- Python 3.8+
- PyTorch 1.8+
- CUDA (optional, for GPU acceleration)


## 🏗️ Model Architecture

scUniGP adopts a multi-scale fusion architecture:
- **Transformer Encoder**: Processes gene expression data
- **GNN Embeddings**: Provides graph structure-aware features
- **Multi-scale Fusion**: Combines different levels of GNN embeddings
- **Loss Functions**: Supports multiple loss function combinations
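
As a rough illustration of what "combining different levels of GNN embeddings" can look like, the sketch below concatenates a gene's layer 1 and layer 2 embeddings and then joins the TF and target vectors into one pair feature. Plain concatenation is an assumption made for illustration; the fusion in `model.py` may work differently (for example with learned weights or attention). The gene names and helper functions are hypothetical, and the dimensions 128 and 64 are taken from the generated embedding file names.

```python
from typing import Dict, List

Embedding = Dict[str, List[float]]


def fuse_gene_embeddings(layer1: Embedding, layer2: Embedding, gene: str) -> List[float]:
    """Concatenate one gene's layer 1 and layer 2 GNN embeddings (assumed fusion rule)."""
    return list(layer1[gene]) + list(layer2[gene])


def pair_features(layer1: Embedding, layer2: Embedding, tf: str, target: str) -> List[float]:
    """Build a feature vector for a (TF, target) pair from the fused embeddings."""
    return fuse_gene_embeddings(layer1, layer2, tf) + fuse_gene_embeddings(layer1, layer2, target)


# Toy embeddings with hypothetical gene names; real values come from the CSV files.
layer1 = {"GENE_A": [0.1] * 128, "GENE_B": [0.2] * 128}
layer2 = {"GENE_A": [0.3] * 64, "GENE_B": [0.4] * 64}

features = pair_features(layer1, layer2, tf="GENE_A", target="GENE_B")
print(len(features))  # (128 + 64) values per gene, two genes per pair
```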

## 📁 Data Preparation

### Input Data Structure

Place your raw data under the repository root in the following structure:
```
yourpath/scUniGP/
└── GRN_GT/
    ├── Specific Dataset/
    ├── Non-Specific Dataset/
    ├── STRING Dataset/
    └── Lofgof Dataset/
```

### Output Data Structure

After preprocessing, the data will be organized as:
```
yourpath/scUniGP/Data_processing/Dataspilt/
├── Specific/
│   ├── hESC/
│   │   ├── TFs_500/
│   │   └── TFs_1000/
│   └── ...
├── Non-Specific/
├── STRING/
└── Lofgof/
```

## 🔄 Complete Running Steps

### Step 1: Data Preprocessing

Generate training, validation, and test sets from raw data:

```bash
cd yourpath/scUniGP/Data_processing
python pre-processing.py
```

This processes raw data from `yourpath/scUniGP/GRN_GT/` and saves processed datasets to `yourpath/scUniGP/Data_processing/Dataspilt/`.
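
For readers who want a feel for this step, here is a minimal sketch of a reproducible three-way split over TF-target pairs. It is not the logic of `pre-processing.py`: the function name, the 60/20/20 ratios, the seed, and the toy pair names are all assumptions, and the real script may also handle negative sampling or per-TF balancing.

```python
import random
from typing import List, Sequence, Tuple

Pair = Tuple[str, str]


def three_way_split(
    pairs: Sequence[Pair],
    ratios: Tuple[float, float, float] = (0.6, 0.2, 0.2),  # hypothetical ratios
    seed: int = 0,
) -> Tuple[List[Pair], List[Pair], List[Pair]]:
    """Shuffle pairs with a fixed seed and cut them into train/validation/test."""
    if abs(sum(ratios) - 1.0) > 1e-9:
        raise ValueError("ratios must sum to 1")
    shuffled = list(pairs)
    random.Random(seed).shuffle(shuffled)
    n_train = round(len(shuffled) * ratios[0])
    n_val = round(len(shuffled) * ratios[1])
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # the remainder goes to the test set
    return train, val, test


# Toy TF-target pairs with made-up names.
pairs = [(f"TF{i}", f"G{j}") for i in range(5) for j in range(4)]
train, val, test = three_way_split(pairs)
print(len(train), len(val), len(test))
```

Fixing the seed keeps the three sets identical across runs, which matters when GNN embeddings and the main model are trained in separate steps on the same split.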

### Step 2: Generate GNN Embeddings

Create graph embeddings using GAT or GCN models:

```bash
cd yourpath/scUniGP/GNN_embedding_generate

# Run GAT model for all datasets
bash run_all.sh gat

# Run GCN model for all datasets
bash run_all.sh gcn

# Run for a specific dataset
bash run_all.sh gat /path/to/specific/dataset
bash run_all.sh gcn /path/to/specific/dataset
```

**Generated Files** (GAT name / GCN name):
- `TF_gat_Channel1.csv` / `gcn_TF_Channel1.csv`: TF embeddings
- `Target_gat_Channel2.csv` / `gcn_Target_Channel2.csv`: Target embeddings
- `gene_gat1_embedding128.csv` / `gene_gcn1_embedding128.csv`: Layer 1 embeddings
- `gene_gat2_embedding64.csv` / `gene_gcn2_embedding64.csv`: Layer 2 embeddings
- `gat_predictions.csv` / `gcn_predictions.csv`: GNN prediction scores
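
Note that the TF and Target files place the model name differently for GAT (`TF_gat_...`) and GCN (`gcn_TF_...`). The helper below is a small sketch, not part of the repository, that returns the file names listed above for a given model so downstream code does not have to hard-code them. The function name and dictionary keys are hypothetical.

```python
from typing import Dict


def embedding_files(model: str) -> Dict[str, str]:
    """Return the generated file names listed in this README for 'gat' or 'gcn'."""
    model = model.lower()
    if model not in ("gat", "gcn"):
        raise ValueError(f"unknown model: {model!r} (expected 'gat' or 'gcn')")
    # The TF/Target files use a different prefix order for the two models.
    if model == "gat":
        tf_name, target_name = "TF_gat_Channel1.csv", "Target_gat_Channel2.csv"
    else:
        tf_name, target_name = "gcn_TF_Channel1.csv", "gcn_Target_Channel2.csv"
    return {
        "tf": tf_name,
        "target": target_name,
        "layer1": f"gene_{model}1_embedding128.csv",
        "layer2": f"gene_{model}2_embedding64.csv",
        "predictions": f"{model}_predictions.csv",
    }


for key, name in embedding_files("gcn").items():
    print(f"{key}: {name}")
```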

### Step 3: Train scUniGP Model

Train the main scUniGP model:

```bash
cd yourpath/scUniGP

# Run with default parameters
bash run.sh

# Run with a specific dataset
bash run.sh /path/to/dataset

# Run with custom parameters
bash run.sh /path/to/dataset --batch_size 256 --epochs 100
```

## 📊 Data Format

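Each dataset directory should contain the following files (GAT embedding file names are shown; the GCN equivalents are listed under Step 2):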

| File | Description |
|------|-------------|
| `BL--ExpressionData.csv` | Gene expression data |
| `Train_set.csv` | Training set |
| `Validation_set.csv` | Validation set |
| `Test_set.csv` | Test set |
| `TF.csv` | Transcription factor information |
| `TF_gat_Channel1.csv` | TF GNN embeddings |
| `Target_gat_Channel2.csv` | Target GNN embeddings |
| `gene_gat1_embedding128.csv` | Layer 1 GNN embeddings |
| `gene_gat2_embedding64.csv` | Layer 2 GNN embeddings |
| `gat_predictions.csv` | GNN prediction scores (optional) |
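
Before starting a long training run, it can help to confirm that a dataset directory is complete. The following sketch checks for the files in the table above (treating `gat_predictions.csv` as optional). It is a standalone helper that is not part of the repository, and the function name is hypothetical; the demo at the bottom runs against a temporary directory rather than real data.

```python
import tempfile
from pathlib import Path
from typing import List

# Required files for a GAT-based run, as listed in the table above.
REQUIRED_FILES = [
    "BL--ExpressionData.csv",
    "Train_set.csv",
    "Validation_set.csv",
    "Test_set.csv",
    "TF.csv",
    "TF_gat_Channel1.csv",
    "Target_gat_Channel2.csv",
    "gene_gat1_embedding128.csv",
    "gene_gat2_embedding64.csv",
]


def missing_files(dataset_dir: str) -> List[str]:
    """Return the required file names that are absent from dataset_dir."""
    root = Path(dataset_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


# Demo on a temporary directory that holds only two of the required files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("TF.csv", "Train_set.csv"):
        (Path(tmp) / name).touch()
    missing = missing_files(tmp)

print("missing:", missing)
```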

## ⚙️ Configuration

### Model Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `--batch_size` | 512 | Batch size |
| `--embed_size` | 1024 | Embedding dimension |
| `--num_layers` | 4 | Number of Transformer layers |
| `--num_head` | 8 | Number of attention heads |

### Training Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `--lr` | 0.000005 | Learning rate |
| `--epochs` | 80 | Number of training epochs |
| `--patience` | 8 | Early stopping patience |
| `--step_size` | 10 | Learning rate scheduler step size |
| `--gamma` | 0.999 | Learning rate decay factor |
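
To see how `--step_size` and `--gamma` interact, the sketch below computes the learning rate per epoch under step decay, where the rate is multiplied by `gamma` once every `step_size` epochs (the behaviour of PyTorch's `StepLR`). That the training code uses this exact scheduler is an assumption, and `step_lr` is a hypothetical helper.

```python
def step_lr(initial_lr: float, epoch: int, step_size: int = 10, gamma: float = 0.999) -> float:
    """Learning rate at a given epoch under step decay (assumed StepLR-style schedule)."""
    return initial_lr * gamma ** (epoch // step_size)


initial_lr = 0.000005  # default --lr
for epoch in (0, 9, 10, 79):
    print(epoch, step_lr(initial_lr, epoch))
```

Under these assumptions the defaults decay the learning rate by less than 1% over 80 epochs, so a smaller `--gamma` is needed if a noticeable decay is intended.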

### Loss Function Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `--loss_type` | combined | Loss function type (bce/focal/weighted_bce/combined) |
| `--focal_weight` | 0.7 | Focal Loss weight |
| `--bce_weight` | 0.2 | BCE Loss weight |
| `--consistency_weight` | 0.1 | Consistency loss weight |
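
The three weights suggest that the `combined` loss is a weighted sum of a focal term, a BCE term, and a consistency term. The sketch below shows that arithmetic for a single prediction in plain Python. It is an assumption-laden illustration rather than the code in `losses.py`: the focal parameters (`gamma=2.0`, `alpha=0.25`) are common defaults, not values from this project, and the consistency term (squared gap between the model probability and a GNN prediction score) is hypothetical.

```python
import math

EPS = 1e-7  # clamp probabilities away from 0 and 1 so log() stays finite


def bce_loss(p: float, y: int) -> float:
    """Binary cross-entropy for one probability p and one label y in {0, 1}."""
    p = min(max(p, EPS), 1.0 - EPS)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def focal_loss(p: float, y: int, gamma: float = 2.0, alpha: float = 0.25) -> float:
    """Focal loss: cross-entropy down-weighted for well-classified examples."""
    p = min(max(p, EPS), 1.0 - EPS)
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


def combined_loss(p: float, y: int, consistency: float,
                  focal_weight: float = 0.7, bce_weight: float = 0.2,
                  consistency_weight: float = 0.1) -> float:
    """Weighted sum of the three terms, using the default weights from the table."""
    return (focal_weight * focal_loss(p, y)
            + bce_weight * bce_loss(p, y)
            + consistency_weight * consistency)


# Hypothetical consistency term: squared gap to a GNN prediction score.
p, gnn_score, label = 0.9, 0.8, 1
loss = combined_loss(p, label, consistency=(p - gnn_score) ** 2)
print(f"{loss:.4f}")
```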

## 📈 Usage Examples

### Basic Training

```bash
python demo.py --data_dir /path/to/your/data
```

### Advanced Training

```bash
python demo.py \
    --batch_size 512 \
    --embed_size 1024 \
    --num_layers 4 \
    --num_head 8 \
    --lr 0.000005 \
    --epochs 80 \
    --loss_type combined \
    --focal_weight 0.7 \
    --bce_weight 0.2 \
    --consistency_weight 0.1
```
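
For reference, the flags above could be declared with `argparse` roughly as follows. This is a sketch that mirrors the defaults documented in the Configuration tables; the actual parser in `demo.py` may use different types, help strings, or additional options, and `build_parser` is a hypothetical name.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Argument parser mirroring the defaults documented in this README."""
    parser = argparse.ArgumentParser(description="scUniGP training options (sketch)")
    parser.add_argument("--data_dir", type=str, default=None, help="Dataset directory")
    # Model parameters
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--embed_size", type=int, default=1024)
    parser.add_argument("--num_layers", type=int, default=4)
    parser.add_argument("--num_head", type=int, default=8)
    # Training parameters
    parser.add_argument("--lr", type=float, default=0.000005)
    parser.add_argument("--epochs", type=int, default=80)
    parser.add_argument("--patience", type=int, default=8)
    parser.add_argument("--step_size", type=int, default=10)
    parser.add_argument("--gamma", type=float, default=0.999)
    # Loss function parameters
    parser.add_argument("--loss_type", default="combined",
                        choices=["bce", "focal", "weighted_bce", "combined"])
    parser.add_argument("--focal_weight", type=float, default=0.7)
    parser.add_argument("--bce_weight", type=float, default=0.2)
    parser.add_argument("--consistency_weight", type=float, default=0.1)
    return parser


# Override two options and keep the remaining defaults.
args = build_parser().parse_args(["--batch_size", "256", "--epochs", "100"])
print(args.batch_size, args.epochs, args.lr, args.loss_type)
```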

## 📁 Project Structure

```
scUniGP/
├── 📁 Data_processing/
│   ├── pre-processing.py          # Data preprocessing script
│   └── Dataspilt/                # Processed datasets
├── 📁 GNN_embedding_generate/
│   ├── main_gat.py               # GAT model training
│   ├── main_gcn.py               # GCN model training
│   ├── scGAT.py                  # GAT model definition
│   ├── scGCN.py                  # GCN model definition
│   ├── utils.py                  # Utility functions
│   └── run_all.sh                # GNN embedding generation script
├── 📄 model.py                   # scUniGP model definition
├── 📄 main.py                    # Main training function
├── 📄 demo.py                    # Single dataset training script
├── 📄 train_val.py               # Training and validation functions
├── 📄 dataset.py                 # Data loading and preprocessing
├── 📄 utils.py                   # Utility functions
├── 📄 losses.py                  # Loss function definitions
├── 📄 run.sh                     # Main model training script
└── 📄 requirements.txt           # Python dependencies
```

## 🎯 Model Features

- ✅ **Multi-scale Fusion**: Combines different levels of GNN embeddings
- ✅ **Flexible Loss Functions**: Supports multiple loss function combinations
- ✅ **Early Stopping**: Prevents overfitting
- ✅ **Model Saving**: Automatically saves the best model
- ✅ **Performance Monitoring**: Real-time monitoring of training and validation performance
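
Early stopping and best-model saving usually work together: training stops once the validation metric has not improved for `--patience` consecutive epochs, and the checkpoint from the best epoch is kept. The class below is a minimal sketch of that pattern, assuming a higher-is-better metric. It is not the code in `train_val.py`, and the metric values in the demo are made-up toy numbers.

```python
from typing import Optional


class EarlyStopping:
    """Stop when a higher-is-better metric has not improved for `patience` epochs."""

    def __init__(self, patience: int = 8) -> None:
        self.patience = patience
        self.best: Optional[float] = None
        self.bad_epochs = 0

    def step(self, metric: float) -> bool:
        """Record one validation result; return True when training should stop."""
        if self.best is None or metric > self.best:
            self.best = metric
            self.bad_epochs = 0
            return False  # a real training loop would save the model here
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Toy validation scores; patience is shortened to 3 to keep the demo brief.
stopper = EarlyStopping(patience=3)
history = [0.70, 0.74, 0.73, 0.74, 0.72, 0.71]
stopped_at = None
for epoch, score in enumerate(history):
    if stopper.step(score):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)
```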

## 📊 Output Results

During training, the script reports:
- Training and validation performance for each epoch
- Best model save path
- Final average and best performance metrics

**Result files are saved in:**
- `./model_pt/`: Model weight files
- `result_tf500.log`: Training result logs

## 🤝 Contributing

1. Fork the repository
2. Create your feature branch (`git checkout -b feature/AmazingFeature`)
3. Commit your changes (`git commit -m 'Add some AmazingFeature'`)
4. Push to the branch (`git push origin feature/AmazingFeature`)
5. Open a Pull Request

## 📝 License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## 📞 Contact

- **Author**: Your Name
- **Email**: your.email@example.com
- **Project Link**: https://github.com/yourusername/scUniGP

## 🙏 Acknowledgments

- Thanks to the open-source community for providing excellent tools and libraries
- Special thanks to the PyTorch team for the deep learning framework