<p align="center">
  <em><strong>LogIX</strong>: Logging for Interpretable and Explainable AI <br></em>
</p>

```bash
pip install -e .
```

## Basics
It turns out that most _interpretable & explainable AI_ research (e.g., training data attribution,
saliency maps, mechanistic interpretability) simply requires **(1)** intercepting various training logs
(e.g., activations, gradients) and **(2)** performing computational analyses on these logs. Therefore,
**LogIX** focuses on simple, efficient, and interoperable logging of training artifacts for maximal
flexibility, while also providing some pre-implemented interpretability algorithms (e.g., influence functions)
for general users.


## Usage
### Logging
Extracting training logs with LogIX is as simple as adding one `with` statement to existing
training code. LogIX automatically extracts user-specified logs using PyTorch hooks and stores
them as a tuple of `([data_ids], log[module_name][log_type])`. If needed, LogIX efficiently writes
these logs to disk using memory-mapped files.
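To make the log layout and the memory-mapped storage described above concrete, here is a minimal sketch in plain NumPy. It is an illustration under stated assumptions, not LogIX's implementation: the module name, file name, and the helpers `write_log_to_memmap` / `read_log_from_memmap` are hypothetical, and LogIX's on-disk format may differ.

```python
import os
import tempfile

import numpy as np

# Hypothetical log for one batch, following the ([data_ids], log[module_name][log_type]) layout
MODULE = "model.layers.0.mlp.down_proj"
rng = np.random.default_rng(0)
data_ids = ["ex0", "ex1", "ex2"]
log = {MODULE: {"grad": rng.standard_normal((3, 8, 16)).astype(np.float32)}}  # (batch, d_out, d_in)


def write_log_to_memmap(path, log, module_name, log_type):
    """Copy one (module, log_type) array into a memory-mapped file; return its shape and dtype."""
    array = log[module_name][log_type]
    mm = np.memmap(path, dtype=array.dtype, mode="w+", shape=array.shape)
    mm[:] = array  # write the whole batch into the mapped file
    mm.flush()  # push pending writes to disk
    del mm  # release the mapping
    return array.shape, array.dtype


def read_log_from_memmap(path, shape, dtype, index):
    """Read a single example's log; only the touched region of the file needs to be paged in."""
    mm = np.memmap(path, dtype=dtype, mode="r", shape=shape)
    row = np.array(mm[index])  # copy the requested example out of the mapping
    del mm
    return row


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "down_proj_grad.mmap")
    shape, dtype = write_log_to_memmap(path, log, MODULE, "grad")
    # Look up an example by its data_id, then read only that slice from disk
    row = read_log_from_memmap(path, shape, dtype, data_ids.index("ex1"))

print(row.shape)  # (8, 16)
```

Keeping `data_ids` alongside the arrays is what lets a single example's log be located and read back without loading the rest of the file.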

```python
import torch.nn as nn

import logix

# Assumes `model`, `data_loader`, and `loss_fn` are already defined

# Initialize LogIX
run = logix.init(project="my_project")

# Users can specify artifacts they want to log
run.setup({"log": "grad", "save": "grad", "statistic": "kfac"})

# Users can specify specific modules they want to track logs for
run.watch(model, name_filter=["mlp"], type_filter=[nn.Linear])

for input, target in data_loader:
    # Set data_id for the log from the current batch
    with run(data_id=input):
        out = model(input)
        loss = loss_fn(out, target, reduction="sum")
        loss.backward()
        model.zero_grad()

    # Access the log extracted within the LogIX context block
    log = run.get_log()  # (data_id, log_dict)
    # For example, users can print the gradient of a specific module:
    # print(log[1]["model.layers.23.mlp.down_proj"]["grad"])
    # or perform any custom analysis

# Synchronize statistics (e.g. grad variance) and write logs to disk
run.finalize()
```
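For `nn.Linear` modules, per-example gradients like those requested with `"log": "grad"` can be assembled from what forward and backward hooks observe. The NumPy sketch below assumes a layer `y = W a` applied to one vector per example (no sequence dimension) and uses random stand-ins for the captured tensors; it illustrates the standard outer-product identity and K-FAC-style second-moment statistics, not LogIX's exact internals.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_in, d_out = 4, 5, 3

# Stand-ins for what hooks on a Linear layer y = W a could capture
acts = rng.standard_normal((batch, d_in))        # a_n: layer inputs (forward hook)
out_grads = rng.standard_normal((batch, d_out))  # g_n = dL/dy_n (backward hook)

# Per-example weight gradient: dL_n/dW = g_n a_n^T
per_example_grads = np.einsum("bo,bi->boi", out_grads, acts)  # (batch, d_out, d_in)

# With reduction="sum", the batch gradient is the plain sum of per-example gradients,
# so individual contributions are not scaled by 1/batch
batch_grad = out_grads.T @ acts  # (d_out, d_in)
assert np.allclose(per_example_grads.sum(axis=0), batch_grad)

# K-FAC-style statistics: uncentered second moments of inputs and output gradients
kfac_A = acts.T @ acts / batch            # (d_in, d_in)
kfac_S = out_grads.T @ out_grads / batch  # (d_out, d_out)
```

Storing the two factors `a_n` and `g_n` instead of the full `d_out × d_in` gradient is also what makes per-example logging of large layers tractable.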

### Training Data Attribution
As part of our initial research, we implemented influence functions using LogIX. We plan to provide more
pre-implemented interpretability algorithms if there is demand.

```python
# Build PyTorch DataLoader from saved log data
log_loader = run.build_log_dataloader()

with run(data_id=test_input):
    test_out = model(test_input)
    test_loss = loss_fn(test_out, test_target, reduction="sum")
    test_loss.backward()
test_log = run.get_log() # extract a log for test data

run.influence.compute_influence_all(test_log, log_loader) # data attribution
run.influence.compute_self_influence(test_log) # uncertainty
```
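For reference, the quantities these two calls are named after can be written down in a few lines. The sketch below uses the textbook influence score `g_test^T H^{-1} g_train` (sign conventions vary) with `H` approximated by a damped empirical Fisher built from flattened training gradients; self-influence is the same score with a gradient paired with itself. This is an assumption for illustration: LogIX may work per module with its own Hessian approximations (a `kfac` statistic and an `ekfac` flag appear elsewhere in this README), so its scores will generally differ, and the function names below are hypothetical.

```python
import numpy as np


def damped_fisher(train_grads, damping):
    """Empirical Fisher (mean outer product of gradients) plus damping * I."""
    n_train, dim = train_grads.shape
    return train_grads.T @ train_grads / n_train + damping * np.eye(dim)


def influence_all(test_grads, train_grads, damping=1e-3):
    """Score matrix with entry [i, j] = g_test_i^T H^{-1} g_train_j."""
    H = damped_fisher(train_grads, damping)
    solved = np.linalg.solve(H, train_grads.T)  # H^{-1} g_train, shape (dim, n_train)
    return test_grads @ solved  # (n_test, n_train)


def self_influence(grads, train_grads, damping=1e-3):
    """Per-example g^T H^{-1} g, often used as an uncertainty signal."""
    H = damped_fisher(train_grads, damping)
    solved = np.linalg.solve(H, grads.T)  # (dim, n)
    return np.einsum("nd,dn->n", grads, solved)


rng = np.random.default_rng(0)
train_grads = rng.standard_normal((100, 10))  # 100 training examples, 10 parameters
test_grads = rng.standard_normal((3, 10))

scores = influence_all(test_grads, train_grads)         # (3, 100)
top_train = np.argsort(-scores, axis=1)[:, :5]          # highest-scoring training examples per test example
uncertainty = self_influence(test_grads, train_grads)   # (3,)
```

Using `np.linalg.solve` rather than forming `H^{-1}` explicitly is the usual choice for numerical stability; the damping term keeps `H` invertible when there are fewer gradients than parameters.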

### HuggingFace Integration
Our software design allows seamless integration with HuggingFace's
[Transformers](https://github.com/huggingface/transformers/tree/main), a popular deep learning framework
that conveniently handles distributed training, data loading, etc. We plan to support more
frameworks (e.g., Lightning) in the future.

```python
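from transformers import Trainer, Seq2SeqTrainer
from logix.huggingface import patch_trainer, LogIXArguments

# `project`, `config`, `model`, and `train_dataset` are assumed to be defined by the user
logix_args = LogIXArguments(project, config, lora=True, ekfac=True)
LogIXTrainer = patch_trainer(Trainer)  # Seq2SeqTrainer can be patched the same way

trainer = LogIXTrainer(
    logix_args=logix_args,  # LogIX-specific arguments, alongside the usual Trainer arguments
    model=model,
    train_dataset=train_dataset,
    **kwargs,  # other standard Trainer arguments (e.g., args=training_args)
)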

# Instead of trainer.train(),
trainer.extract_log()
trainer.influence()
trainer.self_influence()
```
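A single `patch_trainer` helper can serve both `Trainer` and `Seq2SeqTrainer` if it builds a subclass of whatever class it receives. The pure-Python sketch below shows that class-factory pattern with hypothetical stand-in classes instead of the real `transformers` ones; it only wires up the extra `logix_args` keyword and is not LogIX's actual implementation, which also provides methods such as `extract_log()`.

```python
class StubTrainer:
    """Hypothetical stand-in for transformers.Trainer."""

    def __init__(self, model=None, train_dataset=None, **kwargs):
        self.model = model
        self.train_dataset = train_dataset
        self.extra_kwargs = kwargs


class StubSeq2SeqTrainer(StubTrainer):
    """Hypothetical stand-in for transformers.Seq2SeqTrainer."""


def patch_trainer_sketch(trainer_cls):
    """Return a subclass of `trainer_cls` that also accepts a `logix_args` keyword."""

    class Patched(trainer_cls):
        def __init__(self, *args, logix_args=None, **kwargs):
            # Remove the LogIX-specific argument before delegating to the original constructor
            super().__init__(*args, **kwargs)
            self.logix_args = logix_args

    Patched.__name__ = Patched.__qualname__ = f"LogIX{trainer_cls.__name__}"
    return Patched


LogIXStubTrainer = patch_trainer_sketch(StubTrainer)
LogIXStubSeq2SeqTrainer = patch_trainer_sketch(StubSeq2SeqTrainer)

trainer = LogIXStubTrainer(logix_args={"project": "demo"}, model="model", train_dataset=[1, 2, 3])
```

Because the patched class inherits from the original, everything the base trainer already handles (distributed setup, data loading) stays available unchanged.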

Please check out [Examples](/examples) for more advanced features!


## Features
Logs from neural networks are difficult to handle because of their size. For example,
the gradient of *each* training datapoint is about as large as the whole model. Therefore,
we provide several systems-level features to efficiently scale neural network analysis to
billion-parameter models. Below are a few features that LogIX currently supports:

- **Gradient compression** (compression ratio: 1,000-100,000x)
- **Memory-map-based data IO**
- **CPU offloading of statistics**

## Compatibility
| DistributedDataParallel| Mixed Precision| Gradient Checkpointing | torch.compile  | FSDP           |
|:----------------------:|:--------------:|:----------------------:|:-------------:|:--------------:|
| ✅                     | ✅             | ✅                    | ✅           |   ✅             |

## Contributing

We welcome contributions from the community. Please see our [contributing
guidelines](CONTRIBUTING.md) (deleted for anonymization) for details on how to contribute to LogIX.

## Citation
To cite this repository:

```
@ deleted for anonymization
```

## License
LogIX is licensed under the [Apache 2.0 License](LICENSE).
