## Description

This repo contains the code needed to build a Mixture of LoRA (MoL) model, which enables multimodal fine-tuning of most LLMs loaded with the `transformers` library.

Once the MoL model is constructed, implementing the training pipeline is left to the user.

### Repo description
```
model
  ├──  ffn.py                 # per-modality FFN 
  ├──  layernorm.py           # per-modality LayerNorm (or equivalent)
  ├──  mixture_modules.py     # mixture of LoRA modules
  └──  MoL.py                 # LLM with attention replaced with MoL modules.
```
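The tree only names the modules. As a rough mental model (an assumption, not a description of what `mixture_modules.py` actually does), a mixture-of-LoRA projection can be pictured as a frozen linear layer plus one low-rank adapter per modality, where each token only receives the update of its own modality's adapter. The sketch below uses plain Python lists instead of tensors, the usual LoRA scaling `lora_alpha / lora_rank`, and hypothetical function names.

```python
from typing import Dict, List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def mixture_of_lora_linear(
    hidden: List[Vector],          # one hidden vector per token
    token_modality: List[str],     # modality label of each token
    base_weight: Matrix,           # frozen weight W, shape (d_out, d_in)
    lora_a: Dict[str, Matrix],     # per-modality A, shape (rank, d_in)
    lora_b: Dict[str, Matrix],     # per-modality B, shape (d_out, rank)
    lora_alpha: float,
    lora_rank: int,
) -> List[Vector]:
    """Return W x, plus (alpha / rank) * B_m A_m x when modality m has an adapter."""
    scale = lora_alpha / lora_rank  # standard LoRA scaling (assumed here)
    outputs = []
    for x, mod in zip(hidden, token_modality):
        y = matvec(base_weight, x)
        if mod in lora_a:  # tokens without an adapter keep the frozen projection
            delta = matvec(lora_b[mod], matvec(lora_a[mod], x))
            y = [yi + scale * di for yi, di in zip(y, delta)]
        outputs.append(y)
    return outputs


# Toy example: 2-d hidden states, rank-1 adapters for 'text' and 'image'
out = mixture_of_lora_linear(
    hidden=[[1, 2], [3, 4], [5, 6]],
    token_modality=['text', 'image', ''],
    base_weight=[[1, 0], [0, 1]],
    lora_a={'text': [[0, 1]], 'image': [[1, 0]]},
    lora_b={'text': [[0], [2]], 'image': [[1], [1]]},
    lora_alpha=2.0,
    lora_rank=1,
)
print(out)  # [[1.0, 10.0], [9.0, 10.0], [5, 6]]
```

Under this picture, each adapter pair adds `lora_rank * (d_in + d_out)` parameters per adapted projection, so the parameter overhead grows linearly with the number of modalities that get their own adapter.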
### How to use

#### Libraries

```bash
pip install torch transformers
```

#### Loading the MoL model
```python
from transformers import AutoModelForCausalLM
from src.model.mol import MixtureOfLoRAModel, MoLConfig  # assumes MoLConfig is defined in the same module
import torch

mol_config = MoLConfig(
    modalities=['image', 'text', 'audio'],
    trainable_modalities=['image', 'text', 'audio'],
    lora_rank=16,
    lora_alpha=16,
    use_modality_specific_ln=True,     # whether to use per-modality LayerNorm
    use_modality_specific_ffn=False,   # whether to use per-modality FFN
    use_lora_ffn=True,                 # whether to use a MoL-augmented FFN
    text_lora_enabled=True,            # whether to include a MoL adapter for the text modality
    baseline_lora=False,               # whether to include a single shared LoRA adapter
)

llm = AutoModelForCausalLM.from_pretrained(
    path_to_llm,                  # Hugging Face Hub id or local path of the base LLM
    torch_dtype=torch.bfloat16,   # any supported dtype
    low_cpu_mem_usage=True,
    local_files_only=False,
    attn_implementation='sdpa',   # any supported attention implementation
)

mol_model = MixtureOfLoRAModel(
    base_model=llm,
    config=mol_config,
    llm_config=llm.config,
)
```
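The forward code in the next section assumes that the `<image_token_{i}>` / `<audio_token_{i}>` codebook tokens and the `<image>`, `</image>`, `<audio>`, `</audio>` delimiters already exist in the text tokenizer's vocabulary, and that the tokenizer has a pad token. If they are missing, one way to add them (an assumption, not something this repo provides) is to build the token strings with a helper such as the hypothetical `modality_token_strings` below, then register them with the `transformers` calls shown in the comments, presumably resizing the embeddings of `llm` before it is wrapped in `MixtureOfLoRAModel`.

```python
from typing import Dict, List, Tuple


def modality_token_strings(codebook_sizes: Dict[str, int]) -> Tuple[List[str], List[str]]:
    """Return (codebook tokens, delimiter tokens) following the naming used in the forward code.

    `codebook_sizes` maps 'image' and/or 'audio' to the size of that modality's codebook.
    """
    codebook_tokens: List[str] = []
    delimiter_tokens: List[str] = []
    for mod in ('image', 'audio'):
        if mod not in codebook_sizes:
            continue
        codebook_tokens += [f"<{mod}_token_{i}>" for i in range(codebook_sizes[mod])]
        delimiter_tokens += [f"<{mod}>", f"</{mod}>"]
    return codebook_tokens, delimiter_tokens


# Hypothetical usage with a transformers tokenizer and model (not executed here):
# codebook_tokens, delimiter_tokens = modality_token_strings({
#     'image': len(image_tokenizer.codebook),
#     'audio': len(audio_tokenizer.codebook),
# })
# text_tokenizer.add_tokens(codebook_tokens)
# text_tokenizer.add_tokens(delimiter_tokens, special_tokens=True)
# if text_tokenizer.pad_token is None:
#     text_tokenizer.pad_token = text_tokenizer.eos_token
# llm.resize_token_embeddings(len(text_tokenizer))
```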

#### Forward

First, compute the input embeddings and the per-modality token masks.

```python
# Placeholders from your own pipeline: `modalities` (same list as
# MoLConfig.modalities), `text_tokenizer`, `image_tokenizer`, `audio_tokenizer`
# (discrete tokenizers exposing a `codebook`) and `input_ids`.

# Map every discrete image/audio token id to its modality
token_id_to_modality = {}
if 'image' in modalities:
    for i in range(len(image_tokenizer.codebook)):
        token_id_to_modality[text_tokenizer.convert_tokens_to_ids(
            f"<image_token_{i}>")] = 'image'
if 'audio' in modalities:
    for i in range(len(audio_tokenizer.codebook)):
        token_id_to_modality[text_tokenizer.convert_tokens_to_ids(
            f"<audio_token_{i}>")] = 'audio'

# Padding and modality delimiters are special tokens: they are removed from the text mask
all_special_tokens = [text_tokenizer.pad_token]
if 'audio' in modalities:
    all_special_tokens.extend(['<audio>', '</audio>'])
if 'image' in modalities:
    all_special_tokens.extend(['<image>', '</image>'])
all_special_tokens_ids = text_tokenizer.convert_tokens_to_ids(all_special_tokens)

inputs_embeds = llm.get_input_embeddings()(input_ids)

# Build one boolean mask per modality, with the same shape and device as input_ids
all_modalities = ['text', 'image', 'audio']
modality_masks = {
    mod: torch.zeros_like(input_ids, dtype=torch.bool)
    for mod in all_modalities
}

# Every token is treated as text by default
modality_masks['text'].fill_(True)

# Carve the image and audio tokens out of the text mask
for mod_name in ['image', 'audio']:
    mod_token_ids = torch.tensor(
        [tid for tid, mod in token_id_to_modality.items() if mod == mod_name],
        dtype=input_ids.dtype, device=input_ids.device,
    )
    modality_mask = torch.isin(input_ids, mod_token_ids)
    modality_masks[mod_name] = modality_mask
    modality_masks['text'] &= ~modality_mask

# Special tokens belong to no modality
for token_id in all_special_tokens_ids:
    modality_masks['text'][input_ids == token_id] = False
```
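The tensor code above is hard to check at a glance. The stdlib-only sketch below applies the same rules to a plain list of token ids (text by default, image/audio codebook tokens carved out, special tokens removed from the text mask only), which makes it easy to unit-test the expected masks. The function name and the token ids in the example are made up for illustration.

```python
from typing import Dict, Iterable, List


def build_modality_masks(
    input_ids: List[int],
    token_id_to_modality: Dict[int, str],
    special_token_ids: Iterable[int],
) -> Dict[str, List[bool]]:
    """List-based version of the mask rules used with the tensors above."""
    special = set(special_token_ids)
    masks = {mod: [False] * len(input_ids) for mod in ('text', 'image', 'audio')}
    for pos, tid in enumerate(input_ids):
        mod = token_id_to_modality.get(tid, 'text')  # unknown ids default to text
        if mod == 'text' and tid in special:
            continue  # padding and <image>/<audio> delimiters belong to no modality
        masks[mod][pos] = True
    return masks


# Hypothetical ids: 0 = pad, 1 = <image>, 2 = </image>, 10/11 = image codes, 20 = audio code
masks = build_modality_masks(
    [5, 1, 10, 11, 2, 6, 20, 0],
    {10: 'image', 11: 'image', 20: 'audio'},
    [0, 1, 2],
)
print(masks['image'])  # [False, False, True, True, False, False, False, False]
```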

Once the embeddings and modality masks are computed, the instantiated `mol_model` can be called as follows:

```python
outputs = mol_model(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    modality_mask=modality_masks,  # dict of per-modality boolean masks built above
    labels=labels,
    output_per_modality_loss=True,
    condition_on_first_modality=False,
    **kwargs,
)
logits = outputs.logits
per_modality_loss = outputs.per_modality_loss
```
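How `per_modality_loss` is computed is not documented here. A plausible reading, sketched below on plain lists for a single sequence, is the usual shifted next-token cross-entropy averaged separately over the positions selected by each modality mask, skipping labels equal to -100 (the ignore value of the `transformers` causal LM loss). Whether the MoL code attributes a prediction to the modality of the target token or of the input token is an assumption; this sketch uses the target token, and the function name is hypothetical.

```python
import math
from typing import Dict, List

IGNORE_INDEX = -100  # label value skipped by the transformers causal LM loss


def per_modality_loss(
    logits: List[List[float]],              # (seq_len, vocab_size) for one sequence
    labels: List[int],                      # (seq_len,) target token ids
    modality_masks: Dict[str, List[bool]],  # (seq_len,) mask per modality
) -> Dict[str, float]:
    """Mean next-token cross-entropy, reported separately for each modality."""
    sums = {mod: 0.0 for mod in modality_masks}
    counts = {mod: 0 for mod in modality_masks}
    for t in range(len(labels) - 1):
        target = labels[t + 1]  # logits at position t predict token t + 1
        if target == IGNORE_INDEX:
            continue
        row = logits[t]
        m = max(row)  # log-sum-exp with max subtraction for numerical stability
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        nll = log_z - row[target]
        for mod, mask in modality_masks.items():
            if mask[t + 1]:  # attribute the loss to the target token's modality
                sums[mod] += nll
                counts[mod] += 1
    # Modalities without any supervised target are omitted instead of dividing by zero
    return {mod: sums[mod] / counts[mod] for mod in sums if counts[mod] > 0}
```

Reporting the losses separately makes it easy to spot one modality dominating the total loss, for example when image tokens far outnumber text tokens in a batch.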