# Energy-based Action Advising for Transfer Learning

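## Set up the environment

Install the dependencies:

`pip install -r requirement.txt`

Then change into the `overcooked_ai` directory, install it in editable mode, and return to the repository root:

`cd overcooked_ai && pip install -e . && cd ..`

Train from scratch:

`python main.py --config multi_grid`

Transfer learning (replace the `--teacher_dir` value with your own checkpoint directory):

`python transfer.py --config multi_grid_locked --teacher_dir ckpts/iac-multi_grid-b5b14c30-24_02_21-05_36_57`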

## Pipeline

For full usage and details of all scripts and notebooks mentioned below, please refer to the documentation under the `Documentation/` directory.

1. **Generate encodings for IID and OOD samples**

The encodings should have shape:

```
[batch_size, C, X, Y]
```

Where:
- `C` is the number of feature channels (e.g., `3` in GridWorld),
- `X` and `Y` are the spatial dimensions of the environment encoding (e.g., grid size),
- `batch_size` is the number of stored IID/OOD samples.
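To make the `[batch_size, C, X, Y]` layout concrete, here is a minimal standard-library sketch that builds random encodings and a single encoding with the agent and objects at fixed cells. It assumes each channel holds small integer codes; `random_encodings`, `specific_encoding`, `shape`, `max_value`, and every code value used below are hypothetical placeholders, not the repository's API or the environment's real object indices.

```python
import random
from typing import Dict, List, Optional, Tuple

# Nested list indexed as encoding[batch][channel][x][y].
Encoding = List[List[List[List[int]]]]


def random_encodings(batch_size: int, channels: int, x_dim: int, y_dim: int,
                     max_value: int = 10, seed: Optional[int] = None) -> Encoding:
    """Random integer encodings of shape [batch_size, C, X, Y].

    max_value is a hypothetical upper bound on the per-channel codes;
    set it to match your environment's vocabularies.
    """
    rng = random.Random(seed)
    return [[[[rng.randint(0, max_value) for _ in range(y_dim)]
              for _ in range(x_dim)]
             for _ in range(channels)]
            for _ in range(batch_size)]


def specific_encoding(channels: int, x_dim: int, y_dim: int,
                      agent_pos: Tuple[int, int], agent_code: List[int],
                      objects: Optional[Dict[Tuple[int, int], List[int]]] = None
                      ) -> List[List[List[int]]]:
    """One [C, X, Y] encoding with the agent and objects at fixed cells.

    Each code lists one value per channel (hypothetical placeholders).
    Cells not mentioned stay 0.
    """
    enc = [[[0] * y_dim for _ in range(x_dim)] for _ in range(channels)]
    cells = dict(objects or {})
    cells[agent_pos] = agent_code  # the agent overwrites any object on its cell
    for (cx, cy), code in cells.items():
        for c in range(channels):
            enc[c][cx][cy] = code[c]
    return enc


def shape(enc) -> Tuple[int, ...]:
    """Shape of a regular nested list, e.g. (batch_size, C, X, Y)."""
    dims = []
    while isinstance(enc, list):
        dims.append(len(enc))
        enc = enc[0] if enc else None
    return tuple(dims)


if __name__ == "__main__":
    batch = random_encodings(4, 3, 7, 7, seed=0)
    print(shape(batch))  # (4, 3, 7, 7)
```

A batch built this way can be turned into a tensor with `torch.tensor(batch)` before saving, assuming the rest of the pipeline expects a PyTorch tensor.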

You have several options to generate the encodings:

- **Random encodings**: generate them to match your requirements.
- **Specific encodings**: construct them by fixing the object and agent positions.
- **Encodings collected during evaluation**: run `evaluate.py` and extract the frames afterwards.

   - For `evaluate.py`, you can either follow the teacher policy or take random actions. For Overcooked environments, make sure to **add a conditional check in the `if td["done"].any()` block that skips some of its lines**; this avoids errors caused by environment-specific edge cases.
   - After the `.pt` file is saved, you can use the script below to extract the `"images"` field (which is the encoding):

     ```bash
     python EncodingScript/extract_frames_eval.py --input_file <path_to_pt_file>
     ```
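Extracted frames are not guaranteed to be channel-first. MiniGrid-style `grid.encode()` arrays, for example, are laid out as `[X, Y, C]`. If your frames look like that, they need to be transposed before stacking into `[batch_size, C, X, Y]`. The sketch below does this in plain Python under that assumption; `stack_channel_last` is a hypothetical helper, and with PyTorch tensors the equivalent is `torch.stack(frames).permute(0, 3, 1, 2)`.

```python
from typing import List, Sequence

Frame = Sequence[Sequence[Sequence[int]]]  # indexed as frame[x][y][c]


def stack_channel_last(frames: Sequence[Frame]) -> List[List[List[List[int]]]]:
    """Stack channel-last frames [X][Y][C] into a batch [batch][C][X][Y]."""
    if not frames:
        return []
    x_dim, y_dim, c_dim = len(frames[0]), len(frames[0][0]), len(frames[0][0][0])
    batch = []
    for i, frame in enumerate(frames):
        # Every frame must share the first frame's shape to form one batch.
        if (len(frame), len(frame[0]), len(frame[0][0])) != (x_dim, y_dim, c_dim):
            raise ValueError(f"frame {i} has a different shape from frame 0")
        batch.append([[[frame[x][y][c] for y in range(y_dim)]
                       for x in range(x_dim)]
                      for c in range(c_dim)])
    return batch
```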

---

2. **Run `main.py`** to train the source policy

```bash
python main.py --config <src_env_config>
```

This will train the policy and save model checkpoints under a generated directory inside `./ckpts`.
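The checkpoint directory in the setup example (`ckpts/iac-multi_grid-b5b14c30-24_02_21-05_36_57`) appears to follow an `<algo>-<config>-<hash>-<date>-<time>` naming pattern. Assuming that pattern holds, a small helper like the hypothetical `latest_ckpt_dir` below can find the newest run for a given config, which is handy when choosing a `--teacher_dir`. It ranks runs by filesystem modification time instead of parsing the date field, because the date format is not documented.

```python
import os
from typing import Optional


def config_of(dirname: str) -> Optional[str]:
    """Config name embedded in a checkpoint directory name, or None.

    Assumes the <algo>-<config>-<hash>-<date>-<time> layout, with no
    hyphen inside <algo>.
    """
    parts = dirname.rsplit("-", 3)  # [<algo>-<config>, hash, date, time]
    if len(parts) != 4 or "-" not in parts[0]:
        return None
    return parts[0].split("-", 1)[1]


def latest_ckpt_dir(root: str, config: str) -> Optional[str]:
    """Most recently modified checkpoint directory for `config` under `root`."""
    candidates = [
        os.path.join(root, name)
        for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name)) and config_of(name) == config
    ]
    return max(candidates, key=os.path.getmtime, default=None)
```

Matching on the parsed config name, rather than a prefix, keeps `multi_grid` from also matching `multi_grid_locked` runs.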

---

3. **Run `distribution.py`** to generate energy scores

```bash
python distribution.py --config <src_env_config> --ckpt_dir <path_to_ckpt>
```

This will:
- Generate a list of raw energy values using the IID encodings.
- Save a `.pt` file and a corresponding plot in the IID directory.
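For intuition about what a "raw energy value" can be, the sketch below computes the widely used logit-based energy score, `E(x) = -T * logsumexp(f(x) / T)`, where `f(x)` are a network's output logits and `T` is a temperature. This is an assumption: check `distribution.py` for the exact definition it uses and for which outputs (for example, the policy's action logits) it feeds in. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def energy_score(logits: Sequence[float], temperature: float = 1.0) -> float:
    """E(x) = -T * logsumexp(f(x) / T); lower energy suggests in-distribution."""
    if not logits:
        raise ValueError("logits must be non-empty")
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max so exp() cannot overflow
    lse = m + math.log(sum(math.exp(s - m) for s in scaled))
    return -temperature * lse


def energy_scores(batch_logits: Sequence[Sequence[float]],
                  temperature: float = 1.0) -> List[float]:
    """One energy value per sample, e.g. per IID encoding."""
    return [energy_score(z, temperature) for z in batch_logits]
```

Confident (large) logits give a large log-sum-exp and therefore a low energy, which is why low energy is usually read as "familiar" and high energy as "out of distribution".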

---

4. **Run `transfer.py`** to transfer the source policy to a target environment

```bash
python transfer.py --config <tgt_env_config> --teacher_dir <path_to_ckpt>
```

Use the checkpoint directory created in Step 2 as `--teacher_dir`.
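One way to read "energy-based action advising" is that the teacher's action is used only in states the teacher finds familiar, i.e. states whose teacher energy falls below a threshold derived from the IID energies. The sketch below illustrates that gating rule. It is a hypothetical illustration, not necessarily what `transfer.py` implements: the nearest-rank percentile threshold, the default `q=95`, and the names `percentile_threshold` and `choose_action` are all assumptions.

```python
import math
from typing import Sequence


def percentile_threshold(iid_energies: Sequence[float], q: float = 95.0) -> float:
    """Energy at or below which q percent of IID samples fall (nearest rank)."""
    if not iid_energies:
        raise ValueError("need at least one IID energy")
    if not 0.0 <= q <= 100.0:
        raise ValueError("q must be in [0, 100]")
    ordered = sorted(iid_energies)
    rank = max(1, math.ceil(q * len(ordered) / 100))
    return ordered[rank - 1]


def choose_action(student_action: int, teacher_action: int,
                  teacher_energy: float, threshold: float) -> int:
    """Follow the teacher only when the state looks in-distribution to it."""
    return teacher_action if teacher_energy <= threshold else student_action
```

Raising `q` makes the teacher advise in more states; lowering it restricts advice to states that look very close to the source environment.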

---

5. **Run the notebook** to gather and analyze rewards

```bash
jupyter notebook notebooks/q_reward.ipynb
```

This notebook gathers rewards and allows you to analyze them based on advice conditions or energy thresholds.
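The core of such an analysis is grouping episode rewards by a condition and averaging them. The sketch below shows that step on plain dictionaries; the record fields `threshold`, `advised`, and `reward` and the helper `mean_reward_by` are hypothetical and should be mapped to whatever the notebook actually loads.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, Hashable, Iterable, List


def mean_reward_by(records: Iterable[dict], key: str) -> Dict[Hashable, float]:
    """Average the 'reward' field of records grouped by one field."""
    groups: Dict[Hashable, List[float]] = defaultdict(list)
    for rec in records:
        groups[rec[key]].append(rec["reward"])
    # Groups appear in the order their first record was seen.
    return {k: mean(v) for k, v in groups.items()}


if __name__ == "__main__":
    records = [
        {"threshold": -3.0, "advised": True, "reward": 1.0},
        {"threshold": -3.0, "advised": False, "reward": 0.0},
        {"threshold": -1.0, "advised": True, "reward": 0.5},
    ]
    print(mean_reward_by(records, "threshold"))  # {-3.0: 0.5, -1.0: 0.5}
    print(mean_reward_by(records, "advised"))    # {True: 0.75, False: 0.0}
```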
