# File Structure

* data: all data for experiments
  * mlp: data for MLP model;
  * cdt: data for CDT model;
  * sdt: data for SDT model;
  * il: data for general Imitation Learning (IL);
  * rl: data for general Reinforcement Learning (RL);
* src: source code
  * mlp: training configurations for the MLP policy function approximator;
  * cdt: the Cascading Decision Tree (CDT) class and necessary functions;
  * sdt: the Soft Decision Tree (SDT) class and necessary functions;
  * hdt: the heuristic agents;
  * il: configurations for Imitation Learning (IL);
  * rl: configurations for Reinforcement Learning (RL) and RL agents (e.g., PPO);
  * utils: common utility functions;
  * `il_data_collect.py`: collect a dataset of state-action pairs (from a heuristic or well-trained policy) for IL;
  * `rl_data_collect.py`: collect a dataset of states visited during training, used to compute normalization statistics for RL;
  * `il_train.py`: train IL agents with different function approximators (e.g., SDT, CDT);
  * `rl_train.py`: train RL agents with different function approximators (e.g., SDT, CDT, MLP);
  * `il_eval.py`: evaluate the trained IL agents before and after tree discretization, based on prediction accuracy;
  * `rl_eval.py`: evaluate the trained RL agents before and after tree discretization, based on episodic reward;
  * `il_train.sh`: bash script to run IL on a server;
  * `rl_train.sh`: bash script to run RL on a server.
* visual: visualization notebooks
  * `plot.ipynb`: plot learning curves, etc.

# To Start

To fully replicate the experiments in the paper, run the code in the stages below.

### A. Reinforcement Learning Comparison with SDT, CDT and MLP

1. Collect the dataset of states used for normalization statistics

   ```bash
   cd ./src
   python rl_data_collect.py
   ```

2. Compute statistics on the dataset

   ```bash
   cd rl
   jupyter notebook
   ```

   Open `stats.ipynb` and run all of its cells to generate the dataset statistics files.

3. Train RL agents with different policy function approximators: SDT, CDT, MLP

   ```bash
   cd ..
   python rl_train.py --train --method='sdt' --id=0
   python rl_train.py --train --method='cdt' --id=0
   python rl_train.py --train --method='mlp' --id=0
   ```

   Alternatively, run the provided script:

   ```bash
   ./rl_train.sh
   ```

4. Evaluate the trained agents

   ```bash
   python rl_eval.py --method='sdt'
   python rl_eval.py --method='cdt'
   python rl_eval.py --method='mlp'
   ```
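
On a headless server, step 2 can be run without opening the notebook in a browser. The sketch below assumes that Jupyter's `nbconvert` is installed and that it is run from `./src/rl`; the helper name `run_stats_notebook` is hypothetical and not part of the repository. It executes every cell of `stats.ipynb` and writes the outputs back into the same file.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repository): execute stats.ipynb
# non-interactively, e.g. over SSH where no browser is available.
# Assumes the current directory is ./src/rl and nbconvert is installed.
run_stats_notebook() {
  local nb="${1:-stats.ipynb}"
  if [ ! -f "$nb" ]; then
    echo "error: $nb not found; run this from ./src/rl" >&2
    return 1
  fi
  # --execute runs all cells; --inplace overwrites the notebook with its outputs.
  jupyter nbconvert --to notebook --execute --inplace "$nb"
}
```

The statistics files are still produced by the notebook's own cells, so the result should match the interactive workflow.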

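To repeat the training commands from step 3 for several runs, a loop can launch them in parallel with one log file per run. This is a sketch, not the contents of `rl_train.sh`. It assumes that it is run from `./src` and that `--id` only distinguishes runs (e.g., by seed or output name). The function name `train_rl_all`, the `logs/` directory and the `DRY_RUN` switch are hypothetical.

```bash
#!/usr/bin/env bash
# Hypothetical launcher (not the repository's rl_train.sh): train SDT, CDT and
# MLP agents for every run id given as an argument (default: 0).
# Assumes the current directory is ./src. With DRY_RUN=1 the commands are
# only printed, not executed.
train_rl_all() {
  if [ "$#" -eq 0 ]; then
    set -- 0
  fi
  local method id
  for method in sdt cdt mlp; do
    for id in "$@"; do
      local cmd=(python rl_train.py --train "--method=${method}" "--id=${id}")
      if [ "${DRY_RUN:-0}" = "1" ]; then
        echo "${cmd[*]}"
      else
        mkdir -p logs
        # Start each run in the background, writing its output to its own log.
        nohup "${cmd[@]}" > "logs/rl_${method}_${id}.log" 2>&1 &
      fi
    done
  done
  # Block until every background run has finished (returns at once in dry-run mode).
  wait
}
```

For example, `DRY_RUN=1 train_rl_all 0 1 2` prints the nine commands without running them. Running `train_rl_all 0 1 2` starts all nine at once, which may oversubscribe a small machine; launching one method at a time is the safer choice there.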
   
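The evaluation commands in step 4 can be gathered into one helper that keeps each method's console output and reports any failures at the end. It assumes the working directory is `./src`; `eval_rl_all`, the log file names and `DRY_RUN` are hypothetical.

```bash
#!/usr/bin/env bash
# Hypothetical helper: evaluate the SDT, CDT and MLP agents one after another,
# saving each method's console output to logs/rl_eval_<method>.log.
# Assumes the current directory is ./src; DRY_RUN=1 only prints the commands.
eval_rl_all() {
  local method
  local failed=()
  for method in sdt cdt mlp; do
    if [ "${DRY_RUN:-0}" = "1" ]; then
      echo "python rl_eval.py --method=${method}"
      continue
    fi
    mkdir -p logs
    # Continue after a failure so the remaining methods are still evaluated.
    if ! python rl_eval.py "--method=${method}" > "logs/rl_eval_${method}.log" 2>&1; then
      failed+=("$method")
    fi
  done
  if [ "${#failed[@]}" -gt 0 ]; then
    echo "evaluation failed for: ${failed[*]}" >&2
    return 1
  fi
  return 0
}
```

Running the methods sequentially avoids competing for the same resources, and the per-method logs keep the reported episodic rewards separate.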

### B. Imitation Learning Comparison with SDT and CDT

1. Collect dataset

   ```bash
   cd ./src
   python il_data_collect.py
   ```

2. Train IL agents with different policy function approximators: SDT, CDT

   ```bash
   python il_train.py --train --method='sdt' --id=0
   python il_train.py --train --method='cdt' --id=0
   ```

   Alternatively, run the provided script:

   ```bash
   ./il_train.sh
   ```

3. Evaluate the trained agents

   ```bash
   python il_eval.py --method='sdt'
   python il_eval.py --method='cdt'
   ```
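
Pipeline B can also be run end to end as a single function that stops at the first failing step. This is a minimal sketch that assumes the default `--id=0` from the commands above; `run_il_pipeline`, its optional source-directory argument and the `DRY_RUN` switch are hypothetical and not part of the repository.

```bash
#!/usr/bin/env bash
# Hypothetical end-to-end runner for pipeline B (collect -> train -> evaluate).
# The optional argument is the source directory (default: ./src, i.e. run from
# the repository root). Each step runs only if the previous one succeeded.
# DRY_RUN=1 prints the commands instead of executing them.
run_il_pipeline() {
  local src_dir="${1:-./src}"
  local method
  (
    # Subshell: the cd below does not change the caller's working directory.
    cd "$src_dir" || exit 1
    step() {
      if [ "${DRY_RUN:-0}" = "1" ]; then
        echo "$*"
      else
        "$@"
      fi
    }
    step python il_data_collect.py || exit 1
    for method in sdt cdt; do
      step python il_train.py --train "--method=${method}" --id=0 || exit 1
    done
    for method in sdt cdt; do
      step python il_eval.py "--method=${method}" || exit 1
    done
  )
}
```

Chaining each step with `|| exit 1` (rather than `set -e`) keeps the early exit working even when the function is called inside an `if` condition, where `set -e` is ignored.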

   

