# Transformer-Graph Models

This directory provides a unified interface to several transformer architectures for graph processing tasks. All models share one API (`create_model()`, `forward()`, `get_hidden_states()`), so architectures can be swapped and compared with minimal code changes.

## 🏗️ Architecture Overview

The models package provides three main transformer architectures:

1. **RoBERTa Graph Model** - Traditional transformer architecture adapted for graph adjacency matrices
2. **Looped Transformer** - Iterative transformer that loops through the same layer multiple times
3. **Disentangled Transformer** - Custom attention mechanism with disentangled components

## 🚀 Quick Start

```python
# Import the unified interface
from models import create_model, list_models

# List available models
print("Available models:", list_models())
# Output: ['roberta', 'looped_transformer', 'disentangled_transformer']

# Create a RoBERTa model
model = create_model(
    'roberta',
    num_nodes=32,
    hidden_size=128,
    num_layers=6,
    roberta_type='relu'
)

# Create a looped transformer
model = create_model(
    'looped_transformer',
    num_nodes=64,
    hidden_size=256,
    num_layers=10,
    read_in_method='zero_pad'
)

# Create a disentangled transformer
model = create_model(
    'disentangled_transformer',
    num_nodes=32,
    heads=[8, 8, 4],
    init_type='randn'
)

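# Use the model (here the disentangled transformer, built for 32 nodes)
import torch
batch_size, num_nodes = 4, 32
x = torch.randint(0, 2, (batch_size, num_nodes, num_nodes)).float()  # Random 0/1 adjacency matrices
output = model.forward(x)  # Predicted adjacency matrices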
hidden_states = model.get_hidden_states(x)  # Intermediate representations
```

## 📁 Directory Structure

```
models/
├── __init__.py                     # Package initialization and exports
├── models.py                       # Unified interface and model registry
├── README.md                       # This file
├── roberta/                        # RoBERTa-based models
│   ├── __init__.py
│   ├── robertaModels.py           # Main RoBERTa graph model
│   ├── preLayerNorm/              # Pre-layer norm variants
│   │   ├── __init__.py
│   │   ├── relu.py                # ReLU activation variant
│   │   ├── softmax.py             # Softmax activation variant
│   │   └── tie_qk.py              # Tied query-key variant
│   └── postLayerNorm/             # Post-layer norm variants
│       ├── __init__.py
│       ├── relu.py                # ReLU activation variant
│       ├── softmax.py             # Softmax activation variant
│       └── tie_qk.py              # Tied query-key variant
├── loopedTransformer/             # Looped transformer implementation
│   ├── __init__.py
│   └── loopedTransformerModel.py  # Main looped transformer
└── disentangledTransformer/       # Disentangled attention implementation
    ├── __init__.py
    └── disentangledTransformerModels.py  # Main disentangled transformer
```

## 🎯 Model Details

### RoBERTa Graph Model

A traditional transformer architecture adapted for graph processing tasks.

**Key Parameters:**
- `num_nodes`: Number of nodes in the graph
- `hidden_size`: Hidden dimension size (default: 128)
- `num_attention_heads`: Number of attention heads (default: 1)
- `num_layers`: Number of transformer layers (default: 12)
- `roberta_type`: Model variant: "relu" (ReLU activation), "softmax" (softmax activation), or "tie_qk" (tied query-key weights) (default: "relu")
- `layer_norm_type`: Layer norm position - "pre" or "post" (default: "pre")

**Example:**
```python
model = create_model(
    'roberta',
    num_nodes=32,
    hidden_size=256,
    num_layers=8,
    num_attention_heads=4,
    roberta_type='relu',
    layer_norm_type='pre'
)
```
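The three `roberta_type` variants differ mainly in how attention scores become attention weights. The single-head sketch below assumes that `"relu"` applies an elementwise ReLU to the scaled scores instead of a softmax, and that `"tie_qk"` keeps the softmax but uses one shared projection for queries and keys. `SketchAttention` is a hypothetical class, not part of the package; the modules in `roberta/preLayerNorm/` and `roberta/postLayerNorm/` may scale or normalise differently.

```python
import torch
import torch.nn as nn


class SketchAttention(nn.Module):
    """Hypothetical single-head attention illustrating the three roberta_type variants."""

    def __init__(self, hidden_size: int, variant: str = "softmax"):
        super().__init__()
        if variant not in ("relu", "softmax", "tie_qk"):
            raise ValueError(f"unknown variant: {variant}")
        self.variant = variant
        self.q = nn.Linear(hidden_size, hidden_size, bias=False)
        # tie_qk (assumed): reuse the query projection as the key projection
        self.k = self.q if variant == "tie_qk" else nn.Linear(hidden_size, hidden_size, bias=False)
        self.v = nn.Linear(hidden_size, hidden_size, bias=False)
        self.scale = hidden_size ** -0.5

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [batch, num_nodes, hidden_size]; each node is one token
        scores = self.q(h) @ self.k(h).transpose(-2, -1) * self.scale
        if self.variant == "relu":
            # ReLU attention (assumed): nonnegative weights whose rows need not sum to 1
            weights = torch.relu(scores)
        else:
            # softmax and tie_qk: each row becomes a probability distribution
            weights = torch.softmax(scores, dim=-1)
        return weights @ self.v(h)


# Usage: compare output shapes and parameter counts of the three variants
h = torch.randn(2, 32, 128)  # [batch, num_nodes, hidden_size]
for variant in ("relu", "softmax", "tie_qk"):
    attn = SketchAttention(128, variant)
    n_params = sum(p.numel() for p in attn.parameters())
    print(variant, tuple(attn(h).shape), n_params)
```

Tying the query and key projections removes one `hidden_size x hidden_size` matrix per layer and makes the raw score matrix symmetric in its two node indices.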

### Looped Transformer

An iterative transformer that applies the same layer to the input repeatedly. Because the weights are shared across iterations, effective depth grows with `num_layers` while the number of layer parameters does not.

**Key Parameters:**
- `num_nodes`: Number of nodes in the graph
- `hidden_size`: Hidden dimension size (default: 128)
- `num_layers`: Number of iterations/loops (default: 5)
- `num_attention_heads`: Number of attention heads (default: 1)
- `read_in_method`: Input encoding - "linear" or "zero_pad" (default: "linear")
- `layernorm_type`: Layer norm position - "pre" or "post" (default: "pre")
- `tie_qk`: Whether to tie query and key weights (default: False)

**Example:**
```python
model = create_model(
    'looped_transformer',
    num_nodes=48,
    hidden_size=128,
    num_layers=10,
    num_attention_heads=2,
    read_in_method='linear',
    tie_qk=True
)
```
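The weight sharing described above can be sketched with a single `nn.TransformerEncoderLayer` reused on every iteration. `TinyLoopedTransformer` is hypothetical and is not the package's `LoopedTransformer`. It assumes that `"linear"` read-in projects each node's adjacency row to `hidden_size`, that `"zero_pad"` copies the row into the first `num_nodes` channels and pads the rest with zeros, and it maps `layernorm_type` onto PyTorch's `norm_first` flag. `tie_qk` is not modelled.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLoopedTransformer(nn.Module):
    """Hypothetical sketch: one shared encoder layer applied num_layers times."""

    def __init__(self, num_nodes, hidden_size=128, num_layers=5, num_attention_heads=1,
                 read_in_method="linear", layernorm_type="pre"):
        super().__init__()
        if read_in_method == "zero_pad" and hidden_size < num_nodes:
            raise ValueError("zero_pad read-in needs hidden_size >= num_nodes")
        self.num_nodes = num_nodes
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # "linear" (assumed): learned projection of each adjacency row
        self.read_in = nn.Linear(num_nodes, hidden_size) if read_in_method == "linear" else None
        # A single layer whose weights are reused on every loop iteration
        self.layer = nn.TransformerEncoderLayer(
            d_model=hidden_size, nhead=num_attention_heads,
            dim_feedforward=4 * hidden_size, dropout=0.0, batch_first=True,
            norm_first=(layernorm_type == "pre"),  # pre-LN vs post-LN residual blocks
        )
        self.read_out = nn.Linear(hidden_size, num_nodes)

    def embed(self, x):
        # x: [batch, num_nodes, num_nodes]; each node's adjacency row is one token
        if self.read_in is not None:
            return self.read_in(x)
        # "zero_pad" (assumed): row in the first num_nodes channels, zeros elsewhere
        return F.pad(x, (0, self.hidden_size - self.num_nodes))

    def get_hidden_states(self, x):
        h = self.embed(x)
        states = [h]
        for _ in range(self.num_layers):
            h = self.layer(h)  # same weights every iteration
            states.append(h)
        return states

    def forward(self, x):
        # Map the final hidden state back to one logit per (node, node) pair
        return self.read_out(self.get_hidden_states(x)[-1])


# Usage
loop_model = TinyLoopedTransformer(num_nodes=32, hidden_size=64, num_layers=10,
                                   num_attention_heads=2, read_in_method="zero_pad")
x = torch.randint(0, 2, (4, 32, 32)).float()
print(loop_model(x).shape, len(loop_model.get_hidden_states(x)))  # torch.Size([4, 32, 32]) 11
```

Because the loop reuses `self.layer`, raising `num_layers` adds compute but no parameters.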

### Disentangled Transformer

A custom attention mechanism with disentangled components, designed for specific graph reasoning tasks.

**Key Parameters:**
- `num_nodes`: Number of nodes in the graph
- `heads`: List giving the number of attention heads in each layer; its length sets the number of layers (default: [4, 4, 4])
- `extra_pos_id`: Whether to add positional embeddings (default: True)
- `init_type`: Weight initialization - "randn", "zeros", "eye", "psd", "sym" (default: "randn")
- `readout_type`: Output method - "linear", "sum", "last" (default: "linear")

**Example:**
```python
model = create_model(
    'disentangled_transformer',
    num_nodes=32,
    heads=[8, 8, 4, 2],
    extra_pos_id=True,
    init_type='psd',
    readout_type='sum'
)
```
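The `init_type` names correspond to standard ways of building a square matrix. The sketch below assumes each option initialises one square attention weight matrix as shown; `init_matrix` and its `std` scale are hypothetical, and the package may scale or shape its weights differently.

```python
import torch


def init_matrix(dim: int, init_type: str = "randn", std: float = 0.02) -> torch.Tensor:
    """Hypothetical initialiser for one dim x dim attention weight matrix."""
    if init_type == "randn":
        return torch.randn(dim, dim) * std
    if init_type == "zeros":
        return torch.zeros(dim, dim)
    if init_type == "eye":
        return torch.eye(dim)
    if init_type == "psd":
        # B @ B.T is symmetric positive semi-definite for any real B
        b = torch.randn(dim, dim) * std
        return b @ b.T
    if init_type == "sym":
        # Averaging with the transpose gives a symmetric (not necessarily PSD) matrix
        b = torch.randn(dim, dim) * std
        return (b + b.T) / 2
    raise ValueError(f"unknown init_type: {init_type}")


# Usage: check the defining property of the "psd" option
w = init_matrix(8, "psd")
print(torch.allclose(w, w.T), torch.linalg.eigvalsh(w).min().item())
```

`"psd"` guarantees nonnegative eigenvalues (up to rounding), while `"sym"` only guarantees symmetry.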

## 🔧 Advanced Usage

### Custom Model Configuration

```python
from models import UnifiedModelRegistry

# Get detailed model information
info = UnifiedModelRegistry.get_model_info('roberta')
print(f"Model: {info['model_type']}")
print(f"Class: {info['class_name']}")
print("Parameters:")
for param, desc in info['parameters'].items():
    print(f"  - {param}: {desc}")
```

### Accessing Internal Components

```python
# Create a model
model = create_model('roberta', num_nodes=32, hidden_size=128)

# Access the underlying model for advanced operations
underlying_model = model.model  # Direct access to RobertaModelForGraph

# Get the model name
print(f"Model name: {underlying_model.name}")
```
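Assuming the wrapped `model.model` is an ordinary `nn.Module`, PyTorch's `register_forward_hook` can capture the output of any submodule without modifying the model code. `capture_activations` is a hypothetical helper, demonstrated on a toy `nn.Sequential`; for a graph model you would pass `model.model` and the layer types you want to inspect.

```python
import torch
import torch.nn as nn


def capture_activations(module: nn.Module, x: torch.Tensor, submodule_types=(nn.Linear,)):
    """Run `module` on `x` and record the output of every submodule of the given types."""
    captured = {}
    handles = []
    for name, sub in module.named_modules():
        if isinstance(sub, submodule_types):
            # Bind `name` as a default argument so each hook stores under its own key
            def hook(_mod, _inp, out, name=name):
                captured[name] = out.detach()
            handles.append(sub.register_forward_hook(hook))
    try:
        with torch.no_grad():
            module(x)
    finally:
        for handle in handles:
            handle.remove()  # otherwise the hooks stay attached to the model
    return captured


# Usage on a toy model applied to a batch of 32x32 matrices
toy = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
acts = capture_activations(toy, torch.randn(4, 32, 32))
print({name: tuple(t.shape) for name, t in acts.items()})  # {'0': (4, 32, 64), '2': (4, 32, 32)}
```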

### Batch Processing

```python
import torch

# Create model
model = create_model('looped_transformer', num_nodes=32, num_layers=5)

# Process batch of adjacency matrices
batch_size = 16
adjacency_matrices = torch.randint(0, 2, (batch_size, 32, 32)).float()  # Random 0/1 adjacency matrices

# Forward pass
outputs = model.forward(adjacency_matrices)  # Shape: [16, 32, 32]

# Get intermediate states
hidden_states = model.get_hidden_states(adjacency_matrices)
print(f"Number of hidden states: {len(hidden_states)}")
print(f"Hidden state shapes: {[h.shape for h in hidden_states]}")
```

## 🧪 Model Comparison

```python
import torch
from models import create_model

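# Create one model of each type with comparable settings
# (the disentangled transformer is configured through `heads` rather than `hidden_size`/`num_layers`)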
models = {
    'roberta': create_model('roberta', num_nodes=32, hidden_size=128, num_layers=4),
    'looped': create_model('looped_transformer', num_nodes=32, hidden_size=128, num_layers=4),
    'disentangled': create_model('disentangled_transformer', num_nodes=32, heads=[4, 4])
}

# Test input
x = torch.randn(4, 32, 32)

# Compare outputs
for name, model in models.items():
    output = model.forward(x)
    print(f"{name}: {output.shape}")
    
    # Count parameters
    total_params = sum(p.numel() for p in model.model.parameters())
    print(f"{name} parameters: {total_params:,}")
```

## 📚 Legacy Compatibility

For backward compatibility, the original model classes are still available:

```python
# Direct imports (not recommended for new code)
from models.roberta.robertaModels import RobertaModelForGraph
from models.loopedTransformer.loopedTransformerModel import LoopedTransformer
from models.disentangledTransformer.disentangledTransformerModels import DisentangledTransformer

# Legacy convenience functions
from models import create_roberta_model, create_looped_transformer, create_disentangled_transformer

# Each is equivalent to calling create_model() with the corresponding model type
roberta_model = create_roberta_model(num_nodes=32, hidden_size=128)
```

## 🛠️ Development Guidelines

When adding new models to this package:

1. **Create a new subdirectory** under `models/` for your model
2. **Implement the model class** inheriting from `nn.Module`
3. **Create a wrapper class** inheriting from `BaseGraphModel`
4. **Add the wrapper** to `UnifiedModelRegistry.MODEL_MAP`
5. **Update documentation** and examples
6. **Add comprehensive tests**

Example structure for a new model:
```python
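from typing import List

import torch

from models.models import BaseGraphModel  # assumed location of the base class


class NewGraphModel(BaseGraphModel):
    def __init__(self, num_nodes: int, **kwargs):
        super().__init__()  # adjust if BaseGraphModel.__init__ takes arguments
        # YourNewModel is a placeholder for the nn.Module implemented in step 2
        self.model = YourNewModel(num_nodes=num_nodes, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def get_hidden_states(self, x: torch.Tensor) -> List[torch.Tensor]:
        return self.model.get_hidden_states(x)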
```
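Steps 3 and 4 amount to a name-to-class registry. The following self-contained sketch mimics that pattern with stand-in classes: `BaseGraphModel` here is a local placeholder, `ModelRegistry` and `EchoGraphModel` are hypothetical, and deriving `get_model_info` from `inspect.signature` is an assumption about how parameter information could be collected, not how `UnifiedModelRegistry` does it.

```python
import inspect
from abc import ABC, abstractmethod


class BaseGraphModel(ABC):
    """Local stand-in for the package's base class (the real one may differ)."""

    @abstractmethod
    def forward(self, x): ...

    @abstractmethod
    def get_hidden_states(self, x): ...


class EchoGraphModel(BaseGraphModel):
    """Hypothetical toy wrapper that returns its input unchanged."""

    def __init__(self, num_nodes: int, hidden_size: int = 128):
        self.num_nodes = num_nodes
        self.hidden_size = hidden_size

    def forward(self, x):
        return x

    def get_hidden_states(self, x):
        return [x]


class ModelRegistry:
    """Minimal registry mirroring the MODEL_MAP / create_model pattern."""

    MODEL_MAP = {"echo": EchoGraphModel}  # step 4: register new wrappers here

    @classmethod
    def list_models(cls):
        return list(cls.MODEL_MAP)

    @classmethod
    def create_model(cls, model_type: str, **kwargs):
        try:
            model_cls = cls.MODEL_MAP[model_type]
        except KeyError:
            raise ValueError(f"unknown model_type {model_type!r}; choose from {cls.list_models()}") from None
        return model_cls(**kwargs)

    @classmethod
    def get_model_info(cls, model_type: str):
        model_cls = cls.MODEL_MAP[model_type]
        # Describe each constructor argument as required or with its default value
        params = {
            name: "required" if p.default is inspect.Parameter.empty else f"default={p.default!r}"
            for name, p in inspect.signature(model_cls.__init__).parameters.items()
            if name != "self"
        }
        return {"model_type": model_type, "class_name": model_cls.__name__, "parameters": params}


# Usage
print(ModelRegistry.list_models())  # ['echo']
print(ModelRegistry.get_model_info("echo")["parameters"])  # {'num_nodes': 'required', 'hidden_size': 'default=128'}
```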

## 🔍 Troubleshooting

### Common Issues

1. **Import Errors**: Make sure you're importing from the `models` package correctly
2. **Shape Mismatches**: Ensure `num_nodes` matches your adjacency matrix dimensions
3. **Memory Issues**: Reduce `hidden_size` or `num_layers` for large graphs
4. **Performance**: Tune `batch_size` to the available memory; if activations still do not fit, consider gradient checkpointing (`torch.utils.checkpoint`)
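Shape mismatches are easiest to catch before the model is called. A hypothetical helper, assuming the models expect a float tensor of shape `[batch_size, num_nodes, num_nodes]`:

```python
import torch


def check_adjacency_batch(x: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Hypothetical helper: validate (and batch) input before calling model.forward."""
    if x.dim() == 2:
        x = x.unsqueeze(0)  # single graph -> batch of one
    if x.dim() != 3:
        raise ValueError(f"expected [batch_size, num_nodes, num_nodes], got {tuple(x.shape)}")
    if x.shape[1:] != (num_nodes, num_nodes):
        raise ValueError(
            f"model built for num_nodes={num_nodes}, got matrices of size {tuple(x.shape[1:])}"
        )
    if not x.is_floating_point():
        x = x.float()  # integer 0/1 adjacency -> float for the model's linear layers
    return x


# Usage
single = torch.randint(0, 2, (32, 32))
print(check_adjacency_batch(single, num_nodes=32).shape)  # torch.Size([1, 32, 32])
```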

### Getting Help

- Check the model info: `UnifiedModelRegistry.get_model_info(model_type)`
- Review the examples in `models.py`
- Ensure your input tensors have the correct shape: `[batch_size, num_nodes, num_nodes]`

## 📈 Performance Tips

1. **Use appropriate model size** for your graph size
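2. **Batch processing**: pass multiple graphs as one `[batch_size, num_nodes, num_nodes]` tensor
3. **Mixed precision training** with `torch.autocast` (the `torch.amp` API; `torch.cuda.amp.autocast` is deprecated in recent PyTorch releases)
4. **Gradient accumulation** to reach a large effective batch size when a full batch does not fit in memory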
5. **Model parallelism** for very large models
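Mixed precision and gradient accumulation combine naturally in one training loop. A minimal sketch, assuming a binary edge-prediction objective (`BCEWithLogitsLoss`) and CPU `bfloat16` autocast so it runs without a GPU; `train_epoch` is hypothetical, and with CUDA `float16` you would also need a gradient scaler (`GradScaler`).

```python
import torch
import torch.nn as nn


def train_epoch(model, batches, optimizer, accum_steps=4, device_type="cpu"):
    """Hypothetical loop: autocast forward passes, one optimizer step per accum_steps batches."""
    loss_fn = nn.BCEWithLogitsLoss()  # assumes 0/1 adjacency targets
    model.train()
    optimizer.zero_grad()
    losses, pending = [], 0
    for x, target in batches:
        # bfloat16 autocast runs on CPU; CUDA float16 would additionally need a GradScaler
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits = model(x)
        # Loss in float32, divided so accumulated gradients average over the group
        loss = loss_fn(logits.float(), target) / accum_steps
        loss.backward()
        losses.append(loss.item() * accum_steps)
        pending += 1
        if pending == accum_steps:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0
    if pending:  # apply gradients from a final, incomplete group
        optimizer.step()
        optimizer.zero_grad()
    return losses


# Usage with a toy stand-in for model.model
torch.manual_seed(0)
toy = nn.Linear(32, 32)
opt = torch.optim.SGD(toy.parameters(), lr=0.1)
data = [(torch.randn(4, 32, 32), torch.randint(0, 2, (4, 32, 32)).float()) for _ in range(6)]
losses = train_epoch(toy, data, opt, accum_steps=4)
```

With six batches and `accum_steps=4`, the optimizer steps once after the fourth batch and once more for the remaining two.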

---

For more examples and detailed usage, see the main project README.md or run the examples in `models.py` directly.
