# K conv basis implementation and visualization

## Prerequisites
1. Please refer to the README.md in the llama3 directory for the prerequisites.
2. Clone the text prompts we use from https://github.com/gkamradt/LLMTest_NeedleInAHaystack:
```bash
git clone https://github.com/gkamradt/LLMTest_NeedleInAHaystack
```
Then replace the variable `text_root_dir` in `save_attention_matrix_from_gpt2.py` with the path to the cloned directory.

## K conv basis FFT unit test and visualization
Please run the following command to unit test the k conv basis computation and produce the visualization:
```bash
python unit_test.py
```
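
For orientation, the FFT step that such a test exercises can be sketched as below. This is a minimal NumPy sketch, not the contents of `unit_test.py`. It assumes a conv basis vector `b` defines a lower-triangular Toeplitz matrix with entries `b[i - j]` for `i >= j`, and the helper names `conv_matrix` and `conv_matvec_fft` are hypothetical.

```python
import numpy as np


def conv_matrix(b):
    """Dense lower-triangular Toeplitz matrix M with M[i, j] = b[i - j] for i >= j."""
    n = len(b)
    diff = np.arange(n)[:, None] - np.arange(n)[None, :]
    return np.where(diff >= 0, b[np.clip(diff, 0, n - 1)], 0.0)


def conv_matvec_fft(b, x):
    """Compute conv_matrix(b) @ x without forming the matrix.

    x may have shape (n,) or (n, d); the convolution runs along axis 0.
    """
    n = len(b)
    size = 2 * n  # zero-pad so the circular convolution equals the linear one
    fb = np.fft.rfft(b, n=size)
    fx = np.fft.rfft(x, n=size, axis=0)
    if x.ndim == 2:
        fb = fb[:, None]  # broadcast the basis spectrum over the d columns
    y = np.fft.irfft(fb * fx, n=size, axis=0)
    return y[:n]  # keep the causal part only


# Small self-check on random data (illustrative sizes).
rng = np.random.default_rng(0)
b = rng.standard_normal(16)
V = rng.standard_normal((16, 4))
print("max abs difference:", np.abs(conv_matrix(b) @ V - conv_matvec_fft(b, V)).max())
```

Padding to length `2n` matters here: without it the FFT product is a circular convolution, and entries from the end of the sequence would wrap around into the first rows.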

## Visualization of attention matrix from Meta-Llama-3-8B
Please run the following command to visualize the attention from Meta-Llama-3-8B:
```bash
torchrun --nproc_per_node 1 visualize_attention_llama3.py
```

## Save attention matrix from GPT2 and run comparison test between naive attention and k conv basis
To save (visualize) the attention from GPT2, first install Hugging Face transformers, then overwrite the library's `transformers/src/transformers/models/gpt2/modeling_gpt2.py` with the `modeling_gpt2.py` provided here:
```bash
cp modeling_gpt2.py {path_to_transformers_lib}/models/gpt2/modeling_gpt2.py
```
Then run the following command to save the attention matrix from GPT2:
```bash
python visualize_attention_gpt2.py
```
Then compute and save the results for different values of k:
```bash
zsh run_k_conv_basis.sh
```

Finally, to visualize the result, run the following command:
```bash
python plot_k_conv_basis.py
```

### torch version attention
#### compare naive vs conv_attn in np
`gpt2_np_attn.py` contains the naive attention and the conv attention, with FLOPs counted by hand. The results compare error and FLOPs for different values of `k`.
```
bash scripts/run_k_conv_basis.sh
```
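
As an illustration of what applying `k` conv bases can look like, here is a minimal NumPy sketch. It is not the code in `gpt2_np_attn.py`. It assumes the approximating matrix is a sum of `k` sub-convolution matrices, where a basis `(b, m)` acts as a lower-triangular Toeplitz matrix on the last `m` positions only. The function names and the `(b, m)` representation are hypothetical.

```python
import numpy as np


def causal_conv_fft(b, x):
    """First len(b) entries of the linear convolution of b and x, via FFT."""
    m = len(b)
    size = 2 * m  # zero-pad to avoid circular wrap-around
    return np.fft.irfft(np.fft.rfft(b, size) * np.fft.rfft(x, size), size)[:m]


def apply_conv_bases(bases, x):
    """Apply sum_i conv(b_i, m_i) to a vector x using one FFT convolution per basis."""
    n = len(x)
    y = np.zeros(n)
    for b, m in bases:
        # Basis i only touches the trailing m positions (assumed sub-conv form).
        y[n - m:] += causal_conv_fft(b[:m], x[n - m:])
    return y


def dense_from_bases(bases, n):
    """Reference: build the same n x n matrix explicitly."""
    H = np.zeros((n, n))
    for b, m in bases:
        i, j = np.indices((m, m))
        block = np.where(i >= j, b[np.clip(i - j, 0, m - 1)], 0.0)
        H[n - m:, n - m:] += block
    return H


# Illustrative sizes: n = 12 positions, k = 3 bases with decreasing m.
rng = np.random.default_rng(1)
n = 12
bases = [(rng.standard_normal(n), m) for m in (12, 7, 3)]
x = rng.standard_normal(n)
print("max abs difference:", np.abs(dense_from_bases(bases, n) @ x - apply_conv_bases(bases, x)).max())
```

Under this assumption each basis costs one `O(n log n)` FFT convolution, so `k` bases cost `O(k n log n)` per vector, against `O(n^2)` for the dense product.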

#### compare conv_attn in np and torch
Compare the errors between the NumPy and torch implementations. The test passes with `tol=1e-4`.
```
python debug/unit_test_attn.py
```

#### compare naive vs conv_attn in gpt2 torch
```
python debug/gpt2_forward.py
```

#### compare gpt2 hidden states on one sample
This currently produces many `NaN`s; see the output of the command below.
```
python debug/gpt2_conv_attn.py
```
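
When tracking down where the `NaN`s first appear, a small helper along these lines can narrow the search. It is a sketch with hypothetical names, not part of `debug/gpt2_conv_attn.py`, and it assumes the per-layer hidden states have already been converted to NumPy arrays (for torch tensors, for example, with `.detach().cpu().numpy()`).

```python
import numpy as np


def first_nan_layer(hidden_states):
    """Return (layer_index, nan_count) for the first layer containing NaN, else None.

    hidden_states: iterable of array-like objects, one per layer (assumed layout).
    """
    for index, layer in enumerate(hidden_states):
        nan_count = int(np.isnan(np.asarray(layer)).sum())
        if nan_count:
            return index, nan_count
    return None


# Illustrative data: layer 0 is clean, layer 1 contains two NaNs.
clean = np.zeros((2, 3))
broken = np.zeros((2, 3))
broken[0, 1] = np.nan
broken[1, 2] = np.nan
print(first_nan_layer([clean, broken, broken]))
```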

#### infer gpt2 on imdb, and save hidden states
Once the `NaN` issue above is solved, run it on batches of examples.

run with naive attention:
```
python infer_gpt2_hidden.py --naive
```

run with conv attention:
```
python infer_gpt2_hidden.py --k 10
```

- run batch experiments:
```
bash scripts/run_infer_gpt2.sh
```
Edit the script to choose whether it runs `infer_gpt2.py` or `infer_gpt2_hidden.py`.

- see results of generated hidden states:

```
cd out
python see_res.py
```
Adjust `constant` accordingly. This writes `errors.json`, which records the details.

Then see `visz/see_error.ipynb`.

#### implement naive conv
See `visz/test_shift.ipynb`.
Same as the previous step, but replace `_conv_attn()` with `_naive_conv_attn()` in `Conv_GPT2Attention:forward()`.

### add support for causal generation for gpt2
Convert the task into a causal language modeling (causalLM) task.
```
python infer_gpt2.py --naive
python infer_gpt2.py --k 10
```

### Llama naive conv results

#### compare llama hidden states on one sample
```
python debug/llama_conv_attn.py
```

- run batch experiments to see hidden:
```
bash scripts/run_infer_all_gpus_para_sample.sh
```
Modify the script to run `infer_llama_hidden.py`.

- run batch experiments to see generation on imdb:
```
bash scripts/run_infer_all_gpus_para_sample2.sh
```
Modify the script to run `infer_llama.py`.

----------

Note: `run_infer_all_gpus_para_sample.sh` and similar scripts limit each process to one GPU, which causes OOM for long `seq_len`. Use `run_infer_k_loop.sh` instead.

### Llama conv attn naive select k
The following scripts show how this is done:
```
python debug/test_shift.py
python debug/test_shift_parallel.py
```

Then replace the corresponding operation in `model_llama.py`.

