
<h2 align="center"> <a href="https://arxiv.org/abs/2406.11194">In-Context Editing: Learning Knowledge from Self-Induced Distributions</a></h2>

# About

This project is developed based on [EasyEdit](https://github.com/zjunlp/EasyEdit). Please refer to the original repository for more details of other methods and an overview of knowledge editing. The following is a list of related repositories:

- [EasyEdit](https://github.com/zjunlp/EasyEdit)  An open-source knowledge editing framework.
- [ROME](https://github.com/kmeng01/rome)  Rank-One Model Editing, a related locate-then-edit method.
- [MEMIT](https://github.com/kmeng01/memit)  Mass-Editing Memory in a Transformer, a related locate-then-edit method.

## Table of Contents

<!-- - [🔔News](#news)
- [🌟Overview](#overview) -->
- [🤗Dataset](#dataset)
- [🛠️Requirements and Installation](#️requirements-and-installation)
- [🤖Evaluation](#evaluation)
- [💥Training](#training)


## 🤗Dataset

We evaluate our method on four datasets: **WikiData<sub>recent</sub>**, **ZsRE**, **WikiBio**, and **WikiData<sub>counterfact</sub>**. Together they cover two knowledge-editing tasks, knowledge insertion and knowledge modification, to test how well our method generalizes.

<table class="tg" align="center" style="border-collapse: collapse; width: 100%;">
<thead>
  <tr>
    <th class="tg-7btt" style="text-align: center;">Task</th>
    <th class="tg-7btt" style="text-align: center;">Knowledge Insertion</th>
    <th class="tg-7btt" colspan="3" style="text-align: center;">Knowledge Modification</th>
  </tr>
</thead>
<tbody>
  <tr>
    <td class="tg-c3ow" style="text-align: center;">Datasets</td>
    <td class="tg-c3ow" style="text-align: center;">WikiData<sub>recent</sub></td>
    <td class="tg-c3ow" style="text-align: center;">ZsRE</td>
    <td class="tg-c3ow" style="text-align: center;">WikiBio</td>
    <td class="tg-c3ow" style="text-align: center;">WikiData<sub>counterfact</sub></td>
  </tr>
  <tr>
    <td class="tg-c3ow" style="text-align: center;">Type</td>
    <td class="tg-c3ow" style="text-align: center;">Fact</td>
    <td class="tg-c3ow" style="text-align: center;">Question Answering</td>
    <td class="tg-c3ow" style="text-align: center;">Hallucination</td>
    <td class="tg-c3ow" style="text-align: center;">Counterfact</td>
  </tr>
  <tr>
    <td class="tg-c3ow" style="text-align: center;"># Train</td>
    <td class="tg-c3ow" style="text-align: center;">570</td>
    <td class="tg-c3ow" style="text-align: center;">10,000</td>
    <td class="tg-c3ow" style="text-align: center;">592</td>
    <td class="tg-c3ow" style="text-align: center;">1,455</td>
  </tr>
  <tr>
    <td class="tg-c3ow" style="text-align: center;"># Test</td>
    <td class="tg-c3ow" style="text-align: center;">1,266</td>
    <td class="tg-c3ow" style="text-align: center;">1,230</td>
    <td class="tg-c3ow" style="text-align: center;">1,392</td>
    <td class="tg-c3ow" style="text-align: center;">885</td>
  </tr>
</tbody>
</table>

You can download the data from the 🤗 [Huggingface Dataset](https://huggingface.co/datasets/Yofuria/ICE). The expected file structure is:

```text
ICE
|-- data
|   |-- wikibio.json
|   |-- wikidata_counterfact.json
|   |-- wikidata_recent.json
|   |-- zsre.json
```
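Before running an edit, it can help to confirm that the files are where the scripts expect them. The following is a minimal sketch: the helper name `missing_data_files` is hypothetical and not part of this repository, and it checks only the four file names listed above.

```python
from pathlib import Path

# File names taken from the expected layout above.
EXPECTED_FILES = (
    "wikibio.json",
    "wikidata_counterfact.json",
    "wikidata_recent.json",
    "zsre.json",
)


def missing_data_files(data_dir="./data"):
    """Return the expected dataset files that are absent from data_dir."""
    root = Path(data_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_data_files()
    if missing:
        print("Missing dataset files:", ", ".join(missing))
    else:
        print("All four dataset files found.")
```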

## 🛠️Requirements and Installation

```shell
# clone ICE
git clone https://github.com/Yofuria/ICE.git
cd ICE

# create conda env
conda create -n ICE python=3.10
conda activate ICE

# install package
pip install -r requirements.txt
```

**Lines 32 and 33** of **`examples/run_knowedit_llama2.py`** download the NLTK **`punkt`** package, which the script requires. There are two ways to get it:

- If your Internet **speed is fast** enough, you can **run the code directly** from the command line.

```python
if __name__ == "__main__":
    # If you have a slow Internet connection and can't download nltk quickly, comment these two lines and use the second method of Requirements and Installation in README.md
    import nltk
    nltk.download('punkt')
```

- If your Internet **speed is slow**, **comment out lines 32 and 33** and **download punkt manually** from 🤗 [punkt](https://huggingface.co/datasets/kailinjiang/punkt). In the ICE environment directory you created, create an **nltk_data/tokenizers** folder and **unpack punkt** into it.

<div align="center">   <img src="assets/punkt.png" width="650px"> </div>
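After a manual install, a quick check that the folder landed in the right place can save a failed run. This is a sketch using only the standard library: `find_punkt` is a hypothetical helper, and the default search roots (the active environment prefix and the home directory) are assumptions about where `nltk_data` may live on your machine.

```python
import sys
from pathlib import Path


def find_punkt(roots=None):
    """Return the first existing <root>/nltk_data/tokenizers/punkt directory, or None."""
    if roots is None:
        # Assumed search roots: the active environment and the home directory.
        roots = [Path(sys.prefix), Path.home()]
    for root in roots:
        candidate = Path(root) / "nltk_data" / "tokenizers" / "punkt"
        if candidate.is_dir():
            return candidate
    return None


if __name__ == "__main__":
    location = find_punkt()
    if location:
        print(f"punkt found at {location}")
    else:
        print("punkt not found; unpack it into nltk_data/tokenizers")
```

With the ICE conda environment active, `sys.prefix` points at the environment directory, so the first root corresponds to the manual install location described above.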

## 🤖Evaluation

You can get the evaluation results using `eval.py`. The evaluation metrics are as follows:

- `rewrite_acc` $\rightarrow$ **Edit Success** [measures the ability of the model to produce the edited response $x^*$ for a query **$q$**]

  <div align="center">
    <img src="assets/edit succ.png" width="450px">
  </div>

- `locality` $\rightarrow$ **Locality** [evaluates if the model maintains original predictions for queries outside the edit scope]

<div align="center">
  <img src="assets/loc.png" width="450px">
</div>

- `portability` $\rightarrow$ **Portability** [assesses how well the model generalizes the knowledge for rephrased or logically related queries within the edit scope **$D_q$**]

<div align="center">
  <img src="assets/port.png" width="450px">
</div>

- `ngram_entropy` $\rightarrow$ **Fluency** [estimates the linguistic quality of the post-edit model's output, given by a weighted sum of bi- and tri-gram entropies]

<div align="center">
  <img src="assets/flu.png" width="400px">
</div>

- `PPL_r` $\rightarrow$ **PPL<sub>r</sub>** [we introduce a normalized perplexity ratio, comparing the perplexity of the generated sentence beyond the target token to that of the prompt and target token combined]

<div align="center">
  <img src="assets/PPL.png" width="350px">
</div>

<div align="center">
  <img src="assets/PPL_r.png" width="200px">
</div>
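The Fluency metric in the list above can be illustrated with a short sketch. It is not the code in `eval.py`: whitespace tokenization and the default weights of 2/3 and 4/3 are assumptions made for illustration, and the function names are hypothetical.

```python
import math
from collections import Counter


def ngram_entropy(tokens, n):
    """Shannon entropy (in bits) of the n-gram frequency distribution of tokens."""
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    counts = Counter(ngrams)
    total = len(ngrams)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


def fluency(sentence, weights=(2 / 3, 4 / 3)):
    """Weighted sum of bi-gram and tri-gram entropies of a sentence.

    Assumptions: whitespace tokenization and illustrative default weights.
    """
    tokens = sentence.split()
    w2, w3 = weights
    return w2 * ngram_entropy(tokens, 2) + w3 * ngram_entropy(tokens, 3)
```

A degenerate output that repeats a single token has zero n-gram entropy, so a low value signals repetitive text, while varied text scores higher.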

After the editing operation, you get a JSON file of per-case results with the following structure:

```text
{
    "pre": {
        "rewrite_acc": [],
        "portability": {
            "Subject_Aliasing_acc": [],
            "reasoning_acc": []
        },
        "fluency": {
            "ngram_entropy": 
        }
    },
    
    "case_id": 0,
    "requested_rewrite": {
        //...
    },
    
    "time": ,
    "post": {
        "rewrite_acc": [],
        "locality": {
            "Relation_Specificity_acc": [],
            "Forgetfulness_acc": []
        },
        "portability": {
            "Subject_Aliasing_acc": [],
            "reasoning_acc": []
        },
        "fluency": {
            "ngram_entropy": 
        }
    }
}
```
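To get a feel for how such records turn into summary numbers, here is a minimal aggregation sketch. It assumes the result file is a JSON list of records shaped like the one above, and it averages each list per case before averaging across cases; `eval.py` may aggregate differently, and `summarize` is a hypothetical name.

```python
from statistics import mean


def case_mean(values):
    """Mean of one case's accuracy list, or None if the list is empty."""
    return mean(values) if values else None


def summarize(records):
    """Average post-edit metrics over all cases, expressed as percentages."""
    edit_succ = [case_mean(r["post"]["rewrite_acc"]) for r in records]
    edit_succ = [v for v in edit_succ if v is not None]
    summary = {"Edit_Succ": 100 * mean(edit_succ)}
    for group in ("locality", "portability"):
        per_key = {}
        for r in records:
            for key, values in r["post"].get(group, {}).items():
                m = case_mean(values)
                if m is not None:
                    per_key.setdefault(key, []).append(m)
        for key, vals in per_key.items():
            summary[f"{group} ({key})"] = 100 * mean(vals)
    return summary
```

To use it, load a result file with `json.load` and pass the resulting list to `summarize`.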

`PPL_r` is computed from the sentences generated by the model, which the editing run saves to a separate file.

For example: `ICE_zsre_Llama-2-7b-chat-hf_gen_sentence.json`
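The ratio behind `PPL_r`, as described in the metric list, can be sketched from per-token log-probabilities. This follows the prose description only (perplexity of the generated continuation divided by perplexity of the prompt plus target); the function names are hypothetical, and obtaining the log-probabilities from a model is left out.

```python
import math


def perplexity(token_logprobs):
    """exp of the negative mean natural-log probability of the tokens."""
    if not token_logprobs:
        raise ValueError("need at least one token log-probability")
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


def ppl_ratio(context_logprobs, continuation_logprobs):
    """Perplexity of the continuation relative to the prompt-plus-target context.

    context_logprobs: log-probabilities of the prompt and target tokens.
    continuation_logprobs: log-probabilities of the tokens generated after the target.
    """
    return perplexity(continuation_logprobs) / perplexity(context_logprobs)
```

A ratio near 1 means the continuation is about as predictable to the model as the context it follows; larger values indicate a less natural continuation.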

```shell
# --model_name_or_path: path to the pre-trained model
# --output_file: generated sentences file (xxx.json)
# --result_file: result file (xxx.json)
python eval.py \
    --model_name_or_path='' \
    --output_file='./FT-M_counterfact_gpt2-xl_gen_sentence.json' \
    --result_file='./FT-M_counterfact_gpt2-xl_results.json'
```

You will get the **following metrics**:

```text
Edit_Succ: 30.262626262626263
Portability: 7.3802393354053
Portability (Subject_Aliasing_acc): 6.939620928384972
Portability (reasoning_acc): 3.511697773992855
Portability (Logical_Generalization_acc): 9.11111111111111
Locality: 33.95236461069794
Fluency: 557.8193009507412
ppl_r:  tensor(9.9633, device='cuda:0')
```

## 💥Training

We provide the training hyperparameters for five methods in `./hparams`.

For ICE, we update **GPT2-xl** using **layers 13 to 17** and **Llama2-7b-chat** using **layers 4 to 8**.
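For GPT2-xl, the layer range can be expanded into concrete module names using the `rewrite_module_tmp` template from `ICE/gpt2-xl.yaml`. The snippet below only formats strings; how the editing code consumes these names is not shown here, and `edited_modules` is a hypothetical helper.

```python
# Template and layer list copied from the GPT2-xl hyperparameter file.
REWRITE_MODULE_TMP = "transformer.h.{}.mlp.c_proj"
LAYERS = [13, 14, 15, 16, 17]


def edited_modules(template=REWRITE_MODULE_TMP, layers=LAYERS):
    """Fill the module-name template with each layer index."""
    return [template.format(layer) for layer in layers]


if __name__ == "__main__":
    for name in edited_modules():
        print(name)  # transformer.h.13.mlp.c_proj ... transformer.h.17.mlp.c_proj
```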

Both FT-L and FT-M use the same hparams located in `./hparams/FT`.

For FT-L, replace `objective_optimization` with `prompt_last`, and for FT-M, replace it with `target_new`. For details on other methods, please refer to [EasyEdit](https://github.com/zjunlp/EasyEdit). You can execute the following commands to obtain results:

**For ICE:**

```shell
python examples/run_knowedit_llama2.py \
    --editing_method=ICE \
    --hparams_dir=./hparams/ICE/gpt2-xl.yaml \
    --data_dir=./data/zsre.json \
    --datatype='zsre' \
    --metrics_save_dir=./results/gpt2-xl/ICE
```

**For FT-L:**

```shell
python examples/run_knowedit_llama2.py \
    --editing_method=FT-L \
    --hparams_dir=./hparams/FT/gpt2-xl.yaml \
    --data_dir=./data/zsre.json \
    --datatype='zsre' \
    --metrics_save_dir=./results/gpt2-xl/ICE
```

**For FT-M:**

```shell
python examples/run_knowedit_llama2.py \
    --editing_method=FT-M \
    --hparams_dir=./hparams/FT/gpt2-xl.yaml \
    --data_dir=./data/zsre.json \
    --datatype='zsre' \
    --metrics_save_dir=./results/gpt2-xl/ICE
```

**For MEMIT:**

```shell
python examples/run_knowedit_llama2.py \
    --editing_method=MEMIT \
    --hparams_dir=./hparams/MEMIT/gpt2-xl.yaml \
    --data_dir=./data/zsre.json \
    --datatype='zsre' \
    --metrics_save_dir=./results/gpt2-xl/ICE
```

**For ROME:**

```shell
python examples/run_knowedit_llama2.py \
    --editing_method=ROME \
    --hparams_dir=./hparams/ROME/gpt2-xl.yaml \
    --data_dir=./data/zsre.json \
    --datatype='zsre' \
    --metrics_save_dir=./results/gpt2-xl/ICE
```

Valid values for `--datatype` are `['zsre','recent','counterfact','wikibio']`.
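To run one method across all four datasets, the command line can be assembled in a loop. The sketch below uses only the flags shown in the commands above; the mapping from `datatype` values to files under `./data` is an assumption inferred from the file names, and `build_command` and `run_all` are hypothetical helpers.

```python
import subprocess

# Assumed mapping from --datatype values to dataset files (inferred from file names).
DATA_FILES = {
    "zsre": "zsre.json",
    "recent": "wikidata_recent.json",
    "counterfact": "wikidata_counterfact.json",
    "wikibio": "wikibio.json",
}


def build_command(datatype, method="ICE",
                  hparams="./hparams/ICE/gpt2-xl.yaml",
                  save_dir="./results/gpt2-xl/ICE"):
    """Build the argument list for one editing run."""
    if datatype not in DATA_FILES:
        raise ValueError(f"unknown datatype: {datatype!r}")
    return [
        "python", "examples/run_knowedit_llama2.py",
        f"--editing_method={method}",
        f"--hparams_dir={hparams}",
        f"--data_dir=./data/{DATA_FILES[datatype]}",
        f"--datatype={datatype}",
        f"--metrics_save_dir={save_dir}",
    ]


def run_all(dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for datatype in DATA_FILES:
        cmd = build_command(datatype)
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```

The dry run is the default so that the commands can be inspected before any model is loaded.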

**ICE/gpt2-xl.yaml**

```yaml
alg_name: "FT"
model_name: openai-community/gpt2-xl   # or local checkpoint path
device: 0
layers: [13, 14, 15, 16, 17]
num_steps: 25
batch_size: 3
max_length: 40
lr: 7e-4
weight_decay: 0
kl_factor: 0
norm_constraint: 5e-4
grad_norm_constraint: 5e-4
num_return_sequences: 1
max_new_tokens: 3
static_target: False
sample_with_context: True
target_update_interval: 1
temperature: 100.0
print_kl: True

objective_optimization: "target_and_completion_with_context"
rewrite_module_tmp: "transformer.h.{}.mlp.c_proj"
layer_module_tmp: "transformer.h.{}"
mlp_module_tmp: "transformer.h.{}.mlp"
attn_module_tmp: "transformer.h.{}.attn"
ln_f_module: "transformer.ln_f"
lm_head_module: "transformer.wte"
model_parallel: False
```
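As noted in the Training section, FT-L and FT-M differ only in the value of `objective_optimization`. The sketch below switches that key in the text of a hyperparameter file with the same layout as the one above. It edits the text directly rather than parsing YAML, and `set_objective` is a hypothetical helper that is not part of this repository.

```python
import re


def set_objective(yaml_text, value):
    """Return yaml_text with the objective_optimization value replaced by value."""
    pattern = r"^(objective_optimization:[ \t]*).*$"
    new_text, count = re.subn(
        pattern,
        lambda m: f'{m.group(1)}"{value}"',
        yaml_text,
        flags=re.MULTILINE,
    )
    if count != 1:
        raise ValueError(
            f"expected exactly one objective_optimization key, found {count}"
        )
    return new_text
```

For example, calling `set_objective(text, "prompt_last")` on the contents of a file under `./hparams/FT` yields the FT-L variant, and `set_objective(text, "target_new")` yields the FT-M variant.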
