# KeyCLD: Learning Constrained Lagrangian Dynamics in Keypoint Coordinates from Images

## Installation

### MuJoCo

Download [MuJoCo 2.1.1](https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz).
Unpack the archive and place it at `~/.mujoco/mujoco-2.1.1/`.
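
If you prefer to script the unpacking step, the following is a minimal sketch. It assumes you have already downloaded the archive and that it contains a single top-level `mujoco-2.1.1/` folder; the helper name `unpack_mujoco` is hypothetical and not part of this repository.

```python
import tarfile
from pathlib import Path


def unpack_mujoco(archive_path, mujoco_root=None):
    """Extract a downloaded MuJoCo archive into `mujoco_root` (default: ~/.mujoco).

    Assumption: the archive holds one top-level `mujoco-2.1.1/` folder, so that
    extraction yields `<mujoco_root>/mujoco-2.1.1/`.
    """
    root = Path(mujoco_root) if mujoco_root is not None else Path.home() / ".mujoco"
    root.mkdir(parents=True, exist_ok=True)

    with tarfile.open(archive_path, "r:gz") as archive:
        archive.extractall(root)

    target = root / "mujoco-2.1.1"
    if not target.is_dir():
        # The archive layout differs from the assumption above; move it manually.
        raise FileNotFoundError(f"expected {target} after extraction")
    return target
```

For example, `unpack_mujoco("mujoco-2.1.1-linux-x86_64.tar.gz")` would extract into `~/.mujoco/` under the stated assumption.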

### KeyCLD

Install `jax==0.3.13` and `jaxlib==0.3.13` with CUDA support, following https://github.com/google/jax#pip-installation-gpu-cuda.
Alternatively, install the CPU-only build (training will be very slow):
```
pip install jax==0.3.13 jaxlib==0.3.13
```

Clone the repository, then install KeyCLD from the repository root:
```
pip install .
```

Replace the relevant `xml` files in your `dm_control` installation with the adapted files from `dm_control_environments`.
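
The replacement can be done by hand or with a short script. The sketch below is one possible way, under two assumptions that you should verify first: that the adapted files carry the same file names as the originals, and that the originals live in the `suite` folder of the installed `dm_control` package. The helper `replace_xml_files` is hypothetical and not part of this repository.

```python
import shutil
from pathlib import Path


def replace_xml_files(adapted_dir, suite_dir):
    """Overwrite same-named xml files in `suite_dir` with those from `adapted_dir`.

    Each original is first backed up next to itself as `<name>.xml.bak`.
    Adapted files without a same-named original are skipped.
    Returns the list of file names that were replaced.
    """
    adapted_dir, suite_dir = Path(adapted_dir), Path(suite_dir)
    replaced = []
    for adapted in sorted(adapted_dir.glob("*.xml")):
        target = suite_dir / adapted.name
        if not target.exists():
            continue  # no matching original, leave the installation untouched
        shutil.copy2(target, target.with_name(target.name + ".bak"))
        shutil.copy2(adapted, target)
        replaced.append(adapted.name)
    return replaced
```

The suite folder can usually be located with `python -c "import dm_control.suite, os; print(os.path.dirname(dm_control.suite.__file__))"`; pass that path as `suite_dir` and `dm_control_environments` as `adapted_dir`.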

### Weights and Biases

KeyCLD logs training progress and results to wandb.ai.
If you are not familiar with wandb, follow the instructions shown when you first run the code.

## Reproduce Experiments

Run the commands below to reproduce all experiments reported in the paper.
Check the results in wandb, or use the notebooks in the `notebooks` directory to interact with the trained models.

### KeyCLD
```
python keycld/dm.py --environment=pendulum --init_mode=random --control=no --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python keycld/dm.py --environment=pendulum --init_mode=random --control=yes --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python keycld/dm.py --environment=cartpole --init_mode=random --control=no --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python keycld/dm.py --environment=cartpole --init_mode=random --control=underactuated --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python keycld/dm.py --environment=cartpole --init_mode=random --control=yes --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python keycld/dm.py --environment=acrobot --init_mode=random --control=no --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python keycld/dm.py --environment=acrobot --init_mode=random --control=underactuated --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python keycld/dm.py --environment=acrobot --init_mode=random --control=yes --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
```

### KeyLD
```
python ablations/no_constraint.py --environment=pendulum --init_mode=random --control=no --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/no_constraint.py --environment=pendulum --init_mode=random --control=yes --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/no_constraint.py --environment=cartpole --init_mode=random --control=no --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/no_constraint.py --environment=cartpole --init_mode=random --control=underactuated --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/no_constraint.py --environment=cartpole --init_mode=random --control=yes --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/no_constraint.py --environment=acrobot --init_mode=random --control=no --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/no_constraint.py --environment=acrobot --init_mode=random --control=underactuated --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/no_constraint.py --environment=acrobot --init_mode=random --control=yes --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
```

### KeyODE2
```
python ablations/node2.py --environment=pendulum --init_mode=random --control=no --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/node2.py --environment=pendulum --init_mode=random --control=yes --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/node2.py --environment=cartpole --init_mode=random --control=no --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/node2.py --environment=cartpole --init_mode=random --control=underactuated --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/node2.py --environment=cartpole --init_mode=random --control=yes --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/node2.py --environment=acrobot --init_mode=random --control=no --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/node2.py --environment=acrobot --init_mode=random --control=underactuated --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
python ablations/node2.py --environment=acrobot --init_mode=random --control=yes --batch_size=1 --dynamics_weight=1 --learning_rate=0.0003 --num_epochs=40 --num_hidden_dim=32 --low_data_regime=false --num_predicted_steps=4
```
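
The KeyCLD, KeyLD and KeyODE2 grids above differ only in the script path: each runs the same eight environment and control combinations with identical hyperparameters. As a convenience, the sketch below regenerates those command lines from a small table. The function `build_commands` and the two dictionaries are hypothetical helpers; the flag names and values are copied from the commands listed above.

```python
# Control modes used per environment in the command lists above.
CONTROL_MODES = {
    "pendulum": ["no", "yes"],
    "cartpole": ["no", "underactuated", "yes"],
    "acrobot": ["no", "underactuated", "yes"],
}

# Hyperparameters shared by every run, in the order used above.
COMMON_FLAGS = {
    "batch_size": "1",
    "dynamics_weight": "1",
    "learning_rate": "0.0003",
    "num_epochs": "40",
    "num_hidden_dim": "32",
    "low_data_regime": "false",
    "num_predicted_steps": "4",
}


def build_commands(script):
    """Return the eight command lines for one script, e.g. `keycld/dm.py`."""
    commands = []
    for environment, modes in CONTROL_MODES.items():
        for control in modes:
            flags = {
                "environment": environment,
                "init_mode": "random",
                "control": control,
                **COMMON_FLAGS,
            }
            args = " ".join(f"--{name}={value}" for name, value in flags.items())
            commands.append(f"python {script} {args}")
    return commands


if __name__ == "__main__":
    for script in ("keycld/dm.py", "ablations/no_constraint.py", "ablations/node2.py"):
        print("\n".join(build_commands(script)))
```

Printing the commands rather than executing them keeps the sketch side-effect free; pipe the output to a shell or a job scheduler as needed.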

### Lag-caVAE & Lag-VAE & HGN
```
python ablations/lagcavae.py --control no
python ablations/lagcavae.py --control yes
python ablations/lagcavae.py --control no --low_data_regime true
python ablations/lagcavae.py --control yes --low_data_regime true
python ablations/lagcavae.py --control no --ca
python ablations/lagcavae.py --control yes --ca
python ablations/lagcavae.py --control no --low_data_regime true --ca
python ablations/lagcavae.py --control yes --low_data_regime true --ca
python ablations/hgn.py --environment pendulum
python ablations/hgn.py --environment pendulum --low_data_regime true
python ablations/hgn.py --environment cartpole
python ablations/hgn.py --environment cartpole --low_data_regime true
python ablations/hgn.py --environment acrobot
python ablations/hgn.py --environment acrobot --low_data_regime true
```
