# Unsupervised SE(3) Disentanglement for in situ Macromolecular Morphology Identification from Cryo-Electron Tomography

This repository contains the implementation of ICLR 2026 submission 9125.

## Prerequisites

Install the required packages:

```bash
pip install torch torchvision numpy matplotlib kornia functorch mrcfile umap-learn scikit-learn scipy
```

## Dataset

The code expects data in pickle format. A sample dataset is included at:
```
datasets/yeast_mwa_30_snr_001_cellular_protein_mixture_subtomograms.pkl
```

Each entry should contain a `'subtomo'` field holding a 48×48×48 subtomogram.
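A minimal sketch for loading a dataset in this format and checking volume shapes. It assumes the pickle stores a list of dictionaries, each with a `'subtomo'` array; the container type and any other fields (for example, ground-truth labels) are assumptions, not documented here. The helper `load_subtomograms` is hypothetical and not part of the repository.

```python
import pickle

import numpy as np


def load_subtomograms(path: str, expected_shape=(48, 48, 48)) -> np.ndarray:
    """Hypothetical helper: stack the 'subtomo' volumes of a pickle into (N, 48, 48, 48).

    Assumes the pickle holds a list (or other iterable) of dicts with a 'subtomo' key.
    """
    with open(path, "rb") as f:
        entries = pickle.load(f)

    volumes = []
    for i, entry in enumerate(entries):
        vol = np.asarray(entry["subtomo"], dtype=np.float32)
        # Fail early if a volume does not match the size the model expects.
        if vol.shape != tuple(expected_shape):
            raise ValueError(f"entry {i}: expected shape {tuple(expected_shape)}, got {vol.shape}")
        volumes.append(vol)

    if not volumes:
        raise ValueError(f"no entries found in {path}")
    return np.stack(volumes)
```

For example, `load_subtomograms("datasets/yeast_mwa_30_snr_001_cellular_protein_mixture_subtomograms.pkl")` should return an array of shape `(N, 48, 48, 48)` if the file follows the assumed layout.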

## How to Run

### Training and Inference

To train the model and run inference:

```bash
python harmony3D_cleaned.py
```

By default, the script uses the sample dataset in the `datasets/` folder. To train on a different dataset, set `dataset_name` at the beginning of the script.
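To use your own data, one option is to write it in the same layout as the sample file. The sketch below assumes that layout is a list of dictionaries with a `'subtomo'` key, as described in the Dataset section; the list container is an assumption, and any field the script uses for ground-truth labels during evaluation is not documented here, so check the script before relying on this. The helper `save_subtomogram_pickle` is hypothetical.

```python
import pickle

import numpy as np


def save_subtomogram_pickle(volumes, path: str) -> None:
    """Hypothetical helper: write 48x48x48 volumes as a list of {'subtomo': array} dicts."""
    entries = []
    for i, vol in enumerate(volumes):
        vol = np.asarray(vol, dtype=np.float32)
        # The model in this repository expects 48x48x48 subtomograms.
        if vol.shape != (48, 48, 48):
            raise ValueError(f"volume {i}: expected (48, 48, 48), got {vol.shape}")
        entries.append({"subtomo": vol})

    with open(path, "wb") as f:
        pickle.dump(entries, f)
```

After writing the file into `datasets/`, point `dataset_name` at it.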

To run inference only with a pretrained model, set `train=False` in the script; it then reports the unsupervised classification accuracy of the pretrained model. A pretrained model is provided in the `models/` directory.

### Template Creation

To create template morphologies:

```bash
python create_template_morphology.py
```

This creates template morphologies and saves them in `.mrc` format. By default, the script uses the sample dataset in the `datasets/` folder and the pretrained model in the `models/` folder; to change these, set `dataset_name` and `model_name` at the beginning of the script. `K` sets the number of template morphology classes identified by our method; adjust it accordingly.

## What the Code Does

1. **Training**: Trains an SE(3) disentanglement network on 3D subtomograms
2. **Feature Extraction**: Extracts 50-dimensional SE(3)-disentangled latent representations from the trained model
3. **Clustering**: Computes a UMAP embedding of the latents and clusters it with a Gaussian Mixture Model (GMM) to classify the macromolecular morphology of each subtomogram
4. **Evaluation**: Computes clustering metrics (adjusted Rand index (ARI) and Hungarian-matching accuracy) on the provided dataset
5. **Visualization**: Generates UMAP plots and template morphologies from the decoder
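The clustering and evaluation steps can be sketched with `umap-learn`, `scikit-learn`, and `scipy`. This is an illustrative sketch, not the repository's exact pipeline: the function names are hypothetical, and the embedding dimension, GMM settings, and random seeds are assumptions. Hungarian-matching accuracy finds the one-to-one assignment of cluster ids to class ids that maximizes the number of correctly labeled samples.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import adjusted_rand_score
from sklearn.mixture import GaussianMixture


def umap_embed(latents: np.ndarray, n_components: int = 2, seed: int = 0) -> np.ndarray:
    """Hypothetical: project (N, 50) latents to a low-dimensional UMAP embedding."""
    import umap  # provided by the umap-learn package

    return umap.UMAP(n_components=n_components, random_state=seed).fit_transform(latents)


def fit_gmm(embedding: np.ndarray, n_clusters: int, seed: int = 0) -> np.ndarray:
    """Hypothetical: cluster an embedding with a GMM and return hard cluster labels."""
    gmm = GaussianMixture(n_components=n_clusters, random_state=seed)
    return gmm.fit_predict(embedding)


def hungarian_accuracy(y_true, y_pred) -> float:
    """Accuracy under the best one-to-one matching of predicted clusters to true classes."""
    true_ids, true_idx = np.unique(np.asarray(y_true), return_inverse=True)
    pred_ids, pred_idx = np.unique(np.asarray(y_pred), return_inverse=True)
    # contingency[i, j] = number of samples in predicted cluster i with true class j
    contingency = np.zeros((len(pred_ids), len(true_ids)), dtype=np.int64)
    np.add.at(contingency, (pred_idx, true_idx), 1)
    # linear_sum_assignment minimizes total cost, so negate counts to maximize matches
    rows, cols = linear_sum_assignment(-contingency)
    return contingency[rows, cols].sum() / len(true_idx)
```

A typical call chain would be `labels = fit_gmm(umap_embed(latents), n_clusters=K)`, followed by `adjusted_rand_score(y_true, labels)` and `hungarian_accuracy(y_true, labels)`, where `latents` and `y_true` stand for the extracted features and ground-truth classes.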

## Output Files

- Model checkpoints: `models/*.pt` files
- Loss curves: `models/loss_curve_*.png`
- UMAP visualizations: `results/umap_*.png`
- Decoded structures: `results/*.png` and `results/*.mrc`


