# Crystal Fourier Transformer (CFT)
Transformer architecture using crystallographic Fourier basis functions for positional encoding.

## Setup

1. Create and activate the conda environment:
   ```
   conda env create -f environment.yml
   conda activate cft
   ```

2. Download the Materials Project data, replacing `YOUR_API_KEY` with your own Materials Project API key:
   ```
   python -m data.get_materials --api-key YOUR_API_KEY
   ```

## Training Pipeline

### 1. Pretrain Positional Encodings

This step generates synthetic crystal data and trains an MLP on this dataset to learn positional encodings in the crystal Fourier basis.

```
python -m pretrain.gen_data
python -m pretrain.mlp
```

Current setup:
- Generates 5000 crystals with 50 atoms each and saves them to `data/synthetic_crystals`
- Samples lattice parameters according to the Bravais lattice of each crystal's space group
- Trains two ResNets (one for positions, one for lattice vectors) and multiplies their outputs to obtain the positional encodings

Adjust these parameters in the respective scripts as needed.
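
To make the lattice sampling step concrete, here is a minimal sketch of drawing cell parameters that respect the symmetry constraints of a space group. It is not the code in `pretrain.gen_data`: the function names, the length and angle ranges, and the choice to group by crystal system (ignoring centering, and using hexagonal axes for trigonal groups) are all assumptions made for illustration.

```python
import random

# Standard ranges of international space-group numbers per crystal system.
CRYSTAL_SYSTEMS = [
    (1, 2, "triclinic"),
    (3, 15, "monoclinic"),
    (16, 74, "orthorhombic"),
    (75, 142, "tetragonal"),
    (143, 167, "trigonal"),
    (168, 194, "hexagonal"),
    (195, 230, "cubic"),
]


def crystal_system(space_group):
    """Map a space-group number (1-230) to its crystal system."""
    for lo, hi, name in CRYSTAL_SYSTEMS:
        if lo <= space_group <= hi:
            return name
    raise ValueError(f"space group must be in 1..230, got {space_group}")


def sample_lattice(space_group, rng, min_len=3.0, max_len=10.0):
    """Return (a, b, c, alpha, beta, gamma) in angstroms and degrees.

    The sampling ranges are hypothetical, chosen only for this sketch.
    """
    system = crystal_system(space_group)
    a, b, c = (rng.uniform(min_len, max_len) for _ in range(3))
    alpha = beta = gamma = 90.0
    if system == "cubic":
        b = c = a                      # a = b = c, all angles 90
    elif system == "tetragonal":
        b = a                          # a = b, all angles 90
    elif system in ("hexagonal", "trigonal"):
        b = a                          # hexagonal axes: a = b, gamma = 120
        gamma = 120.0
    elif system == "monoclinic":
        beta = rng.uniform(91.0, 120.0)  # unique axis b, beta is free
    elif system == "triclinic":
        # Angles kept near 90 so the cell volume stays positive.
        alpha, beta, gamma = (rng.uniform(70.0, 110.0) for _ in range(3))
    # Orthorhombic needs no change: three free lengths, all angles 90.
    return a, b, c, alpha, beta, gamma
```

For example, `sample_lattice(225, random.Random(0))` returns a cubic cell with three equal edge lengths and 90 degree angles, while a space group in the 168-194 range yields `a == b` and `gamma == 120.0`.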

### 2. Train Transformer Model

To train the full Crystal Fourier Transformer:

```
python train.py --fourier
```

Key options:
- `--fourier`: Use the Fourier positional encoding (omit to use a naive sine/cosine encoding instead)
- `--num_epochs`: Number of training epochs
- `--batch_size`: Batch size for training
- `--learning_rate`: Initial learning rate
- `--num_attn_blocks`: Number of attention blocks in the transformer
- `--num_heads`: Number of attention heads per block

For a full list of options, run `python train.py --help`.
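
As a rough illustration of how the options above fit together, the following sketch declares them with `argparse`. It is not a copy of `train.py`: the flag names come from the list above, but the types, default values, and help strings are assumptions.

```python
import argparse


def build_parser():
    """Build a parser mirroring the documented flags (defaults are placeholders)."""
    parser = argparse.ArgumentParser(
        description="Train the Crystal Fourier Transformer (illustrative sketch)"
    )
    # Boolean switch: present means Fourier encoding, absent means sine/cosine.
    parser.add_argument("--fourier", action="store_true",
                        help="use Fourier positional encoding")
    parser.add_argument("--num_epochs", type=int, default=100,
                        help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="batch size for training")
    parser.add_argument("--learning_rate", type=float, default=1e-4,
                        help="initial learning rate")
    parser.add_argument("--num_attn_blocks", type=int, default=4,
                        help="number of attention blocks")
    parser.add_argument("--num_heads", type=int, default=8,
                        help="attention heads per block")
    return parser


# Equivalent to: python train.py --fourier --num_epochs 50 --batch_size 16
args = build_parser().parse_args(
    ["--fourier", "--num_epochs", "50", "--batch_size", "16"]
)
print(args.fourier, args.num_epochs, args.batch_size)
```

Because `--fourier` is modeled here as a `store_true` switch, leaving it off the command line selects the sine/cosine baseline, which matches the behavior described in the option list.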
