# E2Former: A Linear Time Equivariant and Efficient Transformer for Molecular Simulations
E2Former is an equivariant transformer model for molecular simulations with linear time complexity. This implementation builds on several excellent projects:
- [FairChem](https://github.com/FAIR-Chem/fairchem)
- [EScAIP](https://github.com/ASK-Berkeley/EScAIP)
- [e3nn](https://github.com/e3nn/e3nn)


## Key Components

- **Node Embedding**: Default configuration of `128x0e+64x1e+32x2e`
- **Attention Mechanism**: 
  - Configurable number of attention heads (default: 4)
  - Customizable scalar head size (default: 32)
  - Support for various attention types and biases
- **Edge Features**: 
  - Gaussian radial basis functions
  - Customizable basis size and cutoff radius
- **Normalization**: Support for various normalization layers including RMS and layer normalization
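
As a rough illustration of the Gaussian radial basis edge features listed above, the sketch below expands a single interatomic distance into basis features. It is not the E2Former implementation: the function name `gaussian_rbf`, the evenly spaced centers, and the choice of width are assumptions made for illustration. The parameters `number_of_basis` and `max_radius` simply mirror the configuration keys of the same name.

```python
import math


def gaussian_rbf(distance, number_of_basis=128, max_radius=15.0):
    """Expand one distance into Gaussian radial basis features.

    Illustrative only: centers are evenly spaced on [0, max_radius] and the
    Gaussian width equals the spacing between centers (both are assumptions).
    """
    if number_of_basis < 2:
        raise ValueError("number_of_basis must be at least 2")
    spacing = max_radius / (number_of_basis - 1)
    centers = [i * spacing for i in range(number_of_basis)]
    return [math.exp(-0.5 * ((distance - c) / spacing) ** 2) for c in centers]


# Small example: 16 basis functions over a 15.0 cutoff.
features = gaussian_rbf(1.5, number_of_basis=16, max_radius=15.0)
print(len(features))  # one feature per basis function
```

A larger basis size gives finer distance resolution at the cost of wider edge features, and the cutoff radius bounds the distances the basis has to cover.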

## Installation

### Prerequisites
- Python 3.7+
- CUDA-compatible GPU (recommended)

### Step 1: Install Mamba (Optional but Recommended)
```bash
conda install mamba -n base -c conda-forge
```

### Step 2: Create and Activate Environment
```bash
mamba env create -f env.yml
conda activate e2former
```

### Step 3: Install FairChem Core Package
```bash
git submodule update --init --recursive
pip install -e fairchem/packages/fairchem-core
```
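
After the three steps above, a quick sanity check can confirm that the main dependencies are visible to the active environment. This is a minimal sketch, not part of the repository: the helper `check_modules` is hypothetical, and the top-level module names `torch`, `e3nn`, and `fairchem` are assumed to be what `env.yml` and the FairChem core package provide.

```python
import importlib.util


def check_modules(names):
    """Return {module_name: bool} telling whether each top-level module is importable."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Assumed top-level module names for the packages installed above.
    for name, found in check_modules(["torch", "e3nn", "fairchem"]).items():
        print(f"{name}: {'ok' if found else 'MISSING'}")
```

Run it inside the activated `e2former` environment. Any module reported as missing points to the installation step that needs to be repeated.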

## Usage

### Training on OC20 Dataset

1. Configure your data paths and parameters in `finetune_psm_oc20.sh`
2. Set up your Weights & Biases (wandb) API key:
   - Replace `DUMMYKEY` with your actual wandb API key
   - This is required for experiment tracking and visualization

3. Start training:
```bash
bash finetune_psm_oc20.sh
```

### Model Configuration

The model can be configured through YAML files located in the `configs` directory. Key configuration options include:

```yaml
backbone_config:
  irreps_node_embedding: "128x0e+128x1e+128x2e"
  num_layers: 8
  pbc_max_radius: 12.0
  max_radius: 15.0
  basis_type: "gaussian"
  number_of_basis: 128
  num_attn_heads: 4
  attn_scalar_head: 32
  irreps_head: "32x0e+32x1e+32x2e"
```
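
The irreps strings in the configuration above follow the e3nn convention `<multiplicity>x<degree><parity>`, where an irrep of degree `l` has `2l + 1` components. The sketch below computes the resulting feature widths in plain Python. The helpers `parse_irreps` and `irreps_dim` are hypothetical and written only for this example. The final check is an observation about this particular example config, not a documented requirement of the model.

```python
def parse_irreps(irreps):
    """Parse a string like "128x0e+64x1e" into (multiplicity, degree, parity) tuples."""
    parsed = []
    for term in irreps.split("+"):
        mul, ir = term.strip().split("x")
        parsed.append((int(mul), int(ir[:-1]), ir[-1]))
    return parsed


def irreps_dim(irreps):
    """Total feature width: each irrep of degree l contributes 2l + 1 components."""
    return sum(mul * (2 * degree + 1) for mul, degree, _ in parse_irreps(irreps))


node = "128x0e+128x1e+128x2e"   # irreps_node_embedding from the example config
head = "32x0e+32x1e+32x2e"      # irreps_head from the example config
num_attn_heads = 4

print(irreps_dim(node))  # 128 * (1 + 3 + 5)
print(irreps_dim(head))  # 32 * (1 + 3 + 5)

# Observation about this example config only: per-head multiplicities times
# the number of heads reproduce the node-embedding multiplicities.
heads_match = all(
    h_mul * num_attn_heads == n_mul
    for (h_mul, _, _), (n_mul, _, _) in zip(parse_irreps(head), parse_irreps(node))
)
print(heads_match)
```

If e3nn is installed, `e3nn.o3.Irreps(node).dim` should report the same width as `irreps_dim(node)`.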
