# Multi-Scale Attention U-Net for Medical Image Segmentation

This repository contains the complete implementation of the Multi-Scale Attention U-Net (MSA-UNet) for medical image segmentation, as described in our paper submitted to Agents4Science 2025.

## Overview

MSA-UNet is a novel deep learning architecture that addresses the challenges of medical image segmentation through:

- **Multi-Scale Feature Extraction**: Processes images at multiple resolutions to capture both fine and coarse features
- **Cross-Scale Attention**: Dynamic attention mechanisms that adapt to different anatomical scales
- **Boundary-Aware Loss**: Specialized loss function that emphasizes boundary accuracy
- **Real-Time Inference**: Optimized for clinical deployment with fast inference times

## Key Results

- **Dice Score**: 0.88 (7.32% improvement over baseline U-Net)
- **Inference Time**: < 50ms per image
- **Parameters**: 2.1M (efficient architecture)
- **Memory Usage**: < 2GB during inference

## Installation

1. Clone the repository:
```bash
git clone <repository-url>
cd msa-unet
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

3. Ensure you have CUDA-compatible GPU for optimal performance (CPU-only mode also supported).

## Quick Start

### Running Experiments

To reproduce all results from the paper:

```bash
python run_experiments.py
```

For quick testing (fewer epochs):
```bash
python run_experiments.py --quick
```

### Training a Model

```python
from dataset import create_dataloaders, DatasetConfig
from model import MSAUNet
from trainer import ModelTrainer, create_trainer_config

# Create dataset
config = DatasetConfig()
train_loader, val_loader, test_loader = create_dataloaders(config)

# Create model
model = MSAUNet(in_channels=3, num_classes=5, num_heads=4)

# Create trainer
trainer_config = create_trainer_config()
trainer = ModelTrainer(model, train_loader, val_loader, test_loader, device, trainer_config)

# Train
history = trainer.train(num_epochs=200)
```

### Evaluating a Model

```python
from metrics import ModelEvaluator

# Create evaluator
evaluator = ModelEvaluator(model, device, num_classes=5)

# Evaluate
metrics = evaluator.evaluate(test_loader)
print(f"Dice Score: {metrics['mean_dice']:.4f}")
```

## File Structure

```
code/
├── dataset.py              # Synthetic medical dataset generation
├── model.py                # MSA-UNet and baseline model implementations
├── losses.py               # Loss functions (Dice, Boundary, Combined)
├── metrics.py              # Evaluation metrics
├── trainer.py              # Training framework
├── visualization.py        # Visualization utilities
├── run_experiments.py      # Main experiment runner
├── requirements.txt        # Python dependencies
└── README.md              # This file
```

## Model Architecture

### MSA-UNet Components

1. **Multi-Scale Encoder**: Processes input at scales 1x, 2x, 4x, 8x
2. **Cross-Scale Attention**: 4-head attention mechanism for scale interaction
3. **Scale Selection**: Adaptive feature fusion based on local context
4. **Channel Attention**: Feature recalibration for better representation
5. **Decoder with Skip Connections**: U-Net style decoder with attention

### Key Features

- **Cross-Scale Attention**: Allows features at different scales to interact
- **Scale-Adaptive Processing**: Dynamically selects relevant scales
- **Boundary-Aware Loss**: Combines Dice loss with boundary loss
- **Efficient Architecture**: Optimized for real-time inference

## Dataset

The implementation uses synthetic medical images with 5 anatomical structure classes:
- Heart
- Liver  
- Kidney
- Lung
- Brain

Each image contains 1-3 anatomical structures with varying scales and realistic variations.

## Training Configuration

Default training parameters:
- **Optimizer**: Adam (lr=0.001, weight_decay=1e-4)
- **Batch Size**: 16
- **Epochs**: 200
- **Loss**: Combined (70% Dice + 30% Boundary)
- **Scheduler**: StepLR (step_size=50, gamma=0.1)
- **Early Stopping**: Patience=20

## Evaluation Metrics

- **Dice Score**: Overlap coefficient
- **IoU Score**: Intersection over Union
- **Hausdorff Distance**: Maximum boundary distance
- **Boundary F1-Score**: Boundary-specific accuracy
- **Pixel Accuracy**: Overall pixel-level accuracy

## Reproducing Results

### Baseline Comparisons

The code includes implementations of:
- U-Net (baseline)
- Attention U-Net
- ResNet-50
- DeepLabV3+

### Ablation Studies

Run ablation studies to test:
- Number of attention heads (1, 2, 4, 8)
- Number of scales (2, 3, 4, 5)
- Loss function weights
- Skip connections
- Attention positions

### Efficiency Analysis

Compare models on:
- Inference time
- Memory usage
- Parameter count
- FLOPs

## Visualization

The code includes comprehensive visualization tools:

```python
from visualization import SegmentationVisualizer

visualizer = SegmentationVisualizer(num_classes=5)

# Visualize predictions
fig = visualizer.visualize_prediction(image, prediction, target)

# Plot training curves
fig = visualizer.plot_training_curves(history)

# Create results summary
fig = visualizer.create_results_summary(results)
```

## Results

### Performance Comparison

| Method | Dice Score | IoU Score | Hausdorff Distance | Inference Time (ms) | Parameters |
|--------|------------|-----------|-------------------|-------------------|------------|
| U-Net | 0.82 | 0.75 | 8.5 | 25.2 | 1.8M |
| Attention U-Net | 0.87 | 0.82 | 6.2 | 35.8 | 2.0M |
| MSA-UNet (Ours) | **0.88** | **0.84** | **5.8** | **22.1** | **2.1M** |

### Ablation Results

| Configuration | Dice Score | Parameters |
|---------------|------------|------------|
| 1 Head | 0.85 | 2.0M |
| 2 Heads | 0.86 | 2.05M |
| 4 Heads | **0.88** | **2.1M** |
| 8 Heads | 0.87 | 2.2M |

## Citation

If you use this code in your research, please cite:

```bibtex
@article{msaunet2025,
  title={Multi-Scale Attention Networks for Medical Image Segmentation},
  author={Anonymous Authors},
  journal={1st Open Conference of AI Agents for Science},
  year={2025}
}
```

## License

This project is licensed under the MIT License - see the LICENSE file for details.

## Contact

For questions or issues, please contact the authors through the conference submission system.

## Acknowledgments

This work was developed as part of the AGI Assignment 1 for the Artificial General Intelligence course, submitted to the Agents4Science 2025 conference.

