# Imitation learning with distribution shifts

This folder has code to reproduce results for the imitation learning example in the paper. The code builds on the following paper and repositories:
* [Transporter Networks: Rearranging the Visual World for Robotic Manipulation](https://transporternets.github.io/), Conference on Robot Learning (CoRL) 2020, 
  *Andy Zeng, Pete Florence, Jonathan Tompson, Stefan Welker, Jonathan Chien, Maria Attarian, Travis Armstrong, Ivan Krasin, Dan Duong, Vikas Sindhwani, Johnny Lee*.
* [PyTorch - Transporter Networks](https://github.com/thomaschabal/transporter-nets-torch/tree/main)
* [Ravens - Transporter Networks](https://github.com/google-research/ravens)

## Installation

```
conda create -n drm-imitation pip python==3.7.0
conda activate drm-imitation
pip install -r requirements.txt
python setup.py install --user
```

Set up the environment variables (run these from this folder, since the paths are derived from `pwd`):
```
export RAVENS_ASSETS_DIR=`pwd`/ravens_torch/environments/assets/;
export WORK=`pwd`;
export PYTHONPATH=`pwd`:$PYTHONPATH
```
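
As a quick sanity check before running anything, you may want to confirm that the variables are set and point at real directories. The following is a minimal sketch and is not part of the repository: the helper name `check_env` is hypothetical, and it only inspects the three variables exported above.

```python
import os

# Variables exported in the setup step above.
REQUIRED_VARS = ("RAVENS_ASSETS_DIR", "WORK", "PYTHONPATH")
# Of these, the ones expected to name a single existing directory.
DIRECTORY_VARS = ("RAVENS_ASSETS_DIR", "WORK")


def check_env(env=None):
    """Return a list of human-readable problems; empty if nothing looks wrong.

    `env` defaults to the current process environment. Passing a plain dict
    makes the function easy to exercise without touching os.environ.
    """
    env = os.environ if env is None else env
    problems = []
    for name in REQUIRED_VARS:
        if not env.get(name):
            problems.append(f"{name} is not set")
    for name in DIRECTORY_VARS:
        path = env.get(name)
        if path and not os.path.isdir(path):
            problems.append(f"{name} does not point to a directory: {path}")
    return problems
```

Calling `check_env()` in the same shell session should return an empty list if the exports succeeded.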


## Running the code

First, generate demonstration data for the training environments and for a set of test environments:
```
python ravens_torch/demos_ood.py --disp=False --mode=train --n=300
python ravens_torch/demos_ood.py --disp=False --mode=test --n=20
```

Next, train models using behavior cloning (ERM) and DRM. This will perform training across 10 seeds:
```
python ravens_torch/generate_results.py
```

To plot training and test success rates across seeds, run:
```
python ravens_torch/plot_results.py
```

To plot martingale values (as in the paper) for the BC and DRM policies, run:
```
python ravens_torch/plot_martingale.py
```

To generate videos of the BC or DRM policy operating in test environments, run:
```
python ravens_torch/test_ood.py --disp=False --n_demos=300 --n_steps=120 --record_mp4=True --method=drm
```
This saves a video to the `data/videos/` folder. To visualize the policy without recording a video, set `--disp=True` and `--record_mp4=False` in the command above. To visualize the BC policy instead, change the `--method` flag to `bc`.
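
If you switch between these variants often, the command can be assembled programmatically. The sketch below is an illustration, not part of the repository: `build_test_command` and `run_test` are hypothetical helper names, only the flags shown in the command above are used, and tying `--disp` to the opposite of `--record_mp4` is an assumption based on the two use cases described above.

```python
import subprocess
import sys


def build_test_command(method="drm", record=True, n_demos=300, n_steps=120):
    """Build the argument list for ravens_torch/test_ood.py.

    record=True  -> record a video without opening the display.
    record=False -> open the display without recording (assumed pairing).
    """
    if method not in ("bc", "drm"):
        raise ValueError(f"method must be 'bc' or 'drm', got {method!r}")
    return [
        sys.executable,
        "ravens_torch/test_ood.py",
        f"--disp={not record}",
        f"--n_demos={n_demos}",
        f"--n_steps={n_steps}",
        f"--record_mp4={record}",
        f"--method={method}",
    ]


def run_test(**kwargs):
    """Run the command from the current directory; raises if it fails."""
    subprocess.run(build_test_command(**kwargs), check=True)
```

For example, `run_test(method="bc", record=False)` would display the BC policy without saving a video, assuming it is called from this folder with the environment variables set.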

## Improvements

For simplicity, the training code is currently not batched efficiently and runs slowly as a result. Efficient batching should lead to significant speedups.
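
As a starting point, the first step of batching is grouping samples into fixed-size minibatches. The sketch below shows only that step in plain Python; `iter_minibatches` is a hypothetical helper that does not exist in this repository, and it makes no assumption about how the training code stores its samples. Stacking each minibatch into tensors and running a single batched forward and backward pass would still need to be done in the model code.

```python
import random


def iter_minibatches(samples, batch_size, shuffle=True, seed=None):
    """Yield lists of up to `batch_size` samples, covering each sample once.

    The last minibatch may be smaller when len(samples) is not a multiple
    of batch_size. A fixed `seed` makes the shuffled order reproducible.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    indices = list(range(len(samples)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield [samples[i] for i in indices[start:start + batch_size]]
```

A training loop would then iterate over `iter_minibatches(dataset, batch_size)` once per epoch instead of processing one sample at a time.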