# Hierarchical Few-Shot Imitation with Skill Transition Models

This repository contains the implementation of *Hierarchical Few-Shot Imitation with Skill Transition Models*.

The code is based on [SPiRL](https://github.com/clvrai/spirl).

## Requirements

To install the requirements and then the package itself in editable mode, run:

```
pip install -r requirements.txt
pip install -e .
```

## Datasets

The Kitchen dataset is included in the attached zip file. The AntMaze dataset can be generated by running the following script:

```
./data/antmaze/get_antmaze_dsets.sh
```
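After the script finishes, a quick check can confirm that the AntMaze files used in the contrastive-metric example under Training are present. This is a minimal sketch: the helper name `check_antmaze_data` is hypothetical, and it only checks the two file names that appear in that example (`Antmaze_filtered_LR.hdf5` and `Antmaze_LR.pkl`). The script may produce other files as well.

```shell
# Hypothetical helper: verify that the AntMaze files referenced in the
# contrastive-metric example exist. Defaults to ./data/antmaze.
check_antmaze_data() {
    local data_dir="${1:-./data/antmaze}"
    local missing=0
    local f
    for f in Antmaze_filtered_LR.hdf5 Antmaze_LR.pkl; do
        if [ ! -f "$data_dir/$f" ]; then
            echo "missing: $data_dir/$f" >&2
            missing=1
        fi
    done
    # Non-zero exit status if any expected file is absent.
    return "$missing"
}
```

For example, `check_antmaze_data && echo "AntMaze data ready"`.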

## Training

Before training, load the path environment variables defined in your shell configuration by running:
```
source ~/.bashrc
```
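This README does not list which variables `~/.bashrc` is expected to define. SPiRL, which this code is based on, reads its experiment and data roots from the `EXP_DIR` and `DATA_DIR` environment variables (e.g. `export EXP_DIR=./experiments` and `export DATA_DIR=./data`). Assuming this fork keeps that convention, the sketch below reports any of them that are not exported after sourcing. The helper name `check_path_vars` is hypothetical.

```shell
# Hypothetical helper: report path variables that are not exported.
# EXP_DIR and DATA_DIR follow SPiRL's convention (an assumption for this fork).
check_path_vars() {
    local missing=0
    local var
    for var in EXP_DIR DATA_DIR; do
        # printenv only sees exported variables, which is what Python reads.
        if [ -z "$(printenv "$var")" ]; then
            echo "$var is not exported" >&2
            missing=1
        fi
    done
    return "$missing"
}
```

For example, `source ~/.bashrc && check_path_vars`.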
First, train the contrastive distance metric by running:
```
python contrastive_reachability.py --env=<ENV> --training_set=<PATH-TO-OFFLINE-DATA> --demos=<PATH-TO-FINE-TUNE-DEMOS> --save_dir=./experiments/<ENV>/contrastive/
```
For example, to train the metric on AntMaze, run:
```
python contrastive_reachability.py --env=maze --training_set=./data/antmaze/Antmaze_filtered_LR.hdf5 --demos=./data/antmaze/Antmaze_LR.pkl --save_dir=./experiments/antmaze/contrastive/
```
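When training the metric for several environments, it can help to wrap the command above in a small function. The sketch below is hypothetical (`train_contrastive` and `DRY_RUN` are not part of the repository). It uses only the flags shown above, creates the save directory before launching, and prints the command instead of running it when `DRY_RUN=1`. The value passed to `--env` (e.g. `maze`) does not have to match the directory name under `experiments/` (e.g. `antmaze`), so both are taken as arguments.

```shell
# Hypothetical wrapper around contrastive_reachability.py.
# Arguments: <ENV> <EXP-NAME> <PATH-TO-OFFLINE-DATA> <PATH-TO-FINE-TUNE-DEMOS>
train_contrastive() {
    local env="$1" exp_name="$2" training_set="$3" demos="$4"
    local save_dir="./experiments/$exp_name/contrastive/"
    if [ "${DRY_RUN:-0}" = "1" ]; then
        # Only print the command that would be run.
        echo python contrastive_reachability.py --env="$env" \
            --training_set="$training_set" --demos="$demos" --save_dir="$save_dir"
        return 0
    fi
    mkdir -p "$save_dir"
    python contrastive_reachability.py --env="$env" \
        --training_set="$training_set" --demos="$demos" --save_dir="$save_dir"
}
```

For example, `DRY_RUN=1 train_contrastive maze antmaze ./data/antmaze/Antmaze_filtered_LR.hdf5 ./data/antmaze/Antmaze_LR.pkl` prints the AntMaze command shown above.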
#### Skill Extraction

To train the skill encoder/decoder and inverse skill model, run:
```
python3 spirl/train.py --path <PATH> --val_data_size 160
```
Here `<PATH>` is the config directory for the run. All config files are located in `spirl/configs/skill_prior_learning`.
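In SPiRL, each config lives in its own directory containing a `conf.py`, and that directory is what `--path` receives. Assuming the same layout here, the sketch below lists candidate values for `<PATH>`. The helper name `list_configs` is hypothetical.

```shell
# Hypothetical helper: print every directory under the given root that
# contains a conf.py (SPiRL's config layout; assumed to hold here).
list_configs() {
    local root="${1:-spirl/configs/skill_prior_learning}"
    find "$root" -type f -name conf.py -exec dirname {} \; | sort
}
```

Any printed directory can then be passed as `--path` to `spirl/train.py`.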

#### Fine-tuning
For fine-tuning, set `ckpt_path` in the config file to the latest checkpoint epoch of the skill extraction run, then run the following (the command below is for the kitchen environment):
```
CUDA_VISIBLE_DEVICES=0 python3 spirl/fewshot_kitchen_train.py --path <PATH> --val_data_size 160 --resume <EPOCH>
```
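To find the latest checkpoint epoch of a run, one option is to scan its checkpoint folder. SPiRL names its weight files `weights_ep<EPOCH>.pth`, typically inside a `weights/` subfolder of the experiment directory. Assuming this fork keeps that naming, the hypothetical helper `latest_epoch` below prints the highest epoch found.

```shell
# Hypothetical helper: print the highest epoch N among files named
# weights_ep<N>.pth in the given directory (SPiRL's naming; assumed here).
# Prints nothing if the directory is missing or has no matching files.
latest_epoch() {
    local ckpt_dir="$1"
    ls "$ckpt_dir" 2>/dev/null \
        | sed -n 's/^weights_ep\([0-9][0-9]*\)\.pth$/\1/p' \
        | sort -n \
        | tail -n 1
}
```

The numeric sort matters: a plain lexical sort would rank `weights_ep9.pth` above `weights_ep49.pth`.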
#### Few-shot Learning
To evaluate the fine-tuned model, comment out `ckpt_path` in the config file and run:
```
CUDA_VISIBLE_DEVICES=0 python3 spirl/fewshot_train.py --path <PATH> --val_data_size 160 --resume 49 --eval 1
```

To evaluate the non-fine-tuned model, keep `ckpt_path` set in the config file and run:
```
CUDA_VISIBLE_DEVICES=0 python3 spirl/fewshot_train.py --path <PATH> --val_data_size 160 --resume 199 --eval 1
```
For few-shot learning with the kitchen environment, run `spirl/fewshot_kitchen_train.py` instead of `spirl/fewshot_train.py`.
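The script choice above can be captured in a small helper so the evaluation command is built the same way for every environment. This is a hypothetical sketch: `fewshot_script`, `fewshot_eval_cmd`, and the environment name `kitchen` used as the switch are illustrative, and the flags are exactly those in the evaluation commands above. It prints the command rather than running it, so it can be inspected first.

```shell
# Hypothetical helper: pick the few-shot script for an environment name.
# Kitchen uses its own script; every other environment uses fewshot_train.py.
fewshot_script() {
    case "$1" in
        kitchen) echo spirl/fewshot_kitchen_train.py ;;
        *) echo spirl/fewshot_train.py ;;
    esac
}

# Print the evaluation command for <ENV> <PATH> <EPOCH>.
fewshot_eval_cmd() {
    echo CUDA_VISIBLE_DEVICES=0 python3 "$(fewshot_script "$1")" \
        --path "$2" --val_data_size 160 --resume "$3" --eval 1
}
```

For example, `fewshot_eval_cmd kitchen <PATH> 49` prints the fine-tuned evaluation command with the kitchen script substituted.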
