# Equivariant neural networks for general linear symmetries on Lie algebras

**PyTorch implementation of "Reductive Lie Neurons"**
<br>

**Anonymous Authors**
<br>
<!-- Author names and affiliations removed for double-blind review. -->

---

## About

Encoding symmetries is a powerful inductive bias for improving the generalization of deep neural networks. However, most existing equivariant models are limited to simple symmetries like rotations, failing to address the broader class of general linear transformations, GL(n), that appear in many scientific domains. We introduce **Reductive Lie Neurons (ReLNs)**, a novel neural network architecture that is exactly equivariant to these general linear symmetries.

<img src="figures/applications.png" alt="Applications of ReLN across various scientific domains" width="70%">

*ReLNs are applicable to a wide range of scientific domains governed by diverse Lie group symmetries, from physics and robotics to computer vision.*

Unlike prior methods such as [LieNeurons]([Link removed]), which are tailored to semi-simple Lie algebras (e.g., `so(3)`), our work introduces a general approach to constructing **non-degenerate bilinear forms for any `n x n` matrix Lie algebra**, including reductive ones like `gl(n)`. This enables the principled design of equivariant layers and nonlinearities for a much broader class of symmetries.

This repository provides the official code to reproduce the experiments in our paper.

---

## Core Concept: Adjoint Equivariance by Design

A key contribution of our work is a unified framework that embeds diverse geometric inputs (like vectors and covariance matrices) into a common Lie algebra, where they transform consistently under the **adjoint action**. Our network is designed to commute with this action, guaranteeing equivariance.

<img src="figures/equivariance_diagram.png" alt="Equivariance Diagram" width="60%">

*Our network `f` is provably equivariant. A transformation `Ad_g` on the input results in the same transformation `Ad_g` on the output feature.*
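The embedding and the adjoint action can be sketched in a few lines of PyTorch. This is an illustration under stated assumptions, not code from this repository: `hat_so3`, `adjoint`, and `random_rotation` are hypothetical helpers, velocities are assumed to embed into `gl(3)` as skew-symmetric matrices via the hat map, and covariances are assumed to embed as themselves. The check shows that, for a rotation `R`, both embeddings transform by the same `Ad_R(X) = R X R^{-1}`.

```python
import torch


def hat_so3(v: torch.Tensor) -> torch.Tensor:
    """Hypothetical hat map: 3-vector -> 3x3 skew-symmetric matrix (so(3) inside gl(3))."""
    x, y, z = v.unbind(-1)
    zero = torch.zeros_like(x)
    return torch.stack(
        [
            torch.stack([zero, -z, y], dim=-1),
            torch.stack([z, zero, -x], dim=-1),
            torch.stack([-y, x, zero], dim=-1),
        ],
        dim=-2,
    )


def adjoint(g: torch.Tensor, X: torch.Tensor) -> torch.Tensor:
    """Adjoint action of GL(n) on gl(n): Ad_g(X) = g X g^{-1}."""
    return g @ X @ torch.linalg.inv(g)


def random_rotation(dtype=torch.float64) -> torch.Tensor:
    """Sample a rotation in SO(3) from the QR decomposition of a Gaussian matrix."""
    q, r = torch.linalg.qr(torch.randn(3, 3, dtype=dtype))
    q = q * torch.sign(torch.diagonal(r))  # fix the sign of each column
    if torch.det(q) < 0:  # flip one column so that det(q) = +1
        q[:, 0] = -q[:, 0]
    return q


torch.manual_seed(0)
R = random_rotation()
v = torch.randn(3, dtype=torch.float64)              # e.g. a velocity
A = torch.randn(3, 3, dtype=torch.float64)
Sigma = A @ A.T + torch.eye(3, dtype=torch.float64)  # e.g. a covariance (SPD)

# Rotating the vector and then embedding equals applying Ad_R to the embedding.
print(torch.allclose(adjoint(R, hat_so3(v)), hat_so3(R @ v)))  # True
# For R in SO(3), R^{-1} = R^T, so Ad_R(Sigma) is the familiar R Sigma R^T.
print(torch.allclose(adjoint(R, Sigma), R @ Sigma @ R.T))  # True
```

Because `R^{-1} = R^T` for rotations, the usual covariance update `R Σ R^T` is exactly `Ad_R(Σ)`, which is what lets vectors and covariances share one Lie algebra.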

To achieve this for general reductive algebras like `gl(n)`, we introduce a non-degenerate, Ad-invariant bilinear form:

`B(X, Y) = 2n * tr(XY) - tr(X)tr(Y)`
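As a minimal sketch (the helper names `trace`, `bilinear_form`, and `killing_form_gl` are illustrative, not this repository's API), the form can be evaluated directly and checked numerically. The example verifies Ad-invariance under a random invertible `g` and compares the rank of the Gram matrix on the basis `E_ij` of `gl(3)` with that of the standard Killing form of `gl(n)`, `2n tr(XY) - 2 tr(X)tr(Y)`, which is degenerate because the identity lies in its kernel.

```python
import torch


def trace(M: torch.Tensor) -> torch.Tensor:
    """Trace over the last two dimensions (batched)."""
    return M.diagonal(dim1=-2, dim2=-1).sum(-1)


def bilinear_form(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """B(X, Y) = 2n tr(XY) - tr(X) tr(Y), batched over leading dimensions."""
    n = X.shape[-1]
    return 2 * n * trace(X @ Y) - trace(X) * trace(Y)


def killing_form_gl(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Standard Killing form of gl(n), for comparison: 2n tr(XY) - 2 tr(X) tr(Y)."""
    n = X.shape[-1]
    return 2 * n * trace(X @ Y) - 2 * trace(X) * trace(Y)


torch.manual_seed(0)
n = 3
X, Y = torch.randn(2, n, n, dtype=torch.float64)
g = torch.randn(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)
g_inv = torch.linalg.inv(g)

# Ad-invariance: B(g X g^-1, g Y g^-1) == B(X, Y)
print(torch.allclose(bilinear_form(g @ X @ g_inv, g @ Y @ g_inv), bilinear_form(X, Y)))

# Non-degeneracy: Gram matrices on the basis {E_ij} of gl(n), shape (n*n, n*n)
basis = torch.eye(n * n, dtype=torch.float64).reshape(n * n, n, n)
gram_B = bilinear_form(basis[:, None], basis[None, :])
gram_K = killing_form_gl(basis[:, None], basis[None, :])
print(int(torch.linalg.matrix_rank(gram_B)))  # 9: B is non-degenerate on gl(3)
print(int(torch.linalg.matrix_rank(gram_K)))  # 8: the identity is in the Killing form's kernel
```

Intuitively, the identity satisfies `B(I, Y) = n tr(Y)`, which is nonzero whenever `tr(Y) ≠ 0`, whereas the Killing form gives `0` for every `Y`.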

This form is the fundamental tool used to build our equivariant layers. Here’s a simple code snippet demonstrating how it creates an invariant feature:

```python
# From core/layers.py
import torch
import torch.nn as nn

# HatLayer and killingform are defined elsewhere in the repository (not shown).


class LNInvariant(nn.Module):
    """
    Computes an invariant scalar feature from a Lie algebra element
    using our non-degenerate bilinear form.
    """

    def __init__(self, in_channels, algebra_type='gl3'):
        super().__init__()
        self.hat_layer = HatLayer(algebra_type)  # Maps vector to matrix
        self.algebra_type = algebra_type

    def forward(self, x):
        """
        Input x: Lie algebra vectors
        Output: Invariant scalars
        """
        # 1. Map vector representation to matrix representation
        x_hat = self.hat_layer(x)

        # 2. Compute the invariant B(X, X); despite its name, `killingform`
        #    evaluates the bilinear form for the given algebra_type
        invariant_scalar = killingform(x_hat, x_hat, self.algebra_type)

        # 3. Aggregate features (e.g., via mean)
        return invariant_scalar.mean(dim=[-2, -1])
```
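The snippet above relies on helpers from the repository. As a self-contained sketch of how the bilinear form can be used to build an equivariant layer (a hypothetical `InvariantGatedLinear`, not the paper's exact layers), the module below mixes channels linearly and rescales each output channel by a gate computed from the invariant `B(Y, Y)`. Both steps commute with `Ad_g`: channel mixing is linear in the features, and the gate is a scalar that `Ad_g` leaves unchanged.

```python
import torch
import torch.nn as nn


def trace(M: torch.Tensor) -> torch.Tensor:
    """Trace over the last two dimensions (batched)."""
    return M.diagonal(dim1=-2, dim2=-1).sum(-1)


def bilinear_form(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """B(X, Y) = 2n tr(XY) - tr(X) tr(Y) over the last two dimensions."""
    n = X.shape[-1]
    return 2 * n * trace(X @ Y) - trace(X) * trace(Y)


class InvariantGatedLinear(nn.Module):
    """Hypothetical Ad-equivariant layer on features of shape (batch, channels, n, n)."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Channel-mixing weights only: adding a fixed bias matrix would
        # generally break Ad-equivariance.
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels) / in_channels**0.5)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # 1. Linear channel mixing: Y_d = sum_c W_dc X_c (commutes with Ad_g).
        Y = torch.einsum("dc,bcij->bdij", self.weight, X)
        # 2. Per-channel gate from the invariant B(Y_d, Y_d), unchanged by Ad_g.
        gate = torch.sigmoid(bilinear_form(Y, Y))
        # 3. Scaling by an invariant scalar preserves equivariance.
        return gate[..., None, None] * Y


def ad(g: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
    """Apply Ad_g to every channel: g Z g^{-1}."""
    return g @ Z @ torch.linalg.inv(g)


torch.manual_seed(0)
layer = InvariantGatedLinear(4, 8).double()
X = torch.randn(2, 4, 3, 3, dtype=torch.float64)
g = torch.randn(3, 3, dtype=torch.float64) + 3 * torch.eye(3, dtype=torch.float64)

# Equivariance check: f(Ad_g X) == Ad_g f(X)
print(torch.allclose(layer(ad(g, X)), ad(g, layer(X))))  # True
```

A sigmoid gate is only one choice; any function of the invariant scalars keeps the layer equivariant, since `Ad_g` acts linearly on each channel and does not change the scalars.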

For a more detailed interactive example, please see our [Toy Problem Notebook](examples/toy_problem.ipynb).

---

## Installation

To set up the environment, please follow these steps:

```bash
# Clone the repository
git clone <this-repository-url>
cd ReLN_2026
```
Each experiment has its own set of dependencies. Please refer to the `README` file within each experiment's directory (e.g., `./lorentznet/`) for specific installation instructions.

---

## Reproducing Paper Results

All experiment scripts are located in the `experiments/` directory. For each experiment, first download the required dataset and place it in the corresponding `data/` subfolder.

### Algebraic Benchmarks (`sl(3)` and `sp(4)`)
These experiments reproduce the Platonic Solid Classification and `sp(4)` Invariant Function Regression results. Our model directly adopts the architecture from Lie Neurons, replacing only the bilinear form.

For detailed instructions on data generation, training, and evaluation for these benchmarks, please refer to the original **[LieNeurons GitHub repository]([Link removed])**.
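Because `sl(3)` and `sp(4)` consist of traceless matrices, the `tr(X)tr(Y)` term of `B` vanishes there and `B` reduces to `2n tr(XY)`, which for `sl(n)` coincides with its Killing form. The sketch below checks this numerically; the helpers are hypothetical, and `sp(4)` elements are generated with the standard parametrization `X = J S` (`S` symmetric, `J = [[0, I], [-I, 0]]`), an assumption of this example rather than a description of how the benchmark data is produced.

```python
import torch


def trace(M: torch.Tensor) -> torch.Tensor:
    """Trace over the last two dimensions."""
    return M.diagonal(dim1=-2, dim2=-1).sum(-1)


def bilinear_form(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """B(X, Y) = 2n tr(XY) - tr(X) tr(Y)."""
    n = X.shape[-1]
    return 2 * n * trace(X @ Y) - trace(X) * trace(Y)


def random_sl(n: int, dtype=torch.float64) -> torch.Tensor:
    """Random element of sl(n): a general matrix with its trace part removed."""
    A = torch.randn(n, n, dtype=dtype)
    return A - trace(A) / n * torch.eye(n, dtype=dtype)


def symplectic_J(m: int, dtype=torch.float64) -> torch.Tensor:
    """Standard symplectic form J = [[0, I], [-I, 0]] of size 2m x 2m."""
    I, Z = torch.eye(m, dtype=dtype), torch.zeros(m, m, dtype=dtype)
    return torch.cat([torch.cat([Z, I], dim=1), torch.cat([-I, Z], dim=1)], dim=0)


def random_sp(J: torch.Tensor) -> torch.Tensor:
    """Random element of sp(2m), parametrized as X = J S with S symmetric."""
    A = torch.randn(J.shape, dtype=J.dtype)
    return J @ (A + A.T)


torch.manual_seed(0)
X3, Y3 = random_sl(3), random_sl(3)
J = symplectic_J(2)
X4, Y4 = random_sp(J), random_sp(J)

# Both algebras are traceless, so the tr(X)tr(Y) term drops out.
print(torch.allclose(bilinear_form(X3, Y3), 6 * trace(X3 @ Y3)))  # n = 3
print(torch.allclose(bilinear_form(X4, Y4), 8 * trace(X4 @ Y4)))  # n = 4
```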

### Particle Physics: Top-Tagging (`SO(1,3)`)
This experiment reproduces the Top-Tagging benchmark results. The following command trains our ReLN-based model.

```bash
# Navigate to the experiment directory
cd ./lorentznet/

# Run training
torchrun --nproc_per_node=1 top_tagging.py \
    --batch_size=32 \
    --epochs=35 \
    --warmup_epochs=4 \
    --n_layers=5 \
    --n_hidden=48 \
    --lr=0.001 \
    --weight_decay=0.01 \
    --exp_name=reln_top_tagging_repro \
    --datadir ./data/toptag/
```
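For background on the symmetry involved (this is not code from `lorentznet/`): `SO(1,3)` acts on particle four-momenta and preserves the Minkowski inner product with metric `η = diag(1, -1, -1, -1)`. The sketch below builds a boost along x with a hypothetical helper `boost_x` and checks that it preserves `η` and hence each four-vector's invariant mass squared `E^2 - |p|^2`.

```python
import math

import torch

# Minkowski metric with signature (+, -, -, -)
ETA = torch.diag(torch.tensor([1.0, -1.0, -1.0, -1.0], dtype=torch.float64))


def boost_x(rapidity: float) -> torch.Tensor:
    """Hypothetical helper: Lorentz boost along the x-axis."""
    c, s = math.cosh(rapidity), math.sinh(rapidity)
    L = torch.eye(4, dtype=torch.float64)
    L[0, 0], L[0, 1], L[1, 0], L[1, 1] = c, s, s, c
    return L


def minkowski_norm_sq(p: torch.Tensor) -> torch.Tensor:
    """p^T eta p = E^2 - |p|^2 for four-vectors p of shape (..., 4)."""
    return ((p @ ETA) * p).sum(-1)


torch.manual_seed(0)
Lam = boost_x(0.7)
p = torch.randn(5, 4, dtype=torch.float64)  # toy (not necessarily physical) four-vectors
p_boosted = p @ Lam.T                       # p' = Lam p for each row

print(torch.allclose(Lam.T @ ETA @ Lam, ETA))                              # True: Lam preserves eta
print(torch.allclose(minkowski_norm_sq(p_boosted), minkowski_norm_sq(p)))  # True
```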

### Drone State Estimation (`SO(3)` with Uncertainty)
This experiment reproduces the drone trajectory estimation results. The main script allows you to train different model architectures by changing the `--arch` flag.

**To train our best-performing model (ReLN with log-covariance):**
```bash
# Navigate to the experiment directory
cd ./velocity_learning/

# Run training for the main ReLN model
python3 src/main_net.py \
    --mode train \
    --root_dir ./data_drone/ \
    --out_dir ./results/reln_log_cov/ \
    --epochs 200 \
    --arch ln_resnet_cov \
    --input_dim 6 
```
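The `ln_resnet_cov` model uses log-covariance features. Assuming this means the matrix logarithm of the symmetric positive-definite covariance (an assumption here; see `src/` for the exact preprocessing), the sketch below shows why it fits the adjoint framework: `log(R Σ R^T) = R log(Σ) R^T`, so the log-covariance transforms by `Ad_R` just like `Σ`, while living in the linear space of symmetric matrices inside `gl(3)`. The helpers `spd_log` and `random_rotation` are hypothetical.

```python
import torch


def spd_log(S: torch.Tensor) -> torch.Tensor:
    """Matrix logarithm of a symmetric positive-definite matrix via eigendecomposition."""
    evals, evecs = torch.linalg.eigh(S)
    return evecs @ torch.diag_embed(torch.log(evals)) @ evecs.transpose(-2, -1)


def random_rotation(dtype=torch.float64) -> torch.Tensor:
    """Sample a rotation in SO(3) from the QR decomposition of a Gaussian matrix."""
    q, r = torch.linalg.qr(torch.randn(3, 3, dtype=dtype))
    q = q * torch.sign(torch.diagonal(r))
    if torch.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


torch.manual_seed(0)
A = torch.randn(3, 3, dtype=torch.float64)
Sigma = A @ A.T + 0.1 * torch.eye(3, dtype=torch.float64)  # toy covariance
R = random_rotation()
log_Sigma = spd_log(Sigma)

# The log maps SPD covariances to symmetric matrices, a linear subspace of gl(3).
print(torch.allclose(log_Sigma, log_Sigma.T))
# Equivariance: log(R Sigma R^T) = R log(Sigma) R^T, i.e. the log commutes with Ad_R.
print(torch.allclose(spd_log(R @ Sigma @ R.T), R @ log_Sigma @ R.T))
# Round trip through the matrix exponential recovers Sigma.
print(torch.allclose(torch.linalg.matrix_exp(log_Sigma), Sigma))
```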

**To train other baseline models for comparison:**

You can reproduce the ablation studies in our paper by changing the `--arch` flag. Key architectures include:

* `--arch resnet`: Non-equivariant ResNet baseline.
* `--arch vn_resnet`: Equivariant Vector Neurons (VN) baseline (velocity-only).
* `--arch vn_resnet_cov`: VN baseline adapted for covariance.
* `--arch ln_resnet`: Our ReLN model using only velocity information.
* `--arch ln_resnet_cov`: ReLN model using both velocity and covariance features (SO(3)-equivariant, GL(3)-based bilinear form).
* `--arch ln_resnet_cov_sl3`: ReLN (covariance) variant where the equivariant bilinear form is restricted to the traceless `sl(3)` part.
* `--arch ln_resnet_cov_no_sl3`: ReLN (covariance) variant that removes the `sl(3)` contribution from the bilinear form.
* `--arch tfn_resnet`: Tensor Field Network (TFN) ResNet baseline (velocity-only).
* `--arch tfn_resnet_cov`: TFN-based ResNet with covariance. Input configuration:
  * 12D: velocity (3D) + general `3x3` matrix (9D)
* `--arch vn_transformer`: VN-Transformer regressor baseline (velocity-only). Uses the official VN-Transformer implementation.
* `--arch vn_transformer_cov`: VN-Transformer baseline with covariance. Assumes a 12D input (3D velocity + 9D covariance matrix) and predicts a 3D velocity output.
* `--arch emlp`: EMLP-based SO(3)-equivariant ResNet baseline operating on velocity-only inputs.
* `--arch emlp_cov`: EMLP-based SO(3)-equivariant ResNet baseline extended to handle both velocity and covariance features.
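
Several of the covariance baselines take a 12D input made of a 3D velocity and a flattened `3x3` matrix. A minimal sketch of that layout, assuming row-major flattening (`make_input` and `rotate_input` are hypothetical helpers, not the repository's data pipeline): under a rotation `R` the velocity transforms as `v -> R v` and the matrix as `M -> R M R^T`, so the 12D vector transforms by the block-diagonal matrix `diag(R, R ⊗ R)`.

```python
import torch


def make_input(v: torch.Tensor, M: torch.Tensor) -> torch.Tensor:
    """Concatenate a 3D velocity and a row-major flattened 3x3 matrix into 12D."""
    return torch.cat([v, M.flatten(start_dim=-2)], dim=-1)


def rotate_input(R: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Apply a rotation R to a 12D input: v -> R v, M -> R M R^T."""
    v = x[..., :3]
    M = x[..., 3:].reshape(*x.shape[:-1], 3, 3)
    return make_input(v @ R.T, R @ M @ R.T)


torch.manual_seed(0)
R, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))  # orthogonal matrix
v = torch.randn(5, 3, dtype=torch.float64)
M = torch.randn(5, 3, 3, dtype=torch.float64)
x = make_input(v, M)                                  # shape (5, 12)
rep = torch.block_diag(R, torch.kron(R, R))           # 12 x 12 representation matrix

print(x.shape)                                        # torch.Size([5, 12])
print(torch.allclose(rotate_input(R, x), x @ rep.T))  # True
```

With row-major flattening, `vec(R M R^T) = (R ⊗ R) vec(M)`, which is why `torch.kron(R, R)` appears in the lower block.
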
---

